在 pyspark 中计算分组中位数



使用 pyspark 时,我希望能够计算分组值与其组中位数之间的差异。 这可能吗? 这是我破解的一些代码,它做了我想要的,除了它计算分组差异与平均值。 另外,如果您觉得有帮助,请随时评论我如何使它变得更好:)

from pyspark import SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.types import (
    StringType,
    LongType,
    DoubleType,
    StructField,
    StructType
)
from pyspark.sql import functions as F

sc = SparkContext(appName='myapp')
spark = SparkSession(sc)
file_name = 'data.csv'
fields = [
    StructField(
        'group2',
        LongType(),
        True),
    StructField(
        'name',
        StringType(),
        True),
    StructField(
        'value',
        DoubleType(),
        True),
    StructField(
        'group1',
        LongType(),
        True)
]
schema = StructType(fields)
df = spark.read.csv(
    file_name, header=False, mode="DROPMALFORMED", schema=schema
)
df.show()
means = df.select([
    'group1',
    'group2',
    'name',
    'value']).groupBy([
        'group1',
        'group2'
    ]).agg(
        F.mean('value').alias('mean_value')
    ).orderBy('group1', 'group2')
cond = [df.group1 == means.group1, df.group2 == means.group2]
means.show()
df = df.select([
    'group1',
    'group2',
    'name',
    'value']).join(
        means,
        cond
    ).drop(
        df.group1
    ).drop(
        df.group2
    ).select('group1',
             'group2',
             'name',
             'value',
             'mean_value')
final = df.withColumn(
    'diff',
    F.abs(df.value - df.mean_value))
final.show()
sc.stop()

这是我正在使用的示例数据集:

100,name1,0.43,0
100,name2,0.33,0
100,name3,0.73,0
101,name1,0.29,0
101,name2,0.96,0
101,name3,0.42,0
102,name1,0.01,0
102,name2,0.42,0
102,name3,0.51,0
103,name1,0.55,0
103,name2,0.45,0
103,name3,0.02,0
104,name1,0.93,0
104,name2,0.16,0
104,name3,0.74,0
105,name1,0.41,0
105,name2,0.65,0
105,name3,0.29,0
100,name1,0.51,1
100,name2,0.51,1
100,name3,0.43,1
101,name1,0.59,1
101,name2,0.55,1
101,name3,0.84,1
102,name1,0.01,1
102,name2,0.98,1
102,name3,0.44,1
103,name1,0.47,1
103,name2,0.16,1
103,name3,0.02,1
104,name1,0.83,1
104,name2,0.89,1
104,name3,0.31,1
105,name1,0.59,1
105,name2,0.77,1
105,name3,0.45,1

这是我正在尝试制作的:

group1,group2,name,value,median,diff
0,100,name1,0.43,0.43,0.0
0,100,name2,0.33,0.43,0.10
0,100,name3,0.73,0.43,0.30
0,101,name1,0.29,0.42,0.13
0,101,name2,0.96,0.42,0.54
0,101,name3,0.42,0.42,0.0
0,102,name1,0.01,0.42,0.41
0,102,name2,0.42,0.42,0.0
0,102,name3,0.51,0.42,0.09
0,103,name1,0.55,0.45,0.10
0,103,name2,0.45,0.45,0.0
0,103,name3,0.02,0.45,0.43
0,104,name1,0.93,0.74,0.19
0,104,name2,0.16,0.74,0.58
0,104,name3,0.74,0.74,0.0
0,105,name1,0.41,0.41,0.0
0,105,name2,0.65,0.41,0.24
0,105,name3,0.29,0.41,0.24
1,100,name1,0.51,0.51,0.0
1,100,name2,0.51,0.51,0.0
1,100,name3,0.43,0.51,0.08
1,101,name1,0.59,0.59,0.0
1,101,name2,0.55,0.59,0.04
1,101,name3,0.84,0.59,0.25
1,102,name1,0.01,0.44,0.43
1,102,name2,0.98,0.44,0.54
1,102,name3,0.44,0.44,0.0
1,103,name1,0.47,0.16,0.31
1,103,name2,0.16,0.16,0.0
1,103,name3,0.02,0.16,0.14
1,104,name1,0.83,0.83,0.0
1,104,name2,0.89,0.83,0.06
1,104,name3,0.31,0.83,0.52
1,105,name1,0.59,0.59,0.0
1,105,name2,0.77,0.59,0.18
1,105,name3,0.45,0.59,0.14
您可以使用

udf函数median来解决它。首先,让我们创建上面给出的简单示例。

# example data
ls = [[100,'name1',0.43,0],
      [100,'name2',0.33,0],
      [100,'name3',0.73,0],
      [101,'name1',0.29,0],
      [101,'name2',0.96,0],
      [...]]
df = spark.createDataFrame(ls, schema=['a', 'b', 'c', 'd'])

这是用于计算中位数的udf函数

# udf for median
import numpy as np
import pyspark.sql.functions as func
def median(values_list):
    med = np.median(values_list)
    return float(med)
udf_median = func.udf(median, FloatType())
group_df = df.groupby(['a', 'd'])
df_grouped = group_df.agg(udf_median(func.collect_list(col('c'))).alias('median'))
df_grouped.show()

最后,您可以使用原始df将其连接回来,以便恢复中位数列。

df_grouped = df_grouped.withColumnRenamed('a', 'a_').withColumnRenamed('d', 'd_')
df_final = df.join(df_grouped, [df.a == df_grouped.a_, df.d == df_grouped.d_]).select('a', 'b', 'c', 'median')
df_final = df_final.withColumn('diff', func.round(func.col('c') - func.col('median'), scale=2))

请注意,我在末尾使用round来防止在中位数操作后出现额外的数字。

相关内容

  • 没有找到相关文章

最新更新