Join dataset with case class spark scala

I am converting a dataframe to a dataset using a case class that contains a sequence of another case class:
case class IdMonitor(id: String, ipLocation: Seq[IpLocation])
case class IpLocation(
    ip: String,
    ipVersion: Byte,
    ipType: String,
    city: String,
    state: String,
    country: String)

Now I have another dataset of strings that holds only IPs. My requirement is to fetch all records from IpLocation where ipType == "home" or where the IP dataset contains the given IP from ipLocation. I am trying to search the dataset using a Bloom filter built over the IP dataset, but it is inefficient and not working well in general. I want to join the IP dataset with IpLocation, but I am having trouble since the locations live inside a Seq. I am very new to Spark and Scala, so I may be missing something. Right now my code looks like this:

import breeze.util.BloomFilter  // assuming this is Breeze's Bloom filter, which provides optimallySized, += and |

def buildBloomFilter(ips: Dataset[String]): BloomFilter[String] = {
  val count = ips.count
  ips.rdd
    .mapPartitions { iter =>
      // one filter per partition, sized for the full dataset
      val b = BloomFilter.optimallySized[String](count, FP_PROBABILITY)
      iter.foreach(i => b += i)
      Iterator(b)
    }
    .treeReduce(_ | _)  // merge the per-partition filters by union
}

val ipBf = buildBloomFilter(Ips)
val ipBfBroadcast = spark.sparkContext.broadcast(ipBf)

idMonitor.map { x =>
  x.ipLocation.filter(
    loc => loc.ipType == "home" && ipBfBroadcast.value.contains(loc.ip)
  )
}
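
For reference, Spark also ships a distributed Bloom-filter builder on DataFrameStatFunctions, which avoids hand-rolling the per-partition build and merge. A minimal sketch, assuming the same Ips dataset and FP_PROBABILITY as above:

import org.apache.spark.util.sketch.BloomFilter

// Build the filter in a single distributed pass over the IP column.
// Ips.toDF("ip") names the lone String column "ip" for the stat call.
val bf: BloomFilter = Ips.toDF("ip").stat.bloomFilter("ip", Ips.count, FP_PROBABILITY)
val bfBroadcast = spark.sparkContext.broadcast(bf)

// on executors: bfBroadcast.value.mightContain(someIp)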

I am just trying to figure out how to join IpLocation and Ips.

You can explode the sequence of IpLocation inside the IdMonitor objects with the explode function, use a left outer join to match the ips present in the Ips dataset, filter on ipType == "home" or the ip being present in the Ips dataset, and finally rebuild your IpLocation sequence by grouping by id and using collect_list.

The complete code is as follows:

import org.apache.spark.sql.functions.{col, collect_list, explode}

val result = idMonitor.select(col("id"), explode(col("ipLocation")))  // one row per IpLocation, struct column "col"
  .join(Ips, col("col.ip") === col("value"), "left_outer")            // Dataset[String] column is named "value"
  .filter(col("col.ipType") === "home" || col("value").isNotNull)     // keep "home" rows or matched ips
  .groupBy("id")
  .agg(collect_list("col").as("value"))                               // alias "value" matches the Seq encoder's column name
  .drop("id")
  .as[Seq[IpLocation]]
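
Note that this assumes the usual session implicits are in scope: the exploded struct lands in a column named col by default, and a bare Dataset[String] exposes its payload under the default column name value, which is why the join condition reads col("col.ip") === col("value"). A minimal setup sketch, assuming a local session:

import org.apache.spark.sql.{Dataset, SparkSession}

val spark = SparkSession.builder().appName("ip-join").getOrCreate()
import spark.implicits._  // provides the encoders behind .toDF(), .toDS() and .as[Seq[IpLocation]]

val Ips: Dataset[String] = Seq("123.123.123.125").toDS()  // single column, named "value" by default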

Sample:

Starting from your case classes,

case class IpLocation(
    ip: String,
    ipVersion: Byte,
    ipType: String,
    city: String,
    state: String,
    country: String
)
case class IdMonitor(id: String, ipLocation: Seq[IpLocation])

I defined the sample data as follows:

val ip_locations1 = Seq(IpLocation("123.123.123.123", 12.toByte, "home", "test", "test", "test"), IpLocation("123.123.123.124", 12.toByte, "otherwise", "test", "test", "test"))
val ip_locations2 = Seq(IpLocation("123.123.123.125", 13.toByte, "company", "test", "test", "test"), IpLocation("123.123.123.124", 13.toByte, "otherwise", "test", "test", "test"))

val id_monitor = Seq(IdMonitor("1", ip_locations1), IdMonitor("2", ip_locations2))
val df = id_monitor.toDF()
df.show(false)

+---+------------------------------------------------------------------------------------------------------+
|id |ipLocation                                                                                            |
+---+------------------------------------------------------------------------------------------------------+
|1  |[{123.123.123.123, 12, home, test, test, test}, {123.123.123.124, 12, otherwise, test, test, test}]   |
|2  |[{123.123.123.125, 13, company, test, test, test}, {123.123.123.124, 13, otherwise, test, test, test}]|
+---+------------------------------------------------------------------------------------------------------+

and the IPs:

val ips = Seq("123.123.123.125")
val df_ips = ips.toDF("ips")
df_ips.show()

+---------------+
|            ips|
+---------------+
|123.123.123.125|
+---------------+

Join:

From the sample data above, explode the array in IdMonitor and join with the IPs:

import org.apache.spark.sql.functions.lit  // lit was not in the earlier import

df.withColumn("ipLocation", explode('ipLocation)).alias("a")
  .join(df_ips.alias("b"), col("a.ipLocation.ipType") === lit("home") || col("a.ipLocation.ip") === col("b.ips"), "inner")
  .select("ipLocation.*")
  .as[IpLocation].collect()

Finally, collecting gives the following result:

res32: Array[IpLocation] = Array(IpLocation(123.123.123.123,12,home,test,test,test), IpLocation(123.123.123.125,13,company,test,test,test))
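
One design note: the OR in the join condition rules out an equi-join, so Spark falls back to a broadcast nested-loop join here. That is fine for a small IP list, but it can get expensive, and an inner join against a multi-row df_ips would also duplicate rows whose ipType is "home". A hedged alternative sketch that keeps an equi-join by splitting the two predicates:

// Take the "home" rows directly, left-semi join the rest on ip, then union.
val exploded = df.withColumn("ipLocation", explode('ipLocation)).select("ipLocation.*")

val homeRows    = exploded.filter(col("ipType") === "home")
val matchedRows = exploded.join(df_ips, exploded("ip") === df_ips("ips"), "left_semi")

homeRows.union(matchedRows).distinct().as[IpLocation].collect()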