As shown in the PyTorch tutorial, the code for the encoder model looks like this:
# imports and device selection as in the tutorial
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input, hidden):
        embedded = self.embedding(input).view(1, 1, -1)
        output = embedded
        output, hidden = self.gru(output, hidden)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)
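For context, here is a minimal usage sketch of how this encoder is driven in the tutorial's one-token-at-a-time loop (the vocabulary size and token indices below are hypothetical); each forward call receives a single token index together with the previous hidden state:
# Usage sketch; reuses EncoderRNN and device defined above.
encoder = EncoderRNN(input_size=10, hidden_size=256).to(device)  # hypothetical vocabulary of 10 tokens
hidden = encoder.initHidden()                                    # shape (1, 1, 256)
sentence = torch.tensor([2, 5, 7], device=device)                # hypothetical token indices

for token in sentence:
    # token.view(1) has shape (1,), so self.embedding(input) inside forward has shape (1, 256)
    output, hidden = encoder(token.view(1), hidden)              # output and hidden are (1, 1, 256)

print(output.shape, hidden.shape)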
My question is: what is the reason for calling the view function on the output of the embedding layer?
The view call adds an extra dimension to the embedding output so that it matches the input shape the GRU expects. In the initHidden function, the hidden state is initialized with shape (1, 1, 256):
def initHidden(self):
    return torch.zeros(1, 1, self.hidden_size, device=device)
According to the documentation, the GRU input must have 3 dimensions: input of shape (seq_len, batch, input_size).
https://pytorch.org/docs/stable/generated/torch.nn.GRU.html
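To see the reshape in isolation, here is a minimal sketch (the tensor names and the vocabulary size are hypothetical, assuming hidden_size = 256); unsqueeze(0) would produce the same (seq_len, batch, input_size) layout as view(1, 1, -1) does in this case:
import torch
import torch.nn as nn

embedding = nn.Embedding(10, 256)    # hypothetical vocabulary of 10 tokens
token = torch.tensor([3])            # a single token index, shape (1,)

emb = embedding(token)               # shape (1, 256)
gru_input = emb.view(1, 1, -1)       # shape (1, 1, 256): (seq_len=1, batch=1, input_size=256)
same = emb.unsqueeze(0)              # equivalent reshape, also (1, 1, 256)
print(emb.shape, gru_input.shape, same.shape)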
self.embedding(input) has shape (1, 256), and a sample output is
tensor([[ 0.1421, 0.4135, -1.0619, 0.0149, 0.0673, -0.3770, 0.4231, 2.2803,
-1.6939, -0.0071, 1.1131, -1.0019, 0.6593, 0.1366, 1.1033, -0.8804,
1.3676, 0.4115, -0.5671, 0.3314, -0.2599, -0.3082, 1.3644, 0.5788,
-0.1929, -2.0505, 0.4518, 0.8757, -0.2360, -0.4099, -0.5697, -1.5973,
-0.6638, -1.1523, 1.4425, 1.3651, 1.9371, 0.5698, -0.3541, -1.3883,
-0.0195, -1.0757, -1.4324, -1.6226, -2.4267, 0.3874, -0.7529, 1.4938,
-2.5773, -1.1962, 0.3759, -0.6143, -1.0444, -0.6443, -0.8130, -1.7283,
1.4167, 1.3945, -1.2695, 0.7289, 0.7777, -0.0094, -1.8108, 0.2126,
-0.2018, -0.4055, -0.7779, -0.8523, 0.0162, 0.2463, 0.5588, -0.7250,
-0.0128, 0.6272, -0.7729, 0.4259, 0.7596, -1.9500, 0.5853, 0.3764,
-0.1112, 0.7274, -2.8535, -0.0445, 0.4225, 1.2179, 0.2219, -0.7064,
-0.9654, 1.0501, 1.7142, 0.5312, -0.8180, -1.5697, 1.3062, -0.9321,
-0.1652, -1.5298, -0.3575, -1.2046, -0.6571, -0.7689, -0.7032, 1.0727,
-1.3259, 0.1200, 1.9357, -0.2519, -0.3717, 0.8054, 0.1180, -0.6921,
1.0245, -1.5500, -0.5280, -0.7462, 0.7924, 2.2701, -1.5094, -0.1973,
-1.5919, 0.4869, 0.6739, -0.5242, 0.2559, -0.0149, -0.5332, -1.8313,
0.3598, 0.0804, -0.0780, -0.2930, -0.2844, -0.4752, -0.9919, 0.1809,
0.7622, -2.5069, -0.7724, -0.9441, 1.6101, 0.6461, -0.8932, 0.0600,
0.6911, 0.5191, -0.1719, -0.5829, -0.9168, 1.5282, 1.4399, 0.3264,
-0.8894, 0.2880, -0.0697, 0.8977, -0.5004, 0.3844, 0.0925, 0.5592,
-0.1664, 0.8575, -1.0348, 0.7326, -0.2124, 0.7533, 0.6270, -0.9559,
-1.4159, 0.6788, 0.6163, -0.5951, -0.1403, -1.6088, -0.7731, 0.3876,
1.0429, -2.0960, 0.1726, 1.7446, -0.3963, 0.0785, -0.4701, 1.0074,
0.3319, -2.2675, -1.6163, -0.4003, -0.5468, 0.0452, -2.5586, 0.4747,
-0.0271, -1.2161, 1.2121, 1.8738, -1.2207, -0.9218, -0.1430, 0.2512,
-0.5236, -0.2544, -0.5868, -0.7086, -1.3328, -0.0243, 0.4759, 1.4125,
0.4947, 0.5054, 1.6253, 0.4198, -0.9150, 0.6374, 0.4581, 1.1527,
1.4440, -0.0590, -0.4601, 0.2490, -0.5739, 0.6798, -0.2156, -1.1386,
-0.5011, -0.7411, 0.2825, -0.2595, 0.8070, 0.5270, 0.2595, -0.1089,
0.4221, -0.7851, 0.7112, -0.3038, 0.6169, -0.1513, -0.5872, 0.3974,
0.2431, 0.4934, -0.9406, -0.9372, 1.4525, 0.1376, 0.2558, 0.0661,
0.3509, 2.1667, 2.8428, 0.9429, -0.6143, -1.0969, 0.0955, 0.0914]],
device='cuda:0', grad_fn=<EmbeddingBackward>)
self.embedding(input).view(1, 1, -1) has shape (1, 1, 256), and a sample output is
tensor([[[ 0.1421, 0.4135, -1.0619, 0.0149, 0.0673, -0.3770, 0.4231,
2.2803, -1.6939, -0.0071, 1.1131, -1.0019, 0.6593, 0.1366,
1.1033, -0.8804, 1.3676, 0.4115, -0.5671, 0.3314, -0.2599,
-0.3082, 1.3644, 0.5788, -0.1929, -2.0505, 0.4518, 0.8757,
-0.2360, -0.4099, -0.5697, -1.5973, -0.6638, -1.1523, 1.4425,
1.3651, 1.9371, 0.5698, -0.3541, -1.3883, -0.0195, -1.0757,
-1.4324, -1.6226, -2.4267, 0.3874, -0.7529, 1.4938, -2.5773,
-1.1962, 0.3759, -0.6143, -1.0444, -0.6443, -0.8130, -1.7283,
1.4167, 1.3945, -1.2695, 0.7289, 0.7777, -0.0094, -1.8108,
0.2126, -0.2018, -0.4055, -0.7779, -0.8523, 0.0162, 0.2463,
0.5588, -0.7250, -0.0128, 0.6272, -0.7729, 0.4259, 0.7596,
-1.9500, 0.5853, 0.3764, -0.1112, 0.7274, -2.8535, -0.0445,
0.4225, 1.2179, 0.2219, -0.7064, -0.9654, 1.0501, 1.7142,
0.5312, -0.8180, -1.5697, 1.3062, -0.9321, -0.1652, -1.5298,
-0.3575, -1.2046, -0.6571, -0.7689, -0.7032, 1.0727, -1.3259,
0.1200, 1.9357, -0.2519, -0.3717, 0.8054, 0.1180, -0.6921,
1.0245, -1.5500, -0.5280, -0.7462, 0.7924, 2.2701, -1.5094,
-0.1973, -1.5919, 0.4869, 0.6739, -0.5242, 0.2559, -0.0149,
-0.5332, -1.8313, 0.3598, 0.0804, -0.0780, -0.2930, -0.2844,
-0.4752, -0.9919, 0.1809, 0.7622, -2.5069, -0.7724, -0.9441,
1.6101, 0.6461, -0.8932, 0.0600, 0.6911, 0.5191, -0.1719,
-0.5829, -0.9168, 1.5282, 1.4399, 0.3264, -0.8894, 0.2880,
-0.0697, 0.8977, -0.5004, 0.3844, 0.0925, 0.5592, -0.1664,
0.8575, -1.0348, 0.7326, -0.2124, 0.7533, 0.6270, -0.9559,
-1.4159, 0.6788, 0.6163, -0.5951, -0.1403, -1.6088, -0.7731,
0.3876, 1.0429, -2.0960, 0.1726, 1.7446, -0.3963, 0.0785,
-0.4701, 1.0074, 0.3319, -2.2675, -1.6163, -0.4003, -0.5468,
0.0452, -2.5586, 0.4747, -0.0271, -1.2161, 1.2121, 1.8738,
-1.2207, -0.9218, -0.1430, 0.2512, -0.5236, -0.2544, -0.5868,
-0.7086, -1.3328, -0.0243, 0.4759, 1.4125, 0.4947, 0.5054,
1.6253, 0.4198, -0.9150, 0.6374, 0.4581, 1.1527, 1.4440,
-0.0590, -0.4601, 0.2490, -0.5739, 0.6798, -0.2156, -1.1386,
-0.5011, -0.7411, 0.2825, -0.2595, 0.8070, 0.5270, 0.2595,
-0.1089, 0.4221, -0.7851, 0.7112, -0.3038, 0.6169, -0.1513,
-0.5872, 0.3974, 0.2431, 0.4934, -0.9406, -0.9372, 1.4525,
0.1376, 0.2558, 0.0661, 0.3509, 2.1667, 2.8428, 0.9429,
-0.6143, -1.0969, 0.0955, 0.0914]]], device='cuda:0',
grad_fn=<ViewBackward>)
Code
This code works:
rnn1 = nn.GRU(256, 128, 1)
input1 = torch.randn(100, 2, 256)
h01 = torch.randn(1, 2, 128)
output1, hn1 = rnn1(input1, h01)
print(input1.shape, h01.shape)
print(output1.shape, hn1.shape)
Output
torch.Size([100, 2, 256]) torch.Size([1, 2, 128])
torch.Size([100, 2, 128]) torch.Size([1, 2, 128])
Code
This code also works:
rnn1 = nn.GRU(256, 256)
input1 = torch.randn(1, 1, 256)
h01 = torch.randn(1, 1, 256)
output1, hn1 = rnn1(input1, h01)
print(input1.shape, h01.shape)
print(output1.shape, hn1.shape)
Output
torch.Size([1, 1, 256]) torch.Size([1, 1, 256])
torch.Size([1, 1, 256]) torch.Size([1, 1, 256])
Code
This code does not work:
rnn1 = nn.GRU(256, 256)
input1 = torch.randn(1, 256)
#input1 = input1.view(1, 1, -1)
h01 = torch.randn(1, 1, 256)
output1, hn1 = rnn1(input1, h01)
print(input1.shape, h01.shape)
print(output1.shape, hn1.shape)
Output
RuntimeError: input must have 3 dimensions, got 2
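For completeness, a minimal sketch showing that reshaping the 2-D tensor to 3 dimensions, exactly what view(1, 1, -1) does in the encoder, removes the error on the PyTorch version used here:
import torch
import torch.nn as nn

rnn1 = nn.GRU(256, 256)
input1 = torch.randn(1, 256)         # 2-D: missing the sequence-length dimension
input1 = input1.view(1, 1, -1)       # now (1, 1, 256): (seq_len, batch, input_size)
h01 = torch.randn(1, 1, 256)
output1, hn1 = rnn1(input1, h01)
print(output1.shape, hn1.shape)      # torch.Size([1, 1, 256]) torch.Size([1, 1, 256])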