I have a DataFrame of roughly 38,313 rows, and for some A/B-testing use cases I need to split it into two halves and store each half separately.
For this I'm using `randomSplit` from `org.apache.spark.sql.Dataset`. The function seems to work fine on small datasets, but with a large DataFrame it starts causing problems: every time I split the DataFrame in half, the two halves overlap.
val dedupTarget = target.dropDuplicates("identifier")
val splitDF = dedupTarget.randomSplit(Array(0.5, 0.5), 1000)
// splitDF(0) and splitDF(1) have some overlapping rows, and some rows
// that were in dedupTarget don't appear in either of them
Based on the randomSplit implementation:
// It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its
// constituent partitions each time a split is materialized which could result in
// overlapping splits. To prevent this, we explicitly sort each input partition to make the
// ordering deterministic.
// MapType cannot be sorted.
So I tried sorting my DataFrame before splitting, but it didn't help at all.
val dedupTarget = target.dropDuplicates("identifier").orderBy(col("identifier").desc)
val splitDF = dedupTarget.randomSplit(Array(0.5, 0.5), 1000)
I'd suggest a different approach. Sample half of the DataFrame as the first half:
val firstDF = dedupTarget.sample(false, 0.5)
Then subtract it from the original DF to get the second half:
val secondDF = dedupTarget.except(firstDF)
This way you get two non-overlapping DataFrames.
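One caveat, assuming the same nondeterminism that bites `randomSplit`: `sample` is also re-evaluated on each action, so it is safer to cache the deduplicated DataFrame before sampling. A minimal self-contained sketch of the sample/except approach (the toy data, object name, and `local[*]` session are illustrative, not from the original post):

```scala
import org.apache.spark.sql.SparkSession

object SampleExceptSplit {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("split").getOrCreate()
    import spark.implicits._

    // Toy stand-in for dedupTarget; cache so sample() sees a stable input
    val dedupTarget = (1 to 1000).toDF("identifier").cache()

    val firstDF  = dedupTarget.sample(withReplacement = false, fraction = 0.5, seed = 1000)
    val secondDF = dedupTarget.except(firstDF)

    // The halves are disjoint and together cover the original
    assert(firstDF.intersect(secondDF).count() == 0L)
    assert(firstDF.count() + secondDF.count() == dedupTarget.count())

    spark.stop()
  }
}
```

Note that `except` deduplicates its result, which is harmless here because `dedupTarget` already has unique identifiers, and that `fraction = 0.5` only splits approximately in half, since `sample` draws each row independently with that probability.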
Another solution is to add a random column and use it to split the initial DataFrame in two. If you need two equal halves, take the median of `myrandcol` and use it in the filter instead of the 0.5 in the example below.
scala> df.show
+---+----+
| id|data|
+---+----+
| 1| 10|
| 2| 20|
| 3| 30|
| 4| 40|
| 5| 50|
| 6| 5|
| 7| 15|
| 8| 25|
| 9| 35|
| 10| 45|
| 11| 55|
| 12| 65|
+---+----+
scala> val dfrand = df.withColumn("myrandcol", rand())
dfrand: org.apache.spark.sql.DataFrame = [id: int, data: int ... 1 more field]
scala> dfrand.show
+---+----+--------------------+
| id|data| myrandcol|
+---+----+--------------------+
| 1| 10|0.032922537840013755|
| 2| 20| 0.3033357451409988|
| 3| 30| 0.3540969077830527|
| 4| 40| 0.3303614771224386|
| 5| 50| 0.43494868849484125|
| 6| 5| 0.4250309835092507|
| 7| 15| 0.7405114480878822|
| 8| 25| 0.7591141079555013|
| 9| 35| 0.7497022992064433|
| 10| 45| 0.27779407072568674|
| 11| 55| 0.8203602166103228|
| 12| 65| 0.9171256953932918|
+---+----+--------------------+
scala> val dfA = dfrand.where($"myrandcol" <= 0.5)
dfA: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [id: int, data: int ... 1 more field]
scala> val dfB = dfrand.where($"myrandcol" > 0.5)
dfB: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [id: int, data: int ... 1 more field]
scala> dfA.show
+---+----+--------------------+
| id|data| myrandcol|
+---+----+--------------------+
| 1| 10|0.032922537840013755|
| 2| 20| 0.3033357451409988|
| 3| 30| 0.3540969077830527|
| 4| 40| 0.3303614771224386|
| 5| 50| 0.43494868849484125|
| 6| 5| 0.4250309835092507|
| 10| 45| 0.27779407072568674|
+---+----+--------------------+
scala> dfB.show
+---+----+------------------+
| id|data| myrandcol|
+---+----+------------------+
| 7| 15|0.7405114480878822|
| 8| 25|0.7591141079555013|
| 9| 35|0.7497022992064433|
| 11| 55|0.8203602166103228|
| 12| 65|0.9171256953932918|
+---+----+------------------+
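For the equal-halves variant mentioned above, the median of the random column can be computed with `DataFrameStatFunctions.approxQuantile` (available since Spark 2.0). A sketch under the same setup as the transcript; the seed and object name are illustrative. Caching matters here because `rand()` would otherwise be re-evaluated between computing the median and filtering on it:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.rand

object MedianSplit {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("median-split").getOrCreate()
    import spark.implicits._

    val df = (1 to 12).map(i => (i, i * 5)).toDF("id", "data")

    // Cache so the rand() values are fixed before we both compute
    // the median and filter on it
    val dfrand = df.withColumn("myrandcol", rand(seed = 1000)).cache()

    // relativeError = 0.0 requests the exact quantile
    val Array(median) = dfrand.stat.approxQuantile("myrandcol", Array(0.5), 0.0)

    val dfA = dfrand.where($"myrandcol" <= median)
    val dfB = dfrand.where($"myrandcol" > median)

    // Near-equal, non-overlapping halves that cover all rows
    assert(math.abs(dfA.count() - dfB.count()) <= 1)
    assert(dfA.intersect(dfB).count() == 0L)
    assert(dfA.count() + dfB.count() == dfrand.count())

    spark.stop()
  }
}
```

Because the filter threshold is the median rather than a fixed 0.5, the two halves differ in size by at most one row regardless of how the random values happen to fall.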