Sklearn的GroupShuffleSplit正在产生重叠的结果



我过去使用过GroupShuffleSplit,它工作得很好。但现在我试图根据列进行拆分,它会在测试和训练数据之间产生重叠。这是我正在运行的

val_inds, test_inds = next(GroupShuffleSplit(test_size=0.5,
n_splits=2,).split(df, groups=df['cl_uid'].values))

df_val = df[df.index.isin(val_inds)]
df_test = df[df.index.isin(test_inds)]
# this value is not zero
len(set(df_val.cl_uid).intersection(set(df_test.cl_uid)))

你知道是怎么回事吗?

sklearn版本0.24.1和Python版本3.6

GroupShuffleSplit的返回值为数组索引所以如果你想分割你的DataFrame,你应该使用.iloc来过滤。

df_val = df.iloc[val_inds]
df_test = df.iloc[test_inds]

如果您错误地尝试使用index进行过滤,那么您假设您有一个从0开始的非重复RangeIndex。如果不是这种情况,则此过滤必然会失败。


import pandas as pd
from sklearn.model_selection import GroupShuffleSplit
# DataFrame with a non-RangeIndex
df = pd.DataFrame({'clust_id': [1,1,1,2,2,2,2]}, index=[1,2,1,2,1,2,3])
val_inds, test_inds = next(GroupShuffleSplit(test_size=0.5, n_splits=2,).split(df, groups=df['clust_id']))

正确分割

df_val = df.iloc[val_inds]
#   clust_id
#2         2
#1         2
#2         2
#3         2
df_test = df.iloc[test_inds]
#   clust_id
#1         1
#2         1
#1         1

不正确的分割,混淆了索引标签和数组位置标签

df[df.index.isin(val_inds)]
#   clust_id
#3         2
df[df.index.isin(test_inds)]
#   clust_id
#1         1
#2         1
#1         1
#2         2
#1         2
#2         2

最新更新