Dot product in PySpark?
I have:
df1
+----+----------+
| var|multiplier|
+----+----------+
|var1|         1|
|var2|         2|
|var3|         3|
+----+----------+
df2
+----+----+----+----+----+----+
|varA|varB|varC|var1|var2|var3|
+----+----+----+----+----+----+
|abcd| at1|   5|   1|  45|  12|
|xyzw| vt1|   7|   1|  23|  17|
+----+----+----+----+----+----+
Expected result (each var column multiplied by its multiplier, plus the row sum):
df3
+----+----+----+----+----+----+----------+
|varA|varB|varC|var1|var2|var3|sumproduct|
+----+----+----+----+----+----+----------+
|abcd| at1|   5|   1|  90|  36|       127|
|xyzw| vt1|   7|   1|  46|  51|        98|
+----+----+----+----+----+----+----------+
In pandas, I can get the dot product with:
df1 = df1.set_index('var')
df3 = df2.assign(sumproduct=df2[df1.index].dot(df1['multiplier']))
Is there a similar approach in PySpark?
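For comparison with the expected table above: `dot` on its own only yields the row sums, not the scaled `var` columns. A minimal pandas sketch that reproduces the whole expected `df3`, assuming the sample data from the question (the literals below simply rebuild those frames, and `weights`/`cols` are illustrative names):

```python
import pandas as pd

# Rebuild the sample frames from the question
df1 = pd.DataFrame({'var': ['var1', 'var2', 'var3'], 'multiplier': [1, 2, 3]})
df2 = pd.DataFrame({
    'varA': ['abcd', 'xyzw'], 'varB': ['at1', 'vt1'], 'varC': [5, 7],
    'var1': [1, 1], 'var2': [45, 23], 'var3': [12, 17],
})

weights = df1.set_index('var')['multiplier']  # Series indexed by column name
cols = list(weights.index)                    # ['var1', 'var2', 'var3']

df3 = df2.copy()
df3['sumproduct'] = df2[cols].dot(weights)    # row-wise dot product on the original values
df3[cols] = df2[cols].mul(weights, axis=1)    # scale each var column by its multiplier
print(df3)
```

Because `mul(..., axis=1)` and `dot` both align on column labels, the order of rows in `df1` does not matter.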
One option is to pack the values into arrays and use the Spark SQL higher-order functions transform and aggregate (Spark 2.4+):
from pyspark.sql import functions as F

cols = ['var1', 'var2', 'var3']
mult = dict(df1.select('var', 'multiplier').collect())  # {'var1': 1, 'var2': 2, 'var3': 3}, keyed by name so row order doesn't matter
df3 = (
    df2.withColumn('a1', F.array(*cols))  # array of df2's var columns
    .withColumn('a2', F.array(*[F.lit(mult[c]) for c in cols]))  # matching multipliers from df1
    .withColumn('a1', F.expr("transform(a1, (x, i) -> a2[i] * x)"))  # element-wise products
    .select('varA', 'varB', 'varC', 'a1',
            *[F.col('a1')[i].alias(c) for i, c in enumerate(cols)])  # expand a1 back into the var columns
    .select('*', F.expr("aggregate(a1, cast(0 as bigint), (acc, x) -> acc + x)").alias('sumproduct'))  # sum of the products
    .drop('a1')
)
df3.show()
+----+----+----+----+----+----+----------+
|varA|varB|varC|var1|var2|var3|sumproduct|
+----+----+----+----+----+----+----------+
|abcd| at1| 5| 1| 90| 36| 127|
|xyzw| vt1| 7| 1| 46| 51| 98|
+----+----+----+----+----+----+----------+
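To make the two SQL expressions concrete, here is a plain-Python sketch of the same per-row logic. It is not Spark code: the helper functions `transform` and `aggregate` below are hypothetical stand-ins that mimic the semantics of the Spark SQL functions of the same name, applied to the question's sample rows.

```python
from functools import reduce

multipliers = [1, 2, 3]            # a2: multipliers for var1, var2, var3
rows = [[1, 45, 12], [1, 23, 17]]  # a1 for each row of df2: [var1, var2, var3]

def transform(arr, fn):
    # Mimics Spark SQL transform(arr, (x, i) -> ...): fn gets each element and its index
    return [fn(x, i) for i, x in enumerate(arr)]

def aggregate(arr, start, merge):
    # Mimics Spark SQL aggregate(arr, start, (acc, x) -> ...): a left fold from start
    return reduce(merge, arr, start)

results = []
for a1 in rows:
    scaled = transform(a1, lambda x, i: multipliers[i] * x)
    sumproduct = aggregate(scaled, 0, lambda acc, x: acc + x)
    results.append((scaled, sumproduct))
    print(scaled, sumproduct)
```

The printed rows match the `var1`..`var3` and `sumproduct` columns of the Spark output.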
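Keep in mind that if you only need the dot product, a UDF also works, and NumPy is well suited to it (although a Python UDF is usually slower than the built-in functions above, since each row has to be shipped to a Python worker):
import numpy as np
from pyspark.sql import functions as F
from pyspark.sql.types import LongType

lst = df1.select("multiplier").rdd.flatMap(lambda x: x).collect()  # [1, 2, 3]
# NumPy dot product per row; int() turns the NumPy scalar into a Python int that Spark can serialize
dot_array = F.udf(lambda x, y: int(np.dot(x, y)), LongType())
df2.withColumn("dotproduct", dot_array(F.array('var1', 'var2', 'var3'), F.array(*[F.lit(x) for x in lst]))).show()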
+----+----+----+----+----+----+----------+
|varA|varB|varC|var1|var2|var3|dotproduct|
+----+----+----+----+----+----+----------+
|abcd| at1| 5| 1| 45| 12| 127|
|xyzw| vt1| 7| 1| 23| 17| 98|
+----+----+----+----+----+----+----------+
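Since the UDF body is ordinary Python, it can be checked locally before wrapping it with `udf`. A small sketch using the question's rows (the name `dot_fn` is hypothetical, chosen only for illustration):

```python
import numpy as np

def dot_fn(x, y):
    # Same body as the UDF lambda: NumPy dot product, converted to a plain Python int
    # so Spark receives a native integer rather than a NumPy scalar
    return int(np.dot(x, y))

multipliers = [1, 2, 3]
print(dot_fn([1, 45, 12], multipliers))  # 127
print(dot_fn([1, 23, 17], multipliers))  # 98
```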