RuntimeError: Trying to backward through the graph a second time

I am trying to train a 'trainable Bernoulli distribution' with 'pyro'.

I want to train the parameters of the Bernoulli distribution (the success probabilities) with an NLL loss.
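For reference, assume the data has N rows x_n ∈ {0,1}^D and there is one success probability p_d per feature. The nll function in the code below takes torch.mean over every entry of the (N, D) log-probability matrix, so the objective it computes should be:

```latex
\mathrm{NLL}(p) = -\frac{1}{ND}\sum_{n=1}^{N}\sum_{d=1}^{D}\Big[x_{nd}\log p_d + (1 - x_{nd})\log(1 - p_d)\Big]
```

The sum separates per feature, so setting the derivative with respect to p_d to zero gives p_d = (1/N) Σ_n x_nd, the column mean. That makes a useful sanity check on whatever the optimizer learns.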

train_data is a one-hot-encoded sparse matrix of shape (2034, 19475), and train_labels takes 4 values (4 classes: [0, 1, 2, 3]).
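You can reproduce the problem without the original data by using a hypothetical stand-in with the same structure: a binary 0/1 matrix and integer labels in {0, 1, 2, 3}, but much smaller. The sketch also converts the selected class rows to a float tensor once, up front. The nll function in the code below calls torch.tensor(...) on every invocation instead. If the real train_data is a scipy.sparse matrix, it probably has to be densified with .toarray() before it can become a tensor. That is an assumption about the data type.

```python
import numpy as np
import torch

# Hypothetical stand-in for the real data: the question's matrix is
# (2034, 19475); a much smaller shape keeps this sketch fast.
rng = np.random.default_rng(0)
n_rows, n_features = 200, 50
train_data = (rng.random((n_rows, n_features)) < 0.05).astype(np.float32)  # sparse-ish 0/1 matrix
train_labels = rng.integers(0, 4, size=n_rows)  # 4 classes: 0, 1, 2, 3

# Assumption: if train_data were a scipy.sparse matrix, densify it first:
# train_data = train_data.toarray()

# Select the rows of class 0 and convert them to a float tensor once,
# outside any training loop.
class_mask = train_labels == 0
class_data = torch.as_tensor(train_data[class_mask, :], dtype=torch.float)
print(class_data.shape)
```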

import torch
import pyro
pyd = pyro.distributions

print("torch version:", torch.__version__)
print("pyro version:", pyro.__version__)

import matplotlib.pyplot as plt
import numpy as np
torch.manual_seed(123)


### 0. define Negative Log Likelihood(NLL) loss function
def nll(x_train, distribution):    
    return -torch.mean(distribution.log_prob(torch.tensor(x_train, dtype=torch.float)))


### 1. initialize bernoulli distribution(trainable distribution)
train_vars = (pyd.Uniform(low=torch.FloatTensor([0.01]),
                          high=torch.FloatTensor([0.1])).rsample([train_data.shape[-1]]).squeeze())
distribution = pyd.Bernoulli(probs=train_vars)

### 2. initialize 'label 0' data
class_mask = (train_labels==0)
class_data = train_data[class_mask, :]

### 3. initialize optimizer
optim = torch.optim.Adam([train_vars])

train_vars.requires_grad=True

### 4. train loop
for i in range(0,100):
    
    loss = nll(class_data, distribution)
    
    loss.backward()

When I run this code, I get the runtime error shown below.

How can I fix this error?

Thanks a lot for any comments.

torch version: 1.9.0+cu102
pyro version: 1.7.0
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-269-0081bb1bb843> in <module>
     25     loss = nll(class_data, distribution)
     26 
---> 27     loss.backward()
     28 

/nf/yes/lib/python3.8/site-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    253                 create_graph=create_graph,
    254                 inputs=inputs)
--> 255         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    256 
    257     def register_hook(self, hook):

/nf/yes/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    145         retain_graph = create_graph
    146 
--> 147     Variable._execution_engine.run_backward(
    148         tensors, grad_tensors_, retain_graph, create_graph, inputs,
    149         allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag

RuntimeError: Trying to backward through the graph a second time (or directly access saved variables after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward.
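The error is not specific to Pyro. The following is a minimal plain-PyTorch sketch of the same pattern. A tensor derived from a leaf that requires grad is computed once, outside the loop, and every iteration builds its loss on top of it. Here exp(w) stands in for whatever the distribution derives from train_vars. The first backward() frees the output that exp() saved for its gradient, so the second call fails with the same message.

```python
import torch

w = torch.tensor([0.3, 0.7], requires_grad=True)

# Computed once, outside the loop: exp() saves its output for the backward pass.
cached = torch.exp(w)

errors = []
for i in range(2):
    loss = (cached * 3.0).sum()  # fresh nodes, but built on the *same* exp() node
    try:
        loss.backward()  # iteration 0 frees the tensor exp() saved
    except RuntimeError as e:
        errors.append((i, str(e)))  # iteration 1 needs it again and fails

print(errors)
```

Moving the `cached = torch.exp(w)` line into the loop body makes each iteration build its own graph, and both backward calls then succeed.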

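You need to move

distribution = pyd.Bernoulli(probs=train_vars)

inside the loop, because it uses train_vars, which has requires_grad=True. The distribution object computes some tensors from train_vars once and then reuses them. For example, Bernoulli derives logits from probs on first use and caches the result. As a result, every iteration's loss is built on the same piece of graph. The first loss.backward() frees that piece, and the second call fails when it tries to go through it again. Rebuilding the distribution at the start of each iteration creates a fresh graph every time.

Passing retain_graph=True, as the error message suggests, is not the right fix here: the cached tensors would still be computed from the initial train_vars, so the loss would not follow the optimizer's updates. Also note that the loop never calls optim.zero_grad() or optim.step(), so train_vars is never actually updated.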
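Here is a sketch of the training loop with the distribution rebuilt inside the loop. It is self-contained and runs under these assumptions:

* The data are a small hypothetical random stand-in.
* It uses torch.distributions.Bernoulli, so it runs without Pyro installed. pyd.Bernoulli(logits=...) should behave the same way.
* It adds the optim.zero_grad() and optim.step() calls the original loop lacks.

It also swaps in one design change: instead of optimizing probs directly, it optimizes unconstrained logits (train_logits, a renamed train_vars). This prevents Adam's updates from pushing a probability outside (0, 1). The probabilities are recovered with torch.sigmoid.

```python
import numpy as np
import torch
import torch.distributions as td

torch.manual_seed(123)

# Hypothetical stand-in for the question's class-0 rows (binary 0/1 matrix).
rng = np.random.default_rng(0)
true_p = rng.uniform(0.02, 0.3, size=20)
class_data = torch.as_tensor((rng.random((500, 20)) < true_p).astype(np.float32))

# Trainable parameters: one logit per feature. Initialised, as in the question,
# from probabilities drawn uniformly in [0.01, 0.1), then mapped to logits.
init_probs = torch.empty(class_data.shape[-1]).uniform_(0.01, 0.1)
train_logits = torch.logit(init_probs).requires_grad_()

optim = torch.optim.Adam([train_logits], lr=0.05)

def nll(x, distribution):
    # Mean negative log-likelihood over all entries of x.
    return -distribution.log_prob(x).mean()

losses = []
for i in range(500):
    optim.zero_grad()
    # Rebuilt every iteration, so each loss gets a fresh graph.
    distribution = td.Bernoulli(logits=train_logits)
    loss = nll(class_data, distribution)
    loss.backward()
    optim.step()
    losses.append(loss.item())

learned_probs = torch.sigmoid(train_logits).detach()
print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
print(learned_probs[:5])
print(class_data.mean(0)[:5])
```

The learned probabilities should end up close to the per-feature column means of class_data. That is the closed-form maximum-likelihood solution for independent Bernoulli features, so the comparison in the last two lines is a quick correctness check.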