如何在 Pytorch 中保存和加载随机数生成器状态?

How to save and load random number generator state in Pytorch?

我正在 Pytorch 中训练深度学习模型,并希望以确定性方式训练我的模型。 正如 this 官方指南中所写,我这样设置随机种子:

np.random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

现在,我的训练很长,我想保存,然后加载所有内容,包括 RNG。我对模型和优化器使用 torch.savetorch.load_state_dict

如何保存和加载随机数生成器?

您可以使用torch.get_rng_state and torch.set_rng_state

调用 torch.get_rng_state 时,您将获得随机数生成器状态 torch.ByteTensor。

然后您可以将此张量保存在文件中的某个位置,稍后您可以加载并使用 torch.set_rng_state 来设置随机数生成器状态。


当使用 numpy 时,您当然可以使用以下方法做同样的事情:
numpy.random.get_state and numpy.random.set_state