batch_first in PyTorch LSTM

I'm new to this field, so I don't yet fully understand batch_first in PyTorch's LSTM. I tried some code that was suggested to me, and it works on my training data when batch_first = False: the official LSTM and the manual LSTM produce the same output. However, when I change to batch_first = True, they no longer produce the same values, and I need batch_first = True because my dataset is shaped (Batch, Sequences, Input size). Which part of the manual LSTM needs to change so that it produces the same output as the official LSTM when batch_first = True? Here is the code snippet:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

train_x = torch.tensor([[[0.14285755], [0], [0.04761982], [0.04761982], [0.04761982],
  [0.04761982], [0.04761982], [0.09523869], [0.09523869], [0.09523869], 
  [0.09523869], [0.09523869], [0.04761982], [0.04761982], [0.04761982],
  [0.04761982], [0.09523869], [0.        ], [0.        ], [0.        ],
  [0.        ], [0.09523869], [0.09523869], [0.09523869], [0.09523869],
  [0.09523869], [0.09523869], [0.09523869],[0.14285755], [0.14285755]]], 
  requires_grad=True)

seed = 23
torch.manual_seed(seed)
np.random.seed(seed)

pytorch_lstm = torch.nn.LSTM(1, 1, bidirectional=False, num_layers=1, batch_first=True)
weights = torch.randn(pytorch_lstm.weight_ih_l0.shape, dtype=torch.float)
pytorch_lstm.weight_ih_l0 = torch.nn.Parameter(weights)
# Set bias to Zero
pytorch_lstm.bias_ih_l0 = torch.nn.Parameter(torch.zeros(pytorch_lstm.bias_ih_l0.shape))
pytorch_lstm.weight_hh_l0 = torch.nn.Parameter(torch.ones(pytorch_lstm.weight_hh_l0.shape))
# Set bias to Zero
pytorch_lstm.bias_hh_l0 = torch.nn.Parameter(torch.zeros(pytorch_lstm.bias_ih_l0.shape))
pytorch_lstm_out = pytorch_lstm(train_x)

batch_size=1

# Manual Calculation
W_ii, W_if, W_ig, W_io = pytorch_lstm.weight_ih_l0.split(1, dim=0)
b_ii, b_if, b_ig, b_io = pytorch_lstm.bias_ih_l0.split(1, dim=0)

W_hi, W_hf, W_hg, W_ho = pytorch_lstm.weight_hh_l0.split(1, dim=0)
b_hi, b_hf, b_hg, b_ho = pytorch_lstm.bias_hh_l0.split(1, dim=0)

prev_h = torch.zeros((batch_size, 1))
prev_c = torch.zeros((batch_size, 1))

i_t = torch.sigmoid(F.linear(train_x, W_ii, b_ii) + F.linear(prev_h, W_hi, b_hi))
f_t = torch.sigmoid(F.linear(train_x, W_if, b_if) + F.linear(prev_h, W_hf, b_hf))
g_t = torch.tanh(F.linear(train_x, W_ig, b_ig) + F.linear(prev_h, W_hg, b_hg))
o_t = torch.sigmoid(F.linear(train_x, W_io, b_io) + F.linear(prev_h, W_ho, b_ho))
c_t = f_t * prev_c + i_t * g_t
h_t = o_t * torch.tanh(c_t)

print('nn.LSTM output {}, manual output {}'.format(pytorch_lstm_out[0], h_t))
print('nn.LSTM hidden {}, manual hidden {}'.format(pytorch_lstm_out[1][0], h_t))
print('nn.LSTM state {}, manual state {}'.format(pytorch_lstm_out[1][1], c_t))
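
First, note that batch_first=True only changes the layout nn.LSTM expects for its input and uses for its output; the weights and the layout of the returned h_n/c_n are not affected. A minimal shape check (standalone, with made-up dimensions, independent of the code above) illustrates this:

import torch

lstm = torch.nn.LSTM(input_size=1, hidden_size=1, batch_first=True)
x = torch.randn(4, 30, 1)        # (batch, seq_len, input_size), as batch_first=True expects
out, (h_n, c_n) = lstm(x)

print(out.shape)                 # torch.Size([4, 30, 1]) -> (batch, seq_len, hidden_size)
print(h_n.shape)                 # torch.Size([1, 4, 1])  -> (num_layers, batch, hidden_size); h_n/c_n are never batch-first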

You have to iterate over the sequence one element at a time, feeding the hidden state and cell state computed at each step back in as the input for the next time step...
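
For reference, these are the per-step update equations that nn.LSTM implements ($\sigma$ denotes the sigmoid). The key point is that $h_{t-1}$ and $c_{t-1}$ appear on the right-hand side of every gate, so the gates cannot all be computed in one shot the way the snippet above does:

$$
\begin{aligned}
i_t &= \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t &= \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t &= \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
o_t &= \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$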

h_t = torch.zeros((batch_size,1))
c_t = torch.zeros((batch_size,1))

hidden_seq = []

for t in range(train_x.shape[1]):   # iterate over the sequence dimension (dim 1 when batch_first=True)
  x_t = train_x[:, t, :]            # input slice for time step t: (batch, input_size)
  i_t = torch.sigmoid(F.linear(x_t, W_ii, b_ii) + F.linear(h_t, W_hi, b_hi))
  f_t = torch.sigmoid(F.linear(x_t, W_if, b_if) + F.linear(h_t, W_hf, b_hf))
  g_t = torch.tanh(F.linear(x_t, W_ig, b_ig) + F.linear(h_t, W_hg, b_hg))
  o_t = torch.sigmoid(F.linear(x_t, W_io, b_io) + F.linear(h_t, W_ho, b_ho))
  c_t = f_t * c_t + i_t * g_t
  h_t = o_t * torch.tanh(c_t)
  hidden_seq.append(h_t.unsqueeze(0))   # collect per-step outputs as (1, batch, hidden)

hidden_seq = torch.cat(hidden_seq, dim=0)              # (seq_len, batch, hidden)
hidden_seq = hidden_seq.transpose(0, 1).contiguous()   # (batch, seq_len, hidden), matching batch_first=True

print('nn.LSTM output {}, manual output {}'.format(pytorch_lstm_out[0], hidden_seq))
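
To confirm the match numerically rather than by eyeballing the printout (this check is not part of the original answer; the tolerance is an arbitrary choice), you can add:

# The final h_t / c_t from the loop are the last hidden and cell states;
# nn.LSTM returns them with a leading (num_layers * num_directions) axis, hence the squeeze.
assert torch.allclose(pytorch_lstm_out[0], hidden_seq, atol=1e-6)           # full output sequence
assert torch.allclose(pytorch_lstm_out[1][0].squeeze(0), h_t, atol=1e-6)    # final hidden state
assert torch.allclose(pytorch_lstm_out[1][1].squeeze(0), c_t, atol=1e-6)    # final cell state
print('manual LSTM matches nn.LSTM')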