Dot product in PySpark?

I have:

df1

+------------------+----------+
|               var|multiplier|
+------------------+----------+
|              var1|         1|
|              var2|         2|
|              var3|         3|
+------------------+----------+

df2

+-------+----------+-----+-----+----------+---------+
|   varA|      varB| varC| var1|      var2|     var3|
+-------+----------+-----+-----+----------+---------+
|   abcd|       at1|    5|    1|        45|       12|
|   xyzw|       vt1|    7|    1|        23|       17|
+-------+----------+-----+-----+----------+---------+

Expected result: df3

+-------+----------+-----+-----+----------+---------+---------------+
|   varA|      varB| varC| var1|      var2|     var3|     sumproduct|
+-------+----------+-----+-----+----------+---------+---------------+
|   abcd|       at1|    5|    1|        90|       36|            127|
|   xyzw|       vt1|    7|    1|        46|       51|             98|
+-------+----------+-----+-----+----------+---------+---------------+
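
To make the target concrete, the sumproduct column can be checked by hand in plain Python. This is only an arithmetic sketch using the values from the tables above; the names `multipliers`, `rows`, and `sumproduct` are illustrative and not part of any library.

```python
# Multipliers from df1 and the var1..var3 values from df2, copied from the tables above.
multipliers = [1, 2, 3]
rows = {"abcd": [1, 45, 12], "xyzw": [1, 23, 17]}


def sumproduct(values, weights):
    """Dot product of two equal-length lists."""
    return sum(v * w for v, w in zip(values, weights))


results = {key: sumproduct(values, multipliers) for key, values in rows.items()}
print(results)  # {'abcd': 127, 'xyzw': 98}
```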

In Python (pandas), I can achieve this with:

df1 = df1.set_index(['var'])
df3 = df2.dot(df1)

Any help with a similar approach in PySpark?

from pyspark.sql import functions as F

# Collect the multipliers from df1 into a Python list (assumes df1 rows come back in var1..var3 order)
lst = df1.select("multiplier").rdd.flatMap(lambda x: x).collect()
df3 = (
    df2.withColumn("a1", F.array("var1", "var2", "var3"))  # pack the var columns of df2 into an array
    .withColumn("a2", F.array(*[F.lit(x) for x in lst]))  # attach the multipliers from df1 as a second array
    .withColumn("a1", F.expr("transform(a1, (x, i) -> a2[i] * x)"))  # element-wise product
    # expand a1 back into the original var columns
    .select("varA", "varB", "varC", "a1", *[F.col("a1")[i].alias(f"var{i + 1}") for i in range(3)])
    # sum the products to get the dot product
    .select("*", F.expr("aggregate(a1, cast(0 as bigint), (acc, x) -> acc + x)").alias("sumproduct"))
    .drop("a1", "a2")
)

df3.show()

+----+----+----+----+----+----+----------+
|varA|varB|varC|var1|var2|var3|sumproduct|
+----+----+----+----+----+----+----------+
|abcd| at1|   5|   1|  90|  36|       127|
|xyzw| vt1|   7|   1|  46|  51|        98|
+----+----+----+----+----+----+----------+
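
If the array round trip feels heavy, one alternative is to build the sum of products as a single SQL expression string and pass it to expr. The sketch below only builds that string, so it runs without a Spark session. The helper name `sumproduct_sql` and the `multipliers` dict are assumptions for illustration; in practice the dict could be built from the collected rows of df1.

```python
def sumproduct_sql(multipliers):
    """Build a Spark SQL expression such as '(`var1` * 1) + (`var2` * 2)'.

    `multipliers` maps a column name to its numeric multiplier.
    """
    return " + ".join(f"(`{name}` * {mult})" for name, mult in multipliers.items())


# Hypothetical mapping mirroring df1 (var -> multiplier).
multipliers = {"var1": 1, "var2": 2, "var3": 3}
sumproduct_expr = sumproduct_sql(multipliers)
print(sumproduct_expr)  # (`var1` * 1) + (`var2` * 2) + (`var3` * 3)
```

With a live session, this string could be applied as df2.withColumn("sumproduct", F.expr(sumproduct_expr)), which keeps the whole computation in native Spark expressions.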

Note that if you only need the dot product, a UDF is also an option. We can use NumPy, which is well suited to this kind of computation:

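import numpy as np
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType

lst = df1.select("multiplier").rdd.flatMap(lambda x: x).collect()
# int() converts the NumPy scalar to a plain Python int, matching IntegerType
dot_array = F.udf(lambda x, y: int(np.dot(x, y)), IntegerType())
df2.withColumn("dotproduct", dot_array(F.array("var1", "var2", "var3"), F.array(*[F.lit(x) for x in lst]))).show()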

+----+----+----+----+----+----+----------+
|varA|varB|varC|var1|var2|var3|dotproduct|
+----+----+----+----+----+----+----------+
|abcd| at1|   5|   1|  45|  12|       127|
|xyzw| vt1|   7|   1|  23|  17|        98|
+----+----+----+----+----+----+----------+
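
The function wrapped by the UDF can be tried outside Spark to confirm the arithmetic and the return type. This is a minimal sketch, assuming NumPy is installed; `dot_as_int` is a hypothetical name standing in for the lambda above.

```python
import numpy as np


def dot_as_int(x, y):
    """Same body as the UDF lambda: NumPy dot product cast to a plain Python int."""
    return int(np.dot(x, y))


first = dot_as_int([1, 45, 12], [1, 2, 3])
second = dot_as_int([1, 23, 17], [1, 2, 3])
print(first, second)  # 127 98
print(type(first).__name__)  # int
```

The int() cast matters because np.dot returns a NumPy scalar, while a UDF declared with IntegerType is expected to return a plain Python int.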