How to use a Spark SQL UDAF to implement windowed counting with a condition?
I have a table with columns timestamp, id, and condition, and I want to count rows per id within each time interval (e.g. 10 seconds): when the condition is true the count is incremented, otherwise the previous value is returned.
The UDAF code is as follows:
import java.util.Arrays;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class MyCount extends UserDefinedAggregateFunction {

    // Input: (condition, event timestamp, interval length in seconds)
    @Override
    public StructType inputSchema() {
        return DataTypes.createStructType(
                Arrays.asList(
                        DataTypes.createStructField("condition", DataTypes.BooleanType, true),
                        DataTypes.createStructField("timestamp", DataTypes.LongType, true),
                        DataTypes.createStructField("interval", DataTypes.IntegerType, true)
                )
        );
    }

    // Buffer: start timestamp of the current interval and the running count
    @Override
    public StructType bufferSchema() {
        return DataTypes.createStructType(
                Arrays.asList(
                        DataTypes.createStructField("timestamp", DataTypes.LongType, true),
                        DataTypes.createStructField("count", DataTypes.LongType, true)
                )
        );
    }

    @Override
    public DataType dataType() {
        return DataTypes.LongType;
    }

    @Override
    public boolean deterministic() {
        return true;
    }

    @Override
    public void initialize(MutableAggregationBuffer mutableAggregationBuffer) {
        mutableAggregationBuffer.update(0, 0L);
        mutableAggregationBuffer.update(1, 0L);
    }

    @Override
    public void update(MutableAggregationBuffer mutableAggregationBuffer, Row row) {
        long timestamp = mutableAggregationBuffer.getLong(0);
        long count = mutableAggregationBuffer.getLong(1);
        long event_time = row.getLong(1);
        int interval = row.getInt(2);
        // When the event falls outside the current interval, start a new interval
        // aligned to the interval boundary and reset the count.
        if (event_time > timestamp + interval) {
            timestamp = event_time - event_time % interval;
            count = 0;
        }
        if (row.getBoolean(0)) {
            count++;
        }
        mutableAggregationBuffer.update(0, timestamp);
        mutableAggregationBuffer.update(1, count);
    }

    // Intentionally left empty -- see the question below about when merge is executed.
    @Override
    public void merge(MutableAggregationBuffer mutableAggregationBuffer, Row row) {
    }

    // Return the running count from the buffer.
    @Override
    public Object evaluate(Row row) {
        return row.getLong(1);
    }
}
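For context, the UDAF has to be registered with the Spark session before the SQL below can call it. That step is not shown in the post; a minimal sketch, assuming a SparkSession instance named spark, would be:

// Hypothetical registration step (not shown above): expose the Java UDAF to Spark SQL
// under the name "MyCount". Assumes a SparkSession instance named `spark`.
spark.udf.register("MyCount", new MyCount())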
Then I run a SQL query like:
select timestamp, id, MyCount(true, timestamp, 10) over(PARTITION BY id ORDER BY timestamp) as count from xxx.xxx
The result is:
timestamp    id  count
1642760594   0   1
1642760596   0   2
1642760599   0   3
1642760610   0   2   -- duplicate
1642760610   0   2
1642760613   0   3
1642760594   1   1
1642760597   1   2
1642760600   1   1
1642760603   1   2
1642760606   1   4   -- duplicate
1642760606   1   4
1642760608   1   5
When timestamps are duplicated, I get 1, 2, 4, 4, 5 instead of 1, 2, 3, 4, 5.
How can I fix this?
Another question: when is the UDAF's merge method executed? I left it empty, yet everything runs fine. I tried adding a log statement inside the method, but I never see that log. Is merge really necessary?
There is also a similar question:
But row_number() does not have this problem. row_number() is a Hive UDAF, so I then tried to create a Hive UDAF myself, but I ran into problems there too... Why does the Hive UDAF row_number()'s terminate() return an 'ArrayList'? I created my own UDAF row_number2() by copying its code, and then I got a list back as well.
Finally I solved it with Spark's AggregateWindowFunction:
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{Add, AggregateWindowFunction, AttributeReference, Expression, If, Literal}
import org.apache.spark.sql.types.{DataType, LongType}

case class Count(condition: Expression) extends AggregateWindowFunction with Logging {
  override def prettyName: String = "myCount"
  override def dataType: DataType = LongType
  override def nullable: Boolean = false
  override def children: Seq[Expression] = Seq(condition)

  private val zero = Literal(0L)
  private val one = Literal(1L)
  // Single buffer attribute holding the running count.
  private val count = AttributeReference("count", LongType, nullable = false)()
  // Increment the count only when the condition evaluates to true.
  private val increaseCount = If(condition, Add(count, one), count)

  override val initialValues: Seq[Expression] = zero :: Nil
  override val updateExpressions: Seq[Expression] = increaseCount :: Nil
  override val evaluateExpression: Expression = count
  override val aggBufferAttributes: Seq[AttributeReference] = count :: Nil
}
Then register it with spark_session.functionRegistry.registerFunction.
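A minimal sketch of that registration, going through the internal session state (the exact registerFunction overloads differ between Spark versions, so this is an assumption rather than the exact call from the post):

import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.Expression

// Hypothetical registration sketch -- FunctionRegistry is an internal API and its
// registerFunction overloads vary by Spark version; adjust to the version in use.
spark.sessionState.functionRegistry.registerFunction(
  FunctionIdentifier("myCount"),
  (children: Seq[Expression]) => Count(children.head)
)

After that, the window query can refer to myCount directly: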
"select myCount(true) over(partition by window(timestamp, '10 seconds'), id order by timestamp) as count from xxx"