Scala new Map column from integer and string columns
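Problem statement:
I have a dataframe with four columns: service (String), show (String), country_1 (Integer) and country_2 (Integer). My objective is to produce a dataframe with only two columns: service (String) and information (Map[Integer, List[String]]).
Each streaming service's map can contain multiple key-value records, for example: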
{
"34521": ["The Crown", "Bridgerton", "The Queen's Gambit"],
"49678": ["The Crown", "Bridgerton", "The Queen's Gambit"]
}
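One important thing to note: more countries may be added in the future, i.e. additional columns in the input dataframe such as "country_3", "country_4", etc. Ideally the solution code should account for this too, rather than hardcoding the selected columns as I did in my attempt below, if that makes sense.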
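A minimal sketch of picking up the country columns by name pattern instead of hardcoding them, assuming every such column is named country_ followed by digits. The object and helper names (CountryColumns, countryColumns) are hypothetical. Matching the full pattern skips unrelated columns such as a hypothetical country_name, and sorting by the numeric suffix keeps country_10 after country_2.

```scala
import scala.util.matching.Regex

object CountryColumns {
  // Assumed naming convention: "country_" followed by one or more digits
  private val CountryCol: Regex = """country_(\d+)""".r

  // Hypothetical helper: keep only country_N columns, ordered by N
  def countryColumns(columns: Seq[String]): Seq[String] =
    columns
      .collect { case name @ CountryCol(n) => (n.toInt, name) }
      .sortBy(_._1)
      .map(_._2)

  def main(args: Array[String]): Unit = {
    val cols = Seq("service", "show", "country_1", "country_10", "country_2", "country_name")
    println(countryColumns(cols)) // List(country_1, country_2, country_10)
  }
}
```

With a DataFrame in scope, this would be called as countryColumns(df.columns.toSeq).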
Input dataframe:
Schema:
root
|-- service: string (nullable = true)
|-- show: string (nullable = true)
|-- country_1: integer (nullable = true)
|-- country_2: integer (nullable = true)
Dataframe:
service | show | country_1 | country_2
Netflix | The Crown | 34521 | 49678
Netflix | Bridgerton | 34521 | 49678
Netflix | The Queen's Gambit | 34521 | 49678
Peacock | The Office | 34521 | 49678
Disney+ | WandaVision | 34521 | 49678
Disney+ | Marvel's 616 | 34521 | 49678
Disney+ | The Mandalorian | 34521 | 49678
Apple TV | Ted Lasso | 34521 | 49678
Apple TV | The Morning Show | 34521 | 49678
Output dataframe:
Schema:
root
|-- service: string (nullable = true)
|-- information: map (nullable = false)
| |-- key: integer
| |-- value: array (valueContainsNull = true)
| | |-- element: string (containsNull = true)
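If the job's output should be checked against this target schema, it can be written out as a Spark StructType. This is a sketch; TargetSchema is a hypothetical name, and a DataFrame produced by the code below may report different nullability flags (for example for arrays built by collect_list), so an exact == comparison against df.schema may need to ignore nullability.

```scala
import org.apache.spark.sql.types._

object TargetSchema {
  // information: map<int, array<string>>, as described in the question
  val information: MapType =
    MapType(IntegerType, ArrayType(StringType, containsNull = true), valueContainsNull = true)

  // Full output schema: service plus the information map
  val output: StructType = StructType(Seq(
    StructField("service", StringType, nullable = true),
    StructField("information", information, nullable = false)
  ))

  def main(args: Array[String]): Unit =
    output.printTreeString() // prints a root tree like the one above
}
```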
Dataframe:
service | information
Netflix | [34521 -> [The Crown, Bridgerton, The Queen's Gambit], 49678 -> [The Crown, Bridgerton, The Queen's Gambit]]
Peacock | [34521 -> [The Office], 49678 -> [The Office]]
Disney+ | [34521 -> [WandaVision, Marvel's 616, The Mandalorian], 49678 -> [WandaVision, Marvel's 616, The Mandalorian]]
Apple TV | [34521 -> [Ted Lasso, The Morning Show], 49678 -> [Ted Lasso, The Morning Show]]
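To pin down the intended semantics independently of Spark, here is a plain-Scala sketch of the same transformation on an in-memory list. InputRow and information are hypothetical names, and countries stands in for the values of one row's country_X columns, assumed non-null.

```scala
object ExpectedShape {
  // One input row; countries holds the values of all country_X columns (assumed non-null)
  final case class InputRow(service: String, show: String, countries: Seq[Int])

  // service -> (country code -> shows); groupBy keeps input order inside each list
  def information(rows: Seq[InputRow]): Map[String, Map[Int, List[String]]] =
    rows.groupBy(_.service).map { case (service, rs) =>
      val pairs = rs.flatMap(r => r.countries.map(code => code -> r.show))
      service -> pairs.groupBy(_._1).map { case (code, ps) => code -> ps.map(_._2).toList }
    }

  def main(args: Array[String]): Unit = {
    val rows = Seq(
      InputRow("Peacock", "The Office", Seq(34521, 49678)),
      InputRow("Apple TV", "Ted Lasso", Seq(34521, 49678)),
      InputRow("Apple TV", "The Morning Show", Seq(34521, 49678))
    )
    println(information(rows)("Apple TV")(34521)) // List(Ted Lasso, The Morning Show)
  }
}
```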
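What I've already tried
Although I have produced the output I want with the code snippet pasted below, I don't want to rely on very basic SQL-style commands, since I don't think they are always the best choice for fast computation on large datasets. I also don't want to rely on an approach where I manually select the country columns by their exact names when building the map, since that can change as more country columns are added later.
Is there a better way to do this, using UDFs, foldLeft-style code, or anything else that helps with optimization and makes the code more concise and less messy?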
import org.apache.spark.sql.functions._
import spark.implicits._
val df = spark.read.parquet("filepath/*.parquet")
// country_1 and country_2 are hardcoded here, which is what I want to avoid
val grouped = df.groupBy("service", "country_1", "country_2").agg(collect_list("show").alias("show"))
val service_information = grouped.withColumn("information", map($"country_1", $"show", $"country_2", $"show")).drop("country_1", "country_2", "show")
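Based on the country data "spec" described in the comments (i.e. the country code is the same and non-null across all rows of any given country_X column), your code can be generalized to handle any number of country columns: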
import org.apache.spark.sql.functions._
import spark.implicits._
val df = Seq(
("Netflix", "The Crown", 34521, 49678),
("Netflix", "Bridgerton", 34521, 49678),
("Netflix", "The Queen's Gambit", 34521, 49678),
("Peacock", "The Office", 34521, 49678),
("Disney+", "WandaVision", 34521, 49678),
("Disney+", "Marvel's 616", 34521, 49678),
("Disney+", "The Mandalorian", 34521, 49678),
("Apple TV", "Ted Lasso", 34521, 49678),
("Apple TV", "The Morning Show", 34521, 49678)
).toDF("service", "show", "country_1", "country_2")
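// Discover the country columns by name prefix instead of hardcoding them
val countryCols = df.columns.filter(_.startsWith("country_")).toList
// Group by service plus every country column and collect the shows
val grouped = df.groupBy("service", countryCols: _*).agg(collect_list("show").as("shows"))
// map() takes alternating key/value columns: country_1, shows, country_2, shows, ...
val service_information = grouped.withColumn(
"information",
map( countryCols.flatMap{ c => col(c) :: col("shows") :: Nil }: _* )
).drop("shows" :: countryCols: _*)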
service_information.show(false)
// +--------+--------------------------------------------------------------------------------------------------------------+
// |service |information |
// +--------+--------------------------------------------------------------------------------------------------------------+
// |Disney+ |[34521 -> [WandaVision, Marvel's 616, The Mandalorian], 49678 -> [WandaVision, Marvel's 616, The Mandalorian]]|
// |Peacock |[34521 -> [The Office], 49678 -> [The Office]] |
// |Netflix |[34521 -> [The Crown, Bridgerton, The Queen's Gambit], 49678 -> [The Crown, Bridgerton, The Queen's Gambit]] |
// |Apple TV|[34521 -> [Ted Lasso, The Morning Show], 49678 -> [Ted Lasso, The Morning Show]] |
// +--------+--------------------------------------------------------------------------------------------------------------+
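Since the question also mentions foldLeft, here is a sketch of the same map construction with the key/value arguments accumulated by foldLeft instead of flatMap. It assumes the country_ prefix and the question's column names; FoldLeftVariant, mapArgs and serviceInformation are hypothetical names, and the sample rows are a subset of the question's data.

```scala
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.functions._

object FoldLeftVariant {
  // Alternating key/value arguments for functions.map, accumulated with foldLeft;
  // equivalent to countryCols.flatMap { c => col(c) :: col("shows") :: Nil }
  def mapArgs(countryCols: List[String]): List[Column] =
    countryCols.foldLeft(List.empty[Column]) { (acc, c) => acc :+ col(c) :+ col("shows") }

  def serviceInformation(df: DataFrame): DataFrame = {
    val countryCols = df.columns.filter(_.startsWith("country_")).toList
    df.groupBy("service", countryCols: _*)
      .agg(collect_list("show").as("shows"))
      .select(col("service"), map(mapArgs(countryCols): _*).as("information"))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("foldLeft-map").getOrCreate()
    import spark.implicits._
    val df = Seq(
      ("Peacock", "The Office", 34521, 49678),
      ("Apple TV", "Ted Lasso", 34521, 49678)
    ).toDF("service", "show", "country_1", "country_2")
    serviceInformation(df).show(false)
    spark.stop()
  }
}
```

Either form only builds Column expressions on the driver, so Spark ends up with the same plan; the choice is about readability, not speed. A UDF is not needed here: built-in functions such as map and collect_list are visible to the Catalyst optimizer, whereas a UDF is opaque to it.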
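Note that the described country "spec" forces all shows to be associated with the same list of countries. For example, if you have 3 country_X columns and every row of a given country_X column holds the same non-null value, then every show is associated with all 3 countries. What if a show is only available in 2 of the 3 countries?
If your data schema can be modified, a more flexible way to maintain the associated country information is to give each show an ArrayType column listing its countries: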
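import org.apache.spark.sql.functions._
import spark.implicits._
// Each show carries its own list of country codes
val df = Seq(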
("Netflix", "The Crown", Seq(34521, 49678)),
("Netflix", "Bridgerton", Seq(34521)),
("Netflix", "The Queen's Gambit", Seq(10001, 49678)),
("Peacock", "The Office", Seq(34521, 49678)),
("Disney+", "WandaVision", Seq(10001, 20002, 34521)),
("Disney+", "Marvel's 616", Seq(49678)),
("Disney+", "The Mandalorian", Seq(34521, 49678)),
("Apple TV", "Ted Lasso", Seq(34521, 49678)),
("Apple TV", "The Morning Show", Seq(20002, 34521))
).toDF("service", "show", "countries")
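// One row per (service, country, show), then collect the shows per country
val grouped = df.withColumn("country", explode($"countries")).
groupBy("service", "country").agg(collect_list($"show").as("shows"))
// Collect keys and values in the same aggregation so both lists come from the same rows,
// then zip them into a map (map_from_arrays requires Spark 2.4+)
val service_information = grouped.groupBy("service").
agg(collect_list($"country").as("c_list"), collect_list($"shows").as("s_list")).
select($"service", map_from_arrays($"c_list", $"s_list").as("information"))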
service_information.show(false)
// +--------+-----------------------------------------------------------------------------------------------------------------------------------+
// |service |information |
// +--------+-----------------------------------------------------------------------------------------------------------------------------------+
// |Peacock |[34521 -> [The Office], 49678 -> [The Office]] |
// |Disney+ |[20002 -> [WandaVision], 49678 -> [Marvel's 616, The Mandalorian], 34521 -> [WandaVision, The Mandalorian], 10001 -> [WandaVision]]|
// |Apple TV|[34521 -> [Ted Lasso, The Morning Show], 49678 -> [Ted Lasso], 20002 -> [The Morning Show]] |
// |Netflix |[49678 -> [The Crown, The Queen's Gambit], 10001 -> [The Queen's Gambit], 34521 -> [The Crown, Bridgerton]] |
// +--------+-----------------------------------------------------------------------------------------------------------------------------------+
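If the schema cannot be changed, a similar result can be sketched from the existing wide country_X columns by unpivoting them on the fly: pack each row's country codes into an array, explode it, drop nulls, and then group as in the ArrayType version. This is a sketch assuming Spark 2.4+ (for map_from_arrays) and the country_ prefix; UnpivotCountries and serviceInformation are hypothetical names, and the nullable sample rows are made up for illustration.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._

object UnpivotCountries {
  // Wide (country_1, country_2, ...) -> map of country code -> shows, per service.
  // Rows may carry different or missing (null) codes in their country_X columns.
  def serviceInformation(df: DataFrame): DataFrame = {
    val countryCols = df.columns.filter(_.startsWith("country_")).map(col)
    df
      .withColumn("country", explode(array(countryCols: _*))) // one row per (show, code)
      .filter(col("country").isNotNull)                       // drop missing countries
      .groupBy("service", "country")
      .agg(collect_list("show").as("shows"))
      .groupBy("service")
      .agg(map_from_arrays(collect_list("country"), collect_list("shows")).as("information"))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("unpivot-countries").getOrCreate()
    import spark.implicits._
    val df = Seq[(String, String, Option[Int], Option[Int])](
      ("Netflix", "The Crown", Some(34521), Some(49678)),
      ("Netflix", "Bridgerton", Some(34521), None),
      ("Apple TV", "Ted Lasso", Some(34521), Some(49678))
    ).toDF("service", "show", "country_1", "country_2")
    serviceInformation(df).show(false)
    spark.stop()
  }
}
```

A service whose rows all have null country codes drops out of the result. If one row could repeat the same code in two country_X columns, collect_set instead of collect_list would avoid listing that show twice.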