PySpark中的GroupBy和ApplyInPadas-如何正确实现UDF



我正试图在PySpark中使用PandasUDF来查找层次结构中的"最长的唯一尾部"。

例如,如果我的输入是:

1.2
1.2.3

那么最长的尾部是"1.2.3">

我也可能有多个唯一的集合,例如:

1.2
1.2.3
5.6.7
5.6

在这种情况下,输出应该是:

1.2.3
5.6.7

我使用的方法是:

  • 对输入进行排序,以便列出类似的行,这样,如果前一行"包含"在后一行中,我可以过滤掉它,只返回最长的唯一行

示例输入:

1.2.3
5.6.7
5.6
1.2

排序为:

1.2
1.2.3
5.6
5.6.7

当我在线上过滤行时,我的输出应该是

1.2.3
5.6.7

我尝试了两种方法。首先是编写一个函数,该函数通过发送到其中的DF循环,如下所示:

def getLongestTail(key, pdf) -> pd.DataFrame:
sortedData = pdf.sort_values(by='value')
for i in range(len(sortedData)-1):
if sortedData.index(i+1).loc['value'].startswith(sortedData.loc['value']):
sortedData.index(i+1) = False
return pd.DataFrame(sortedData)

其次是在中使用lambda函数

def getLongestTail(pdf) -> pd.DataFrame:
pdf = pdf.sort
return (lambda x: pdf.shift(1).loc['value'].startswith(pdf.loc['value']))

我也试着装饰如下:

@pandas_udf(架构,PandasUDFType.GROUPED_MAP(

这是我的总体代码:

import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.functions import pandas_udf, PandasUDFType
import pandas as pd
from pyspark.sql.types import *
simpleData = [
('A', '1.2.3'),
('A', '1.2'),
('B', '9.8'),
('A', '5.6.7.8'),
('B', '9'),
('B', '9.8.7'),
('A', '5')]
schema = StructType([
StructField("letter", StringType()),
StructField("value", StringType())
])
def getLongestTail(pdf) -> pd.DataFrame:
pdf = pdf.sort
return pd.DataFrame((lambda x: pdf.loc['value'].startswith(pdf.shift(1).loc['value'])))

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(data=simpleData, schema = schema)
df_result = df.groupby('letter').applyInPandas(getLongestTail, schema=schema).show()

我的Jupyter笔记本中显示的错误显示worker崩溃以及与Py4JJavaError有关的错误。一我确信我缺少一些基本的东西——任何评论都很感激。

谢谢。

===

错误:

---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
/tmp/ipykernel_34305/1009949605.py in <module>
3 # df_grouped.show()
4 
----> 5 df_result = df.groupby('letter').applyInPandas(getLongestTailL, schema=schema).show()
6 
324             value = OUTPUT_CONVERTER[type](answer[2:], gateway_client)
325             if answer[1] == REFERENCE_TYPE:
--> 326                 raise Py4JJavaError(
327                     "An error occurred while calling {0}{1}{2}.n".
328                     format(target_id, ".", name), value)

好问题。在将数据传递给applyInPandas之后,我们希望在output_schema中添加一些新变量:只需在input_schema中添加一个结果变量,并将扩展的输出模式传递给applyInPandas

最新更新