pytorch "trying to backward through the graph a second time" 字符级 RNN 错误
pytorch "trying to backward through the graph a second time" error with chracter level RNN
I'm training a character-level GRU with PyTorch, splitting the text into batches of a fixed chunk length.
This is the training loop:
for e in range(self.epochs):
    self.model.train()
    h = self.get_init_state(self.batch_size)
    for batch_num in range(self.num_batch_runs):
        batch = self.generate_batch(batch_num).to(device)

        inp_batch = batch[:-1, :]
        tar_batch = batch[1:, :]

        self.model.zero_grad()

        loss = 0
        for i in range(inp_batch.shape[0]):
            out, h = self.model(inp_batch[i:i+1, :], h)
            loss += loss_fn(out[0], tar_batch[i].view(-1))

        loss.backward()

        nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
        optimizer.step()

        if not (batch_num % 5):
            print("epoch: {}, loss: {}".format(e, loss.data.item() / inp_batch.shape[0]))
I still get this error after the first batch:

Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

Thanks in advance..
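For reference, the same error can be reproduced with a tiny graph and no RNN at all (a minimal sketch, not part of the original post): PyTorch frees a graph's intermediate buffers after the first backward(), so a second backward() through the same graph fails.

import torch

x = torch.ones(1, requires_grad=True)
y = x * 2

y.backward()  # first backward succeeds, but frees the graph's buffers
y.backward()  # RuntimeError: Trying to backward through the graph a second time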
I found the answer myself: the GRU's hidden state was still attached to the graph built for the previous batch run, so after loss.backward() freed that graph, the next batch's backward pass tried to traverse it again. The hidden state has to be detached before it is reused:

h.detach_()
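For completeness, a minimal sketch of where the detach would go in the loop above (same names as in the question; placing it at the top of each batch is one reasonable choice, not the only one):

for batch_num in range(self.num_batch_runs):
    batch = self.generate_batch(batch_num).to(device)
    inp_batch = batch[:-1, :]
    tar_batch = batch[1:, :]

    # Cut h's link to the previous batch's graph (truncated BPTT):
    # backward() now stops at h instead of revisiting the freed graph.
    h.detach_()

    self.model.zero_grad()
    loss = 0
    for i in range(inp_batch.shape[0]):
        out, h = self.model(inp_batch[i:i+1, :], h)
        loss += loss_fn(out[0], tar_batch[i].view(-1))

    loss.backward()
    nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
    optimizer.step()

The out-of-place variant h = h.detach() works just as well if you prefer not to mutate the tensor in place.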