Spark 数据帧中的 argmax:如何检索具有最大值的行



给定一个Spark数据帧df,我想在某个数字列中找到最大值'values',并获取达到该值的行。我当然可以这样做:

# it doesn't matter if I use scala or python, 
# since I hope I get this done with DataFrame API
import pyspark.sql.functions as F
max_value = df.select(F.max('values')).collect()[0][0]
df.filter(df.values == max_value).show()

但这效率低下,因为它需要两次通过df .

pandas.Series/DataFramenumpy.arrayargmax/idxmax 种方法可以有效地做到这一点(一次通过(。标准python也是如此(内置函数max接受键参数,因此可以用来查找最高值的索引(。

Spark 的正确方法是什么?请注意,我不介意是获得达到最大值的所有行,还是只是这些行的任意(非空!(子集。

如果模式Orderable(模式仅包含原子/原子数组/递归排序结构(,则可以使用简单聚合:

蟒蛇

df.select(F.max(
    F.struct("values", *(x for x in df.columns if x != "values"))
)).first()

斯卡拉:

df.select(max(struct(
    $"values" +: df.columns.collect {case x if x!= "values" => col(x)}: _*
))).first

否则,您可以减少 Dataset(仅限 Scala(,但它需要额外的反序列化:

type T = ???
df.reduce((a, b) => if (a.getAs[T]("values") > b.getAs[T]("values")) a else b)

您还可以oredrBylimit(1)/take(1)

斯卡拉:

df.orderBy(desc("values")).limit(1)
// or
df.orderBy(desc("values")).take(1)

蟒蛇

df.orderBy(F.desc('values')).limit(1)
# or
df.orderBy(F.desc("values")).take(1)

也许这是一个不完整的答案,但您可以使用DataFrame的内部RDD,应用 max 方法并使用确定的键获得最大记录。

a = sc.parallelize([
    ("a", 1, 100),
    ("b", 2, 120),
    ("c", 10, 1000),
    ("d", 14, 1000)
  ]).toDF(["name", "id", "salary"])
a.rdd.max(key=lambda x: x["salary"]) # Row(name=u'c', id=10, salary=1000)

最新更新