I have a dataframe dataframe_actions with the following fields: user_id, action, day. user_id is unique for each user, and day takes values from 1 to 31. I want to keep only the users that were seen on at least 2 consecutive days, for example:
If a user was seen on days 1, 2, 4, 8 and 9, I want to keep them because they were seen on at least 2 consecutive days.
What I'm doing right now is clunky and very slow (and it doesn't seem to work):
df_final = spark.sql("""
    with t1 (select user_id, day,
             row_number() over (partition by user_id order by day) - day diff
             from dataframe_actions),
         t2 (select user_id, day,
             collect_set(diff) over (partition by user_id) diff2
             from t1)
    select user_id, day from t2 where size(diff2) > 2
""")
Something like that, but I don't know how to actually solve this.
Edit:
| user_id | action | day |
--------------------------
| asdc24 | conn | 1 |
| asdc24 | conn | 2 |
| asdc24 | conn | 5 |
| adsfa6 | conn | 1 |
| adsfa6 | conn | 3 |
| asdc24 | conn | 9 |
| adsfa6 | conn | 5 |
| asdc24 | conn | 11 |
| adsfa6 | conn | 10 |
| asdc24 | conn | 15 |
should return
| user_id | action | day |
--------------------------
| asdc24 | conn | 1 |
| asdc24 | conn | 2 |
| asdc24 | conn | 5 |
| asdc24 | conn | 9 |
| asdc24 | conn | 11 |
| asdc24 | conn | 15 |
because only this user connected on at least two consecutive days (days 1 and 2).
Use lag to get the previous day for each user, subtract it from the current row's day, and then check whether at least one of the differences is 1. This is done with a groupBy and a filter afterwards.
from pyspark.sql import functions as f
from pyspark.sql import Window

# difference between the current day and the previous day, per user
w = Window.partitionBy(dataframe_actions.user_id).orderBy(dataframe_actions.day)
user_prev = dataframe_actions.withColumn('prev_day_diff', dataframe_actions.day - f.lag(dataframe_actions.day).over(w))

# count how many of those differences are exactly 1 for each user
res = user_prev.groupBy(user_prev.user_id).agg(f.sum(f.when(user_prev.prev_day_diff == 1, 1).otherwise(0)).alias('diff_1'))

# users with at least one pair of consecutive days
res.filter(res.diff_1 >= 1).show()
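Note that the groupBy collapses the result to one row per user_id. To get the original rows back, one option is to join the qualifying users to the source dataframe; a minimal sketch, assuming the res and dataframe_actions from above (qualifying is just a helper name I introduce here):

# user_ids with at least one pair of consecutive days
qualifying = res.filter(res.diff_1 >= 1).select('user_id')
# inner join keeps every original row for those users
dataframe_actions.join(qualifying, on='user_id', how='inner').show()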
Another approach, using a row-number difference: day - row_number() is constant within a run of consecutive days, so the length of each run can be counted per user. This also allows selecting all columns for a given user_id.
w = Window.partitionBy(dataframe_actions.user_id).orderBy(dataframe_actions.day)
# day - row_number() is constant within a run of consecutive days
rownum_diff = dataframe_actions.withColumn('rdiff', dataframe_actions.day - f.row_number().over(w))
# length of the run each row belongs to
w1 = Window.partitionBy(rownum_diff.user_id, rownum_diff.rdiff)
runs = rownum_diff.withColumn('run_len', f.count(rownum_diff.day).over(w1))
# a user qualifies if any of their runs covers at least 2 consecutive days
w2 = Window.partitionBy(runs.user_id)
counts_per_user = runs.withColumn('max_run', f.max(runs.run_len).over(w2))
cols_to_select = ['user_id', 'action', 'day']
counts_per_user.filter(counts_per_user.max_run >= 2).select(*cols_to_select).show()
A SQL approach with the given input.
PySpark:
>>> from pyspark.sql.functions import *
>>> df = sc.parallelize([("asdc24","conn",1),
... ("asdc24","conn",2),
... ("asdc24","conn",5),
... ("adsfa6","conn",1),
... ("adsfa6","conn",3),
... ("asdc24","conn",9),
... ("adsfa6","conn",5),
... ("asdc24","conn",11),
... ("adsfa6","conn",10),
... ("asdc24","conn",15)]).toDF(["user_id","action","day"])
>>> df.createOrReplaceTempView("qubix")
>>> spark.sql(" select * from qubix order by user_id, day").show()
+-------+------+---+
|user_id|action|day|
+-------+------+---+
| adsfa6| conn| 1|
| adsfa6| conn| 3|
| adsfa6| conn| 5|
| adsfa6| conn| 10|
| asdc24| conn| 1|
| asdc24| conn| 2|
| asdc24| conn| 5|
| asdc24| conn| 9|
| asdc24| conn| 11|
| asdc24| conn| 15|
+-------+------+---+
>>> spark.sql(""" with t1 (select user_id,action, day,lead(day) over(partition by user_id order by day) ld from qubix), t2 (select user_id from t1 where ld-t1.day=1 ) select * from qubix where user_id in (select user_id from t2) """).show()
+-------+------+---+
|user_id|action|day|
+-------+------+---+
| asdc24| conn| 1|
| asdc24| conn| 2|
| asdc24| conn| 5|
| asdc24| conn| 9|
| asdc24| conn| 11|
| asdc24| conn| 15|
+-------+------+---+
>>>
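For completeness, the same lead-based check can be written with the DataFrame API instead of SQL; a minimal sketch, assuming the df defined above (qualifying and ld are names I introduce here):

from pyspark.sql import Window
from pyspark.sql.functions import col, lead

# users having at least one day that is immediately followed by the next day
w = Window.partitionBy('user_id').orderBy('day')
qualifying = (df.withColumn('ld', lead('day').over(w))
                .filter(col('ld') - col('day') == 1)
                .select('user_id')
                .distinct())

# keep all rows of the qualifying users
df.join(qualifying, on='user_id').orderBy('user_id', 'day').show()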
Scala version:
scala> val df = Seq(("asdc24","conn",1),
| ("asdc24","conn",2),
| ("asdc24","conn",5),
| ("adsfa6","conn",1),
| ("adsfa6","conn",3),
| ("asdc24","conn",9),
| ("adsfa6","conn",5),
| ("asdc24","conn",11),
| ("adsfa6","conn",10),
| ("asdc24","conn",15)).toDF("user_id","action","day")
df: org.apache.spark.sql.DataFrame = [user_id: string, action: string ... 1 more field]
scala> df.orderBy('user_id,'day).show(false)
+-------+------+---+
|user_id|action|day|
+-------+------+---+
|adsfa6 |conn |1 |
|adsfa6 |conn |3 |
|adsfa6 |conn |5 |
|adsfa6 |conn |10 |
|asdc24 |conn |1 |
|asdc24 |conn |2 |
|asdc24 |conn |5 |
|asdc24 |conn |9 |
|asdc24 |conn |11 |
|asdc24 |conn |15 |
+-------+------+---+
scala> df.createOrReplaceTempView("qubix")
scala> spark.sql(""" with t1 (select user_id,action, day,lead(day) over(partition by user_id order by day) ld from qubix), t2 (select user_id from t1 where ld-t1.day=1 ) select * from qubix where user_id in (select user_id from t2) """).show(false)
+-------+------+---+
|user_id|action|day|
+-------+------+---+
|asdc24 |conn |1 |
|asdc24 |conn |2 |
|asdc24 |conn |5 |
|asdc24 |conn |9 |
|asdc24 |conn |11 |
|asdc24 |conn |15 |
+-------+------+---+
scala>