Reproducibility of K-Fold cross-validation in PyTorch



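I recently started a new project using PyTorch, and I'm still new to AI. To make better use of my dataset during training, I'm using cross-validation. Everything seems to work fine, but I'm struggling with reproducibility. I even tried setting the seed at the start of every k-fold iteration, but it doesn't seem to help at all. The differences in loss and accuracy between runs are small, but they are there. Before I introduced cross-validation, the results were perfectly reproducible. Thanks in advance.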

Here is the for loop for the k-fold. The solution I'm using comes from: k-fold cross validation using DataLoaders in PyTorch

import random

import numpy as np
import torch

K_FOLD = 5
fraction = 1 / K_FOLD
unit = int(dataset_length * fraction)
for i in range(K_FOLD):

    # Reset every random number generator at the start of each fold.
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)            # if you are using multi-GPU.
    np.random.seed(SEED)                        # Numpy module.
    random.seed(SEED)                           # Python random module.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    print("-----------K-FOLD {}------------".format(i+1))
    tr_ll = 0
    print("Train left begin:", tr_ll)
    tr_lr = i * unit
    print("Train left end:", tr_lr)
    val_l = tr_lr
    print("Validation begin:", val_l)
    val_r = i * unit + unit
    print("Validation end:", val_r)
    tr_rl = val_r
    print("Train right begin:", tr_rl)
    tr_rr = dataset_length
    print("Train right end:", tr_rr)
    # msg
    #         print("train indices: [%d,%d),[%d,%d), test indices: [%d,%d)"
    #               % (tr_ll,tr_lr,tr_rl,tr_rr,val_l,val_r))
    train_left_indices = list(range(tr_ll, tr_lr))
    train_right_indices = list(range(tr_rl, tr_rr))
    train_indices = train_left_indices + train_right_indices
    val_indices = list(range(val_l, val_r))
    # print("TRAIN Indices:", train_indices, "VAL Indices:", val_indices)
    train_set = torch.utils.data.dataset.Subset(DATASET, train_indices)
    val_set = torch.utils.data.dataset.Subset(DATASET, val_indices)
    # print("Length of train set:", len(train_set), "Length of val set:", len(val_set))

    image_datasets = {"train": train_set, "val": val_set}
    # Both loaders shuffle, so the batch order depends on the global torch RNG state.
    loader = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=10, shuffle=True)
              for x in ['train', 'val']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

    # training
    trained_model = train_model(AlexNet, CRITERION, OPTIMIZER,
                                dataloader=loader, dataset_sizes=dataset_sizes,
                                num_epochs=EPOCHS, k_fold=i)
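The index arithmetic in the loop above can be pulled out into a small function and checked on its own. This is only a sketch: `fold_indices` is a hypothetical helper name, and it uses integer division where the loop multiplies by a float fraction. It involves no random numbers, so the split itself should be identical on every run, which suggests the run-to-run differences come from elsewhere (shuffling, weight initialisation, or non-deterministic kernels).

```python
def fold_indices(dataset_length, k_fold):
    """Yield (train_indices, val_indices) for each contiguous fold.

    Hypothetical helper mirroring the loop in the question.
    """
    unit = dataset_length // k_fold  # size of each validation block
    for i in range(k_fold):
        val_l = i * unit             # validation begin (inclusive)
        val_r = i * unit + unit      # validation end (exclusive)
        # Training data is everything left and right of the validation block.
        train = list(range(0, val_l)) + list(range(val_r, dataset_length))
        val = list(range(val_l, val_r))
        yield train, val


folds = list(fold_indices(12, 5))
```

Note that when `dataset_length` is not divisible by the number of folds, the leftover samples at the end always land in the training set and are never used for validation.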

According to the latest documentation, it looks like you also need:

torch.use_deterministic_algorithms(True)
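As a rough sketch of how that call might fit together with the seeding from the question, the snippet below collects everything into one function and also gives the shuffling `DataLoader` its own seeded `torch.Generator`. The names `seed_everything`, `make_loader`, and `batch_order` are hypothetical, the toy `TensorDataset` stands in for the real dataset, and the `CUBLAS_WORKSPACE_CONFIG` value is an assumption that only matters for some CUDA operations.

```python
import os
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def seed_everything(seed):
    """Seed the RNGs used in the question and request deterministic kernels."""
    # Assumption: some cuBLAS operations need this variable when deterministic
    # algorithms are enabled; it should be set before any CUDA work starts.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU generator and all CUDA devices
    torch.backends.cudnn.benchmark = False
    # Raises an error if an operation has no deterministic implementation.
    torch.use_deterministic_algorithms(True)


def make_loader(dataset, seed, batch_size=10):
    """Build a shuffling DataLoader whose order depends only on `seed`."""
    generator = torch.Generator()
    generator.manual_seed(seed)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                      generator=generator)


def batch_order(seed):
    """Return the shuffled batches of a toy dataset as nested lists."""
    seed_everything(seed)
    data = TensorDataset(torch.arange(20))  # toy stand-in for the real dataset
    return [batch[0].tolist() for batch in make_loader(data, seed)]
```

Calling `batch_order` twice with the same seed should give the same batch order. Passing an explicit generator keeps the shuffle independent of how many random numbers were drawn from the global generator beforehand.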
