Splitting a Spark DataFrame into two halves with no overlapping data



I have a dataframe of about 38,313 rows, and for an A/B-testing use case I need to split it into two halves and store each half separately.

To do this I'm using randomSplit from org.apache.spark.sql.Dataset. It seems to work fine on small datasets, but once the dataframe gets large it starts to cause problems: every time I split the dataframe in two, I get overlapping results.

val dedupTarget = target.dropDuplicates("identifier")
val splitDF = dedupTarget.randomSplit(Array(0.5, 0.5), 1000)
// splitDF(0) and splitDF(1) share some overlapping rows, and some rows that were in dedupTarget don't appear in either of them
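
For reference, one way to quantify the overlap (a sketch reusing dedupTarget and splitDF from the snippet above):

// Rough sanity check on the two splits produced above.
val overlap = splitDF(0).intersect(splitDF(1)).count()          // rows that ended up in both halves
val covered = splitDF(0).union(splitDF(1)).distinct().count()   // rows that ended up in at least one half
println(s"overlapping rows: $overlap")
println(s"rows missing from both halves: ${dedupTarget.count() - covered}")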

From the randomSplit implementation:

// It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its
// constituent partitions each time a split is materialized which could result in
// overlapping splits. To prevent this, we explicitly sort each input partition to make the
// ordering deterministic.
// MapType cannot be sorted.

So I tried sorting my dataframe before splitting it, but that didn't help at all.

val dedupTarget = target.dropDuplicates("identifier").orderBy(col("identifier").desc)
val splitDF = dedupTarget.randomSplit(Array(0.5, 0.5), 1000)

I suggest you take a different approach.

Take a sample containing half of the dataframe as the first half:

val firstDF = dedupTarget.sample(false, 0.5)

Then subtract it from the original DF to get the second half:

val secondDF = dedupTarget.except(firstDF)

That way you get two dataframes with no overlap.
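
Putting the two steps together, a short end-to-end sketch for storing the halves separately as the question asks. The groupA/groupB names and the Parquet paths are placeholders, and caching the sampled half keeps sample from being re-evaluated differently between the except and the write:

// Sketch only: the output paths are placeholders. Caching the sampled half avoids it
// being re-sampled differently when it is evaluated again by except and by the write.
val groupA = dedupTarget.sample(false, 0.5).cache()
val groupB = dedupTarget.except(groupA)

groupA.write.parquet("/tmp/ab_test/group_a")   // placeholder path
groupB.write.parquet("/tmp/ab_test/group_b")   // placeholder path

// Sanity check: the two halves should not share any rows.
assert(groupA.intersect(groupB).count() == 0)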

One solution is to add a random column and use it to split the initial dataframe in two. If you need two equal parts (halves), take the median of myrandcol and use that median in the filter instead of the 0.5 used in the example below.

scala> df.show
+---+----+
| id|data|
+---+----+
|  1|  10|
|  2|  20|
|  3|  30|
|  4|  40|
|  5|  50|
|  6|   5|
|  7|  15|
|  8|  25|
|  9|  35|
| 10|  45|
| 11|  55|
| 12|  65|
+---+----+

scala> val dfrand = df.withColumn("myrandcol", rand())
dfrand: org.apache.spark.sql.DataFrame = [id: int, data: int ... 1 more field]
scala> dfrand.show
+---+----+--------------------+
| id|data|           myrandcol|
+---+----+--------------------+
|  1|  10|0.032922537840013755|
|  2|  20|  0.3033357451409988|
|  3|  30|  0.3540969077830527|
|  4|  40|  0.3303614771224386|
|  5|  50| 0.43494868849484125|
|  6|   5|  0.4250309835092507|
|  7|  15|  0.7405114480878822|
|  8|  25|  0.7591141079555013|
|  9|  35|  0.7497022992064433|
| 10|  45| 0.27779407072568674|
| 11|  55|  0.8203602166103228|
| 12|  65|  0.9171256953932918|
+---+----+--------------------+
scala> val dfA = dfrand.where($"myrandcol" <= 0.5)
dfA: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [id: int, data: int ... 1 more field]
scala> val dfB = dfrand.where($"myrandcol" > 0.5)
dfB: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [id: int, data: int ... 1 more field]
scala> dfA.show
+---+----+--------------------+
| id|data|           myrandcol|
+---+----+--------------------+
|  1|  10|0.032922537840013755|
|  2|  20|  0.3033357451409988|
|  3|  30|  0.3540969077830527|
|  4|  40|  0.3303614771224386|
|  5|  50| 0.43494868849484125|
|  6|   5|  0.4250309835092507|
| 10|  45| 0.27779407072568674|
+---+----+--------------------+

scala> dfB.show
+---+----+------------------+
| id|data|         myrandcol|
+---+----+------------------+
|  7|  15|0.7405114480878822|
|  8|  25|0.7591141079555013|
|  9|  35|0.7497022992064433|
| 11|  55|0.8203602166103228|
| 12|  65|0.9171256953932918|
+---+----+------------------+
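
If the two parts need to be (almost) equal in size, the median mentioned above can be obtained with approxQuantile. A minimal sketch, reusing dfrand from the example (a relative error of 0.0 requests the exact quantile, which is cheap for a dataframe of this size):

// Split at the median of myrandcol instead of a fixed 0.5,
// so the two halves end up with roughly the same number of rows.
val Array(median) = dfrand.stat.approxQuantile("myrandcol", Array(0.5), 0.0)
val dfA = dfrand.where($"myrandcol" <= median)
val dfB = dfrand.where($"myrandcol" > median)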
