我有以下DF:
+--------------+---+----+
|Date |Id |Cond|
+--------------+---+----+
| 2022-01-08| 1| 0|
| 2022-01-10| 1| 0|
| 2022-01-11| 1| 0|
| 2022-01-12| 1| 0|
| 2022-01-13| 1| 0|
| 2022-01-15| 1| 0|
| 2022-01-18| 1| 0|
| 2022-01-19| 1| 0|
| 2022-01-08| 2| 0|
| 2022-01-11| 2| 0|
| 2022-01-12| 2| 0|
| 2022-01-15| 2| 0|
| 2022-01-16| 2| 0|
| 2022-01-17| 2| 0|
| 2022-01-19| 2| 0|
| 2022-01-20| 2| 0|
+--------------+---+----+
+--------------+---+----+
|Date |Id |Cond|
+--------------+---+----+
| 2022-01-09| 1| 1|
| 2022-01-14| 1| 1|
| 2022-01-16| 1| 1|
| 2022-01-17| 1| 1|
| 2022-01-20| 1| 1|
| 2022-01-09| 2| 1|
| 2022-01-10| 2| 1|
| 2022-01-13| 2| 1|
| 2022-01-14| 2| 1|
| 2022-01-18| 2| 1|
+--------------+---+----+
我想获取 DF1 的前 2 个日期,该日期在 DF2 中具有顺序。
例:
对于DF1中的日期"2022-01-15"
和Id = 1
,我需要从DF2收集"2022-01-14"
和"2022-01-09"
的日期。
我的预期输出:
+--------------+---+------------------------------+
|Date |Id |List |
+--------------+---+------------------------------+
| 2022-01-08| 1| [] |
| 2022-01-10| 1| ['2022-01-09'] |
| 2022-01-11| 1| ['2022-01-09'] |
| 2022-01-12| 1| ['2022-01-09'] |
| 2022-01-13| 1| ['2022-01-09'] |
| 2022-01-15| 1| ['2022-01-14', '2022-01-09']|
| 2022-01-18| 1| ['2022-01-17', '2022-01-16']|
| 2022-01-19| 1| ['2022-01-17', '2022-01-16']|
| 2022-01-08| 2| [] |
| 2022-01-11| 2| ['2022-01-10', '2022-01-09']|
| 2022-01-12| 2| ['2022-01-10', '2022-01-09']|
| 2022-01-15| 2| ['2022-01-14', '2022-01-13']|
| 2022-01-16| 2| ['2022-01-14', '2022-01-13']|
| 2022-01-17| 2| ['2022-01-14', '2022-01-13']|
| 2022-01-19| 2| ['2022-01-18', '2022-01-14']|
| 2022-01-20| 2| ['2022-01-18', '2022-01-14']|
+--------------+---+------------------------------+
我知道我可以使用collect_list
将日期作为列表获取,但是如何按范围收集?
MVCE:
data_1 = [
("2022-01-08", 1, 0),
("2022-01-10", 1, 0),
("2022-01-11", 1, 0),
("2022-01-12", 1, 0),
("2022-01-13", 1, 0),
("2022-01-15", 1, 0),
("2022-01-18", 1, 0),
("2022-01-19", 1, 0),
("2022-01-08", 2, 0),
("2022-01-11", 2, 0),
("2022-01-12", 2, 0),
("2022-01-15", 2, 0),
("2022-01-16", 2, 0),
("2022-01-17", 2, 0),
("2022-01-19", 2, 0),
("2022-01-20", 2, 0)
]
schema_1 = StructType([
StructField("Date", StringType(), True),
StructField("Id", IntegerType(), True),
StructField("Cond", IntegerType(), True)
])
df_1 = spark.createDataFrame(data=data_1, schema=schema_1)
data_2 = [
("2022-01-09", 1, 1),
("2022-01-14", 1, 1),
("2022-01-16", 1, 1),
("2022-01-17", 1, 1),
("2022-01-20", 1, 1),
("2022-01-09", 2, 1),
("2022-01-10", 2, 1),
("2022-01-13", 2, 1),
("2022-01-14", 2, 1),
("2022-01-18", 2, 1)
]
schema_2 = StructType([
StructField("Date", StringType(), True),
StructField("Id", IntegerType(), True),
StructField("Cond", IntegerType(), True)
])
df_2 = spark.createDataFrame(data=data_2, schema=schema_2)
您可以通过以下方式完成此操作:
- 在
Id
上连接两个表; - 有条件地从
df_2
收集早于目标日期的日期df_1
(默认情况下collect_list
忽略空值);和 - 结合使用
slice
和sort_array
仅保留最近的两个日期。
import pyspark.sql.functions as F
df_out = df_1
.join(df_2.select(F.col("Date").alias("Date_RHS"), "Id"), on="Id", how="inner")
.groupBy("Date", "Id")
.agg(F.collect_list(F.when(F.col("Date_RHS") < F.col("Date"), F.col("Date_RHS")).otherwise(F.lit(None))).alias("List"))
.select("Date", "Id", F.slice(F.sort_array(F.col("List"), asc=False), start=1, length=2).alias("List"))
# +----------+---+------------------------+
# |Date |Id |List |
# +----------+---+------------------------+
# |2022-01-08|1 |[] |
# |2022-01-10|1 |[2022-01-09] |
# |2022-01-11|1 |[2022-01-09] |
# |2022-01-12|1 |[2022-01-09] |
# |2022-01-13|1 |[2022-01-09] |
# |2022-01-15|1 |[2022-01-14, 2022-01-09]|
# |2022-01-18|1 |[2022-01-17, 2022-01-16]|
# |2022-01-19|1 |[2022-01-17, 2022-01-16]|
# |2022-01-08|2 |[] |
# |2022-01-11|2 |[2022-01-10, 2022-01-09]|
# |2022-01-12|2 |[2022-01-10, 2022-01-09]|
# |2022-01-15|2 |[2022-01-14, 2022-01-13]|
# |2022-01-16|2 |[2022-01-14, 2022-01-13]|
# |2022-01-17|2 |[2022-01-14, 2022-01-13]|
# |2022-01-19|2 |[2022-01-18, 2022-01-14]|
# |2022-01-20|2 |[2022-01-18, 2022-01-14]|
# +----------+---+------------------------+
以下方法将首先聚合df_2
,然后执行左连接。然后,使用高阶函数filter
过滤掉大于"日期"列的日期,slice
从数组中选择 2 个最大值。
from pyspark.sql import functions as F
df = df_1.join(df_2.groupBy('Id').agg(F.collect_set('Date').alias('d2')), 'Id', 'left')
df = df.select(
'Date', 'Id',
F.slice(F.sort_array(F.filter('d2', lambda x: x < F.col('Date')), False), 1, 2).alias('List')
)
df.show(truncate=0)
# +----------+---+------------------------+
# |Date |Id |List |
# +----------+---+------------------------+
# |2022-01-08|1 |[] |
# |2022-01-10|1 |[2022-01-09] |
# |2022-01-11|1 |[2022-01-09] |
# |2022-01-12|1 |[2022-01-09] |
# |2022-01-13|1 |[2022-01-09] |
# |2022-01-15|1 |[2022-01-14, 2022-01-09]|
# |2022-01-18|1 |[2022-01-17, 2022-01-16]|
# |2022-01-19|1 |[2022-01-17, 2022-01-16]|
# |2022-01-08|2 |[] |
# |2022-01-11|2 |[2022-01-10, 2022-01-09]|
# |2022-01-12|2 |[2022-01-10, 2022-01-09]|
# |2022-01-15|2 |[2022-01-14, 2022-01-13]|
# |2022-01-16|2 |[2022-01-14, 2022-01-13]|
# |2022-01-17|2 |[2022-01-14, 2022-01-13]|
# |2022-01-19|2 |[2022-01-18, 2022-01-14]|
# |2022-01-20|2 |[2022-01-18, 2022-01-14]|
# +----------+---+------------------------+
对于较低的 Spark 版本,请使用以下命令:
from pyspark.sql import functions as F
df = df_1.join(df_2.groupBy('Id').agg(F.collect_set('Date').alias('d2')), 'Id', 'left')
df = df.select(
'Date', 'Id',
F.slice(F.sort_array(F.expr("filter(d2, x -> x < Date)"), False), 1, 2).alias('List')
)