在 Apache spark 中高效 运行 一个 "for" 循环,以便并行执行
Efficiently running a "for" loop in Apache spark so that execution is parallel
我们如何在 Spark 中并行化循环,使处理不是顺序的而是并行的。举个例子——
我在包含以下数据的 csv 文件(称为 'bill_item.csv')中包含以下数据:
|-----------+------------|
| bill_id | item_id |
|-----------+------------|
| ABC | 1 |
| ABC | 2 |
| DEF | 1 |
| DEF | 2 |
| DEF | 3 |
| GHI | 1 |
|-----------+------------|
我必须得到如下输出:
|-----------+-----------+--------------|
| item_1 | item_2 | Num_of_bills |
|-----------+-----------+--------------|
| 1 | 2 | 2 |
| 2 | 3 | 1 |
| 1 | 3 | 1 |
|-----------+-----------+--------------|
我们看到在 2 个账单 'ABC' 和 'DEF' 下找到了项目 1 和 2,因此项目 1 和 2 的 'Num_of_bills' 是 2。类似地,项目 2 和 3仅在 bill 'DEF' 下找到,因此 'Num_of_bills' 列为“1”,依此类推。
我正在使用 spark 处理 CSV 文件 'bill_item.csv' 并且我正在使用以下方法:
方法一:
from pyspark.sql.types import StructType, StructField, IntegerType, StringType
# define the schema for the data
bi_schema = StructType([
StructField("bill_id", StringType(), True),
StructField("item_id", IntegerType(), True)
])
bi_df = (spark.read.schema(dataSchema).csv('bill_item.csv'))
# find the list of all items in sorted order
item_list = bi_df.select("item_id").distinct().orderBy("item_id").collect()
item_list_len = len(item_list)
i = 0
# for each pair of items for e.g. (1,2), (1,3), (1,4), (1,5), (2,3), (2,4), (2,5), ...... (4,5)
while i < item_list_len - 1:
# find the list of all bill IDs that contain item '1'
bill_id_list1 = bi_df.filter(bi_df.item_id == item_list[i].item_id).select("bill_id").collect()
j = i+1
while j < item_list_len:
# find the list of all bill IDs that contain item '2'
bill_id_list2 = bi_df.filter(bi_df.item_id == item_list[j].item_id).select("bill_id").collect()
# find the common bill IDs in list bill_id_list1 and bill_id_list2 and then the no. of common items
common_elements = set(basket_id_list1).intersection(bill_id_list2)
num_bils = len(common_elements)
if(num_bils > 0):
print(item_list[i].item_id, item_list[j].item_id, num_bils)
j += 1
i+=1
但是,考虑到现实生活中我们有数百万条记录并且可能存在以下问题,这种方法并不是一种有效的方法:
- 可能没有足够的内存来加载所有项目或账单的列表
- 获取结果可能需要很长时间,因为执行是顺序的(感谢 'for' 循环)。 (我 运行 上面的算法有 ~200000 条记录,花了 4 个多小时才得出想要的结果。)
方法二:
我在 "item_id" 的基础上通过拆分数据进一步优化了这一点,我使用以下代码块来拆分数据:
bi_df = (spark.read.schema(dataSchema).csv('bill_item.csv'))
outputPath='/path/to/save'
bi_df.write.partitionBy("item_id").csv(outputPath)
拆分后,我执行了我在 "Approach 1" 中使用的相同算法,我发现在 200000 条记录的情况下,它仍然需要 1.03 小时(比 'Approach 1' 下的 4 小时有显着改进)获得最终输出。
而上面的瓶颈是因为顺序'for'循环(也因为'collect()'方法)。所以我的问题是:
- 有没有办法并行化 for 循环?
- 或者有其他有效的方法吗?
总是按顺序在 spark 中循环,在代码中使用它也不是一个好主意。根据您的代码,您正在使用 while
并一次读取单个记录,这将不允许 spark 并行 运行。
如果数据集很大,Spark 代码应该设计成没有 for
和 while
循环。
根据我对你的问题的理解,我已经在 scala 中编写了示例代码,它可以在不使用任何循环的情况下提供你想要的输出。请参考下面的代码,尝试按照同样的方式设计一个代码。
注意:我用 Scala 编写的代码也可以用相同的逻辑在 Python 中实现。
scala> import org.apache.spark.sql.expressions.UserDefinedFunction
scala> def sampleUDF:UserDefinedFunction = udf((flagCol:String) => {var out = ""
| val flagColList = flagCol.reverse.split(s""",""").map(x => x.trim).mkString(",").reverse.split(s",").toList
| var i = 0
| var ss = flagColList.size
| flagColList.foreach{ x =>
| i = i + 1
| val xs = List(flagColList(i-1))
| val ys = flagColList.slice(i, ss)
| for (x <- xs; y <- ys)
| out = out +","+x + "~" + y
| }
| if(out == "") { out = flagCol}
| out.replaceFirst(s""",""","")})
//Input DataSet
scala> df.show
+-------+-------+
|bill_id|item_id|
+-------+-------+
| ABC| 1|
| ABC| 2|
| DEF| 1|
| DEF| 2|
| DEF| 3|
| GHI| 1|
+-------+-------+
//Collectin all item_id corresponding to bill_id
scala> val df1 = df.groupBy("bill_id")
.agg(concat_ws(",",collect_list(col("item_id"))).alias("item"))
scala> df1.show
+-------+-----+
|bill_id| item|
+-------+-----+
| DEF|1,2,3|
| GHI| 1|
| ABC| 1,2|
+-------+-----+
//Generating combination of all item_id and filter out for correct data
scala> val df2 = df1.withColumn("item", sampleUDF(col("item")))
.withColumn("item", explode(split(col("item"), ",")))
.withColumn("Item_1", split(col("item"), "~")(0))
.withColumn("Item_2", split(col("item"), "~")(1))
.groupBy(col("Item_1"),col("Item_2"))
.agg(count(lit(1)).alias("Num_of_bills"))
.filter(col("Item_2").isNotNull)
scala> df2.show
+------+------+------------+
|Item_1|Item_2|Num_of_bills|
+------+------+------------+
| 2| 3| 1|
| 1| 2| 2|
| 1| 3| 1|
+------+------+------------+
我们如何在 Spark 中并行化循环,使处理不是顺序的而是并行的。举个例子—— 我在包含以下数据的 csv 文件(称为 'bill_item.csv')中包含以下数据:
|-----------+------------|
| bill_id | item_id |
|-----------+------------|
| ABC | 1 |
| ABC | 2 |
| DEF | 1 |
| DEF | 2 |
| DEF | 3 |
| GHI | 1 |
|-----------+------------|
我必须得到如下输出:
|-----------+-----------+--------------|
| item_1 | item_2 | Num_of_bills |
|-----------+-----------+--------------|
| 1 | 2 | 2 |
| 2 | 3 | 1 |
| 1 | 3 | 1 |
|-----------+-----------+--------------|
我们看到在 2 个账单 'ABC' 和 'DEF' 下找到了项目 1 和 2,因此项目 1 和 2 的 'Num_of_bills' 是 2。类似地,项目 2 和 3仅在 bill 'DEF' 下找到,因此 'Num_of_bills' 列为“1”,依此类推。
我正在使用 spark 处理 CSV 文件 'bill_item.csv' 并且我正在使用以下方法:
方法一:
from pyspark.sql.types import StructType, StructField, IntegerType, StringType
# define the schema for the data
bi_schema = StructType([
StructField("bill_id", StringType(), True),
StructField("item_id", IntegerType(), True)
])
bi_df = (spark.read.schema(dataSchema).csv('bill_item.csv'))
# find the list of all items in sorted order
item_list = bi_df.select("item_id").distinct().orderBy("item_id").collect()
item_list_len = len(item_list)
i = 0
# for each pair of items for e.g. (1,2), (1,3), (1,4), (1,5), (2,3), (2,4), (2,5), ...... (4,5)
while i < item_list_len - 1:
# find the list of all bill IDs that contain item '1'
bill_id_list1 = bi_df.filter(bi_df.item_id == item_list[i].item_id).select("bill_id").collect()
j = i+1
while j < item_list_len:
# find the list of all bill IDs that contain item '2'
bill_id_list2 = bi_df.filter(bi_df.item_id == item_list[j].item_id).select("bill_id").collect()
# find the common bill IDs in list bill_id_list1 and bill_id_list2 and then the no. of common items
common_elements = set(basket_id_list1).intersection(bill_id_list2)
num_bils = len(common_elements)
if(num_bils > 0):
print(item_list[i].item_id, item_list[j].item_id, num_bils)
j += 1
i+=1
但是,考虑到现实生活中我们有数百万条记录并且可能存在以下问题,这种方法并不是一种有效的方法:
- 可能没有足够的内存来加载所有项目或账单的列表
- 获取结果可能需要很长时间,因为执行是顺序的(感谢 'for' 循环)。 (我 运行 上面的算法有 ~200000 条记录,花了 4 个多小时才得出想要的结果。)
方法二:
我在 "item_id" 的基础上通过拆分数据进一步优化了这一点,我使用以下代码块来拆分数据:
bi_df = (spark.read.schema(dataSchema).csv('bill_item.csv'))
outputPath='/path/to/save'
bi_df.write.partitionBy("item_id").csv(outputPath)
拆分后,我执行了我在 "Approach 1" 中使用的相同算法,我发现在 200000 条记录的情况下,它仍然需要 1.03 小时(比 'Approach 1' 下的 4 小时有显着改进)获得最终输出。
而上面的瓶颈是因为顺序'for'循环(也因为'collect()'方法)。所以我的问题是:
- 有没有办法并行化 for 循环?
- 或者有其他有效的方法吗?
总是按顺序在 spark 中循环,在代码中使用它也不是一个好主意。根据您的代码,您正在使用 while
并一次读取单个记录,这将不允许 spark 并行 运行。
如果数据集很大,Spark 代码应该设计成没有 for
和 while
循环。
根据我对你的问题的理解,我已经在 scala 中编写了示例代码,它可以在不使用任何循环的情况下提供你想要的输出。请参考下面的代码,尝试按照同样的方式设计一个代码。
注意:我用 Scala 编写的代码也可以用相同的逻辑在 Python 中实现。
scala> import org.apache.spark.sql.expressions.UserDefinedFunction
scala> def sampleUDF:UserDefinedFunction = udf((flagCol:String) => {var out = ""
| val flagColList = flagCol.reverse.split(s""",""").map(x => x.trim).mkString(",").reverse.split(s",").toList
| var i = 0
| var ss = flagColList.size
| flagColList.foreach{ x =>
| i = i + 1
| val xs = List(flagColList(i-1))
| val ys = flagColList.slice(i, ss)
| for (x <- xs; y <- ys)
| out = out +","+x + "~" + y
| }
| if(out == "") { out = flagCol}
| out.replaceFirst(s""",""","")})
//Input DataSet
scala> df.show
+-------+-------+
|bill_id|item_id|
+-------+-------+
| ABC| 1|
| ABC| 2|
| DEF| 1|
| DEF| 2|
| DEF| 3|
| GHI| 1|
+-------+-------+
//Collectin all item_id corresponding to bill_id
scala> val df1 = df.groupBy("bill_id")
.agg(concat_ws(",",collect_list(col("item_id"))).alias("item"))
scala> df1.show
+-------+-----+
|bill_id| item|
+-------+-----+
| DEF|1,2,3|
| GHI| 1|
| ABC| 1,2|
+-------+-----+
//Generating combination of all item_id and filter out for correct data
scala> val df2 = df1.withColumn("item", sampleUDF(col("item")))
.withColumn("item", explode(split(col("item"), ",")))
.withColumn("Item_1", split(col("item"), "~")(0))
.withColumn("Item_2", split(col("item"), "~")(1))
.groupBy(col("Item_1"),col("Item_2"))
.agg(count(lit(1)).alias("Num_of_bills"))
.filter(col("Item_2").isNotNull)
scala> df2.show
+------+------+------------+
|Item_1|Item_2|Num_of_bills|
+------+------+------------+
| 2| 3| 1|
| 1| 2| 2|
| 1| 3| 1|
+------+------+------------+