Spark DataFrame过滤器无法随机使用



这是我的dataframe

df.groupBy($"label").count.show
+-----+---------+                                                               
|label|    count|
+-----+---------+
|  0.0|400000000|
|  1.0| 10000000|
+-----+---------+

我试图用标签== 0.0进行记录,但以下内容:

val r = scala.util.Random
val df2 = df.filter($"label" === 1.0 || r.nextDouble > 0.5) // keep 50% of 0.0

我的输出看起来像这样:

df2.groupBy($"label").count.show
+-----+--------+                                                                
|label|   count|
+-----+--------+
|  1.0|10000000|
+-----+--------+

r.nextDouble是表达式中的常数,因此实际评估与您的含义完全不同。根据实际采样值,它是

scala> r.setSeed(0)
scala> $"label" === 1.0 || r.nextDouble > 0.5
res0: org.apache.spark.sql.Column = ((label = 1.0) OR true)

scala> r.setSeed(4096)
scala> $"label" === 1.0 || r.nextDouble > 0.5
res3: org.apache.spark.sql.Column = ((label = 1.0) OR false)

因此,在简化之后,它只是:

true

(保留所有记录)或

label = 1.0 

分别只保留您观察到的情况)。

要生成随机数,您应该使用相应的SQL函数

scala> import org.apache.spark.sql.functions.rand
import org.apache.spark.sql.functions.rand
scala> $"label" === 1.0 || rand > 0.5
res1: org.apache.spark.sql.Column = ((label = 1.0) OR (rand(3801516599083917286) > 0.5))

尽管Spark已经提供了分层的采样工具:

df.stat.sampleBy(
  "label",  // column
  Map(0.0 -> 0.5, 1.0 -> 1.0),  // fractions
  42 // seed 
)

相关内容

  • 没有找到相关文章

最新更新