How to split a Spark DataFrame into two using a ratio given in terms of months and a Unix epoch column?

I want to split a Spark DataFrame into two using a ratio given in terms of months and a Unix epoch column.

A sample DataFrame is shown below:

unixepoch
---------
1539754800
1539754800
1539931200
1539927600
1539927600
1539931200
1539931200
1539931200
1539927600
1540014000
1540014000
1540190400
1540190400
1540190400
1540190400
1540190400
1540190400
1540190400
1540190400
1540190400
1540190400
1540190400
1540190400
1540190400

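Split strategy:

If the data covers 30 months in total and splittingRatio is 0.6, then the expected DataFrame 1 should hold 30 * 0.6 = 18 months of data, and the expected DataFrame 2 should hold 30 * 0.4 = 12 months of data.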
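The arithmetic above can be sketched in plain Scala. The object and method names here (SplitMonths, monthsPerSplit) are hypothetical, and rounding to the nearest whole month is an assumption, since the question does not say how fractional months should be handled.

```scala
object SplitMonths {
  // Hypothetical helper: how many months go to each of the two splits.
  // Assumption: the first split is rounded to the nearest whole month,
  // and the second split receives whatever remains.
  def monthsPerSplit(totalMonths: Int, splittingRatio: Double): (Int, Int) = {
    require(totalMonths >= 0, s"totalMonths must be non-negative, found: $totalMonths")
    require(splittingRatio >= 0.0 && splittingRatio <= 1.0,
      s"splittingRatio must be between 0 and 1, found: $splittingRatio")
    val firstSplit = math.round(totalMonths * splittingRatio).toInt
    (firstSplit, totalMonths - firstSplit)
  }
}
```

For example, SplitMonths.monthsPerSplit(30, 0.6) gives 18 months for the first DataFrame and 12 for the second.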

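Edit-1

Most of the answers apply the split ratio to the number of records, i.e. if the total record count is 100 and the split ratio is 0.6, then split1DF has about 60 records and split2DF about 40. To be clear, that is not what I am looking for. Here the split ratio is given in months, which can be computed from the Unix epoch timestamp column in the sample DataFrame above. Suppose the epoch column above is some distribution over 30 months; then I want the rows from the first 18 months in DataFrame 1 and the rows from the last 12 months in the second DataFrame. You can think of it as splitting a DataFrame of time-series data in Spark.

Edit-2

If the data runs from July 2018 to May 2019 (10 months of data), then split1 (0.6 = first 6 months) = (July 2018 to Jan 2019) and split2 (0.4 = last 4 months) = (Feb 2019 to May 2019). There should be no random picking.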
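The requested behaviour can be sketched on plain Scala collections, without Spark, to make the rule concrete. This is only an illustration under stated assumptions: timestamps are interpreted in UTC, months are counted inclusively (July 2018 through May 2019 counts as 11 calendar months, so a 0.6 ratio rounds to the first 7 months, July 2018 to Jan 2019, matching the split asked for above), and the names MonthSplit, monthIndex and splitByMonths are hypothetical.

```scala
import java.time.{Instant, ZoneOffset}

object MonthSplit {
  // Calendar month of an epoch (in seconds) as a single increasing integer.
  // Assumption: timestamps are interpreted in UTC.
  def monthIndex(epochSeconds: Long): Int = {
    val dateTime = Instant.ofEpochSecond(epochSeconds).atZone(ZoneOffset.UTC)
    dateTime.getYear * 12 + (dateTime.getMonthValue - 1)
  }

  // Deterministic split: rows from the earliest months go to the first
  // result, rows from the remaining months to the second. No random sampling.
  def splitByMonths(epochs: Seq[Long], ratio: Double): (Seq[Long], Seq[Long]) = {
    require(ratio >= 0.0 && ratio <= 1.0, s"ratio must be between 0 and 1, found: $ratio")
    if (epochs.isEmpty) (Seq.empty, Seq.empty)
    else {
      val indices = epochs.map(monthIndex)
      val firstMonth = indices.min
      // Assumption: inclusive count of calendar months covered by the data.
      val totalMonths = indices.max - firstMonth + 1
      val cutoff = math.round(totalMonths * ratio).toInt
      epochs.sorted.partition(e => monthIndex(e) - firstMonth < cutoff)
    }
  }
}
```

Rows are assigned by the month they fall in, not by row count, so two months with very different numbers of records still end up on the correct side of the cutoff.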

Use row_number and filter to split the data into two DataFrames.

scala> val totalMonths = 10
totalMonths: Int = 10

scala> val splitRatio = 0.6
splitRatio: Double = 0.6

scala> val condition = (totalMonths * splitRatio).floor + 1
condition: Double = 7.0

scala> epochDF.show(false)
+----------+-----+
|dt        |month|
+----------+-----+
|1530383400|7    |
|1533061800|8    |
|1535740200|9    |
|1538332200|10   |
|1541010600|11   |
|1543602600|12   |
|1546281000|1    |
|1548959400|2    |
|1551378600|3    |
|1554057000|4    |
|1556649000|5    |
+----------+-----+

scala> import org.apache.spark.sql.expressions._
import org.apache.spark.sql.expressions._

scala> epochDF.orderBy($"dt".asc).withColumn("id",row_number().over(Window.orderBy($"dt".asc))).filter($"id" <= condition).show(false)

+----------+-----+---+
|dt        |month|id |
+----------+-----+---+
|2018-07-01|7    |1  |
|2018-08-01|8    |2  |
|2018-09-01|9    |3  |
|2018-10-01|10   |4  |
|2018-11-01|11   |5  |
|2018-12-01|12   |6  |
|2019-01-01|1    |7  |
+----------+-----+---+

scala> epochDF.orderBy($"dt".asc).withColumn("id",row_number().over(Window.orderBy($"dt".asc))).filter($"id" > condition).show(false)

+----------+-----+---+
|dt        |month|id |
+----------+-----+---+
|2019-02-01|2    |8  |
|2019-03-01|3    |9  |
|2019-04-01|4    |10 |
|2019-05-01|5    |11 |
+----------+-----+---+


I split the data by month, and fall back to splitting by day if the data covers only one month.

I prefer this method because it does not depend on a window function. The other answer here uses Window without partitionBy, which seriously degrades performance because all the data is shuffled to a single executor.

1. Split method, given the train ratio in terms of months

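    import org.apache.spark.sql.{Column, DataFrame}
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.types.DataTypes

    val EPOCH = "epoch"
    def splitTrainTest(inputDF: DataFrame,
                       trainRatio: Double): (DataFrame, DataFrame) = {
      require(trainRatio >= 0 && trainRatio <= 0.9, s"trainRatio must be between 0 and 0.9, found: $trainRatio")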

      def extractDateCols(tuples: (String, Column)*): DataFrame = {
        tuples.foldLeft(inputDF) {
          case (df, (dateColPrefix, dateColumn)) =>
            df
              .withColumn(s"${dateColPrefix}_month", month(from_unixtime(dateColumn))) // month
              .withColumn(s"${dateColPrefix}_dayofmonth", dayofmonth(from_unixtime(dateColumn))) // dayofmonth
              .withColumn(s"${dateColPrefix}_year", year(from_unixtime(dateColumn))) // year
        }
      }

      val extractDF = extractDateCols((EPOCH, inputDF(EPOCH)))

      // derive min/max(yyyy-MM)
      val yearCol = s"${EPOCH}_year"
      val monthCol = s"${EPOCH}_month"
      val dayCol = s"${EPOCH}_dayofmonth"
      val SPLIT = "split"

      // first day of each row's month (yyyy-MM-01), used to derive the min/max month
      val dateCol = to_date(date_format(
        concat_ws("-", Seq(yearCol, monthCol).map(col): _*), "yyyy-MM-01"))

      val minMaxDF = extractDF.agg(max(dateCol).as("max_date"), min(dateCol).as("min_date"))
      val min_max_date = minMaxDF.head()
      import java.sql.{Date => SqlDate}
      val minDate = min_max_date.getAs[SqlDate]("min_date")
      val maxDate = min_max_date.getAs[SqlDate]("max_date")

      println(s"Min Date Found: $minDate")
      println(s"Max Date Found: $maxDate")

      // Get the total months for which the data exist
      val totalMonths = (maxDate.toLocalDate.getYear - minDate.toLocalDate.getYear) * 12 +
        maxDate.toLocalDate.getMonthValue - minDate.toLocalDate.getMonthValue
      println(s"Total Months of data found for is $totalMonths months")

      // difference starts with 0
      val splitDF = extractDF.withColumn(SPLIT, round(months_between(dateCol, to_date(lit(minDate)))).cast(DataTypes.IntegerType))

      val (trainDF, testDF) = totalMonths match {
        // data is provided for more than a month
        case tm if tm > 0 =>
          val trainMonths = Math.round(totalMonths * trainRatio)
          println(s"Data considered for training is < $trainMonths months")
          println(s"Data considered for testing is >= $trainMonths months")
          (splitDF.filter(col(SPLIT) < trainMonths), splitDF.filter(col(SPLIT) >= trainMonths))

        // data is provided for a month, split based on the total records  in terms of days
        case tm if tm == 0 =>
          //        val dayCol = PLANNED_START_DAYOFMONTH
          val splitDF1 = splitDF.withColumn(SPLIT,
            datediff(date_format(
              concat_ws("-", Seq(yearCol, monthCol, dayCol).map(col): _*), "yyyy-MM-dd"), lit(minDate))
          )
          // Get the total days for which the data exist
          val totalDays = splitDF1.select(max(SPLIT).as("total_days")).head.getAs[Int]("total_days")
          if (totalDays <= 1) {
            throw new RuntimeException(s"Insufficient data provided for training, Data found for $totalDays days but " +
              s"more than 1 day is required")
          }
          println(s"Total Days of data found is $totalDays days")

          val trainDays = Math.round(totalDays * trainRatio)
          (splitDF1.filter(col(SPLIT) < trainDays), splitDF1.filter(col(SPLIT) >= trainDays))

        // data should be there
        case default => throw new RuntimeException(s"Insufficient data provided for training, Data found for $totalMonths " +
          s"months but $totalMonths >= 1 required")
      }
      (trainDF.cache(), testDF.cache())
    }
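The month arithmetic inside splitTrainTest can be checked in isolation with plain Scala. The sketch below copies the totalMonths and trainMonths formulas from the method above into a hypothetical MonthSpan object (that name is not part of the answer), so the cutoff can be reasoned about without a Spark session.

```scala
import java.time.LocalDate

object MonthSpan {
  // Same formula as totalMonths in splitTrainTest: the difference in
  // calendar months between the first-of-month min and max dates.
  def totalMonths(minDate: LocalDate, maxDate: LocalDate): Int =
    (maxDate.getYear - minDate.getYear) * 12 +
      maxDate.getMonthValue - minDate.getMonthValue

  // Same rounding as trainMonths in splitTrainTest. Rows whose month offset
  // is strictly below this value go to the training split.
  def trainMonths(total: Int, trainRatio: Double): Long =
    Math.round(total * trainRatio)
}
```

Because the month offset starts at 0, data from October 2018 to December 2018 gives totalMonths = 2, and a ratio of 0.6 gives a cutoff of 1, so only the October rows (offset 0) land in the training split.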

2. Test with data spanning multiple months

 //  call methods
    val implicits = sqlContext.sparkSession.implicits
    import implicits._
    val monthData = sc.parallelize(Seq(
      1539754800,
      1539754800,
      1539931200,
      1539927600,
      1539927600,
      1539931200,
      1539931200,
      1539931200,
      1539927600,
      1540449600,
      1540449600,
      1540536000,
      1540536000,
      1540536000,
      1540424400,
      1540424400,
      1540618800,
      1540618800,
      1545979320,
      1546062120,
      1545892920,
      1545892920,
      1545892920,
      1545201720,
      1545892920,
      1545892920
    )).toDF(EPOCH)

    val (split1, split2) = splitTrainTest(monthData, 0.6)
    split1.show(false)
    split2.show(false)

    /**
      * Min Date Found: 2018-10-01
      * Max Date Found: 2018-12-01
      * Total Months of data found for is 2 months
      * Data considered for training is < 1 months
      * Data considered for testing is >= 1 months
      * +----------+-----------+----------------+----------+-----+
      * |epoch     |epoch_month|epoch_dayofmonth|epoch_year|split|
      * +----------+-----------+----------------+----------+-----+
      * |1539754800|10         |17              |2018      |0    |
      * |1539754800|10         |17              |2018      |0    |
      * |1539931200|10         |19              |2018      |0    |
      * |1539927600|10         |19              |2018      |0    |
      * |1539927600|10         |19              |2018      |0    |
      * |1539931200|10         |19              |2018      |0    |
      * |1539931200|10         |19              |2018      |0    |
      * |1539931200|10         |19              |2018      |0    |
      * |1539927600|10         |19              |2018      |0    |
      * |1540449600|10         |25              |2018      |0    |
      * |1540449600|10         |25              |2018      |0    |
      * |1540536000|10         |26              |2018      |0    |
      * |1540536000|10         |26              |2018      |0    |
      * |1540536000|10         |26              |2018      |0    |
      * |1540424400|10         |25              |2018      |0    |
      * |1540424400|10         |25              |2018      |0    |
      * |1540618800|10         |27              |2018      |0    |
      * |1540618800|10         |27              |2018      |0    |
      * +----------+-----------+----------------+----------+-----+
      *
      * +----------+-----------+----------------+----------+-----+
      * |epoch     |epoch_month|epoch_dayofmonth|epoch_year|split|
      * +----------+-----------+----------------+----------+-----+
      * |1545979320|12         |28              |2018      |2    |
      * |1546062120|12         |29              |2018      |2    |
      * |1545892920|12         |27              |2018      |2    |
      * |1545892920|12         |27              |2018      |2    |
      * |1545892920|12         |27              |2018      |2    |
      * |1545201720|12         |19              |2018      |2    |
      * |1545892920|12         |27              |2018      |2    |
      * |1545892920|12         |27              |2018      |2    |
      * +----------+-----------+----------------+----------+-----+
      */

3. Test with data from a single month

 val oneMonthData = sc.parallelize(Seq(
      1589514575, //  Friday, May 15, 2020 3:49:35 AM
      1589600975, // Saturday, May 16, 2020 3:49:35 AM
      1589946575, // Wednesday, May 20, 2020 3:49:35 AM
      1590378575, // Monday, May 25, 2020 3:49:35 AM
      1590464975, // Tuesday, May 26, 2020 3:49:35 AM
      1590470135 // Tuesday, May 26, 2020 5:15:35 AM
    )).toDF(EPOCH)

    val (split3, split4) = splitTrainTest(oneMonthData, 0.6)
    split3.show(false)
    split4.show(false)

    /**
      * Min Date Found: 2020-05-01
      * Max Date Found: 2020-05-01
      * Total Months of data found for is 0 months
      * Total Days of data found is 25 days
      * +----------+-----------+----------------+----------+-----+
      * |epoch     |epoch_month|epoch_dayofmonth|epoch_year|split|
      * +----------+-----------+----------------+----------+-----+
      * |1589514575|5          |15              |2020      |14   |
      * +----------+-----------+----------------+----------+-----+
      *
      * +----------+-----------+----------------+----------+-----+
      * |epoch     |epoch_month|epoch_dayofmonth|epoch_year|split|
      * +----------+-----------+----------------+----------+-----+
      * |1589600975|5          |16              |2020      |15   |
      * |1589946575|5          |20              |2020      |19   |
      * |1590378575|5          |25              |2020      |24   |
      * |1590464975|5          |26              |2020      |25   |
      * |1590470135|5          |26              |2020      |25   |
      * +----------+-----------+----------------+----------+-----+
      */
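
The single-month fallback can also be sketched in plain Scala to show how the day cutoff is derived. This is an illustration only: DaySplit and its methods are hypothetical names, and timestamps are interpreted in UTC here, whereas Spark's from_unixtime uses the session time zone.

```scala
import java.time.{Instant, LocalDate, ZoneOffset}
import java.time.temporal.ChronoUnit

object DaySplit {
  // Assumption: timestamps are interpreted in UTC.
  private def toDate(epochSeconds: Long): LocalDate =
    Instant.ofEpochSecond(epochSeconds).atZone(ZoneOffset.UTC).toLocalDate

  // Days elapsed since the first day of the month, mirroring the
  // datediff(..., minDate) step where minDate is a first-of-month date.
  def dayOffset(epochSeconds: Long, monthStart: LocalDate): Long =
    ChronoUnit.DAYS.between(monthStart, toDate(epochSeconds))

  // Rows whose day offset is below round(maxOffset * trainRatio) form the
  // first split, the rest form the second. Input order is preserved.
  def splitByDays(epochs: Seq[Long], trainRatio: Double): (Seq[Long], Seq[Long]) = {
    require(epochs.nonEmpty, "at least one timestamp is required")
    val monthStart = toDate(epochs.min).withDayOfMonth(1)
    val totalDays = epochs.map(dayOffset(_, monthStart)).max
    val trainDays = Math.round(totalDays * trainRatio)
    epochs.partition(e => dayOffset(e, monthStart) < trainDays)
  }
}
```

With the May 2020 timestamps used in the test above, the largest offset is 25 days and the cutoff is round(25 * 0.6) = 15, so only the May 15 row (offset 14) falls in the first split, which agrees with the output shown.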