How to write this pandas logic for a pyspark.sql.dataframe.DataFrame without using the pandas-on-Spark API



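I am completely new to PySpark. Since PySpark has no loc feature, how can this logic be written? I tried specifying the conditions directly but could not get the desired result. Any help would be greatly appreciated!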

df['Total'] = (df['level1']+df['level2']+df['level3']+df['level4'])/df['Number']
df.loc[df['level4'] > 0, 'Total'] += 4
df.loc[((df['level3'] > 0) & (df['Total'] < 1)), 'Total'] += 3
df.loc[((df['level2'] > 0) & (df['Total'] < 1)), 'Total'] += 2
df.loc[((df['level1'] > 0) & (df['Total'] < 1)), 'Total'] += 1

Suppose the data looks like this:

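from pyspark.sql import functions as func  # alias used in the snippets below

# assumes an active SparkSession named `spark`
data_ls = [
    (1, 1, 1, 1, 10),
    (5, 5, 5, 5, 10)
]
data_sdf = spark.sparkContext.parallelize(data_ls). \
    toDF(['level1', 'level2', 'level3', 'level4', 'number'])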
# +------+------+------+------+------+
# |level1|level2|level3|level4|number|
# +------+------+------+------+------+
# |     1|     1|     1|     1|    10|
# |     5|     5|     5|     5|    10|
# +------+------+------+------+------+

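You are effectively updating the total column in every statement, rather than in an if-then-else fashion. Your code can be replicated in PySpark with multiple withColumn() calls, each wrapping a when(), as shown below.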

data_sdf. \
    withColumn('total', (func.col('level1') + func.col('level2') + func.col('level3') + func.col('level4')) / func.col('number')). \
    withColumn('total', func.when(func.col('level4') > 0, func.col('total') + 4).otherwise(func.col('total'))). \
    withColumn('total', func.when((func.col('level3') > 0) & (func.col('total') < 1), func.col('total') + 3).otherwise(func.col('total'))). \
    withColumn('total', func.when((func.col('level2') > 0) & (func.col('total') < 1), func.col('total') + 2).otherwise(func.col('total'))). \
    withColumn('total', func.when((func.col('level1') > 0) & (func.col('total') < 1), func.col('total') + 1).otherwise(func.col('total'))). \
    show()
# +------+------+------+------+------+-----+
# |level1|level2|level3|level4|number|total|
# +------+------+------+------+------+-----+
# |     1|     1|     1|     1|    10|  4.4|
# |     5|     5|     5|     5|    10|  6.0|
# +------+------+------+------+------+-----+
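
As a cross-check on the table above, here is a minimal plain-Python sketch of the same sequential logic for a single row. It does not use Spark at all; the helper name `trace_total` and the `rules` layout are made up for illustration. It records total after every step, which makes it easy to see that once level4 adds 4, the later `total < 1` checks no longer fire for the sample rows.

```python
def trace_total(row):
    """Return the value of total after each of the five steps for one row."""
    level1, level2, level3, level4, number = row
    total = (level1 + level2 + level3 + level4) / number
    steps = [total]
    # (level value, bonus, whether the step also requires total < 1)
    rules = [
        (level4, 4, False),
        (level3, 3, True),
        (level2, 2, True),
        (level1, 1, True),
    ]
    for level, bonus, needs_low_total in rules:
        # Each step reads the total left behind by the previous step.
        if level > 0 and (total < 1 or not needs_low_total):
            total += bonus
        steps.append(total)
    return steps


for row in [(1, 1, 1, 1, 10), (5, 5, 5, 5, 10)]:
    print(row, trace_total(row))
```

The last entry of each trace should agree with the total column shown above (4.4 and 6.0, up to floating-point rounding).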

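We can merge all of the withColumn() and when() pairs into a single withColumn() with several chained when() clauses. A chained when() behaves like an if / elif chain, so only the first matching branch is applied. That gives the same result as the sequential version whenever the initial total is non-negative, because any branch that fires lifts total to at least 1 and the later `total < 1` conditions could not have fired anyway.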

data_sdf. \
    withColumn('total', (func.col('level1') + func.col('level2') + func.col('level3') + func.col('level4')) / func.col('number')). \
    withColumn('total',
               func.when(func.col('level4') > 0, func.col('total') + 4).
               when((func.col('level3') > 0) & (func.col('total') < 1), func.col('total') + 3).
               when((func.col('level2') > 0) & (func.col('total') < 1), func.col('total') + 2).
               when((func.col('level1') > 0) & (func.col('total') < 1), func.col('total') + 1).
               otherwise(func.col('total'))
               ). \
    show()
# +------+------+------+------+------+-----+
# |level1|level2|level3|level4|number|total|
# +------+------+------+------+------+-----+
# |     1|     1|     1|     1|    10|  4.4|
# |     5|     5|     5|     5|    10|  6.0|
# +------+------+------+------+------+-----+
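
The separate withColumn() calls are sequential updates, while a single chained when() stops at the first match. A rough plain-Python sketch of the two behaviours follows; the function names and the third row are hypothetical, and the negative value in that row is there only to show where the two forms could diverge.

```python
def total_sequential(level1, level2, level3, level4, number):
    # Four independent updates; each sees the total left by the previous one.
    total = (level1 + level2 + level3 + level4) / number
    if level4 > 0:
        total += 4
    if level3 > 0 and total < 1:
        total += 3
    if level2 > 0 and total < 1:
        total += 2
    if level1 > 0 and total < 1:
        total += 1
    return total


def total_chained(level1, level2, level3, level4, number):
    # Only the first matching branch fires, like when().when()...otherwise().
    total = (level1 + level2 + level3 + level4) / number
    if level4 > 0:
        return total + 4
    elif level3 > 0 and total < 1:
        return total + 3
    elif level2 > 0 and total < 1:
        return total + 2
    elif level1 > 0 and total < 1:
        return total + 1
    return total


sample_rows = [(1, 1, 1, 1, 10), (5, 5, 5, 5, 10)]
hypothetical_row = (1, 1, -50, 1, 10)  # made-up row with a negative level3

for row in sample_rows + [hypothetical_row]:
    print(row, total_sequential(*row), total_chained(*row))
```

For the two sample rows both functions agree. For the made-up row the initial total is negative, so adding 4 still leaves it below 1 and the sequential version goes on to apply the level2 bonus as well, while the chained version stops after the first branch.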

It works much like numpy.where and the SQL CASE statement.
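
To make the CASE analogy concrete, here is a small sketch that runs the equivalent CASE expression against an in-memory SQLite table, using Python's standard sqlite3 module. SQLite is only a stand-in so that the query can be run without a Spark session; the table name `data` is an assumption, and the `* 1.0` is there because SQLite would otherwise perform integer division.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE data (level1 INT, level2 INT, level3 INT, level4 INT, number INT)"
)
conn.executemany(
    "INSERT INTO data VALUES (?, ?, ?, ?, ?)",
    [(1, 1, 1, 1, 10), (5, 5, 5, 5, 10)],
)

# The inner query computes the base total; the outer CASE picks the first
# matching branch, mirroring the chained when() calls.
query = """
SELECT
    level1,
    CASE
        WHEN level4 > 0 THEN total + 4
        WHEN level3 > 0 AND total < 1 THEN total + 3
        WHEN level2 > 0 AND total < 1 THEN total + 2
        WHEN level1 > 0 AND total < 1 THEN total + 1
        ELSE total
    END AS total
FROM (
    SELECT *, (level1 + level2 + level3 + level4) * 1.0 / number AS total
    FROM data
)
ORDER BY level1
"""
case_totals = [row[1] for row in conn.execute(query)]
print(case_totals)
conn.close()
```

In PySpark, a CASE expression of the same shape should be usable as a string passed to func.expr(), although that variant is not shown or tested here.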
