检查值是否在 pyspark 的间隔内
check if values are within intervals in pyspark
我有一个大的 DataFrame A,间隔如下:
df_a = spark.createDataFrame([
(0, 23), (1, 6), (2, 55), (3, 1), (4, 12), (5, 51),
], ("id", "x"))
# +---+---+
# | id| x|
# +---+---+
# | 0| 23|
# | 1| 6|
# | 2| 55|
# | 3| 1|
# | 4| 12|
# | 5| 51|
# +---+---+
我有一个 Dataframe B,其中包含这样的已排序非重叠闭区间:
df_b = spark.createDataFrame([
(0, 1, 5), (1, 8, 10), (2, 15, 16), (3, 20, 30), (4, 50, 52),
], ("id", "start", "end"))
# +---+-----+---+
# | id|start|end|
# +---+-----+---+
# | 0| 1| 5|
# | 1| 8| 10|
# | 2| 15| 16|
# | 3| 20| 30|
# | 4| 50| 52|
# +---+-----+---+
我想检查 DataFrame A 的值是否包含在 DataFrame B 的间隔之一中,如果是,则将 id 保存在新列中 (interval_id
)。我的 Output-DataFrame 应该是这样的:
id x interval_id
0 23 3
1 6 null
2 55 null
3 1 0
4 12 null
5 51 4
有没有不使用 udfs 的高效方法?
简单 left_join
应该可以完成工作:
from pyspark.sql import functions as F
result = df_a.join(
df_b.withColumnRenamed("id", "interval_id"),
F.col("x").between(F.col("start"), F.col("end")),
"left"
).drop("start", "end")
result.show()
#+---+---+-----------+
#| id| x|interval_id|
#+---+---+-----------+
#| 0| 23| 3|
#| 1| 6| null|
#| 2| 55| null|
#| 3| 1| 0|
#| 4| 12| null|
#| 5| 51| 4|
#+---+---+-----------+
您可以 join
df_a
和 df_b
这样 df_a["x"] between df_b["start"] and df_b["end"]
.
df_a = spark.createDataFrame([
(0, 23), (1, 6), (2, 55), (3, 1), (4, 12), (5, 51),
], ("id", "x"))
df_b = spark.createDataFrame([
(0, 1, 5), (1, 8, 10), (2, 15, 16), (3, 20, 30), (4, 50, 52),
], ("id", "start", "end"))
df_a.join(df_b, df_a["x"].between(df_b["start"], df_b["end"]), how="left")\
.select(df_a["id"], df_a["x"], df_b["id"].alias("interval_id")).show()
输出
+---+---+-----------+
| id| x|interval_id|
+---+---+-----------+
| 0| 23| 3|
| 1| 6| null|
| 2| 55| null|
| 3| 1| 0|
| 4| 12| null|
| 5| 51| 4|
+---+---+-----------+
我有一个大的 DataFrame A,间隔如下:
df_a = spark.createDataFrame([
(0, 23), (1, 6), (2, 55), (3, 1), (4, 12), (5, 51),
], ("id", "x"))
# +---+---+
# | id| x|
# +---+---+
# | 0| 23|
# | 1| 6|
# | 2| 55|
# | 3| 1|
# | 4| 12|
# | 5| 51|
# +---+---+
我有一个 Dataframe B,其中包含这样的已排序非重叠闭区间:
df_b = spark.createDataFrame([
(0, 1, 5), (1, 8, 10), (2, 15, 16), (3, 20, 30), (4, 50, 52),
], ("id", "start", "end"))
# +---+-----+---+
# | id|start|end|
# +---+-----+---+
# | 0| 1| 5|
# | 1| 8| 10|
# | 2| 15| 16|
# | 3| 20| 30|
# | 4| 50| 52|
# +---+-----+---+
我想检查 DataFrame A 的值是否包含在 DataFrame B 的间隔之一中,如果是,则将 id 保存在新列中 (interval_id
)。我的 Output-DataFrame 应该是这样的:
id x interval_id
0 23 3
1 6 null
2 55 null
3 1 0
4 12 null
5 51 4
有没有不使用 udfs 的高效方法?
简单 left_join
应该可以完成工作:
from pyspark.sql import functions as F
result = df_a.join(
df_b.withColumnRenamed("id", "interval_id"),
F.col("x").between(F.col("start"), F.col("end")),
"left"
).drop("start", "end")
result.show()
#+---+---+-----------+
#| id| x|interval_id|
#+---+---+-----------+
#| 0| 23| 3|
#| 1| 6| null|
#| 2| 55| null|
#| 3| 1| 0|
#| 4| 12| null|
#| 5| 51| 4|
#+---+---+-----------+
您可以 join
df_a
和 df_b
这样 df_a["x"] between df_b["start"] and df_b["end"]
.
df_a = spark.createDataFrame([
(0, 23), (1, 6), (2, 55), (3, 1), (4, 12), (5, 51),
], ("id", "x"))
df_b = spark.createDataFrame([
(0, 1, 5), (1, 8, 10), (2, 15, 16), (3, 20, 30), (4, 50, 52),
], ("id", "start", "end"))
df_a.join(df_b, df_a["x"].between(df_b["start"], df_b["end"]), how="left")\
.select(df_a["id"], df_a["x"], df_b["id"].alias("interval_id")).show()
输出
+---+---+-----------+
| id| x|interval_id|
+---+---+-----------+
| 0| 23| 3|
| 1| 6| null|
| 2| 55| null|
| 3| 1| 0|
| 4| 12| null|
| 5| 51| 4|
+---+---+-----------+