Pyspark从列中获取最新更新的值



我有一个数据帧如下:

+----+--------+--------+------+
| id | value1 | value2 | flag |
+----+--------+--------+------+
|  1 | 7000   | 30     |   0  |
|  2 | 0      | 9      |   0  |
|  3 | 23627  | 17     |   1  |
|  4 | 8373   | 23     |   0  |
|  5 | -0.5   | 4      |   1  |
+----+--------+--------+------+

我想运行以下条件-
1。如果值大于0,我希望前几行的值2
2。如果值等于0,我想要上一行和下一行的值2
3的平均值。如果值小于0,则为NULL
因此我编写了以下代码-

df = df.withColumn('value2',when(col(value1)>0,lag(col(value2))).when(col(value1)==0,
(lag(col(value2))+lead(col(value2)))/2.0).otherwise(None))

我想要的是,当我获取前一行和下一行的值时,应该有更新的值,如下所示。它应该按照找到它们的顺序进行,首先对id-1进行更新,然后对id-2进行更新值,依此类推

+----+--------+--------+------+
| id | value1 | value2 | flag |
+----+--------+--------+------+
|  1 | 7000   | null   |   0  |
|  2 | 0      | 8.5    |   0  |
|  3 | 23627  | 8.5    |   1  |
|  4 | 8373   | 8.5    |   0  |
|  5 | -0.5   | null   |   1  |
+----+--------+--------+------+

我试着在when,assign数据帧中只给出id==1,然后再次执行withcolumn,when操作。

df = df.withColumn('value2',when((col(id)==1)&(col(value1)>0,lag(col(value2)))
.when((col(id)==1)&col(value1)==0,(lag(col(value2))+lead(col(value2)))/2.0)
.when((col(id)==1)&col(col(value1)<0,None).otherwise(col(value2))

在这之后,我将获得更新的列值,如果我对id==2再次执行相同的操作,我可以获得更新的值。但我当然不能对每个id都这样做。我该如何做到这一点?

我认为完全不循环会很复杂。但是,您可以使用udf将工作划分为不同的执行器和panda中的子集。为了实现这一点,必须有足够的断点(即,值小于0且插入NULL的数据点(。

进口:

from pyspark.sql import Window
from pyspark.sql.functions import last
from pyspark.sql.functions import pandas_udf
from pyspark.sql.functions import PandasUDFType
import pandas as pd
import numpy as np
from pyspark.sql.functions import col, lit, when

输入数据:

df = spark.createDataFrame([[ 1, 7000.0, 30.0 ], [ 2, 0.0, 9.0], [3, 23628.0, 17.0], [4, 8373.0, 23.0], [5, -0.5, 4.0]], [ 'id', 'value1', 'value2' ]).cache()

添加下一个值2,并在值小于0:时设置断点

dfwithnextvalue = df.alias("a").join(df.alias("b"), col("a.id") == col("b.id") - lit(1), 'left').select("a.*", col("b.value2").alias("nextvalue"))
dfstartnew = dfwithnextvalue.withColumn("startnew", when(col("value1") < lit(0), col("id")).otherwise(lit(None)))
.withColumn("startnew", when(col("id") == lit(1), lit(1)).otherwise(col("startnew")))
window = Window.orderBy('id')
rolled = last(col('startnew'), ignorenulls=True).over(window)
dfstartnewrolled = dfstartnew.withColumn("startnew", rolled)

现在我们可以按startnew列进行分组,并处理熊猫中的每一块。我对熊猫的了解不多,但这似乎奏效了:

@pandas_udf("id long, value1 double, value2 double", PandasUDFType.GROUPED_MAP)
def loopdata(df):
df = df.set_index('id').sort_index()
for i in range(0, len(df.index)):
if i == 0:
df.loc[df.index[0], 'value2'] = np.nan
elif df.loc[df.index[i], 'value1'] < 0:
df.loc[df.index[i], 'value2'] = np.nan
elif df.loc[df.index[i], 'value1'] > 0:
df.loc[df.index[i], 'value2'] = df.loc[df.index[i-1], 'value2']
else:
nextvalue = df.loc[df.index[i], 'nextvalue']
if pd.isna(nextvalue):
nextvalue = 0
prevvalue = df.loc[df.index[i-1], 'value2']
if pd.isna(prevvalue):
prevvalue = 0
df.loc[df.index[i], 'value2'] = (nextvalue + prevvalue)/2.0
df = df.drop(columns=['nextvalue', 'startnew'])
df = df.reset_index()
return df

现在你可以计算结果:

dfstartnewrolled.groupBy("startnew").apply(loopdata)
from pyspark.sql import SparkSession    
from pyspark.sql.types import *
from pyspark.sql.functions import *
from pyspark.sql.window import Window

spark = SparkSession 
.builder 
.appName('test') 
.getOrCreate()

tab_data = spark.sparkContext.parallelize(tab_inp)
##
schema = StructType([StructField('id',IntegerType(),True),
StructField('value1',FloatType(),True),
StructField('value2',IntegerType(),True),
StructField('flag',IntegerType(),True)
])
table = spark.createDataFrame(tab_data,schema)
table.createOrReplaceTempView("table")
dummy_df=table.withColumn('dummy',lit('dummy'))
pre_value=dummy_df.withColumn('pre_value',lag(dummy_df['value2']).over(Window.partitionBy('dummy').orderBy('dummy')))
cmb_value=pre_value.withColumn('next_value',lead(dummy_df['value2']).over(Window.partitionBy('dummy').orderBy('dummy')))
new_column=when(col('value1')>0,cmb_value.pre_value) 
.when(col('value1')<0,cmb_value.next_value)
.otherwise((cmb_value.pre_value+cmb_value.next_value)/2)

final_table=cmb_value.withColumn('value',new_column)

上面的"final_table"将有您期望的字段。

最新更新