确保 PySpark 数组中相邻元素之间的差异大于给定的最小值
Ensure difference between adjoining elements in PySpark array is more than a given minimum value
我有一个包含三列的 PySpark 数据框 (df
)。
1。
category
: 一些字符串
2。
startTimeArray
: 它是一个数组,包含按升序排列的时间戳。
3。
endTimeArray
: 它是一个数组,包含按升序排列的时间戳。
在每一行中,startTimeArray
中的数组长度与endTimeArray
中的数组长度相同。对于这些数组中的每个索引,startTimeArray
中给出的时间戳比 endTimeArray
中对应的(相同索引)时间戳小(发生在前一个日期)。
在startTimeArray
列(和endTimeArray
列),数组的长度可以不同。
以下是数据帧的示例:
+--------+---------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------+
|category|startTimeArray |endTimeArray |
+--------+---------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------+
|a |[2019-01-10 00:00:00, 2019-01-12 00:00:00, 2019-01-16 00:00:00, 2019-01-20 00:00:00] |[2019-01-11 00:00:00, 2019-01-15 00:00:00, 2019-01-18 00:00:00, 2019-01-22 00:00:00] |
|a |[2019-03-11 00:00:00, 2019-03-18 00:00:00, 2019-03-20 00:00:00, 2019-03-25 00:00:00, 2019-03-27 00:00:00]|[2019-03-16 00:00:00, 2019-03-19 00:00:00, 2019-03-23 00:00:00, 2019-03-26 00:00:00, 2019-03-30 00:00:00]|
|b |[2019-01-14 00:00:00, 2019-01-16 00:00:00, 2019-02-22 00:00:00] |[2019-01-15 00:00:00, 2019-01-18 00:00:00, 2019-02-25 00:00:00] |
+--------+---------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------+
在每一行的 startTimeArray
列中,我想确保数组中连续元素(连续索引处的元素)之间的差异至少为三天。如果 startTimeArray
中的一行有 n
个元素,我愿意删除数组中的条目,第一个条目除外。此外,如果索引 i 处的元素从 startTimeArray
中的一行中删除,我希望索引 i-1 中的元素从 endTimeArray
中的同一行中删除。**
如何使用 PySpark 完成此任务?
有几点需要注意:
如果startTimeArray
中的数组只有一个元素,我们就让它在那里。
我意识到这个任务可以通过删除startTimeArray
中数组中第一个元素之后的所有元素来实现。那将是微不足道的情况。但是我想通过尽可能少的删除来完成任务。
以下是我在上面给出的示例数据帧 df
的情况下想要的输出。
+--------+---------------------------------------------------------------+---------------------------------------------------------------+
|category|startTimeArray |endTimeArray |
+--------+---------------------------------------------------------------+---------------------------------------------------------------+
|a |[2019-01-10 00:00:00, 2019-01-16 00:00:00, 2019-01-20 00:00:00]|[2019-01-15 00:00:00, 2019-01-18 00:00:00, 2019-01-22 00:00:00]|
|a |[2019-03-11 00:00:00, 2019-03-18 00:00:00, 2019-03-25 00:00:00]|[2019-03-16 00:00:00, 2019-03-23 00:00:00, 2019-03-30 00:00:00]|
|b |[2019-01-14 00:00:00, 2019-02-22 00:00:00] |[2019-01-18 00:00:00, 2019-02-25 00:00:00] |
+--------+---------------------------------------------------------------+---------------------------------------------------------------+
用户定义的函数 (UDF) 可以完成这项工作。虽然它比原生 Spark sql 函数有性能损失,但它清楚地表达了所需的操作。
from datetime import date, timedelta
from pyspark.sql.functions import *
from pyspark.sql.types import *
d = [date(2019, 1, d) for d in (10, 12, 16, 20)]
e = [date(2019, 1, d) for d in (11, 15, 18, 22)]
f = [date(2019, 3, d) for d in (11, 18, 20, 25, 27)]
g = [date(2019, 3, d) for d in (16, 19, 23, 26, 30)]
h = [date(2019, 1, 14), date(2019, 1, 16), date(2019, 2, 22)]
i = [date(2019, 1, 15), date(2019, 1, 18), date(2019, 2, 25)]
df = spark.createDataFrame((("a", d, e), ("a", f, g), ("b", h, i)),
schema=("category", "startDates", "endDates"))
@udf(returnType=ArrayType(ArrayType(DateType())))
def retain_dates_n_days_apart(startDates, endDates, min_apart=3):
start_dates = [startDates[0]]
end_dates = []
for start, end in zip(startDates[1:], endDates):
if start >= start_dates[-1] + timedelta(days=min_apart):
start_dates.append(start)
end_dates.append(end)
end_dates.append(endDates[-1])
return start_dates, end_dates
df2 = (df
.withColumn("foo",
retain_dates_n_days_apart(df.startDates,
df.endDates))
.cache())
(df2.withColumn("startDates", df2.foo.getItem(0))
.withColumn("endDates", df2.foo.getItem(1))
.drop("foo")
).show(truncate=False)
# +--------+------------------------------------+------------------------------------+
# |category|startDates |endDates |
# +--------+------------------------------------+------------------------------------+
# |a |[2019-01-10, 2019-01-16, 2019-01-20]|[2019-01-15, 2019-01-18, 2019-01-22]|
# |a |[2019-03-11, 2019-03-18, 2019-03-25]|[2019-03-16, 2019-03-23, 2019-03-30]|
# |b |[2019-01-14, 2019-02-22] |[2019-01-18, 2019-02-25] |
# +--------+------------------------------------+------------------------------------+
我有一个包含三列的 PySpark 数据框 (df
)。
1。
category
: 一些字符串
2。
startTimeArray
: 它是一个数组,包含按升序排列的时间戳。
3。
endTimeArray
: 它是一个数组,包含按升序排列的时间戳。
在每一行中,startTimeArray
中的数组长度与endTimeArray
中的数组长度相同。对于这些数组中的每个索引,startTimeArray
中给出的时间戳比 endTimeArray
中对应的(相同索引)时间戳小(发生在前一个日期)。
在startTimeArray
列(和endTimeArray
列),数组的长度可以不同。
以下是数据帧的示例:
+--------+---------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------+
|category|startTimeArray |endTimeArray |
+--------+---------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------+
|a |[2019-01-10 00:00:00, 2019-01-12 00:00:00, 2019-01-16 00:00:00, 2019-01-20 00:00:00] |[2019-01-11 00:00:00, 2019-01-15 00:00:00, 2019-01-18 00:00:00, 2019-01-22 00:00:00] |
|a |[2019-03-11 00:00:00, 2019-03-18 00:00:00, 2019-03-20 00:00:00, 2019-03-25 00:00:00, 2019-03-27 00:00:00]|[2019-03-16 00:00:00, 2019-03-19 00:00:00, 2019-03-23 00:00:00, 2019-03-26 00:00:00, 2019-03-30 00:00:00]|
|b |[2019-01-14 00:00:00, 2019-01-16 00:00:00, 2019-02-22 00:00:00] |[2019-01-15 00:00:00, 2019-01-18 00:00:00, 2019-02-25 00:00:00] |
+--------+---------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------+
在每一行的 startTimeArray
列中,我想确保数组中连续元素(连续索引处的元素)之间的差异至少为三天。如果 startTimeArray
中的一行有 n
个元素,我愿意删除数组中的条目,第一个条目除外。此外,如果索引 i 处的元素从 startTimeArray
中的一行中删除,我希望索引 i-1 中的元素从 endTimeArray
中的同一行中删除。**
如何使用 PySpark 完成此任务?
有几点需要注意:
如果
startTimeArray
中的数组只有一个元素,我们就让它在那里。我意识到这个任务可以通过删除
startTimeArray
中数组中第一个元素之后的所有元素来实现。那将是微不足道的情况。但是我想通过尽可能少的删除来完成任务。
以下是我在上面给出的示例数据帧 df
的情况下想要的输出。
+--------+---------------------------------------------------------------+---------------------------------------------------------------+
|category|startTimeArray |endTimeArray |
+--------+---------------------------------------------------------------+---------------------------------------------------------------+
|a |[2019-01-10 00:00:00, 2019-01-16 00:00:00, 2019-01-20 00:00:00]|[2019-01-15 00:00:00, 2019-01-18 00:00:00, 2019-01-22 00:00:00]|
|a |[2019-03-11 00:00:00, 2019-03-18 00:00:00, 2019-03-25 00:00:00]|[2019-03-16 00:00:00, 2019-03-23 00:00:00, 2019-03-30 00:00:00]|
|b |[2019-01-14 00:00:00, 2019-02-22 00:00:00] |[2019-01-18 00:00:00, 2019-02-25 00:00:00] |
+--------+---------------------------------------------------------------+---------------------------------------------------------------+
用户定义的函数 (UDF) 可以完成这项工作。虽然它比原生 Spark sql 函数有性能损失,但它清楚地表达了所需的操作。
from datetime import date, timedelta
from pyspark.sql.functions import *
from pyspark.sql.types import *
d = [date(2019, 1, d) for d in (10, 12, 16, 20)]
e = [date(2019, 1, d) for d in (11, 15, 18, 22)]
f = [date(2019, 3, d) for d in (11, 18, 20, 25, 27)]
g = [date(2019, 3, d) for d in (16, 19, 23, 26, 30)]
h = [date(2019, 1, 14), date(2019, 1, 16), date(2019, 2, 22)]
i = [date(2019, 1, 15), date(2019, 1, 18), date(2019, 2, 25)]
df = spark.createDataFrame((("a", d, e), ("a", f, g), ("b", h, i)),
schema=("category", "startDates", "endDates"))
@udf(returnType=ArrayType(ArrayType(DateType())))
def retain_dates_n_days_apart(startDates, endDates, min_apart=3):
start_dates = [startDates[0]]
end_dates = []
for start, end in zip(startDates[1:], endDates):
if start >= start_dates[-1] + timedelta(days=min_apart):
start_dates.append(start)
end_dates.append(end)
end_dates.append(endDates[-1])
return start_dates, end_dates
df2 = (df
.withColumn("foo",
retain_dates_n_days_apart(df.startDates,
df.endDates))
.cache())
(df2.withColumn("startDates", df2.foo.getItem(0))
.withColumn("endDates", df2.foo.getItem(1))
.drop("foo")
).show(truncate=False)
# +--------+------------------------------------+------------------------------------+
# |category|startDates |endDates |
# +--------+------------------------------------+------------------------------------+
# |a |[2019-01-10, 2019-01-16, 2019-01-20]|[2019-01-15, 2019-01-18, 2019-01-22]|
# |a |[2019-03-11, 2019-03-18, 2019-03-25]|[2019-03-16, 2019-03-23, 2019-03-30]|
# |b |[2019-01-14, 2019-02-22] |[2019-01-18, 2019-02-25] |
# +--------+------------------------------------+------------------------------------+