Create a function that performs the following using only PySpark code, without a UDF:
- Accepts a PySpark DataFrame and a list of column names, e.g. values = ['column1', 'column2']
+--+-------+-------+
|id|column1|column2|
+--+-------+-------+
|1 |  2.0  |  3.0  |
|2 |  1.0  |  3.0  |
|3 | -1.0  |  3.0  |
|4 | -1.0  | -1.0  |
+--+-------+-------+
- Create a count column holding the number of values that are not -1
+--+-------+-------+-----+
|id|column1|column2|count|
+--+-------+-------+-----+
|1 |  2.0  |  3.0  |  2  |
|2 |  1.0  |  3.0  |  2  |
|3 | -1.0  |  3.0  |  1  |
|4 | -1.0  | -1.0  |  0  |
+--+-------+-------+-----+
- Create a sum column holding the sum of the values that are not -1
+--+-------+-------+-----+-----+
|id|column1|column2|count| sum |
+--+-------+-------+-----+-----+
|1 |  2.0  |  3.0  |  2  | 5.0 |
|2 |  1.0  |  3.0  |  2  | 4.0 |
|3 | -1.0  |  3.0  |  1  | 3.0 |
|4 | -1.0  | -1.0  |  0  | 0.0 |
+--+-------+-------+-----+-----+
- Create an avg column holding the average of the values that are not -1
+--+-------+-------+-----+-----+-----+
|id|column1|column2|count| sum | avg |
+--+-------+-------+-----+-----+-----+
|1 |  2.0  |  3.0  |  2  | 5.0 | 2.5 |
|2 |  1.0  |  3.0  |  2  | 4.0 | 2.0 |
|3 | -1.0  |  3.0  |  1  | 3.0 | 3.0 |
|4 | -1.0  | -1.0  |  0  | 0.0 | 0.0 |
+--+-------+-------+-----+-----+-----+
I have already done this with a UDF:

def average_columns(columns: list) -> float:
    """
    This function calculates the average conflict score between the different deduplication methods.

    Parameters
    ----------
    columns : list
        The list of column values to be compared.

    Returns
    -------
    float
        The average conflict score, ignoring -1 values; 0.0 when none remain.
    """
    values = [val for val in columns if val != -1]
    return sum(values) / len(values) if values else 0.0

average_columns_udf = F.udf(average_columns, T.FloatType())
But I want to use only PySpark functions. The code below is pseudocode and does not work:
from pyspark.sql import DataFrame
from pyspark.sql import functions as F

values = ['column1', 'column2']

def average_columns(df: DataFrame, values: list) -> DataFrame:
    return (df.withColumn('count', F.sum(F.when(F.col(value) != -1, 1).otherwise(0)) for value in values)
              .withColumn('sum', F.sum(F.when(F.col(value) != -1, F.col(value)).otherwise(0)) for value in values)
              .withColumn('avg', F.col('sum') / F.col('count')))
Higher-order functions would help.
The code and logic below:

new = (df.withColumn('ABCD', array(*[x for x in df.columns if x != 'id']))  # Collect the value columns into an array
       .withColumn('ABCD', expr("filter(ABCD, c -> c != -1)"))              # Drop the -1 entries with a higher-order filter
       .withColumn('count', size(col('ABCD')))                              # Count the remaining valid values
       .withColumn('sum', expr("aggregate(ABCD, cast(0 as double), (acc, x) -> acc + x)"))  # Sum the valid values
       .withColumn('avg', when(col('count') != 0, col('sum') / col('count')).otherwise(0.0)))  # sum / count, 0.0 when empty
new.show()

+---+-------+-------+------+-----+---+---+
| id|column1|column2|  ABCD|count|sum|avg|
+---+-------+-------+------+-----+---+---+
|  1|      2|      3|[2, 3]|    2|5.0|2.5|
|  2|      1|      3|[1, 3]|    2|4.0|2.0|
|  3|     -1|      3|   [3]|    1|3.0|3.0|
|  4|     -1|     -1|    []|    0|0.0|0.0|
+---+-------+-------+------+-----+---+---+