PySpark: exploding arrays is slow



My dataframe

df_a = spark.createDataFrame([
    (0, ["B", "C", "D", "E"], [1, 2, 3, 4]),
    (1, ["E", "A", "C"], [1, 2, 3]),
    (2, ["F", "A", "E", "B"], [1, 2, 3, 4]),
    (3, ["E", "G", "A"], [1, 2, 3]),
    (4, ["A", "C", "E", "B", "D"], [1, 2, 3, 4, 5])], ["id", "items", "rank"])

and I want my output to be:

+---+----+----+
| id|item|rank|
+---+----+----+
|  0|   B|   1|
|  0|   C|   2|
|  0|   D|   3|
|  0|   E|   4|
|  1|   E|   1|
|  1|   A|   2|
|  1|   C|   3|
|  2|   F|   1|
|  2|   A|   2|
|  2|   E|   3|
|  2|   B|   4|
|  3|   E|   1|
|  3|   G|   2|
|  3|   A|   3|
|  4|   A|   1|
|  4|   C|   2|
|  4|   E|   3|
|  4|   B|   4|
|  4|   D|   5|
+---+----+----+

My dataframe has 8 million rows. When I try to zip the two arrays and explode them, it is very slow and the job runs forever, even with 15 executors and 25 GB of memory.

from pyspark.sql.types import ArrayType, StructType, StructField, StringType, IntegerType
import pyspark.sql.functions as F

zip_udf2 = F.udf(
    lambda x, y: list(zip(x, y)),
    ArrayType(StructType([
        StructField("item", StringType()),
        StructField("rank", IntegerType())
    ]))
)

(df_a
    .withColumn('tmp', zip_udf2("items", "rank"))
    .withColumn("tmp", F.explode('tmp'))
    .select("id", F.col("tmp.item"), F.col("tmp.rank"))
    .show())

Is there another way to do this? I also tried flatMap, but it made no difference to performance. The number of elements in the arrays varies from row to row.
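A flatMap attempt along these lines (a sketch of one possible version, not necessarily the exact code used) still pairs the elements inside Python workers, so it is not expected to be faster than the UDF:

# Sketch of a possible flatMap approach: drop to the RDD API, emit one
# (id, item, rank) tuple per array element, then rebuild a dataframe.
exploded = (df_a.rdd
    .flatMap(lambda row: [(row['id'], item, rank)
                          for item, rank in zip(row['items'], row['rank'])])
    .toDF(["id", "item", "rank"]))
exploded.show()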

UPDATE

Since you are using Spark 2.3.2, where arrays_zip is not available, I ran some tests to compare which is the better option: udf or posexplode. The quick answer: posexplode. It returns both the position and the value of each element of items, so the position can be used to index into the parallel rank array.

(df_a
    .select('id', F.posexplode('items'), 'rank')
    .select('id', F.col('col').alias('item'), F.expr('rank[pos]').alias('rank'))
    .show())
Tests

from pyspark.sql.types import *
import pyspark.sql.functions as F
import time

df_a = spark.createDataFrame([
    (0, ["B", "C", "D", "E"], [1, 2, 3, 4]),
    (1, ["E", "A", "C"], [1, 2, 3]),
    (2, ["F", "A", "E", "B"], [1, 2, 3, 4]),
    (3, ["E", "G", "A"], [1, 2, 3]),
    (4, ["A", "C", "E", "B", "D"], [1, 2, 3, 4, 5])], ["id", "items", "rank"])

# My solution
def using_posexplode():
    (df_a
        .select('id', F.posexplode('items'), 'rank')
        .select('id', F.col('col').alias('item'), F.expr('rank[pos]').alias('rank'))
        .count())

# Your solution
zip_udf2 = F.udf(
    lambda x, y: list(zip(x, y)),
    ArrayType(StructType([
        StructField("item", StringType()),
        StructField("rank", IntegerType())
    ])))

def using_udf():
    (df_a
        .withColumn('tmp', zip_udf2("items", "rank"))
        .withColumn("tmp", F.explode('tmp'))
        .select("id", F.col("tmp.item"), F.col("tmp.rank"))
        .count())

def time_run_method(iterations, fn):
    t0 = time.time()
    for i in range(iterations):
        fn()
    td = time.time() - t0
    print(fn.__name__, "Time to count %d iterations: %s [sec]" % (iterations, "{:,}".format(td)))

for function in [using_posexplode, using_udf]:
    time_run_method(iterations=100, fn=function)

using_posexplode Time to count 100 iterations: 24.962905168533325 [sec]
using_udf Time to count 100 iterations: 44.120017290115356 [sec]

I can't guarantee that this alone will solve the whole problem, but one thing to consider is dropping zip_udf2 and replacing it with Spark's native function arrays_zip. Here is an explanation of why we should avoid UDFs whenever possible.
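For Spark 2.4+, the arrays_zip version might look like the sketch below; the struct field names are assumed to follow the input column names (items and rank), as documented for arrays_zip, and df_a is the example dataframe from the question.

import pyspark.sql.functions as F

(df_a
    .withColumn('tmp', F.arrays_zip('items', 'rank'))  # zipping happens natively in the JVM, no Python UDF
    .withColumn('tmp', F.explode('tmp'))
    .select('id', F.col('tmp.items').alias('item'), F.col('tmp.rank').alias('rank'))
    .show())

Because arrays_zip and explode are both native Spark expressions, the rows never have to be serialized out to Python workers, which is the main cost of the UDF version.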
