我不太理解sklearn
函数train_test_split
和StratifiedKFold
根据多个";列";并且不仅根据目标分布。我知道前面的句子有点晦涩,所以我希望下面的代码能有所帮助。
import numpy as np
import pandas as pd
import random
n_samples = 100
prob = 0.2
pos = int(n_samples * prob)
neg = n_samples - pos
target = [1] * pos + [0] * neg
cat = ["a"] * 50 + ["b"] * 50
random.shuffle(target)
random.shuffle(cat)
ds = pd.DataFrame()
ds["target"] = target
ds["cat"] = cat
ds["f1"] = np.random.random(size=(n_samples,))
ds["f2"] = np.random.random(size=(n_samples,))
print(ds.head())
这是一个100个示例数据集,目标分布由p
控制,在这种情况下,我们有20%的正示例。有一个完全平衡的二元分类列cat
。上一个代码的输出为:
target cat f1 f2
0 0 a 0.970585 0.134268
1 0 a 0.410689 0.225524
2 0 a 0.638111 0.273830
3 0 b 0.594726 0.579668
4 0 a 0.737440 0.667996
train_test_split()
、stratify
在target
和cat
上,如果我们研究频率,我们得到:
from sklearn.model_selection import train_test_split, StratifiedKFold
# with train_test_split
training, valid = train_test_split(range(n_samples),
test_size=20,
stratify=ds[["target", "cat"]])
print("---")
print("* training")
print(ds.loc[training, ["target", "cat"]].value_counts() / len(training)) # balanced
print("* validation")
print(ds.loc[valid, ["target", "cat"]].value_counts() / len(valid)) # balanced
我们得到这个:
* dataset
0 0.8
1 0.2
Name: target, dtype: float64
target cat
0 a 0.4
b 0.4
1 a 0.1
b 0.1
dtype: float64
---
* training
target cat
0 a 0.4
b 0.4
1 a 0.1
b 0.1
dtype: float64
* validation
target cat
0 a 0.4
b 0.4
1 a 0.1
b 0.1
dtype: float64
它是完全分层的。
现在使用StratifiedKFold
:
# with stratified k-fold
skf = StratifiedKFold(n_splits=5)
try:
for train, valid in skf.split(X=range(len(ds)), y=ds[["target", "cat"]]):
pass
except:
print("! does not work")
for train, valid in skf.split(X=range(len(ds)), y=ds.target):
print("happily iterating")
输出:
! does not work
happily iterating
happily iterating
happily iterating
happily iterating
happily iterating
如何获得train_test_split
和StratifiedKFold
所获得的内容?我知道可能有数据分布不允许在k倍交叉验证中进行这种分层,但我不明白为什么train_test_split
接受两列或更多列,而另一种方法不接受。
目前这似乎不太可能。
Multibel并不完全是你想要的,而是相关的。这是之前在这里被问到的,也是sklearn的github上的一个问题(不确定它为什么关闭(。
作为一个小技巧,你应该能够将你的两列组合成一个新的列,并对其进行分层?