如何获取列中至少有两个不同值的行?

How to get rows where at least two distinct values are in a column?

我有一个日志文件,我想报告发起了不止一种(至少两种)类型的 protocol 连接的 IP 地址,同时显示这些协议。我试图通过使用 both DataFrames API 和 SparkSQL 来获得这些结果。

这是我的数据示例:

+----------------+--------+--------+---------------+--------------+---------+-------------+------+-----+
|       Timestamp|Duration|Protocol|BytesOriginator|ResponderBytes|LocalHost|   RemoteHost| State|Flags|
+----------------+--------+--------+---------------+--------------+---------+-------------+------+-----+
|748162802.427995| 1.24383|    smtp|              ?|             ?|        1| 128.97.154.3|   REJ|    L|
|748162802.803033| 3.96513|    smtp|           1173|           328|        3|  128.8.142.5|    SF| null|
|748162804.817224| 1.02839|    nntp|             58|           129|        2|   140.98.2.1|    SF|    L|
|748162812.254572| 138.168|    nntp|         363238|          1200|        4| 128.49.4.103|    SF|    L|
|748162817.478016| 10.0858|    nntp|            230|           100|        4| 128.32.133.1|    SF|    N|
|748162833.453963| 2.16477|    smtp|           2524|           306|        5|192.48.232.17|    SF| null|
|748162836.735788| 13.1779|    smtp|          16479|           174|       16| 128.233.1.12|RSTRS3|    L|
|748162839.930331| 6.69767|    smtp|           3104|           371|        8|   139.91.1.1|    SF|    L|
|748162841.854151| 2.07407|    smtp|           1172|           380|        6|  128.8.142.5|    SF| null|
|748162854.814153| 131.659|    nntp|         319292|          1220|        4| 128.110.4.25|    SF|    L|
|748162866.207165| 51.8406|    nntp|         135714|           280|        4| 128.110.4.25|    SF| null|
|748162866.600750|0.402045|    smtp|              ?|             ?|        1| 128.97.154.3|   REJ|    L|
|748162869.790751| 172.363|    smtp|              0|             0|       16|132.230.6.100|    SF|    L|
|748162873.491682|  102.88|    nntp|            346|           180|        4| 128.32.136.1|    SF|   LN|
|748162875.237378| 5.32943|    nntp|             90|            85|        4| 128.32.133.1|    SF|    N|
+----------------+--------+--------+---------------+--------------+---------+-------------+------+-----+

我试图过滤我的数据框,但我一直收到错误,我不知道我是否应该使用 Window 函数。通过使用 SparkSQL,到目前为止我得到了 IPs 但没有 protocols.

这是我所做的:

custom_schema = StructType([
    StructField('Timestamp', StringType(), True),
    StructField('Duration', FloatType(), True),
    StructField('Protocol', StringType(), True),
    StructField('BytesOriginator', StringType(), True),
    StructField('ResponderBytes', StringType(), True),
    StructField('LocalHost', StringType(), True),
    StructField('RemoteHost', StringType(), True),
    StructField('State', StringType(), True),
    StructField('Flags', StringType(), True) 
])

logs = spark.read.csv('lbl-conn-7.csv', header=False, sep=' ', schema=custom_schema)

# I get an error
logs.select('RemoteHost', 'Protocol').distinct().filter(F.countDistinct('Protocol') > 1).show()

logs.createOrReplaceTempView("mytable")
sqlContext = SQLContext(sc)
df = sqlContext.sql("select remotehost, protocol FROM mytable GROUP BY  HAVING COUNT(distinct protocol) > 1")
# It doesn't show the protocols
df.show()

您可以按 RemoteHost 分组并收集使用的不同 Protocol 的列表。然后,使用协议数组的大小过滤生成的数据帧:

import pyspark.sql.functions as F

logs.groupBy("RemoteHost").agg(
    F.collect_set("Protocol").alias("Protocols")
).filter(
    F.size("Protocols") >= 2
).show()

Spark SQL 等效查询:

SELECT  RemoteHost, 
        collect_set(Protocol) AS Protocols
FROM    mytable 
GROUP BY  RemoteHost
HAVING  size(Protocols) >= 2 -- or count(distinct Protocol)  >= 2

如果要保留所有列,请使用 Window 和 collect_set 函数:

logs.withColumn(
    "Protocols",
    F.collect_set("Protocol").over((Window.partitionBy("RemoteHost")))
).filter(
    F.size("Protocols") >= 2
).drop("Protocols").show()