Spark: Dynamic filter on a dataset in scala
I have a dataset (ds) that looks like this:
scala> ds.show()
+----+---+-----+----+-----+--------------+
|name|age|field|optr|value| rule|
+----+---+-----+----+-----+--------------+
| a| 75| age| <| 18| Minor|
| b| 10| age| <| 18| Minor|
| c| 30| age| <| 18| Minor|
| a| 75| age| >=| 18| Major|
| b| 10| age| >=| 18| Major|
| c| 30| age| >=| 18| Major|
| a| 75| age| >| 60|Senior Citizen|
| b| 10| age| >| 60|Senior Citizen|
| c| 30| age| >| 60|Senior Citizen|
+----+---+-----+----+-----+--------------+
Now I need to apply a filter to it, to get only the rows that satisfy the filter condition specified as follows:
- the filter is applied on the column named in the field column,
- the operation to perform is in the optr column, and
- the value to compare against is in the value column.
Example: for the first row, apply the filter on the age column (here every field value happens to be age, but it could differ) and check whether age is less than (<) the value 18. This is false, since age=75.
I don't know how to specify this filter condition in Scala. The resulting dataset should look like:
+----+---+-----+----+-----+--------------+
|name|age|field|optr|value| rule|
+----+---+-----+----+-----+--------------+
| b| 10| age| <| 18| Minor|
| a| 75| age| >=| 18| Major|
| c| 30| age| >=| 18| Major|
| a| 75| age| >| 60|Senior Citizen|
+----+---+-----+----+-----+--------------+
The solution is as follows:
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.Row
import scala.collection.mutable

// Encoder so flatMap can return untyped Rows with the original schema;
// df is the input dataset (ds in the question)
val encoder = RowEncoder(df.schema)

df.flatMap(row => {
  val result = mutable.ListBuffer[Row]()
  // The rule carried by this row: which column to test, how, and against what
  val ruleField = row.getAs[String]("field")
  val ruleValue = row.getAs[Int]("value")
  val ruleOptr = row.getAs[String]("optr")
  // The actual value of the column the rule points at
  val rowField = row.getAs[Int](ruleField)
  val condition = ruleOptr match {
    case "=" => rowField == ruleValue
    case "<" => rowField < ruleValue
    case "<=" => rowField <= ruleValue
    case ">" => rowField > ruleValue
    case ">=" => rowField >= ruleValue
    case _ => false
  }
  // Emit the row only if it satisfies its own rule
  if (condition) {
    result += row
  }
  result
})(encoder).show()
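Since every input row produces at most one output row, the same logic also fits a plain Dataset.filter, which preserves the row type and needs no explicit encoder. A minimal alternative sketch, assuming the same df with integer-typed rule columns as above:

df.filter { row =>
  // Evaluate the rule carried by the row as a boolean predicate
  val ruleValue = row.getAs[Int]("value")
  val rowField = row.getAs[Int](row.getAs[String]("field"))
  row.getAs[String]("optr") match {
    case "=" => rowField == ruleValue
    case "<" => rowField < ruleValue
    case "<=" => rowField <= ruleValue
    case ">" => rowField > ruleValue
    case ">=" => rowField >= ruleValue
    case _ => false
  }
}.show()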
Check this out:
scala> val df = Seq(("a",75,"age","<",18,"Minor"),("b",10,"age","<",18,"Minor"),("c",30,"age","<",18,"Minor"),("a",75,"age",">=",18,"Major"),("b",10,"age",">=",18,"Major"),("c",30,"age",">=",18,"Major"),("a",75,"age",">",60,"Senior Citizen"),("b",10,"age",">",60,"Senior Citizen"),("c",30,"age",">",60,"Senior Citizen")).toDF("name","age","field","optr","value","rule")
df: org.apache.spark.sql.DataFrame = [name: string, age: int ... 4 more fields]
scala> df.show(false)
+----+---+-----+----+-----+--------------+
|name|age|field|optr|value|rule |
+----+---+-----+----+-----+--------------+
|a |75 |age |< |18 |Minor |
|b |10 |age |< |18 |Minor |
|c |30 |age |< |18 |Minor |
|a |75 |age |>= |18 |Major |
|b |10 |age |>= |18 |Major |
|c |30 |age |>= |18 |Major |
|a |75 |age |> |60 |Senior Citizen|
|b |10 |age |> |60 |Senior Citizen|
|c |30 |age |> |60 |Senior Citizen|
+----+---+-----+----+-----+--------------+
scala> val df2 = df.withColumn("condn", concat('field,'optr,'value))
df2: org.apache.spark.sql.DataFrame = [name: string, age: int ... 5 more fields]
scala> val condn_list=df2.groupBy().agg(collect_set('condn).as("condns")).as[(Seq[String])].first
condn_list: Seq[String] = List(age>60, age<18, age>=18)
scala> val df_filters = condn_list.map{ x => df2.filter(s""" condn='${x}' and $x """) }
df_filters: Seq[org.apache.spark.sql.Dataset[org.apache.spark.sql.Row]] = List([name: string, age: int ... 5 more fields], [name: string, age: int ... 5 more fields], [name: string, age: int ... 5 more fields])
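Each element of df_filters keeps only the rows tagged with one rule that also satisfy that rule. To make the interpolation concrete, this is what the map produces for the collected condition "age<18" (sample is just an illustrative name):

// Generated filter expression: condn='age<18' and age<18
// i.e. the row carries this rule AND its age actually satisfies it
val sample = df2.filter(" condn='age<18' and age<18 ")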
scala> df_filters(0).union(df_filters(1)).union(df_filters(2)).show(false)
+----+---+-----+----+-----+--------------+-------+
|name|age|field|optr|value|rule |condn |
+----+---+-----+----+-----+--------------+-------+
|b |10 |age |< |18 |Minor |age<18 |
|a |75 |age |> |60 |Senior Citizen|age>60 |
|a |75 |age |>= |18 |Major |age>=18|
|c |30 |age |>= |18 |Major |age>=18|
+----+---+-----+----+-----+--------------+-------+
To get the union, you can do this:
scala> var res = df_filters(0)
res: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [name: string, age: int ... 5 more fields]
scala> (1 until df_filters.length).map( x => { res = res.union(df_filters(x)) } )
res20: scala.collection.immutable.IndexedSeq[Unit] = Vector((), ())
scala> res.show(false)
+----+---+-----+----+-----+--------------+-------+
|name|age|field|optr|value|rule |condn |
+----+---+-----+----+-----+--------------+-------+
|b |10 |age |< |18 |Minor |age<18 |
|a |75 |age |> |60 |Senior Citizen|age>60 |
|a |75 |age |>= |18 |Major |age>=18|
|c |30 |age |>= |18 |Major |age>=18|
+----+---+-----+----+-----+--------------+-------+
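Equivalently, the mutable var can be avoided by reducing over the list, assuming df_filters is non-empty:

// Union all per-condition datasets in one pass
val res = df_filters.reduce(_ union _)
res.show(false)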