batch_first in PyTorch LSTM



I am new to this field, so I still don't understand batch_first in PyTorch's LSTM. I tried some code someone referred me to, and when batch_first = False it works on my training data: the official LSTM and the manual LSTM produce the same output. However, when I change to batch_first = True it no longer produces the same values, and I need batch_first = True because my dataset has shape (batch, sequence, input size). Which part of the manual LSTM needs to change to produce the same output as the official LSTM when batch_first = True?
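For reference, this is my understanding of the two layouts (a minimal sketch; the sizes here are placeholders, not my real data):

import torch
import torch.nn as nn

# batch_first=False (default): input is (seq, batch, input_size)
# batch_first=True:            input is (batch, seq, input_size)
lstm = nn.LSTM(input_size=1, hidden_size=1, batch_first=True)
x = torch.randn(4, 30, 1)   # (batch=4, seq=30, input=1)
out, (h_n, c_n) = lstm(x)
print(out.shape)            # torch.Size([4, 30, 1]) -- batch first, like the input
print(h_n.shape)            # torch.Size([1, 4, 1])  -- (num_layers, batch, hidden), unaffected by batch_first

Here is the code snippet: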

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
train_x = torch.tensor([[[0.14285755], [0.        ], [0.04761982], [0.04761982], [0.04761982],
                         [0.04761982], [0.04761982], [0.09523869], [0.09523869], [0.09523869],
                         [0.09523869], [0.09523869], [0.04761982], [0.04761982], [0.04761982],
                         [0.04761982], [0.09523869], [0.        ], [0.        ], [0.        ],
                         [0.        ], [0.09523869], [0.09523869], [0.09523869], [0.09523869],
                         [0.09523869], [0.09523869], [0.09523869], [0.14285755], [0.14285755]]],
                        requires_grad=True)  # shape: (batch=1, seq=30, input_size=1)
seed = 23
torch.manual_seed(seed)
np.random.seed(seed)
pytorch_lstm = torch.nn.LSTM(1, 1, bidirectional=False, num_layers=1, batch_first=True)
weights = torch.randn(pytorch_lstm.weight_ih_l0.shape, dtype=torch.float)
pytorch_lstm.weight_ih_l0 = torch.nn.Parameter(weights)
pytorch_lstm.weight_hh_l0 = torch.nn.Parameter(torch.ones(pytorch_lstm.weight_hh_l0.shape))
# Set both biases to zero
pytorch_lstm.bias_ih_l0 = torch.nn.Parameter(torch.zeros(pytorch_lstm.bias_ih_l0.shape))
pytorch_lstm.bias_hh_l0 = torch.nn.Parameter(torch.zeros(pytorch_lstm.bias_hh_l0.shape))
pytorch_lstm_out = pytorch_lstm(train_x)
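# pytorch_lstm_out is the tuple (output, (h_n, c_n)):
#   output: (batch=1, seq=30, hidden=1) because batch_first=True
#   h_n, c_n: (num_layers=1, batch=1, hidden=1)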
batch_size = 1
# Manual calculation: recover the per-gate parameters; PyTorch stacks the
# four gates along dim 0 in the order i, f, g, o
W_ii, W_if, W_ig, W_io = pytorch_lstm.weight_ih_l0.split(1, dim=0)
b_ii, b_if, b_ig, b_io = pytorch_lstm.bias_ih_l0.split(1, dim=0)
W_hi, W_hf, W_hg, W_ho = pytorch_lstm.weight_hh_l0.split(1, dim=0)
b_hi, b_hf, b_hg, b_ho = pytorch_lstm.bias_hh_l0.split(1, dim=0)
prev_h = torch.zeros((batch_size, 1))
prev_c = torch.zeros((batch_size, 1))
i_t = torch.sigmoid(F.linear(train_x, W_ii, b_ii) + F.linear(prev_h, W_hi, b_hi))
f_t = torch.sigmoid(F.linear(train_x, W_if, b_if) + F.linear(prev_h, W_hf, b_hf))
g_t = torch.tanh(F.linear(train_x, W_ig, b_ig) + F.linear(prev_h, W_hg, b_hg))
o_t = torch.sigmoid(F.linear(train_x, W_io, b_io) + F.linear(prev_h, W_ho, b_ho))
c_t = f_t * prev_c + i_t * g_t
h_t = o_t * torch.tanh(c_t)
print('nn.LSTM output {}, manual output {}'.format(pytorch_lstm_out[0], h_t))
print('nn.LSTM hidden {}, manual hidden {}'.format(pytorch_lstm_out[1][0], h_t))
print('nn.LSTM state {}, manual state {}'.format(pytorch_lstm_out[1][1], c_t))

You have to iterate over the sequence one element at a time, feeding the hidden state and cell state computed at each step back in as input for the next time step. In the snippet above, the gates are applied to all 30 time steps at once with prev_h and prev_c fixed at zero, so every position is computed as if it were the first step; only t = 0 can match the official output. With batch_first=True, time is dimension 1, so each step is train_x[:, t, :]:

h_t = torch.zeros((batch_size, 1))
c_t = torch.zeros((batch_size, 1))
hidden_seq = []
for t in range(train_x.shape[1]):      # step through the 30 time steps
    x_t = train_x[:, t, :]             # batch_first input: time is dim 1
    i_t = torch.sigmoid(F.linear(x_t, W_ii, b_ii) + F.linear(h_t, W_hi, b_hi))
    f_t = torch.sigmoid(F.linear(x_t, W_if, b_if) + F.linear(h_t, W_hf, b_hf))
    g_t = torch.tanh(F.linear(x_t, W_ig, b_ig) + F.linear(h_t, W_hg, b_hg))
    o_t = torch.sigmoid(F.linear(x_t, W_io, b_io) + F.linear(h_t, W_ho, b_ho))
    c_t = f_t * c_t + i_t * g_t
    h_t = o_t * torch.tanh(c_t)
    hidden_seq.append(h_t.unsqueeze(0))
# (seq, batch, hidden) -> (batch, seq, hidden) to match batch_first=True
hidden_seq = torch.cat(hidden_seq, dim=0)
hidden_seq = hidden_seq.transpose(0, 1).contiguous()
print('nn.LSTM output {}, manual output {}'.format(pytorch_lstm_out[0], hidden_seq))
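As a quick sanity check (a minimal sketch, assuming default torch.allclose tolerances), all three pieces should now match the official layer:

print(torch.allclose(pytorch_lstm_out[0], hidden_seq))          # full output sequence
print(torch.allclose(pytorch_lstm_out[1][0].squeeze(0), h_t))   # final hidden state h_n
print(torch.allclose(pytorch_lstm_out[1][1].squeeze(0), c_t))   # final cell state c_n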
