如何检查Pyspark Dataframe中是否存在列表的交集



我有一个pyspark数据帧,如下所示:

import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql.functions import udf
schema = T.StructType([  # schema
T.StructField("id", T.StringType(), True),
T.StructField("code", T.ArrayType(T.StringType()), True)])
df = spark.createDataFrame([{"id": "1", "code": ["a1", "a2","a3","a4"]},
{"id": "2", "code": ["b1","b2"]},
{"id": "3", "code": ["c1","c2","c3"]},
{"id": "4", "code": ["d1", "b3"]}],
schema=schema)

它给出输出

df.show()

| id|            code|
|---|----------------|
|  1|[a1, a2, a3, a4]|
|  2|        [b1, b2]|
|  3|    [c1, c2, c3]|
|  4|        [d1, b3]|

我希望能够通过向函数提供列和列表来过滤行,如果有任何interesection(从这里使用不相交,因为会有很多未命中(,则返回true

def lst_intersect(data_lst,query_lst):
return not set(data_lst).isdisjoint(query_lst) 
lst_intersect_udf = F.udf(lambda x,y: lst_intersect(x,y), T.BooleanType())

当我尝试应用此时

query_lst = ['a1','b3']
df = df.withColumn("code_found", lst_intersect_udf(F.col('code'),F.lit(query_lst)))

获取以下错误

Unsupported literal type class java.util.ArrayList [a1, b3]

我可以通过更改函数等来解决它,但我想知道F.lit(query_lst)是否有根本问题?

lit只接受一个值,不接受Python列表。例如,您需要使用列表理解传递一个数组列,该列包含列表中的文字值。

df2 = df.withColumn(
"code_found", 
lst_intersect_udf(
F.col('code'),
F.array(*[F.lit(i) for i in query_lst])
)
)
df2.show()
+---+----------------+----------+
| id|            code|code_found|
+---+----------------+----------+
|  1|[a1, a2, a3, a4]|      true|
|  2|        [b1, b2]|     false|
|  3|    [c1, c2, c3]|     false|
|  4|        [d1, b3]|      true|
+---+----------------+----------+

也就是说,如果你有Spark>=2.4,您还可以使用Spark SQL函数arrays_overlap来提供更好的性能:

df2 = df.withColumn(
"code_found", 
F.arrays_overlap(
F.col('code'),
F.array(*[F.lit(i) for i in query_lst])
)
)

最新更新