我想添加一个新的列score
,这是一个数组,其长度等于另一个列values
的大小,并包含所有值2
。
使用列的size
时出现错误,但如果我用硬编码的数字替换它,则工作正常。
columns = ["id","values"]
data = [("sample1", [12.0,10.0]), ("sample2", [1.0,2.0,3.0,4.0])]
rdd = spark.sparkContext.parallelize(data)
源DataFrame
+-------+--------------------+
| id| values|
+-------+--------------------+
|sample1| [12.0, 10.0]|
|sample2|[1.0, 2.0, 3.0, 4.0]|
+-------+--------------------+
预期输出
+-------+--------------------+--------------------+
| id| values| score|
+-------+--------------------+--------------------+
|sample1| [12.0, 10.0]| [2, 2] |
|sample2|[1.0, 2.0, 3.0, 4.0]| [2, 2, 2, 2]|
+-------+--------------------+--------------------+
from pyspark.sql.functions import *
df.withColumn("score",array([lit(x) for x in [2]*(size(col("values")))])).show()
低于错误
: java.lang.RuntimeException:不支持的文字类型类java.util.ArrayList [2]
不能将Python列表与Spark列相乘。您可以使用array_repeat
函数
import pyspark.sql.functions as F
df2 = df.withColumn('score', F.expr('array_repeat(2, size(values))'))
df2.show()
+-------+--------------------+------------+
| id| values| score|
+-------+--------------------+------------+
|sample1| [12.0, 10.0]| [2, 2]|
|sample2|[1.0, 2.0, 3.0, 4.0]|[2, 2, 2, 2]|
+-------+--------------------+------------+
函数array_repeat
仅适用于Spark 2.4+。对于旧版本,可以使用UDF:
from pyspark.sql.functions import udf, size, lit
from pyspark.sql.types import ArrayType, IntegerType
array_repeat_udf = udf(lambda v, n: [v for _ in range(n)], ArrayType(IntegerType()))
df1 = df.withColumn('score', array_repeat_udf(lit(2), size("values")))
df1.show()
#+-------+--------------------+------------+
#| id| values| score|
#+-------+--------------------+------------+
#|sample1| [12.0, 10.0]| [2, 2]|
#|sample2|[1.0, 2.0, 3.0, 4.0]|[2, 2, 2, 2]|
#+-------+--------------------+------------+