星火列车测试拆分

Spark train test split

我很好奇在最新的 2.0.1 版本中是否有类似于 sklearn 的 http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedShuffleSplit.html for apache-spark 的东西。

到目前为止我只能找到 https://spark.apache.org/docs/latest/mllib-statistics.html#stratified-sampling 这似乎不太适合将严重不平衡的数据集拆分为训练/测试样本。

虽然这个答案并不特定于 Spark,但在 Apache beam 中,我这样做是为了将训练 66% 和测试 33% 分开(只是一个说明性示例,您可以自定义下面的 partition_fn 以使其更加复杂和接受参数,例如指定桶的数量或偏向某物的选择或确保随机化在各个维度上是公平的,等等):

raw_data = p | 'Read Data' >> Read(...)

clean_data = (raw_data
              | "Clean Data" >> beam.ParDo(CleanFieldsFn())


def partition_fn(element):
    return random.randint(0, 2)

random_buckets = (clean_data | beam.Partition(partition_fn, 3))

clean_train_data = ((random_buckets[0], random_buckets[1])
                    | beam.Flatten())

clean_eval_data = random_buckets[2]

假设我们有这样的数据集:

+---+-----+
| id|label|
+---+-----+
|  0|  0.0|
|  1|  1.0|
|  2|  0.0|
|  3|  1.0|
|  4|  0.0|
|  5|  1.0|
|  6|  0.0|
|  7|  1.0|
|  8|  0.0|
|  9|  1.0|
+---+-----+

这个数据集是完美平衡的,但这种方法也适用于不平衡的数据。

现在,让我们用额外的信息来扩充这个 DataFrame,这些信息将有助于决定哪些行应该进入训练集。步骤如下:

  • 给定一些 ratio
  • 确定每个标签的多少示例应该成为训练集的一部分
  • 打乱 DataFrame 的行。
  • 使用 window 函数按 label 对 DataFrame 进行分区和排序,然后使用 row_number().
  • 对每个标签的观察结果进行排序

我们最终得到以下数据框:

+---+-----+----------+
| id|label|row_number|
+---+-----+----------+
|  6|  0.0|         1|
|  2|  0.0|         2|
|  0|  0.0|         3|
|  4|  0.0|         4|
|  8|  0.0|         5|
|  9|  1.0|         1|
|  5|  1.0|         2|
|  3|  1.0|         3|
|  1|  1.0|         4|
|  7|  1.0|         5|
+---+-----+----------+

注意:行被打乱(参见:id 列中的随机顺序),按标签分区(参见:label 列)并排名。

假设我们想要进行 80% 的拆分。在这种情况下,我们希望将四个 1.0 标签和四个 0.0 标签转到训练数据集,将一个 1.0 标签和一个 0.0 标签转到测试数据集。我们在 row_number 列中有此信息,所以现在我们可以简单地在用户定义的函数中使用它(如果 row_number 小于或等于 4,则示例转到训练集)。

应用UDF后,得到的数据框如下:

+---+-----+----------+----------+
| id|label|row_number|isTrainSet|
+---+-----+----------+----------+
|  6|  0.0|         1|      true|
|  2|  0.0|         2|      true|
|  0|  0.0|         3|      true|
|  4|  0.0|         4|      true|
|  8|  0.0|         5|     false|
|  9|  1.0|         1|      true|
|  5|  1.0|         2|      true|
|  3|  1.0|         3|      true|
|  1|  1.0|         4|      true|
|  7|  1.0|         5|     false|
+---+-----+----------+----------+

现在,要获取 train/test 数据,您必须这样做:

val train = df.where(col("isTrainSet") === true)
val test = df.where(col("isTrainSet") === false)

这些排序和分区步骤对于一些非常大的数据集来说可能会让人望而却步,所以我建议首先尽可能地过滤数据集。实物图如下:

== Physical Plan ==
*(3) Project [id#4, label#5, row_number#11, if (isnull(row_number#11)) null else UDF(label#5, row_number#11) AS isTrainSet#48]
+- Window [row_number() windowspecdefinition(label#5, label#5 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS row_number#11], [label#5], [label#5 ASC NULLS FIRST]
   +- *(2) Sort [label#5 ASC NULLS FIRST, label#5 ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(label#5, 200)
         +- *(1) Project [id#4, label#5]
            +- *(1) Sort [_nondeterministic#9 ASC NULLS FIRST], true, 0
               +- Exchange rangepartitioning(_nondeterministic#9 ASC NULLS FIRST, 200)
                  +- LocalTableScan [id#4, label#5, _nondeterministic#9

这是完整的工作示例(使用 Spark 2.3.0 和 Scala 2.11.12 测试):

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.functions.{col, row_number, udf, rand}

class StratifiedTrainTestSplitter {

  def getNumExamplesPerClass(ss: SparkSession, label: String, trainRatio: Double)(df: DataFrame): Map[Double, Long] = {
    df.groupBy(label).count().createOrReplaceTempView("labelCounts")
    val query = f"SELECT $label AS ratioLabel, count, cast(count * $trainRatio as long) AS trainExamples FROM labelCounts"
    import ss.implicits._
    ss.sql(query)
      .select("ratioLabel", "trainExamples")
      .map((r: Row) => r.getDouble(0) -> r.getLong(1))
      .collect()
      .toMap
  }

  def split(df: DataFrame, label: String, trainRatio: Double): DataFrame = {
    val w = Window.partitionBy(col(label)).orderBy(col(label))

    val rowNumPartitioner = row_number().over(w)

    val dfRowNum = df.sort(rand).select(col("*"), rowNumPartitioner as "row_number")

    dfRowNum.show()

    val observationsPerLabel: Map[Double, Long] = getNumExamplesPerClass(df.sparkSession, label, trainRatio)(df)

    val addIsTrainColumn = udf((label: Double, rowNumber: Int) => rowNumber <= observationsPerLabel(label))

    dfRowNum.withColumn("isTrainSet", addIsTrainColumn(col(label), col("row_number")))
  }


}

object StratifiedTrainTestSplitter {

  def getDf(ss: SparkSession): DataFrame = {
    val data = Seq(
      (0, 0.0), (1, 1.0), (2, 0.0), (3, 1.0), (4, 0.0), (5, 1.0), (6, 0.0), (7, 1.0), (8, 0.0), (9, 1.0)
    )
    ss.createDataFrame(data).toDF("id", "label")
  }

  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession
      .builder()
      .config(new SparkConf().setMaster("local[1]"))
      .getOrCreate()

    val df = new StratifiedTrainTestSplitter().split(getDf(spark), "label", 0.8)

    df.cache()

    df.where(col("isTrainSet") === true).show()
    df.where(col("isTrainSet") === false).show()
  }
}

注意:在这种情况下,标签是 Double。如果您的标签是 String,您将不得不在这里和那里切换类型。

Spark 支持分层样本,如 https://s3.amazonaws.com/sparksummit-share/ml-ams-1.0.1/6-sampling/scala/6-sampling_student.html

中所述
df.stat.sampleBy("label", Map(0 -> .10, 1 -> .20, 2 -> .3), 0)

也许OP发布这个问题时这个方法不可用,但我把它留在这里以供将来参考:

# splitting dataset into train and test set
train, test = df.randomSplit([0.7, 0.3], seed=42)