如何在 Apache Spark 上进行非随机数据集拆分?
How to do non-random Dataset splitting on Apache Spark?
我知道我可以使用 randomSplit 方法进行随机拆分:
val splittedData: Array[Dataset[Row]] =
preparedData.randomSplit(Array(0.5, 0.3, 0.2))
我可以用一些 'nonRandomSplit method' 将数据拆分成连续的部分吗?
阿帕奇火花 2.0.1。
提前致谢。
UPD:数据顺序很重要,我将使用 'smaller IDs' 训练我的模型并使用 'larger IDs' 测试它。所以我想把数据分成连续的部分而不打乱。
例如
my dataset = (0,1,2,3,4,5,6,7,8,9)
desired splitting = (0.8, 0.2)
splitting = (0,1,2,3,4,5,6,7), (8,9)
我能想到的唯一解决方案是使用 count 和 limit,但可能还有更好的解决方案。
这是我实现的解决方案:Dataset -> Rdd -> Dataset。
我不确定这是否是最有效的方法,所以我很乐意接受更好的解决方案。
val count = allData.count()
val trainRatio = 0.6
val trainSize = math.round(count * trainRatio).toInt
val dataSchema = allData.schema
// Zipping with indices and skipping rows with indices > trainSize.
// Could have possibly used .limit(n) here
val trainingRdd =
allData
.rdd
.zipWithIndex()
.filter { case (_, index) => index < trainSize }
.map { case (row, _) => row }
// Can't use .limit() :(
val testRdd =
allData
.rdd
.zipWithIndex()
.filter { case (_, index) => index >= trainSize }
.map { case (row, _) => row }
val training = MySession.createDataFrame(trainingRdd, dataSchema)
val test = MySession.createDataFrame(testRdd, dataSchema)
我知道我可以使用 randomSplit 方法进行随机拆分:
val splittedData: Array[Dataset[Row]] =
preparedData.randomSplit(Array(0.5, 0.3, 0.2))
我可以用一些 'nonRandomSplit method' 将数据拆分成连续的部分吗?
阿帕奇火花 2.0.1。 提前致谢。
UPD:数据顺序很重要,我将使用 'smaller IDs' 训练我的模型并使用 'larger IDs' 测试它。所以我想把数据分成连续的部分而不打乱。
例如
my dataset = (0,1,2,3,4,5,6,7,8,9)
desired splitting = (0.8, 0.2)
splitting = (0,1,2,3,4,5,6,7), (8,9)
我能想到的唯一解决方案是使用 count 和 limit,但可能还有更好的解决方案。
这是我实现的解决方案:Dataset -> Rdd -> Dataset。
我不确定这是否是最有效的方法,所以我很乐意接受更好的解决方案。
val count = allData.count()
val trainRatio = 0.6
val trainSize = math.round(count * trainRatio).toInt
val dataSchema = allData.schema
// Zipping with indices and skipping rows with indices > trainSize.
// Could have possibly used .limit(n) here
val trainingRdd =
allData
.rdd
.zipWithIndex()
.filter { case (_, index) => index < trainSize }
.map { case (row, _) => row }
// Can't use .limit() :(
val testRdd =
allData
.rdd
.zipWithIndex()
.filter { case (_, index) => index >= trainSize }
.map { case (row, _) => row }
val training = MySession.createDataFrame(trainingRdd, dataSchema)
val test = MySession.createDataFrame(testRdd, dataSchema)