如何在 PySpark 中从 df.collect() 结果中检索特定值?



我在 PySpark 中有以下数据帧df

import pyspark.sql.functions as func
df = spark
.read 
.format("org.elasticsearch.spark.sql") 
.load("my_index/my_mapping") 
.groupBy(["id", "type"]) 
.agg(
func.count(func.lit(1)).alias("number_occurrences"),
func.countDistinct("host_id").alias("number_hosts")
)
ds = df.collect()

我之所以使用collect是因为分组和聚合后的数据量总是很小,适合内存。 另外,我需要使用collect因为我将ds作为函数udf参数传递。 函数collect返回一个数组。如何对此数组进行以下查询:对于给定的idtype,返回number_occurrencesnumber_hosts

例如,假设df包含以下行:

id   type   number_occurrences   number_hosts
1    xxx    11                   3
2    yyy    10                   4 

做完df.collect()后,如何检索number_occurencesnumber_hostsid等于1type等于xxx。 预期结果是:

number_occurrences = 11
number_hosts = 3

更新:

也许有更优雅的解决方案?

id = 1
type = "xxx"
number_occurrences = 0
number_hosts = 0
for row in ds:
if (row["id"] == id) & (row["type"] == type):
number_occurrences = row["number_occurrences"]
number_hosts = row["number_hosts"]

如果你的id是唯一的(id 应该是这种情况(,你可以根据 id 对数组进行排序。这只能确保正确的顺序,如果您的 id 是连续的,您可以直接访问记录并将 id 减去 1

test_df = spark.createDataFrame([
(1,"xxx",11,3),(2,"yyyy",10,4),
], ("id","type","number_occurrences","number_hosts"))
id = 1
type = "xxx"
sorted_list = sorted(test_df.collect(), cmp=lambda x,y: cmp(x["id"],y["id"]))
sorted_list[id-1]["number_occurrences"],sorted_list[id-1]["number_hosts"]

结果:

(11, 3)

最新更新