Append to a PySpark array column



I want to check whether a column value is within certain boundaries. If it is not, I append some value to the array column "F". This is my code so far:

import pyspark.sql.functions as F
from pyspark.sql import types

df = spark.createDataFrame(
    [
        (1, 56),
        (2, 32),
        (3, 99)
    ],
    ['id', 'some_nr']
)

# start with an empty (null) array column "F"
df = df.withColumn("F", F.lit(None).cast(types.ArrayType(types.ShortType())))

# flag values that fall outside the boundaries 50..60
def boundary_check(val):
    if (val > 60) | (val < 50):
        return 1

udf = F.udf(lambda x: boundary_check(x))
df = df.withColumn("F", udf(F.col("some_nr")))
display(df)

However, I don't know how to append to the array. Right now, if I run another boundary check on df, it simply overwrites the previous values in "F"...

Have a look at the array_union function under pyspark.sql.functions: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html?highlight=join#pyspark.sql.functions.array_union

That way you avoid using a udf, which would take away any benefit of Spark's parallelization. The code would look like this:

from pyspark.context import SparkContext
from pyspark.sql import SparkSession
from pyspark.conf import SparkConf
from pyspark.sql import Row
import pyspark.sql.functions as f

conf = SparkConf()
sc = SparkContext(conf=conf)
spark = SparkSession(sc)
df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2="a", c3=10),
                            Row(c1=["b", "a", "c"], c2="d", c3=20)])
df.show()
+---------+---+---+
|       c1| c2| c3|
+---------+---+---+
|[b, a, c]|  a| 10|
|[b, a, c]|  d| 20|
+---------+---+---+
df.withColumn(
    "output_column",
    f.when(f.col("c3") > 10,
           f.array_union(df.c1, f.array(f.lit("1"))))
     .otherwise(f.col("c1"))
).show()
+---------+---+---+-------------+
|       c1| c2| c3|output_column|
+---------+---+---+-------------+
|[b, a, c]|  a| 10|    [b, a, c]|
|[b, a, c]|  d| 20| [b, a, c, 1]|
+---------+---+---+-------------+

As a side note, this works as a logical (set) union, so if you want to append a value you need to make sure that value is unique, so that it always gets added. Otherwise, have a look at the other array functions here: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html?highlight=join#pyspark.sql.functions.array

Note: for most of the array functions your Spark version needs to be > 2.4.
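For example, here is a minimal sketch (with made-up data) of the set-union behaviour: unioning with a value that is already in the array leaves it unchanged, while a new value gets appended.

import pyspark.sql.functions as f

df_demo = spark.createDataFrame([(["b", "a", "c"],)], ["c1"])
df_demo.select(
    f.array_union("c1", f.array(f.lit("a"))).alias("union_existing"),  # "a" is already in the array
    f.array_union("c1", f.array(f.lit("1"))).alias("union_new")        # "1" is new, so it gets added
).show()
# +--------------+------------+
# |union_existing|   union_new|
# +--------------+------------+
# |     [b, a, c]|[b, a, c, 1]|
# +--------------+------------+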

Edit (as requested in the comments):

The withColumn method only lets you work on one column at a time, so you need a second withColumn; ideally, pre-define the logical statement up front and reuse it in both withColumn calls.

logical_gate = (f.col("c3") > 10)
(
    df.withColumn(
        "output_column",
        f.when(logical_gate,
               f.array_union(df.c1, f.array(f.lit("1"))))
         .otherwise(f.col("c1")))
      .withColumn(
        "c3",
        f.when(logical_gate,
               f.lit(None))
         .otherwise(f.col("c3")))
      .show()
)
+---------+---+----+-------------+
|       c1| c2|  c3|output_column|
+---------+---+----+-------------+
|[b, a, c]|  a|  10|    [b, a, c]|
|[b, a, c]|  d|null| [b, a, c, 1]|
+---------+---+----+-------------+
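To come back to the original concern about repeated checks overwriting "F": if each new check unions into the column produced by the previous one, values accumulate instead of being replaced. A sketch with the same toy data (the second threshold of 15 and the literal "2" are made up for illustration):

step1 = df.withColumn(
    "output_column",
    f.when(f.col("c3") > 10,
           f.array_union(f.col("c1"), f.array(f.lit("1"))))
     .otherwise(f.col("c1")))
step2 = step1.withColumn(
    "output_column",
    f.when(f.col("c3") > 15,
           f.array_union(f.col("output_column"), f.array(f.lit("2"))))
     .otherwise(f.col("output_column")))
step2.show()
# +---------+---+---+---------------+
# |       c1| c2| c3|  output_column|
# +---------+---+---+---------------+
# |[b, a, c]|  a| 10|      [b, a, c]|
# |[b, a, c]|  d| 20|[b, a, c, 1, 2]|
# +---------+---+---+---------------+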

Since Spark 3.4+, you can use array_append:

from pyspark.sql import functions as F
df = spark.createDataFrame([(10, ['a', 'b', 'c']), (20, ['a', 'b', 'c'])], ['c1', 'c2'])
df.show()
# +---+---------+
# | c1|       c2|
# +---+---------+
# | 10|[a, b, c]|
# | 20|[a, b, c]|
# +---+---------+
df = df.withColumn('c2', F.when(F.col('c1') > 15, F.array_append('c2', 'd')).otherwise(F.col('c2')))
df.show()
# +---+------------+
# | c1|          c2|
# +---+------------+
# | 10|   [a, b, c]|
# | 20|[a, b, c, d]|
# +---+------------+
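Applied to the question's setup, a minimal sketch could look like the following (the column names and boundaries mirror the question and are illustrative). Note that starting "F" from an empty array rather than a null literal matters here, since array_append returns null when the array argument itself is null.

from pyspark.sql import functions as F, types

df = spark.createDataFrame([(1, 56), (2, 32), (3, 99)], ['id', 'some_nr'])

# start from an empty array of shorts instead of a null literal
df = df.withColumn('F', F.array().cast(types.ArrayType(types.ShortType())))

# append a flag only when some_nr is outside the boundaries 50..60
out_of_bounds = (F.col('some_nr') > 60) | (F.col('some_nr') < 50)
df = df.withColumn('F', F.when(out_of_bounds,
                               F.array_append('F', F.lit(1).cast('short')))
                         .otherwise(F.col('F')))
df.show()
# +---+-------+---+
# | id|some_nr|  F|
# +---+-------+---+
# |  1|     56| []|
# |  2|     32|[1]|
# |  3|     99|[1]|
# +---+-------+---+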
