PySpark:我们应该迭代更新数据帧吗?



我的问题分为两部分。 第一个是了解Spark的工作方式,第二个是优化。

我有一个具有多个分类变量的火花数据帧。对于这些分类变量中的每一个,我都添加了一个新列,其中每一行都是相应水平的频率。

例如

Date_Built  Square_Footage  Num_Beds    Num_Baths   State   Price     Freq_State
01/01/1920  1700            3           2           NY      700000    4500

在这里,对于State(一个分类变量),我正在添加一个新变量Freq_State。级别NY在数据集中出现4500次,因此此行在Freq_State列中4500

我有多个这样的列,我在其中添加相应级别的列轴承频率。

这是我用于实现此目的的代码

def calculate_freq(df, categorical_cols):
for each_cat_col in categorical_cols:
_freq = df.select(each_cat_col).groupBy(each_cat_col).count()
df = df.join(_freq, each_cat_col, "inner")
return df

第 1 部分

在这里,如您所见,我正在更新for循环中的数据帧。在群集上运行此代码时,是否建议以这种方式更新数据帧?如果它是一个熊猫数据帧,我就不会担心这个。但我不确定上下文何时变为火花。

另外,如果我只是在循环中而不是在函数中运行上述进程,会有什么不同吗?

第 2 部分

有没有更优化的方法可以做到这一点?每次进入循环时,我都会加入这里?这能避免吗

有没有更优化的方法可以做到这一点?

有哪些可能的替代方案?

  1. 您可以使用窗口函数

    def calculate_freq(df, categorical_cols):
    for cat_col in categorical_cols:
    w = Window.partitionBy(cat_col)
    df = df.withColumn("{}_freq".format(each_cat_col), count("*").over(w))
    return df
    

    你应该吗?不。与join不同,它始终需要对非聚合DataFrame进行完全洗牌。

  2. 您可以melt并使用单个本地对象(这要求所有分类列的类型相同):

    from itertools import groupby
    for c in categorical_cols:
    df = df.withColumn(c, df[c].cast("string"))
    
    rows = (melt(df, id_vars=[], value_vars=categorical_cols)
    .groupBy("variable", "value").count().collect())
    mapping = {k: {x.value: x["count"] for x in v} 
    for k, v in groupby(sorted(rows), lambda x: x.variable)}
    

    并使用udf添加值:

    from pyspark.sql.functions import udf
    def get_count(mapping_c):
    @udf("bigint")
    def _(x):
    return mapping_c.get(x)
    return _
    
    for c in categorical_cols:
    df = df.withColumn("{}_freq".format(c), get_count(mapping[c])(c))
    

    你应该吗?或。与迭代联接不同,它只需要一个操作来计算所有统计信息。如果结果很小(使用分类变量预期),则可以获得适度的性能提升。

  3. 添加broadcast提示。

    from pyspark.sql.functions import broadcast
    def calculate_freq(df, categorical_cols):
    for each_cat_col in categorical_cols:
    _freq = df.select(each_cat_col).groupBy(each_cat_col).count()
    df = df.join(broadcast(_freq), each_cat_col, "inner")
    return df
    

    Spark应该自动广播,所以它不应该改变任何事情,但帮助规划总是更好的。

另外,如果我只是在循环中而不是在函数中运行上述进程,会有什么不同吗?

忽略代码可维护性和可测试性,它不会。

最新更新