PySpark: explode multiple columns with a sliding window

I have a similar question to an existing one, but with extra columns that the window has to be applied to, and I also need to know which element is the last one of the list the sliding window runs over.

Here is an example:

Given this df:

input_df = spark.createDataFrame([
    (2, [1, 2, 3, 4, 5], ["a", "b", "c", "c", "b"], ["a", "a", "c", "c", "d"]),
], ("id", "target", "feature1", "feature2"))
input_df.show():
+---+---------------+---------------+---------------+
| id|         target|       feature1|       feature2|
+---+---------------+---------------+---------------+
|  2|[1, 2, 3, 4, 5]|[a, b, c, c, b]|[a, a, c, c, d]|
+---+---------------+---------------+---------------+

I want to split each row into multiple rows by sliding a fixed-size window over the arrays. The resulting df would look like this:

output_df = spark.createDataFrame([
    (2, [1, 2], 3, ["a", "b"], ["a", "a"], False),
    (2, [2, 3], 4, ["b", "c"], ["a", "c"], False),
    (2, [3, 4], 5, ["c", "c"], ["c", "c"], True),
], ("id", "past-target", "future-target", "past-feature1", "past-feature2", "islast"))
output_df.show():
+---+-----------+-------------+-------------+-------------+------+
| id|past-target|future-target|past-feature1|past-feature2|islast|
+---+-----------+-------------+-------------+-------------+------+
|  2|     [1, 2]|            3|       [a, b]|       [a, a]| false|
|  2|     [2, 3]|            4|       [b, c]|       [a, c]| false|
|  2|     [3, 4]|            5|       [c, c]|       [c, c]|  true|
+---+-----------+-------------+-------------+-------------+------+

The logic is to take the columns ["target", "feature1", "feature2"] and apply a sliding window of size N (given as a parameter, 2 in this example): with a pointer placed on an element, the N previous values of each column form the lists [past-target, past-feature1, past-feature2], and the current value of target becomes future-target. The current values of the feature columns can be ignored.

The first row of the output df is created by looking at the first index after N (the 3rd, since N=2) and using it as future-target. The first and second values of the ["target", "feature1", "feature2"] lists then become [past-target, past-feature1, past-feature2], i.e. [1, 2], [a, b], [a, a]. islast is set to False because the pointer is not on the last element of target. Repeating this over and over produces output_df. The logic is hard to follow and I really don't know how to do it in PySpark; I'm happy to explain more if needed.
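In plain Python, the logic I am describing would look roughly like this (a sketch; sliding_rows is just an illustrative name):

def sliding_rows(target, feature1, feature2, n=2):
    # one output row per window position: the n values before index i form
    # the past lists, target[i] is future-target, and islast flags the end
    rows = []
    for i in range(n, len(target)):
        rows.append((
            target[i - n:i],       # past-target
            target[i],             # future-target
            feature1[i - n:i],     # past-feature1
            feature2[i - n:i],     # past-feature2
            i == len(target) - 1,  # islast
        ))
    return rows

sliding_rows([1, 2, 3, 4, 5], ["a", "b", "c", "c", "b"], ["a", "a", "c", "c", "d"])
# [([1, 2], 3, ['a', 'b'], ['a', 'a'], False),
#  ([2, 3], 4, ['b', 'c'], ['a', 'c'], False),
#  ([3, 4], 5, ['c', 'c'], ['c', 'c'], True)]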

I think your best bet is a UDF: just wrap all the logic inside it and you are done. The only tricky part is that you have to return an array of tuples so they can later be exploded into multiple rows. See the demo code below.

# Pure Python function here
def func(target, feature1, feature2, n=2):
    # implement your logic here
    # ...
    # this is just dummy return values
    return [
        ([1, 2], 3, ['a', 'b'], ['a', 'a'], False),
        ([2, 3], 4, ['b', 'c'], ['a', 'c'], False),
        ([3, 4], 5, ['c', 'c'], ['c', 'c'], True),
    ]
# ... which you can unit-test on its own
func([1, 2, 3, 4, 5], ['a', 'b', 'c', 'c', 'b'], ['a', 'a', 'c', 'c', 'd'])
# ... then use it as a UDF like this
from pyspark.sql import functions as F
from pyspark.sql import types as T

schema = T.ArrayType(T.StructType([
    T.StructField('past_target', T.ArrayType(T.IntegerType())),
    T.StructField('future_target', T.IntegerType()),
    T.StructField('past_feature1', T.ArrayType(T.StringType())),
    T.StructField('past_feature2', T.ArrayType(T.StringType())),
    T.StructField('is_last', T.BooleanType()),
]))

(input_df
    .withColumn('temp', F.explode(F.udf(func, schema)('target', 'feature1', 'feature2', F.lit(2))))
    .select('id', 'temp.*')
    .show()
)
# Output
# +---+-----------+-------------+-------------+-------------+-------+
# | id|past_target|future_target|past_feature1|past_feature2|is_last|
# +---+-----------+-------------+-------------+-------------+-------+
# |  2|     [1, 2]|            3|       [a, b]|       [a, a]|  false|
# |  2|     [2, 3]|            4|       [b, c]|       [a, c]|  false|
# |  2|     [3, 4]|            5|       [c, c]|       [c, c]|   true|
# +---+-----------+-------------+-------------+-------------+-------+
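For completeness, the dummy body above could be replaced with the real window logic from the question; a minimal sketch, mirroring the loop sketched earlier:

def func(target, feature1, feature2, n=2):
    # element i is the future value; the n elements before it are the pasts
    return [
        (target[i - n:i], target[i], feature1[i - n:i], feature2[i - n:i],
         i == len(target) - 1)
        for i in range(n, len(target))
    ]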

So I found an answer that also supports padding (adding extra values to the lists, in case we want every value to be a possible future_target).

The code:

from pyspark.sql import functions as f
from pyspark.sql.functions import array_contains, col

filler_token = "-1"
main_feature = "target"
extra_features = ["feature1", "feature2"]
past_length = 2

# For each element i of `target`, build a struct with the previous
# `past_length` values of each column (filler_token pads indices before
# the array start), the element as `future`, and a last-element flag.
expr = (
    f'TRANSFORM({main_feature}, (element, i) -> STRUCT('
    f'TRANSFORM(sequence({past_length}, 1), k -> COALESCE({main_feature}[i - k], {filler_token})) AS pasts, '
    + ''.join(
        f'TRANSFORM(sequence({past_length}, 1), k -> COALESCE({feature}[i - k], {filler_token})) AS past{feature}, '
        for feature in extra_features
    )
    + f'element AS future, i = SIZE({main_feature}) - 1 AS final))'
)
output_df = (input_df
    .withColumn(main_feature, f.expr(expr))
    .selectExpr('id', f'inline({main_feature})'))

padding = False
if not padding:
    # drop windows that needed padding, i.e. whose pasts contain the filler
    output_df = output_df.filter(~array_contains(col("pasts"), int(filler_token)))
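For readability, with these parameters the generated expr string expands to the following Spark SQL (whitespace added here for illustration):

TRANSFORM(target, (element, i) -> STRUCT(
    TRANSFORM(sequence(2, 1), k -> COALESCE(target[i - k], -1)) AS pasts,
    TRANSFORM(sequence(2, 1), k -> COALESCE(feature1[i - k], -1)) AS pastfeature1,
    TRANSFORM(sequence(2, 1), k -> COALESCE(feature2[i - k], -1)) AS pastfeature2,
    element AS future,
    i = SIZE(target) - 1 AS final))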

output_df.show()
+---+------+------------+------------+------+-----+
| id| pasts|pastfeature1|pastfeature2|future|final|
+---+------+------------+------------+------+-----+
|  2|[1, 2]|      [a, b]|      [a, a]|     3|false|
|  2|[2, 3]|      [b, c]|      [a, c]|     4|false|
|  2|[3, 4]|      [c, c]|      [c, c]|     5| true|
+---+------+------------+------------+------+-----+
# With padding=True the output would instead be:
+---+--------+------------+------------+------+-----+
| id|   pasts|pastfeature1|pastfeature2|future|final|
+---+--------+------------+------------+------+-----+
|  2|[-1, -1]|    [-1, -1]|    [-1, -1]|     1|false|
|  2| [-1, 1]|     [-1, a]|     [-1, a]|     2|false|
|  2|  [1, 2]|      [a, b]|      [a, a]|     3|false|
|  2|  [2, 3]|      [b, c]|      [a, c]|     4|false|
|  2|  [3, 4]|      [c, c]|      [c, c]|     5| true|
+---+--------+------------+------------+------+-----+

I did not use UDFs because they are much slower than this solution.
