我的问题分为两部分。 第一个是了解Spark的工作方式,第二个是优化。
我有一个具有多个分类变量的火花数据帧。对于这些分类变量中的每一个,我都添加了一个新列,其中每一行都是相应水平的频率。
例如
Date_Built Square_Footage Num_Beds Num_Baths State Price Freq_State
01/01/1920 1700 3 2 NY 700000 4500
在这里,对于State
(一个分类变量),我正在添加一个新变量Freq_State
。级别NY
在数据集中出现4500
次,因此此行在Freq_State
列中4500
。
我有多个这样的列,我在其中添加相应级别的列轴承频率。
这是我用于实现此目的的代码
def calculate_freq(df, categorical_cols):
for each_cat_col in categorical_cols:
_freq = df.select(each_cat_col).groupBy(each_cat_col).count()
df = df.join(_freq, each_cat_col, "inner")
return df
第 1 部分
在这里,如您所见,我正在更新for
循环中的数据帧。在群集上运行此代码时,是否建议以这种方式更新数据帧?如果它是一个熊猫数据帧,我就不会担心这个。但我不确定上下文何时变为火花。
另外,如果我只是在循环中而不是在函数中运行上述进程,会有什么不同吗?
第 2 部分
有没有更优化的方法可以做到这一点?每次进入循环时,我都会加入这里?这能避免吗
有没有更优化的方法可以做到这一点?
有哪些可能的替代方案?
-
您可以使用窗口函数
def calculate_freq(df, categorical_cols): for cat_col in categorical_cols: w = Window.partitionBy(cat_col) df = df.withColumn("{}_freq".format(each_cat_col), count("*").over(w)) return df
你应该吗?不。与
join
不同,它始终需要对非聚合DataFrame
进行完全洗牌。 -
您可以
melt
并使用单个本地对象(这要求所有分类列的类型相同):from itertools import groupby for c in categorical_cols: df = df.withColumn(c, df[c].cast("string")) rows = (melt(df, id_vars=[], value_vars=categorical_cols) .groupBy("variable", "value").count().collect()) mapping = {k: {x.value: x["count"] for x in v} for k, v in groupby(sorted(rows), lambda x: x.variable)}
并使用
udf
添加值:from pyspark.sql.functions import udf def get_count(mapping_c): @udf("bigint") def _(x): return mapping_c.get(x) return _ for c in categorical_cols: df = df.withColumn("{}_freq".format(c), get_count(mapping[c])(c))
你应该吗?或。与迭代联接不同,它只需要一个操作来计算所有统计信息。如果结果很小(使用分类变量预期),则可以获得适度的性能提升。
-
添加
broadcast
提示。from pyspark.sql.functions import broadcast def calculate_freq(df, categorical_cols): for each_cat_col in categorical_cols: _freq = df.select(each_cat_col).groupBy(each_cat_col).count() df = df.join(broadcast(_freq), each_cat_col, "inner") return df
Spark应该自动广播,所以它不应该改变任何事情,但帮助规划总是更好的。
另外,如果我只是在循环中而不是在函数中运行上述进程,会有什么不同吗?
忽略代码可维护性和可测试性,它不会。