pyspark: more efficient way to unpack an array column into multiple columns



I have an array column, which can be created like this:

from pyspark.sql import functions as F

in_col_name = "in_col_name"
df = spark.createDataFrame([('[{"key":1}, {"key":2}]',)], [in_col_name])
df = (
    df
    .withColumn("tmp", F.from_json(in_col_name, "array<string>"))
    .cache()
)
# obtain the maximum number of components in the array
max_arr_len = df.select(F.size("tmp")).rdd.max()[0]
for i in range(max_arr_len):
    df = df.withColumn("tmp" + str(i), F.col("tmp").getItem(i))

Imagine running this on 100 million rows. I think looping with getItem is inefficient. Is there a way to get all max_arr_len columns at once?

In this case the loop is actually not that inefficient. getItem is a lazy transformation, so Spark is able to optimize the code and perform all of the loop steps as a single step. Look at the plan with df.explain():

== Physical Plan ==
*(1) Project [in_col_name#820, tmp#822, tmp#822[0] AS tmp0#945, tmp#822[1] AS tmp1#949]
+- InMemoryTableScan [in_col_name#820, tmp#822]
      +- InMemoryRelation [in_col_name#820, tmp#822], StorageLevel(disk, memory, deserialized, 1 replicas)
            +- Project [in_col_name#820, from_json(ArrayType(StructType(StructField(key,IntegerType,true)),true), in_col_name#820, Some(Etc/UTC)) AS tmp#822]
               +- *(1) Scan ExistingRDD[in_col_name#820]

You will notice that all of the element extractions are performed on the same line:

tmp#822[0] AS tmp0#945, tmp#822[1] AS tmp1#949
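
If you still prefer to build all the columns in one pass rather than in a Python loop, the same thing can be written as a single select. This is just a sketch, assuming the df, tmp column and max_arr_len from the question; it compiles to the same Project node as the loop above:

# one-shot equivalent of the loop (sketch, reusing the names from the question)
df = df.select(
    "*",
    *[F.col("tmp").getItem(i).alias("tmp" + str(i)) for i in range(max_arr_len)]
)
df.explain()  # the Project again lists every getItem extraction in a single step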
