获取 pyspark 数组类型列的最后 n 个元素
Get last n elements of pyspark array type column
我正在尝试获取名为 Foo 的每个数组列的最后 n 个元素,并从中创建一个单独的列,称为 last_n_items_of_Foo。 Foo 列数组的长度可变
我看过这篇文章
但它有一个方法不能用于访问最后的元素。
import pandas as pd
from pyspark.sql.functions import udf, size
from pyspark.sql.types import StringType
from pyspark.sql.functions import col
df = pd.DataFrame([[[1,1,2,3],1,0],[[1,1,2,7,8,9],0,0],[[1,1,2,3,4,5,8],1,1]],columns = ['Foo','Bar','Baz'])
spark_df = spark.createDataFrame(df)
输出应该是这样的
如果 n=2
Foo Bar Baz last_2_items_of_Foo
0 [1, 1, 2, 3] 1 0 [2, 3]
1 [1, 1, 2, 7, 8, 9] 0 0 [8, 9]
2 [1, 1, 2, 3, 4, 5, 8] 1 1 [5, 8]
您可以编写自己的 UDF 以从数组中获取最后 n 个元素:
import pyspark.sql.functions as f
import pyspark.sql.types as t
def get_last_n_elements_(arr, n):
return arr[-n:]
get_last_n_elements = f.udf(get_last_n_elements_, t.ArrayType(t.IntegerType()))
UDF 将列数据类型作为参数,因此使用 f.lit(n)
spark_df.withColumn('last_2_items_of_Foo', get_last_n_elements('Foo', f.lit(2))).show()
+--------------------+---+---+-------------------+
| Foo|Bar|Baz|last_2_items_of_Foo|
+--------------------+---+---+-------------------+
| [1, 1, 2, 3]| 1| 0| [2, 3]|
| [1, 1, 2, 7, 8, 9]| 0| 0| [8, 9]|
|[1, 1, 2, 3, 4, 5...| 1| 1| [5, 8]|
+--------------------+---+---+-------------------+
显然在 spark 2.4 中,有内置函数 f.slice
可以对数组进行切片。
目前我的系统中没有 2.4+ 版本,但它会像下面这样:
spark_df.withColumn('last_2_items_of_Foo', f.slice('Foo', -2)).show()
我正在尝试获取名为 Foo 的每个数组列的最后 n 个元素,并从中创建一个单独的列,称为 last_n_items_of_Foo。 Foo 列数组的长度可变
我看过这篇文章
import pandas as pd
from pyspark.sql.functions import udf, size
from pyspark.sql.types import StringType
from pyspark.sql.functions import col
df = pd.DataFrame([[[1,1,2,3],1,0],[[1,1,2,7,8,9],0,0],[[1,1,2,3,4,5,8],1,1]],columns = ['Foo','Bar','Baz'])
spark_df = spark.createDataFrame(df)
输出应该是这样的
如果 n=2
Foo Bar Baz last_2_items_of_Foo
0 [1, 1, 2, 3] 1 0 [2, 3]
1 [1, 1, 2, 7, 8, 9] 0 0 [8, 9]
2 [1, 1, 2, 3, 4, 5, 8] 1 1 [5, 8]
您可以编写自己的 UDF 以从数组中获取最后 n 个元素:
import pyspark.sql.functions as f
import pyspark.sql.types as t
def get_last_n_elements_(arr, n):
return arr[-n:]
get_last_n_elements = f.udf(get_last_n_elements_, t.ArrayType(t.IntegerType()))
UDF 将列数据类型作为参数,因此使用 f.lit(n)
spark_df.withColumn('last_2_items_of_Foo', get_last_n_elements('Foo', f.lit(2))).show()
+--------------------+---+---+-------------------+
| Foo|Bar|Baz|last_2_items_of_Foo|
+--------------------+---+---+-------------------+
| [1, 1, 2, 3]| 1| 0| [2, 3]|
| [1, 1, 2, 7, 8, 9]| 0| 0| [8, 9]|
|[1, 1, 2, 3, 4, 5...| 1| 1| [5, 8]|
+--------------------+---+---+-------------------+
显然在 spark 2.4 中,有内置函数 f.slice
可以对数组进行切片。
目前我的系统中没有 2.4+ 版本,但它会像下面这样:
spark_df.withColumn('last_2_items_of_Foo', f.slice('Foo', -2)).show()