如何在 Pytorch(闪电)中使用 T运行cated 反向传播在非常长的序列上 运行 LSTM?

How to run LSTM on very long sequence using Truncated Backpropagation in Pytorch (lightning)?

我有一个很长的时间序列,我想输入 LSTM 进行每帧分类。

我的数据是按帧标记的,我知道发生了一些罕见的事件,这些事件自发生以来就严重影响了分类。

因此,我必须输入整个序列才能获得有意义的预测。

众所周知,仅将非常长的序列输入 LSTM 是次优的,因为梯度会像普通 RNN 一样消失或爆炸。


我想使用一种简单的技术将序列切割成较短的(比如 100 长)序列,运行 每个序列上的 LSTM,然后将最终的 LSTM 隐藏和单元状态作为开始下一次前向传播的隐藏和细胞状态。

是我发现的一个这样做的例子。在那里它被称为“T运行cated Back propagation through time”。我无法为我做同样的工作。


我在 Pytorch 闪电中的尝试(去掉无关部分):

def __init__(self, config, n_classes, datamodule):
    ...
    self._criterion = nn.CrossEntropyLoss(
        reduction='mean',
    )

    num_layers = 1
    hidden_size = 50
    batch_size=1

    self._lstm1 = nn.LSTM(input_size=len(self._in_features), hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
    self._log_probs = nn.Linear(hidden_size, self._n_predicted_classes)
    self._last_h_n = torch.zeros((num_layers, batch_size, hidden_size), device='cuda', dtype=torch.double, requires_grad=False)
    self._last_c_n = torch.zeros((num_layers, batch_size, hidden_size), device='cuda', dtype=torch.double, requires_grad=False)

def training_step(self, batch, batch_index):
    orig_batch, label_batch = batch
    n_labels_in_batch = np.prod(label_batch.shape)
    lstm_out, (self._last_h_n, self._last_c_n) = self._lstm1(orig_batch, (self._last_h_n, self._last_c_n))
    log_probs = self._log_probs(lstm_out)
    loss = self._criterion(log_probs.view(n_labels_in_batch, -1), label_batch.view(n_labels_in_batch))

    return loss

运行 此代码给出以下错误:

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

如果我添加

也会发生同样的情况
def on_after_backward(self) -> None:
    self._last_h_n.detach()
    self._last_c_n.detach()

如果我使用

则不会发生错误
lstm_out, (self._last_h_n, self._last_c_n) = self._lstm1(orig_batch,)

但显然这是无用的,因为当前帧批次的输出不会转发到下一个批次。


导致此错误的原因是什么?我认为分离输出 h_nc_n 应该足够了。

如何将前一个帧批的输出传递到下一个帧批并让 torch 分别向后传播每个帧批?

显然,我错过了 detach() 的尾随 _:

正在使用

def on_after_backward(self) -> None:
    self._last_h_n.detach_()
    self._last_c_n.detach_()

有效。


问题是 self._last_h_n.detach() 没有更新对由 detach() 分配的 新内存 的引用,因此该图仍在取消引用旧变量反向传播通过了。 通过 H = H.detach().

解决了这个问题

更干净(并且可能更快)的是 self._last_h_n.detach_(),它就地执行操作。