如何有效地使用窗口函数根据 N 个先前值决定下一个 N 行数
How to use windowing functions efficiently to decide next N number of rows based on N number of previous values
嗨,我有以下数据。
+----------+----+-------+-----------------------+
| date|item|avg_val|conditions |
+----------+----+-------+-----------------------+
|01-10-2020| x| 10| 0|
|02-10-2020| x| 10| 0|
|03-10-2020| x| 15| 1|
|04-10-2020| x| 15| 1|
|05-10-2020| x| 5| 0|
|06-10-2020| x| 13| 1|
|07-10-2020| x| 10| 1|
|08-10-2020| x| 10| 0|
|09-10-2020| x| 15| 1|
|01-10-2020| y| 10| 0|
|02-10-2020| y| 18| 0|
|03-10-2020| y| 6| 1|
|04-10-2020| y| 10| 0|
|05-10-2020| y| 20| 0|
+----------+----+-------+-----------------------+
我想根据
创建一个名为 flag level 的新列
- 如果标志值为 0,则新列值将为 0。
- 如果标志为 1,则新列将为 1,接下来的四 N 行将为零,即无需检查下一个 N 值。此过程将应用于每个项目,即按项目划分将起作用。
我这里用过N = 4,
我使用了下面的代码,但是没有有效的窗口功能,有没有优化的方法。
DROP TEMPORARY TABLE t2;
CREATE TEMPORARY TABLE t2
SELECT *,
MAX(conditions) OVER (PARTITION BY item ORDER BY item,`date` ROWS 4 PRECEDING ) AS new_row
FROM record
ORDER BY item,`date`;
DROP TEMPORARY TABLE t3;
CREATE TEMPORARY TABLE t3
SELECT *,ROW_NUMBER() OVER (PARTITION BY item,new_row ORDER BY item,`date`) AS e FROM t2;
SELECT *,CASE WHEN new_row=1 AND e%5>1 THEN 0
WHEN new_row=1 AND e%5=1 THEN 1 ELSE 0 END AS flag FROM t3;
输出像
+----------+----+-------+-----------------------+-----+
| date|item|avg_val|conditions |flag |
+----------+----+-------+-----------------------+-----+
|01-10-2020| x| 10| 0| 0|
|02-10-2020| x| 10| 0| 0|
|03-10-2020| x| 15| 1| 1|
|04-10-2020| x| 15| 1| 0|
|05-10-2020| x| 5| 0| 0|
|06-10-2020| x| 13| 1| 0|
|07-10-2020| x| 10| 1| 0|
|08-10-2020| x| 10| 0| 0|
|09-10-2020| x| 15| 1| 1|
|01-10-2020| y| 10| 0| 0|
|02-10-2020| y| 18| 0| 0|
|03-10-2020| y| 6| 1| 1|
|04-10-2020| y| 10| 0| 0|
|05-10-2020| y| 20| 0| 0|
+----------+----+-------+-----------------------+-----+
但是我无法得到输出,我已经尝试了更多。
正如评论中所建议的(@nbk 和@Akina),您将需要某种迭代器来实现逻辑。使用 SparkSQL 和 Spark 2.4+ 版本,我们可以使用内置函数 aggregate 并设置一个结构数组和一个计数器作为累加器。下面是一个示例数据框,table 命名为 record
(假设 conditions
列中的值是 0
或 1
):
val df = Seq(
("01-10-2020", "x", 10, 0), ("02-10-2020", "x", 10, 0), ("03-10-2020", "x", 15, 1),
("04-10-2020", "x", 15, 1), ("05-10-2020", "x", 5, 0), ("06-10-2020", "x", 13, 1),
("07-10-2020", "x", 10, 1), ("08-10-2020", "x", 10, 0), ("09-10-2020", "x", 15, 1),
("01-10-2020", "y", 10, 0), ("02-10-2020", "y", 18, 0), ("03-10-2020", "y", 6, 1),
("04-10-2020", "y", 10, 0), ("05-10-2020", "y", 20, 0)
).toDF("date", "item", "avg_val", "conditions")
df.createOrReplaceTempView("record")
SQL:
spark.sql("""
SELECT t1.item, m.*
FROM (
SELECT item,
sort_array(collect_list(struct(date,avg_val,int(conditions) as conditions,conditions as flag))) as dta
FROM record
GROUP BY item
) as t1 LATERAL VIEW OUTER inline(
aggregate(
/* expr: set up array `dta` from the 2nd element to the last
* notice that indices for slice function is 1-based, dta[i] is 0-based
*/
slice(dta,2,size(dta)),
/* start: set up and initialize `acc` to a struct containing two fields:
* - dta: an array of structs with a single element dta[0]
* - counter: number of rows after flag=1, can be from `0` to `N+1`
*/
(array(dta[0]) as dta, dta[0].conditions as counter),
/* merge: iterate through the `expr` using x and update two fields of `acc`
* - dta: append values from x to acc.dta array using concat + array functions
* update flag using `IF(acc.counter IN (0,5) and x.conditions = 1, 1, 0)`
* - counter: increment by 1 if acc.counter is between 1 and 4
* , otherwise set value to x.conditions
*/
(acc, x) -> named_struct(
'dta', concat(acc.dta, array(named_struct(
'date', x.date,
'avg_val', x.avg_val,
'conditions', x.conditions,
'flag', IF(acc.counter IN (0,5) and x.conditions = 1, 1, 0)
))),
'counter', IF(acc.counter > 0 and acc.counter < 5, acc.counter+1, x.conditions)
),
/* finish: retrieve acc.dta only and discard acc.counter */
acc -> acc.dta
)
) m
""").show(50)
结果:
+----+----------+-------+----------+----+
|item| date|avg_val|conditions|flag|
+----+----------+-------+----------+----+
| x|01-10-2020| 10| 0| 0|
| x|02-10-2020| 10| 0| 0|
| x|03-10-2020| 15| 1| 1|
| x|04-10-2020| 15| 1| 0|
| x|05-10-2020| 5| 0| 0|
| x|06-10-2020| 13| 1| 0|
| x|07-10-2020| 10| 1| 0|
| x|08-10-2020| 10| 0| 0|
| x|09-10-2020| 15| 1| 1|
| y|01-10-2020| 10| 0| 0|
| y|02-10-2020| 18| 0| 0|
| y|03-10-2020| 6| 1| 1|
| y|04-10-2020| 10| 0| 0|
| y|05-10-2020| 20| 0| 0|
+----+----------+-------+----------+----+
其中:
- 使用
groupby
将同一项目的行收集到名为 dta 列的结构数组中,其中包含 4 个字段:date、avg_val、条件和标志并按日期排序
- 使用
aggregate
函数遍历上述结构数组,根据counter更新flag字段和条件(详见上面SQL代码注释)
- 使用
Lateral VIEW
和inline 函数从聚合函数分解结果数组
备注:
(1) 建议的 SQL 用于 N=4,其中我们在 SQL 中有 acc.counter IN (0,5)
和 acc.counter < 5
。对于任何N,将上面的调整为:acc.counter IN (0,N+1)
和acc.counter < N+1
,下面显示N=2
的结果,具有相同的样本数据:
+----+----------+-------+----------+----+
|item| date|avg_val|conditions|flag|
+----+----------+-------+----------+----+
| x|01-10-2020| 10| 0| 0|
| x|02-10-2020| 10| 0| 0|
| x|03-10-2020| 15| 1| 1|
| x|04-10-2020| 15| 1| 0|
| x|05-10-2020| 5| 0| 0|
| x|06-10-2020| 13| 1| 1|
| x|07-10-2020| 10| 1| 0|
| x|08-10-2020| 10| 0| 0|
| x|09-10-2020| 15| 1| 1|
| y|01-10-2020| 10| 0| 0|
| y|02-10-2020| 18| 0| 0|
| y|03-10-2020| 6| 1| 1|
| y|04-10-2020| 10| 0| 0|
| y|05-10-2020| 20| 0| 0|
+----+----------+-------+----------+----+
(2) 我们使用 dta[0]
来初始化 acc
其中包括其字段的值和数据类型。理想情况下,我们应该确保这些字段的数据类型正确,以便正确进行所有计算。例如,在计算 acc.counter
时,如果 conditions
是 StringType,则 acc.counter+1
将 return 具有 DoubleType 值的 StringType
spark.sql("select '2'+1").show()
+---------------------------------------+
|(CAST(2 AS DOUBLE) + CAST(1 AS DOUBLE))|
+---------------------------------------+
| 3.0|
+---------------------------------------+
使用 acc.counter IN (0,5)
或 acc.counter < 5
将它们的值与整数进行比较时,可能会产生 floating-point 错误。根据 OP 的反馈,这产生了错误的结果,没有任何 WARNING/ERROR 消息。
一种解决方法是在设置聚合函数的第二个参数时使用 CAST 指定确切的字段类型,以便在任何类型不匹配时报告错误,见下文:
CAST((array(dta[0]), dta[0].conditions) as struct<dta:array<struct<date:string,avg_val:string,conditions:int,flag:int>>,counter:int>),
另一个解决方案是在创建dta
列时强制类型,在这个例子中,参见下面代码中的int(conditions) as conditions
:
SELECT item,
sort_array(collect_list(struct(date,avg_val,int(conditions) as conditions,conditions as flag))) as dta
FROM record
GROUP BY item
我们也可以在计算中强制数据类型,例如,参见下面的int(acc.counter+1)
:
IF(acc.counter > 0 and acc.counter < 5, int(acc.counter+1), x.conditions)
嗨,我有以下数据。
+----------+----+-------+-----------------------+
| date|item|avg_val|conditions |
+----------+----+-------+-----------------------+
|01-10-2020| x| 10| 0|
|02-10-2020| x| 10| 0|
|03-10-2020| x| 15| 1|
|04-10-2020| x| 15| 1|
|05-10-2020| x| 5| 0|
|06-10-2020| x| 13| 1|
|07-10-2020| x| 10| 1|
|08-10-2020| x| 10| 0|
|09-10-2020| x| 15| 1|
|01-10-2020| y| 10| 0|
|02-10-2020| y| 18| 0|
|03-10-2020| y| 6| 1|
|04-10-2020| y| 10| 0|
|05-10-2020| y| 20| 0|
+----------+----+-------+-----------------------+
我想根据
创建一个名为 flag level 的新列- 如果标志值为 0,则新列值将为 0。
- 如果标志为 1,则新列将为 1,接下来的四 N 行将为零,即无需检查下一个 N 值。此过程将应用于每个项目,即按项目划分将起作用。
我这里用过N = 4,
我使用了下面的代码,但是没有有效的窗口功能,有没有优化的方法。
DROP TEMPORARY TABLE t2;
CREATE TEMPORARY TABLE t2
SELECT *,
MAX(conditions) OVER (PARTITION BY item ORDER BY item,`date` ROWS 4 PRECEDING ) AS new_row
FROM record
ORDER BY item,`date`;
DROP TEMPORARY TABLE t3;
CREATE TEMPORARY TABLE t3
SELECT *,ROW_NUMBER() OVER (PARTITION BY item,new_row ORDER BY item,`date`) AS e FROM t2;
SELECT *,CASE WHEN new_row=1 AND e%5>1 THEN 0
WHEN new_row=1 AND e%5=1 THEN 1 ELSE 0 END AS flag FROM t3;
输出像
+----------+----+-------+-----------------------+-----+
| date|item|avg_val|conditions |flag |
+----------+----+-------+-----------------------+-----+
|01-10-2020| x| 10| 0| 0|
|02-10-2020| x| 10| 0| 0|
|03-10-2020| x| 15| 1| 1|
|04-10-2020| x| 15| 1| 0|
|05-10-2020| x| 5| 0| 0|
|06-10-2020| x| 13| 1| 0|
|07-10-2020| x| 10| 1| 0|
|08-10-2020| x| 10| 0| 0|
|09-10-2020| x| 15| 1| 1|
|01-10-2020| y| 10| 0| 0|
|02-10-2020| y| 18| 0| 0|
|03-10-2020| y| 6| 1| 1|
|04-10-2020| y| 10| 0| 0|
|05-10-2020| y| 20| 0| 0|
+----------+----+-------+-----------------------+-----+
但是我无法得到输出,我已经尝试了更多。
正如评论中所建议的(@nbk 和@Akina),您将需要某种迭代器来实现逻辑。使用 SparkSQL 和 Spark 2.4+ 版本,我们可以使用内置函数 aggregate 并设置一个结构数组和一个计数器作为累加器。下面是一个示例数据框,table 命名为 record
(假设 conditions
列中的值是 0
或 1
):
val df = Seq(
("01-10-2020", "x", 10, 0), ("02-10-2020", "x", 10, 0), ("03-10-2020", "x", 15, 1),
("04-10-2020", "x", 15, 1), ("05-10-2020", "x", 5, 0), ("06-10-2020", "x", 13, 1),
("07-10-2020", "x", 10, 1), ("08-10-2020", "x", 10, 0), ("09-10-2020", "x", 15, 1),
("01-10-2020", "y", 10, 0), ("02-10-2020", "y", 18, 0), ("03-10-2020", "y", 6, 1),
("04-10-2020", "y", 10, 0), ("05-10-2020", "y", 20, 0)
).toDF("date", "item", "avg_val", "conditions")
df.createOrReplaceTempView("record")
SQL:
spark.sql("""
SELECT t1.item, m.*
FROM (
SELECT item,
sort_array(collect_list(struct(date,avg_val,int(conditions) as conditions,conditions as flag))) as dta
FROM record
GROUP BY item
) as t1 LATERAL VIEW OUTER inline(
aggregate(
/* expr: set up array `dta` from the 2nd element to the last
* notice that indices for slice function is 1-based, dta[i] is 0-based
*/
slice(dta,2,size(dta)),
/* start: set up and initialize `acc` to a struct containing two fields:
* - dta: an array of structs with a single element dta[0]
* - counter: number of rows after flag=1, can be from `0` to `N+1`
*/
(array(dta[0]) as dta, dta[0].conditions as counter),
/* merge: iterate through the `expr` using x and update two fields of `acc`
* - dta: append values from x to acc.dta array using concat + array functions
* update flag using `IF(acc.counter IN (0,5) and x.conditions = 1, 1, 0)`
* - counter: increment by 1 if acc.counter is between 1 and 4
* , otherwise set value to x.conditions
*/
(acc, x) -> named_struct(
'dta', concat(acc.dta, array(named_struct(
'date', x.date,
'avg_val', x.avg_val,
'conditions', x.conditions,
'flag', IF(acc.counter IN (0,5) and x.conditions = 1, 1, 0)
))),
'counter', IF(acc.counter > 0 and acc.counter < 5, acc.counter+1, x.conditions)
),
/* finish: retrieve acc.dta only and discard acc.counter */
acc -> acc.dta
)
) m
""").show(50)
结果:
+----+----------+-------+----------+----+
|item| date|avg_val|conditions|flag|
+----+----------+-------+----------+----+
| x|01-10-2020| 10| 0| 0|
| x|02-10-2020| 10| 0| 0|
| x|03-10-2020| 15| 1| 1|
| x|04-10-2020| 15| 1| 0|
| x|05-10-2020| 5| 0| 0|
| x|06-10-2020| 13| 1| 0|
| x|07-10-2020| 10| 1| 0|
| x|08-10-2020| 10| 0| 0|
| x|09-10-2020| 15| 1| 1|
| y|01-10-2020| 10| 0| 0|
| y|02-10-2020| 18| 0| 0|
| y|03-10-2020| 6| 1| 1|
| y|04-10-2020| 10| 0| 0|
| y|05-10-2020| 20| 0| 0|
+----+----------+-------+----------+----+
其中:
- 使用
groupby
将同一项目的行收集到名为 dta 列的结构数组中,其中包含 4 个字段:date、avg_val、条件和标志并按日期排序 - 使用
aggregate
函数遍历上述结构数组,根据counter更新flag字段和条件(详见上面SQL代码注释) - 使用
Lateral VIEW
和inline 函数从聚合函数分解结果数组
备注:
(1) 建议的 SQL 用于 N=4,其中我们在 SQL 中有 acc.counter IN (0,5)
和 acc.counter < 5
。对于任何N,将上面的调整为:acc.counter IN (0,N+1)
和acc.counter < N+1
,下面显示N=2
的结果,具有相同的样本数据:
+----+----------+-------+----------+----+
|item| date|avg_val|conditions|flag|
+----+----------+-------+----------+----+
| x|01-10-2020| 10| 0| 0|
| x|02-10-2020| 10| 0| 0|
| x|03-10-2020| 15| 1| 1|
| x|04-10-2020| 15| 1| 0|
| x|05-10-2020| 5| 0| 0|
| x|06-10-2020| 13| 1| 1|
| x|07-10-2020| 10| 1| 0|
| x|08-10-2020| 10| 0| 0|
| x|09-10-2020| 15| 1| 1|
| y|01-10-2020| 10| 0| 0|
| y|02-10-2020| 18| 0| 0|
| y|03-10-2020| 6| 1| 1|
| y|04-10-2020| 10| 0| 0|
| y|05-10-2020| 20| 0| 0|
+----+----------+-------+----------+----+
(2) 我们使用 dta[0]
来初始化 acc
其中包括其字段的值和数据类型。理想情况下,我们应该确保这些字段的数据类型正确,以便正确进行所有计算。例如,在计算 acc.counter
时,如果 conditions
是 StringType,则 acc.counter+1
将 return 具有 DoubleType 值的 StringType
spark.sql("select '2'+1").show()
+---------------------------------------+
|(CAST(2 AS DOUBLE) + CAST(1 AS DOUBLE))|
+---------------------------------------+
| 3.0|
+---------------------------------------+
使用 acc.counter IN (0,5)
或 acc.counter < 5
将它们的值与整数进行比较时,可能会产生 floating-point 错误。根据 OP 的反馈,这产生了错误的结果,没有任何 WARNING/ERROR 消息。
一种解决方法是在设置聚合函数的第二个参数时使用 CAST 指定确切的字段类型,以便在任何类型不匹配时报告错误,见下文:
CAST((array(dta[0]), dta[0].conditions) as struct<dta:array<struct<date:string,avg_val:string,conditions:int,flag:int>>,counter:int>),
另一个解决方案是在创建
dta
列时强制类型,在这个例子中,参见下面代码中的int(conditions) as conditions
:SELECT item, sort_array(collect_list(struct(date,avg_val,int(conditions) as conditions,conditions as flag))) as dta FROM record GROUP BY item
我们也可以在计算中强制数据类型,例如,参见下面的
int(acc.counter+1)
:IF(acc.counter > 0 and acc.counter < 5, int(acc.counter+1), x.conditions)