所以,我已经做了足够的研究,还没有找到一篇关于我想做的事情的帖子。
我有一个 PySpark 数据帧my_df
它由value
列sorted
+----+-----+
|name|value|
+----+-----+
| A| 30|
| B| 25|
| C| 20|
| D| 18|
| E| 18|
| F| 15|
| G| 10|
+----+-----+
value
列中所有计数的总和等于136
。我想获取其combined values >= x% of 136
的所有行.在此示例中,假设x=80
.然后target sum = 0.8*136 = 108.8
.因此,新的数据帧将包含具有combined value >= 108.8
的所有行。
在我们的示例中,这将归结为第D
行(因为组合值高达 D =30+25+20+18 = 93
)。
但是,困难的部分是我还想包含具有重复值的紧随其后的行。在这种情况下,我还想包含行E
,因为它与行D
具有相同的值,即18
.
我想通过给出一个百分比x
变量来切片my_df
,例如如上所述80
。新的数据帧应包含以下行-
+----+-----+
|name|value|
+----+-----+
| A| 30|
| B| 25|
| C| 20|
| D| 18|
| E| 18|
+----+-----+
我可以在这里做的一件事是遍历数据帧(which is ~360k rows)
,但我想这违背了 Spark 的目的。
这里有我想要的简洁功能吗?
使用 pyspark SQL 函数简洁地执行此操作。
result = my_df.filter(my_df.value > target).select(my_df.name,my_df.value)
result.show()
编辑:根据OP的问题编辑 - 计算运行总和并获取行,直到达到目标值。请注意,这将导致最多为 D 的行,而不是 E。这似乎是一个奇怪的要求。
from pyspark.sql import Window
from pyspark.sql import functions as f
# Total sum of all `values`
target = (my_df.agg(sum("value")).collect())[0][0]
w = Window.orderBy(my_df.name) #Ideally this should be a column that specifies ordering among rows
running_sum_df = my_df.withColumn('rsum',f.sum(my_df.value).over(w))
running_sum_df.filter(running_sum_df.rsum <= 0.8*target)
您的要求非常严格,因此很难为您的问题制定有效的解决方案。不过,这里有一种方法:
首先计算value
列的累积总和和总和,并使用指定的目标条件的百分比筛选数据帧。我们把这个结果称为df_filtered
:
import pyspark.sql.functions as f
from pyspark.sql import Window
w = Window.orderBy(f.col("value").desc(), "name").rangeBetween(Window.unboundedPreceding, 0)
target = 0.8
df_filtered = df.withColumn("cum_sum", f.sum("value").over(w))
.withColumn("total_sum", f.sum("value").over(Window.partitionBy()))
.where(f.col("cum_sum") <= f.col("total_sum")*target)
df_filtered.show()
#+----+-----+-------+---------+
#|name|value|cum_sum|total_sum|
#+----+-----+-------+---------+
#| A| 30| 30| 136|
#| B| 25| 55| 136|
#| C| 20| 75| 136|
#| D| 18| 93| 136|
#+----+-----+-------+---------+
然后将此筛选的数据帧重新联接回value
列上的原始数据帧。由于您的数据帧已按value
排序,因此最终输出将包含所需的行。
df.alias("r")
.join(
df_filtered.alias('l'),
on="value"
).select("r.name", "r.value").sort(f.col("value").desc(), "name").show()
#+----+-----+
#|name|value|
#+----+-----+
#| A| 30|
#| B| 25|
#| C| 20|
#| D| 18|
#| E| 18|
#+----+-----+
total_sum
列和cum_sum
列是使用Window
函数计算的。
窗口w
value
列的顺序降序,后跟name
列。name
列用于断开关系 - 没有它,C
和D
的两行将具有相同的累积111 = 75+18+18
总和,您将错误地在过滤器中丢失它们。
w = Window # Define Window
.orderBy( # This will define ordering
f.col("value").desc(), # First sort by value descending
"name" # Sort on name second
)
.rangeBetween(Window.unboundedPreceding, 0) # Extend back to beginning of window
rangeBetween(Window.unboundedPreceding, 0)
指定窗口应包括当前行(由orderBy
定义)之前的所有行。这就是使它成为累积总和的原因。