Pandas groupby:获取子组中的最大值



我有一个按列、行、年、potveg和总计分组的大型数据集。我正在尝试获取某个组特定年份中"总计"列的最大值。即,对于以下数据集:

col      row    year    potveg  total
-125.0  42.5    2015    9       697.3
2015    13      535.2
2015    15      82.3
2016    9       907.8
2016    13      137.6
2016    15      268.4
2017    9       961.9
2017    13      74.2
2017    15      248.0
2018    9       937.9
2018    13      575.6
2018    15      215.5
-135.0  70.5    2015    8       697.3
2015    10      535.2
2015    19      82.3
2016    8       907.8
2016    10      137.6
2016    19      268.4
2017    8       961.9
2017    10      74.2
2017    19      248.0
2018    8       937.9
2018    10      575.6
2018    19      215.5

我希望输出看起来像这样:

col      row    year    potveg  total
-125.0  42.5    2015    9       697.3
2016    9       907.8
2017    9       961.9
2018    9       937.9
-135.0  70.5    2015    8       697.3
2016    8       907.8
2017    8       961.9
2018    8       937.9

我试过这个:

df.groupby(['col', 'row', 'year', 'potveg']).agg({'total': 'max'})

这个:

df.groupby(['col', 'row', 'year', 'potveg'])['total'].max()

但它们似乎不起作用,因为输出的行太多。我认为问题在于"potveg"列,它是一个子组。我不知道如何选择包含"total"最大值的行。

一种可能的解决方案,在groupby.apply:中使用.idxmax()

print(
df.groupby(["col", "row", "year"], as_index=False, sort=False).apply(
lambda x: x.loc[x["total"].idxmax()]
)
)

打印:

col   row    year  potveg  total
0 -125.0  42.5  2015.0     9.0  697.3
1 -125.0  42.5  2016.0     9.0  907.8
2 -125.0  42.5  2017.0     9.0  961.9
3 -125.0  42.5  2018.0     9.0  937.9
4 -135.0  70.5  2015.0     8.0  697.3
5 -135.0  70.5  2016.0     8.0  907.8
6 -135.0  70.5  2017.0     8.0  961.9
7 -135.0  70.5  2018.0     8.0  937.9

使用的数据帧:

col   row  year potveg  total
0   -125.0  42.5  2015      9  697.3
1   -125.0  42.5  2015     13  535.2
2   -125.0  42.5  2015     15   82.3
3   -125.0  42.5  2016      9  907.8
4   -125.0  42.5  2016     13  137.6
5   -125.0  42.5  2016     15  268.4
6   -125.0  42.5  2017      9  961.9
7   -125.0  42.5  2017     13   74.2
8   -125.0  42.5  2017     15  248.0
9   -125.0  42.5  2018      9  937.9
10  -125.0  42.5  2018     13  575.6
11  -125.0  42.5  2018     15  215.5
12  -135.0  70.5  2015      8  697.3
13  -135.0  70.5  2015     10  535.2
14  -135.0  70.5  2015     19   82.3
15  -135.0  70.5  2016      8  907.8
16  -135.0  70.5  2016     10  137.6
17  -135.0  70.5  2016     19  268.4
18  -135.0  70.5  2017      8  961.9
19  -135.0  70.5  2017     10   74.2
20  -135.0  70.5  2017     19  248.0
21  -135.0  70.5  2018      8  937.9
22  -135.0  70.5  2018     10  575.6
23  -135.0  70.5  2018     19  215.5

选项1:一种方法是执行groupby(),然后与原始df合并

df1 = pd.merge(df.groupby(['col','row','year']).agg({'total':'max'}).reset_index(), 
df, 
on=['col', 'row', 'year', 'total'])
print(df1)

输出:

col   row  year  total potveg
0  -125.0  42.5  2015  697.3      9
1  -125.0  42.5  2016  907.8      9
2  -125.0  42.5  2017  961.9      9
3  -125.0  42.5  2018  937.9      9
4  -135.0  70.5  2015  697.3      8
5  -135.0  70.5  2016  907.8      8
6  -135.0  70.5  2017  961.9      8
7  -135.0  70.5  2018  937.9      8

选项2:或者像这样使用sort_values()drop_duplicates()

df1 = df.sort_values(['col','row','year']).drop_duplicates(['col','row','year'], keep='first')
print(df1)

输出:

col   row  year potveg  total
0   -125.0  42.5  2015      9  697.3
3   -125.0  42.5  2016      9  907.8
6   -125.0  42.5  2017      9  961.9
9   -125.0  42.5  2018      9  937.9
12  -135.0  70.5  2015      8  697.3
15  -135.0  70.5  2016      8  907.8
18  -135.0  70.5  2017      8  961.9
21  -135.0  70.5  2018      8  937.9

最新更新