熊猫设计的分层样本



我有一个带有表示地层(strat(的列的df。我想在这些层上循环,并将行拉出到一个新的df,df_sample。如果案例很少,我想把一个层次中的所有行都去掉。

我试过下面的方法,效果很好。但我想知道这个问题是否有更好的解决方案。例如,当我稍后使用实际的更大的数据时,pd.concat可能会很慢。

df=pd.DataFrame({'ID': range(0,120),
'strat': ['A', 'B', 'B', 'A', 'B', 'A', 'D', 'A', 'B', 'C', 
'A', 'D', 'A', 'A', 'A', 'D', 'F', 'D', 'F', 'C', 
'B', 'A', 'A', 'C', 'A', 'A', 'B', 'D', 'B', 'C', 
'C', 'A', 'C', 'A', 'C', 'A', 'D', 'C', 'C', 'A', 
'B', 'F', 'F', 'C', 'B', 'D', 'A', 'A', 'B', 'B', 
'A', 'C', 'A', 'A', 'F', 'A', 'A', 'B', 'A', 'D', 
'C', 'B', 'B', 'A', 'B', 'C', 'B', 'A', 'D', 'B', 
'B', 'A', 'A', 'C', 'D', 'F', 'F', 'A', 'B', 'C',
'F', 'B', 'D', 'A', 'A', 'F', 'B', 'D', 'B', 'A',
'F', 'D', 'A', 'A', 'C', 'B', 'B', 'C', 'C', 'B',
'F', 'A', 'A', 'B', 'B', 'B', 'F', 'A', 'B', 'C',
'A', 'A', 'A', 'B', 'B', 'A', 'A', 'A', 'B', 'B']})
df_sample=pd.DataFrame()
for i in df.strat.unique():
temp=df[df['strat']==i]

if len(temp) < 21:
strat=temp.sample(len(temp))

elif len(temp) > 20:
strat = temp.sample(frac=0.5)

df_sample=pd.concat([df_sample, strat])

其他解决方案可能更快。如果可读性/可维护性更重要,这里还有另一个。

def sample_stratum(stratum):
nrows = stratum.shape[0]
if nrows < 21:
output = stratum.sample(nrows)
else:
output = stratum.sample(frac=0.5)
return output

# Index may be retained if needed
sampled_df = df.groupby(by=['strat']).apply(sample_stratum).reset_index(drop=True)

#    ID strat
# 0   12     A
# 1    7     A
# 2   50     A
# 3   58     A
# 4    0     A
# ..  ..   ...
# 77  41     F
# 78  42     F
# 79  16     F
# 80  76     F
# 81  90     F
# [82 rows x 2 columns]

您可以groupby"strat";并对每个"表"中的条目数进行计数;strat";,然后识别少于21个条目的strat并对其进行洗牌。然后取剩余的地层(那些有20个以上条目的地层(,并对其中的50%进行采样。最后连接两个数据帧:

msk1 = df.groupby('strat')['strat'].count() < 21
less_than_21 = msk1.index[msk1]
msk2 = df['strat'].isin(less_than_21)    
out = pd.concat((df[~msk2].groupby('strat').sample(frac=0.5), df[msk2].sample(msk2.sum())))

输出:

ID strat
110  110     A
72    72     A
46    46     A
31    31     A
92    92     A
..   ...   ...
18    18     F
9      9     C
23    23     C
42    42     F
82    82     D
[82 rows x 2 columns]

为所有有计数的组创建掩码,然后分别处理每组:

m = df.groupby('strat')['strat'].transform('size').lt(21)
df = pd.concat((df[~m].groupby('strat').sample(frac=0.5), 
df[m].sample(frac=1)),
ignore_index=True)
print (df)
ID strat
0   71     A
1   31     A
2   72     A
3   39     A
4   83     A
..  ..   ...
77  37     C
78  85     F
79  19     C
80  34     C
81  73     C
[82 rows x 2 columns]

替代解决方案:

m = df['strat'].map(df['strat'].value_counts()).lt(21)
df = pd.concat((df[~m].groupby('strat').sample(frac=0.5), df[m].sample(frac=1)))

最新更新