我过去使用过GroupShuffleSplit
,它工作得很好。但现在我试图根据列进行拆分,它会在测试和训练数据之间产生重叠。这是我正在运行的
val_inds, test_inds = next(GroupShuffleSplit(test_size=0.5,
n_splits=2,).split(df, groups=df['cl_uid'].values))
df_val = df[df.index.isin(val_inds)]
df_test = df[df.index.isin(test_inds)]
# this value is not zero
len(set(df_val.cl_uid).intersection(set(df_test.cl_uid)))
你知道是怎么回事吗?
sklearn
版本0.24.1和Python
版本3.6
GroupShuffleSplit
的返回值为数组索引所以如果你想分割你的DataFrame,你应该使用.iloc
来过滤。
df_val = df.iloc[val_inds]
df_test = df.iloc[test_inds]
如果您错误地尝试使用index
进行过滤,那么您假设您有一个从0开始的非重复RangeIndex
。如果不是这种情况,则此过滤必然会失败。
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit
# DataFrame with a non-RangeIndex
df = pd.DataFrame({'clust_id': [1,1,1,2,2,2,2]}, index=[1,2,1,2,1,2,3])
val_inds, test_inds = next(GroupShuffleSplit(test_size=0.5, n_splits=2,).split(df, groups=df['clust_id']))
正确分割
df_val = df.iloc[val_inds]
# clust_id
#2 2
#1 2
#2 2
#3 2
df_test = df.iloc[test_inds]
# clust_id
#1 1
#2 1
#1 1
不正确的分割,混淆了索引标签和数组位置标签
df[df.index.isin(val_inds)]
# clust_id
#3 2
df[df.index.isin(test_inds)]
# clust_id
#1 1
#2 1
#1 1
#2 2
#1 2
#2 2