PySpark DataFrame:找到最接近的值并对数据帧进行切片



所以,我已经做了足够的研究,还没有找到一篇关于我想做的事情的帖子。

我有一个 PySpark 数据帧my_df它由valuesorted

+----+-----+                                                                    
|name|value|
+----+-----+
|   A|   30|
|   B|   25|
|   C|   20|
|   D|   18|
|   E|   18|
|   F|   15|
|   G|   10|
+----+-----+

value列中所有计数的总和等于136。我想获取其combined values >= x% of 136的所有行.在此示例中,假设x=80.然后target sum = 0.8*136 = 108.8.因此,新的数据帧将包含具有combined value >= 108.8的所有行。

在我们的示例中,这将归结为第D行(因为组合值高达 D =30+25+20+18 = 93)。

但是,困难的部分是我还想包含具有重复值的紧随其后的行。在这种情况下,我还想包含行E,因为它与行D具有相同的值,即18.

我想通过给出一个百分比x变量来切片my_df,例如如上所述80。新的数据帧应包含以下行-

+----+-----+                                                                    
|name|value|
+----+-----+
|   A|   30|
|   B|   25|
|   C|   20|
|   D|   18|
|   E|   18|
+----+-----+

我可以在这里做的一件事是遍历数据帧(which is ~360k rows),但我想这违背了 Spark 的目的。

这里有我想要的简洁功能吗?

使用 pyspark SQL 函数简洁地执行此操作。

result = my_df.filter(my_df.value > target).select(my_df.name,my_df.value)
result.show()

编辑:根据OP的问题编辑 - 计算运行总和并获取行,直到达到目标值。请注意,这将导致最多为 D 的行,而不是 E。这似乎是一个奇怪的要求。

from pyspark.sql import Window
from pyspark.sql import functions as f
# Total sum of all `values`
target = (my_df.agg(sum("value")).collect())[0][0]
w = Window.orderBy(my_df.name) #Ideally this should be a column that specifies ordering among rows
running_sum_df = my_df.withColumn('rsum',f.sum(my_df.value).over(w))
running_sum_df.filter(running_sum_df.rsum <= 0.8*target)

您的要求非常严格,因此很难为您的问题制定有效的解决方案。不过,这里有一种方法:

首先计算value列的累积总和和总和,并使用指定的目标条件的百分比筛选数据帧。我们把这个结果称为df_filtered

import pyspark.sql.functions as f
from pyspark.sql import Window
w = Window.orderBy(f.col("value").desc(), "name").rangeBetween(Window.unboundedPreceding, 0)
target = 0.8
df_filtered = df.withColumn("cum_sum", f.sum("value").over(w))
.withColumn("total_sum", f.sum("value").over(Window.partitionBy()))
.where(f.col("cum_sum") <= f.col("total_sum")*target)
df_filtered.show()
#+----+-----+-------+---------+
#|name|value|cum_sum|total_sum|
#+----+-----+-------+---------+
#|   A|   30|     30|      136|
#|   B|   25|     55|      136|
#|   C|   20|     75|      136|
#|   D|   18|     93|      136|
#+----+-----+-------+---------+

然后将此筛选的数据帧重新联接回value列上的原始数据帧。由于您的数据帧已按value排序,因此最终输出将包含所需的行。

df.alias("r")
.join(
df_filtered.alias('l'),
on="value"
).select("r.name", "r.value").sort(f.col("value").desc(), "name").show()
#+----+-----+
#|name|value|
#+----+-----+
#|   A|   30|
#|   B|   25|
#|   C|   20|
#|   D|   18|
#|   E|   18|
#+----+-----+

total_sum列和cum_sum列是使用Window函数计算的。

窗口wvalue列的顺序降序,后跟name列。name列用于断开关系 - 没有它,CD的两行将具有相同的累积111 = 75+18+18总和,您将错误地在过滤器中丢失它们。

w = Window                                     # Define Window
.orderBy(                                   # This will define ordering
f.col("value").desc(),                  # First sort by value descending
"name"                                  # Sort on name second
)
.rangeBetween(Window.unboundedPreceding, 0) # Extend back to beginning of window

rangeBetween(Window.unboundedPreceding, 0)指定窗口应包括当前行(由orderBy定义)之前的所有行。这就是使它成为累积总和的原因。

相关内容

  • 没有找到相关文章

最新更新