来自 Spark 数据帧中列表列的示例值



我有一个 spark-scala 数据帧,如下面的 df1 所示: 我想根据 df1 另一列中的计数从分数列(列表(中替换进行采样。

val df1 = sc.parallelize(Seq(("a1",2,List(20,10)),("a2",1,List(30,10)),
("a3",3,List(10)),("a4",2,List(10,20,40)))).toDF("colA","counts","scores")
df1.show()
+----+------+------------+
|colA|counts|      scores|
+----+------+------------+
|  a1|     2|    [20, 10]|
|  a2|     1|    [30, 10]|
|  a3|     3|        [10]|
|  a4|     2|[10, 20, 40]|
+----+------+------------+
预期输出显示在 df2 中:从第 1 行,从列表 [20,10] 中样本 2 值;从第 2 行样本 1 值从列表

[30,10] 中;从第 3 行样本 3 个来自列表 [10] 的值,重复.. 等。

df2.show() //expected output
+----+------+------------+-------------+
|colA|counts|      scores|sampledScores|
+----+------+------------+-------------+
|  a1|     2|    [20, 10]|     [20, 10]|
|  a2|     1|    [30, 10]|         [30]|
|  a3|     3|        [10]| [10, 10, 10]|
|  a4|     2|[10, 20, 40]|     [10, 40]|
+----+------+------------+-------------+

我写了一个udf 'takeSample'并应用于df1,但没有按预期工作。

val takeSample = udf((a:Array[Int], count1:Int) => {Array.fill(count1)(
a(new Random(System.currentTimeMillis).nextInt(a.size)))}
)
val df2 = df1.withColumn("SampledScores", takeSample(df1("Scores"),df1("counts")))

我收到以下运行时错误;执行时

df2.printSchema()
root
 |-- colA: string (nullable = true)
 |-- counts: integer (nullable = true)
 |-- scores: array (nullable = true)
 |    |-- element: integer (containsNull = false)
 |-- SampledScores: array (nullable = true)
 |    |-- element: integer (containsNull = false)
df2.show()
org.apache.spark.SparkException: Failed to execute user defined   
function($anonfun$1: (array<int>, int) => array<int>)
Caused by: java.lang.ClassCastException:  
scala.collection.mutable.WrappedArray$ofRef cannot be cast to [I
at $anonfun$1.apply(<console>:47)

任何解决方案都非常感谢。

将 UDF 中的数据类型从 Array[Int] 更改为Seq[Int]将解决此问题:

val takeSample = udf((a:Seq[Int], count1:Int) => {Array.fill(count1)(
a(new Random(System.currentTimeMillis).nextInt(a.size)))}
)
val df2 = df1.withColumn("SampledScores", takeSample(df1("Scores"),df1("counts")))

它将为我们提供预期的输出:

df2.printSchema()
root
 |-- colA: string (nullable = true)
 |-- counts: integer (nullable = true)
 |-- scores: array (nullable = true)
 |    |-- element: integer (containsNull = false)
 |-- SampledScores: array (nullable = true)
 |    |-- element: integer (containsNull = false)
df2.show
+----+------+------------+-------------+
|colA|counts|      scores|SampledScores|
+----+------+------------+-------------+
|  a1|     2|    [20, 10]|     [20, 20]|
|  a2|     1|    [30, 10]|         [30]|
|  a3|     3|        [10]| [10, 10, 10]|
|  a4|     2|[10, 20, 40]|     [20, 20]|
+----+------+------------+-------------+

相关内容

  • 没有找到相关文章

最新更新