在PySpark中获取与窗口上的某个条件匹配的第一行



举个例子,假设我们有一个用户操作流,如下所示:

from pyspark.sql import *
spark = SparkSession.builder.appName('test').master('local[8]').getOrCreate()
df = spark.sparkContext.parallelize([
Row(user=1, action=1, time=1),
Row(user=1, action=1, time=2),
Row(user=2, action=1, time=3),
Row(user=1, action=2, time=4),
Row(user=2, action=2, time=5),
Row(user=2, action=2, time=6),
Row(user=1, action=1, time=7),
Row(user=2, action=1, time=8),
]).toDF()
df.show()

数据帧看起来像:

+----+------+----+
|user|action|time|
+----+------+----+
|   1|     1|   1|
|   1|     1|   2|
|   2|     1|   3|
|   1|     2|   4|
|   2|     2|   5|
|   2|     2|   6|
|   1|     1|   7|
|   2|     1|   8|
+----+------+----+

然后,我想在每行中添加一列next_alt_time,给出用户在以下行中更改操作类型的时间。对于上面的输入,输出应该是:

+----+------+----+-------------+
|user|action|time|next_alt_time|
+----+------+----+-------------+
|   1|     1|   1|            4|
|   1|     1|   2|            4|
|   2|     1|   3|            5|
|   1|     2|   4|            7|
|   2|     2|   5|            8|
|   2|     2|   6|            8|
|   1|     1|   7|         null|
|   2|     1|   8|         null|
+----+------+----+-------------+

我知道我可以创建这样的窗口:

wnd = Window().partitionBy('user').orderBy('time').rowsBetween(1, Window.unboundedFollowing)

但是,我不知道如何在窗口上施加条件,并在上面定义的窗口上选择与当前行具有不同操作的第一行。

下面是如何做到的。Spark无法保持数据帧的顺序,但如果你逐一检查行,你可以确认它给出了你期望的答案:

from pyspark.sql import Row
from pyspark.sql.window import Window
import pyspark.sql.functions as F
df = spark.sparkContext.parallelize([
Row(user=1, action=1, time=1),
Row(user=1, action=1, time=2),
Row(user=2, action=1, time=3),
Row(user=1, action=2, time=4),
Row(user=2, action=2, time=5),
Row(user=2, action=2, time=6),
Row(user=1, action=1, time=7),
Row(user=2, action=1, time=8),
]).toDF()
win = Window().partitionBy('user').orderBy('time')
df = df.withColumn('new_action', F.lag('action').over(win) != F.col('action'))
df = df.withColumn('new_action_time', F.when(F.col('new_action'), F.col('time')))
df = df.withColumn('next_alt_time', F.first('new_action', ignorenulls=True).over(win.rowsBetween(1, Window.unboundedFollowing)))
df.show()
+----+------+----+----------+---------------+-------------+
|user|action|time|new_action|new_action_time|next_alt_time|
+----+------+----+----------+---------------+-------------+
|   1|     1|   1|      null|           null|            4|
|   1|     1|   2|     false|           null|            4|
|   1|     2|   4|      true|              4|            7|
|   1|     1|   7|      true|              7|         null|
|   2|     1|   3|      null|           null|            5|
|   2|     2|   5|      true|              5|            8|
|   2|     2|   6|     false|           null|            8|
|   2|     1|   8|      true|              8|         null|
+----+------+----+----------+---------------+-------------+

最新更新