Bitwise operations in PySpark without using a UDF



I have a Spark DataFrame as shown below:

+---------+---------------------------+
|country  |sports                     |
+---------+---------------------------+
|India    |[Cricket, Hockey, Football]|
|Sri Lanka|[Cricket, Football]        |
+---------+---------------------------+

Each sport in the sports column is represented by a code:

sport_to_code_map = {
    'Cricket' : 0x0001,
    'Hockey' : 0x0002,
    'Football' : 0x0004
}

Now I want to add a new column called sportsInt, which is the bitwise OR of the codes from the map above for each sport string, producing:

+---------+---------------------------+---------+
|country  |sports                     |sportsInt|
+---------+---------------------------+---------+
|India    |[Cricket, Hockey, Football]|7        |
|Sri Lanka|[Cricket, Football]        |5        |
+---------+---------------------------+---------+
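For reference, the expected sportsInt values are just the bitwise OR of the individual codes, which is easy to check in plain Python (the helper name here is my own):

```python
sport_to_code_map = {
    'Cricket': 0x0001,
    'Hockey': 0x0002,
    'Football': 0x0004
}

def or_codes(sports):
    # fold the per-sport codes together with bitwise OR
    code = 0
    for s in sports:
        code |= sport_to_code_map[s]
    return code

india = or_codes(['Cricket', 'Hockey', 'Football'])   # 1 | 2 | 4 = 7
sri_lanka = or_codes(['Cricket', 'Football'])         # 1 | 4 = 5
```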

I know one way to do this is with a UDF, like so:

import pyspark.sql.functions as F
from pyspark.sql.types import IntegerType

def get_sport_to_code(sport_name):
    sport_to_code_map = {
        'Cricket': 0x0001,
        'Hockey': 0x0002,
        'Football': 0x0004
    }
    if sport_name not in sport_to_code_map:
        raise Exception(f'Unknown Sport: {sport_name}')
    return sport_to_code_map[sport_name]

def sport_to_code(sports):
    if not sports:
        return None
    code = 0x0000
    for sport in sports:
        code = code | get_sport_to_code(sport)
    return code

sport_to_code_udf = F.udf(sport_to_code, IntegerType())
df.withColumn('sportsInt', sport_to_code_udf('sports'))

But is there any way to do this with Spark functions instead of a UDF?

For Spark 2.4+, we can use the aggregate higher-order function together with the bitwise OR operator in this case.

Example:

from pyspark.sql.types import *
from pyspark.sql.functions import *
sport_to_code_map = {
    'Cricket' : 0x0001,
    'Hockey' : 0x0002,
    'Football' : 0x0004
}
#creating dataframe from the dictionary
lookup = spark.createDataFrame(list(sport_to_code_map.items()), ["key", "value"])
#sample dataframe
df.show(10,False)
#+---------+---------------------------+
#|country  |sports                     |
#+---------+---------------------------+
#|India    |[Cricket, Hockey, Football]|
#|Sri Lanka|[Cricket, Football]        |
#+---------+---------------------------+
df1 = df.selectExpr("explode(sports) as key", "country")
df2 = df1.join(lookup, ['key'], 'left')\
    .groupBy("country")\
    .agg(collect_list(col("key")).alias("sports"),
         collect_list(col("value")).alias("sportsInt"))
df2.withColumn("sportsInt", expr('aggregate(sportsInt, 0, (s, x) -> int(s) | int(x))'))\
    .show(10, False)
#+---------+---------------------------+---------+
#|country  |sports                     |sportsInt|
#+---------+---------------------------+---------+
#|Sri Lanka|[Cricket, Football]        |5        |
#|India    |[Cricket, Hockey, Football]|7        |
#+---------+---------------------------+---------+

If you want to avoid the join and instead look the values up from the sport_to_code_map dict directly, use .replace:

#converting dict values to string
sport_to_code_map={k:str(v) for k,v in sport_to_code_map.items()}
df1.replace(sport_to_code_map).show()
#+---+---------+
#|key|  country|
#+---+---------+
#|  1|    India|
#|  2|    India|
#|  4|    India|
#|  1|Sri Lanka|
#|  4|Sri Lanka|
#+---+---------+
