Reproducibility in PyTorch with K-Fold Cross Validation
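I recently started a new project with PyTorch, and I am still new to AI. To get a more reliable estimate of performance on my dataset during training, I used cross-validation. Everything seems to work fine, but I am struggling with reproducibility. I even tried setting the SEED at the start of every k-fold iteration, but it does not seem to help at all. The differences in loss and accuracy between runs are small, but they are there. Before I used cross-validation, everything was perfectly reproducible. Thanks in advance.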
Here is the for loop for my k folds. I based it on the following solution:
k-fold cross validation using DataLoaders in PyTorch
import random

import numpy as np
import torch

# SEED, DATASET, dataset_length, AlexNet, CRITERION, OPTIMIZER, EPOCHS and
# train_model are defined elsewhere in the script.
K_FOLD = 5
fraction = 1 / K_FOLD
unit = int(dataset_length * fraction)  # number of samples in one validation fold

for i in range(K_FOLD):
    # Reset every random number generator at the start of each fold.
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)  # if you are using multi-GPU.
    np.random.seed(SEED)  # NumPy module.
    random.seed(SEED)  # Python random module.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    print("-----------K-FOLD {}------------".format(i + 1))
    # Fold i validates on [i*unit, (i+1)*unit) and trains on everything else.
    tr_ll = 0
    print("Train left begin:", tr_ll)
    tr_lr = i * unit
    print("Train left end:", tr_lr)
    val_l = tr_lr
    print("Validation begin:", val_l)
    val_r = i * unit + unit
    print("Validation end:", val_r)
    tr_rl = val_r
    print("Train right begin:", tr_rl)
    tr_rr = dataset_length
    print("Train right end:", tr_rr)
    # print("train indices: [%d,%d),[%d,%d), test indices: [%d,%d)"
    #       % (tr_ll, tr_lr, tr_rl, tr_rr, val_l, val_r))

    train_left_indices = list(range(tr_ll, tr_lr))
    train_right_indices = list(range(tr_rl, tr_rr))
    train_indices = train_left_indices + train_right_indices
    val_indices = list(range(val_l, val_r))
    # print("TRAIN Indices:", train_indices, "VAL Indices:", val_indices)

    train_set = torch.utils.data.Subset(DATASET, train_indices)
    val_set = torch.utils.data.Subset(DATASET, val_indices)
    # print("Length of train set:", len(train_set), "Length of val set:", len(val_set))

    image_datasets = {"train": train_set, "val": val_set}
    loader = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=10, shuffle=True)
              for x in ["train", "val"]}
    dataset_sizes = {x: len(image_datasets[x]) for x in ["train", "val"]}

    # Training. AlexNet and OPTIMIZER are not re-created inside this loop, so
    # unless train_model copies them, each fold continues from the weights
    # left by the previous fold.
    trained_model = train_model(AlexNet, CRITERION, OPTIMIZER,
                                dataloader=loader, dataset_sizes=dataset_sizes,
                                num_epochs=EPOCHS, k_fold=i)
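Besides GPU nondeterminism, one thing in the loop above may matter: it reseeds at the start of every fold, but `AlexNet` and `OPTIMIZER` appear to be created once, outside the loop. If so, each fold keeps training the previous fold's weights, and those weights have already seen the current fold's validation samples. The sketch below assumes that is the case. It seeds first, then builds a fresh model and optimizer in every fold, and gives each `DataLoader` its own seeded `torch.Generator`. `make_model`, `fold_indices`, `run_folds`, the value of `SEED` and the toy `TensorDataset` are hypothetical stand-ins for the question's AlexNet and DATASET, and the actual training call is left as a comment. Unlike the original loop, `fold_indices` gives the `n % k` leftover samples to the last fold. In the original they always stay in the training set and are never validated.

```python
import random

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset, TensorDataset

SEED = 42  # assumed value; use your own SEED
K_FOLD = 5


def fold_indices(n, k, i):
    """Contiguous k-fold split in the same layout as the question's loop.

    Returns (train_indices, val_indices) for fold i. The last fold also takes
    the n % k leftover samples, so every sample is validated exactly once.
    """
    unit = n // k
    val_l = i * unit
    val_r = n if i == k - 1 else val_l + unit
    val = list(range(val_l, val_r))
    train = list(range(0, val_l)) + list(range(val_r, n))
    return train, val


def make_model():
    # Hypothetical stand-in for AlexNet: any function returning a new nn.Module works.
    return nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))


def seeded_loader(dataset, shuffle, seed):
    # A dedicated generator makes the shuffle order depend only on `seed`.
    g = torch.Generator()
    g.manual_seed(seed)
    return DataLoader(dataset, batch_size=10, shuffle=shuffle, generator=g)


def run_folds(dataset):
    initial_weights, val_sizes = [], []
    for i in range(K_FOLD):
        # Seed first, then create everything that consumes randomness.
        torch.manual_seed(SEED)  # also seeds all CUDA devices
        np.random.seed(SEED)
        random.seed(SEED)

        model = make_model()  # fresh weights for every fold
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

        train_idx, val_idx = fold_indices(len(dataset), K_FOLD, i)
        loaders = {
            "train": seeded_loader(Subset(dataset, train_idx), True, SEED + i),
            "val": seeded_loader(Subset(dataset, val_idx), False, SEED + i),
        }
        # Train here, e.g. train_model(model, CRITERION, optimizer, dataloader=loaders, ...)
        initial_weights.append(model[0].weight.detach().clone())
        val_sizes.append(len(loaders["val"].dataset))
    return initial_weights, val_sizes


if __name__ == "__main__":
    data = TensorDataset(torch.randn(53, 8), torch.randint(0, 2, (53,)))
    weights, sizes = run_folds(data)
    print(all(torch.equal(w, weights[0]) for w in weights))  # True: same init every fold
    print(sizes)  # [10, 10, 10, 10, 13]
```

Re-creating the model after seeding means every fold starts from identical weights, so fold results are independent of each other and of the order in which the folds run.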
According to the latest docs, it looks like you also need:
torch.use_deterministic_algorithms(True)
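A minimal sketch of a single setup function that combines the question's seeding with `torch.use_deterministic_algorithms(True)`. Two details from PyTorch's reproducibility notes are worth including. On CUDA 10.2 and later, deterministic mode raises a RuntimeError for some cuBLAS operations unless the environment variable `CUBLAS_WORKSPACE_CONFIG` is set (for example to `:4096:8`). Also, DataLoader worker processes need their own NumPy and `random` seeds. The names `seed_everything`, `seed_worker` and `make_loader` are chosen here for illustration. Call `seed_everything` at the top of the script, before any CUDA work.

```python
import os
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def seed_everything(seed):
    # Must be in place before the first cuBLAS call; required on CUDA >= 10.2
    # when deterministic algorithms are enabled.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU and all CUDA generators
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # From now on, ops without a deterministic implementation raise a RuntimeError.
    torch.use_deterministic_algorithms(True)


def seed_worker(worker_id):
    # Derive each worker's NumPy/random seed from the seed torch assigned to it.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def make_loader(dataset, seed, num_workers=0):
    g = torch.Generator()
    g.manual_seed(seed)
    return DataLoader(dataset, batch_size=10, shuffle=True,
                      num_workers=num_workers, worker_init_fn=seed_worker,
                      generator=g)


if __name__ == "__main__":
    data = TensorDataset(torch.arange(30))
    seed_everything(0)
    first = [batch[0].tolist() for batch in make_loader(data, seed=0)]
    seed_everything(0)
    second = [batch[0].tolist() for batch in make_loader(data, seed=0)]
    print(first == second)  # True: same seed, same shuffle order
```

If a model uses an operation that has no deterministic implementation, the RuntimeError names it, which makes the remaining source of run-to-run variation easy to find.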