在 Apache spark 中高效 运行 一个 "for" 循环,以便并行执行

Efficiently running a "for" loop in Apache spark so that execution is parallel

我们如何在 Spark 中并行化循环,使处理不是顺序的而是并行的。举个例子—— 我在包含以下数据的 csv 文件(称为 'bill_item.csv')中包含以下数据:

    |-----------+------------|
    | bill_id   | item_id    |
    |-----------+------------|
    | ABC       | 1          |
    | ABC       | 2          |
    | DEF       | 1          |
    | DEF       | 2          |
    | DEF       | 3          |
    | GHI       | 1          |
    |-----------+------------|

我必须得到如下输出:

    |-----------+-----------+--------------|
    | item_1    | item_2    | Num_of_bills |
    |-----------+-----------+--------------|
    | 1         | 2         | 2            |
    | 2         | 3         | 1            |
    | 1         | 3         | 1            |
    |-----------+-----------+--------------|

我们看到在 2 个账单 'ABC' 和 'DEF' 下找到了项目 1 和 2,因此项目 1 和 2 的 'Num_of_bills' 是 2。类似地,项目 2 和 3仅在 bill 'DEF' 下找到,因此 'Num_of_bills' 列为“1”,依此类推。

我正在使用 spark 处理 CSV 文件 'bill_item.csv' 并且我正在使用以下方法:

方法一:

from pyspark.sql.types import StructType, StructField, IntegerType, StringType

# define the schema for the data 
bi_schema = StructType([
    StructField("bill_id", StringType(), True), 
    StructField("item_id", IntegerType(), True) 
]) 

bi_df = (spark.read.schema(dataSchema).csv('bill_item.csv'))

# find the list of all items in sorted order
item_list = bi_df.select("item_id").distinct().orderBy("item_id").collect()

item_list_len = len(item_list)
i = 0
# for each pair of items for e.g. (1,2), (1,3), (1,4), (1,5), (2,3), (2,4), (2,5), ...... (4,5)
while i < item_list_len - 1:
    # find the list of all bill IDs that contain item '1'
    bill_id_list1 = bi_df.filter(bi_df.item_id == item_list[i].item_id).select("bill_id").collect()
    j = i+1
    while j < item_list_len:
        # find the list of all bill IDs that contain item '2'
        bill_id_list2 = bi_df.filter(bi_df.item_id == item_list[j].item_id).select("bill_id").collect()

        # find the common bill IDs in list bill_id_list1 and bill_id_list2 and then the no. of common items
        common_elements = set(basket_id_list1).intersection(bill_id_list2)
        num_bils = len(common_elements)
        if(num_bils > 0):
            print(item_list[i].item_id, item_list[j].item_id, num_bils)
        j += 1    
    i+=1

但是,考虑到现实生活中我们有数百万条记录并且可能存在以下问题,这种方法并不是一种有效的方法:

  1. 可能没有足够的内存来加载所有项目或账单的列表
  2. 获取结果可能需要很长时间,因为执行是顺序的(感谢 'for' 循环)。 (我 运行 上面的算法有 ~200000 条记录,花了 4 个多小时才得出想要的结果。)

方法二:

我在 "item_id" 的基础上通过拆分数据进一步优化了这一点,我使用以下代码块来拆分数据:

bi_df = (spark.read.schema(dataSchema).csv('bill_item.csv'))
outputPath='/path/to/save'
bi_df.write.partitionBy("item_id").csv(outputPath)

拆分后,我执行了我在 "Approach 1" 中使用的相同算法,我发现在 200000 条记录的情况下,它仍然需要 1.03 小时(比 'Approach 1' 下的 4 小时有显着改进)获得最终输出。

而上面的瓶颈是因为顺序'for'循环(也因为'collect()'方法)。所以我的问题是:

总是按顺序在 spark 中循环,在代码中使用它也不是一个好主意。根据您的代码,您正在使用 while 并一次读取单个记录,这将不允许 spark 并行 运行。

如果数据集很大,Spark 代码应该设计成没有 forwhile 循环。

根据我对你的问题的理解,我已经在 scala 中编写了示例代码,它可以在不使用任何循环的情况下提供你想要的输出。请参考下面的代码,尝试按照同样的方式设计一个代码。

注意:我用 Scala 编写的代码也可以用相同的逻辑在 Python 中实现。

scala> import org.apache.spark.sql.expressions.UserDefinedFunction

scala> def sampleUDF:UserDefinedFunction = udf((flagCol:String) => {var out = ""
     |       val flagColList = flagCol.reverse.split(s""",""").map(x => x.trim).mkString(",").reverse.split(s",").toList
     |       var i = 0
     |     var ss = flagColList.size
     |     flagColList.foreach{ x =>
     |        i =  i + 1
     |      val xs = List(flagColList(i-1))
     |      val ys =  flagColList.slice(i, ss)
     |      for (x <- xs; y <- ys)  
     |           out = out +","+x + "~" + y
     |         }
     |             if(out == "") { out = flagCol}
     |    out.replaceFirst(s""",""","")})

//Input DataSet 
scala> df.show
+-------+-------+
|bill_id|item_id|
+-------+-------+
|    ABC|      1|
|    ABC|      2|
|    DEF|      1|
|    DEF|      2|
|    DEF|      3|
|    GHI|      1|
+-------+-------+

//Collectin all item_id corresponding to bill_id

scala> val df1 = df.groupBy("bill_id")
               .agg(concat_ws(",",collect_list(col("item_id"))).alias("item"))

scala> df1.show
+-------+-----+
|bill_id| item|
+-------+-----+
|    DEF|1,2,3|
|    GHI|    1|
|    ABC|  1,2|
+-------+-----+


//Generating combination of all item_id and filter out for correct data

scala>   val df2 = df1.withColumn("item", sampleUDF(col("item")))
                      .withColumn("item", explode(split(col("item"), ",")))
                      .withColumn("Item_1", split(col("item"), "~")(0))
                      .withColumn("Item_2", split(col("item"), "~")(1))
                      .groupBy(col("Item_1"),col("Item_2"))
                      .agg(count(lit(1)).alias("Num_of_bills"))
                      .filter(col("Item_2").isNotNull)

scala> df2.show
+------+------+------------+
|Item_1|Item_2|Num_of_bills|
+------+------+------------+
|     2|     3|           1|
|     1|     2|           2|
|     1|     3|           1|
+------+------+------------+