pytorch的autoencoder教程为什么会改变嵌入层输出的视图



如PyTorch教程中所示,自动编码器模型的代码如下所示:

class EncoderRNN(nn.Module):
def __init__(self, input_size, hidden_size):
super(EncoderRNN, self).__init__()
self.hidden_size = hidden_size
self.embedding = nn.Embedding(input_size, hidden_size)
self.gru = nn.GRU(hidden_size, hidden_size)
def forward(self, input, hidden):
embedded = self.embedding(input).view(1, 1, -1)
output = embedded
output, hidden = self.gru(output, hidden)
return output, hidden
def initHidden(self):
return torch.zeros(1, 1, self.hidden_size, device=device)

我的问题是,在embedding层的输出上使用view函数的原因是什么?

视图函数为给定的输入形状添加了额外的维度,以匹配预期的输入形状。在函数initHidden中,隐藏形状被初始化为(1, 1, 256)

def initHidden(self):
return torch.zeros(1, 1, self.hidden_size, device=device)

根据文档,GRU输入形状必须具有3个维度input of shape (seq_len, batch, input_size)

https://pytorch.org/docs/stable/generated/torch.nn.GRU.html

self.embedding(input)的形状为(1, 256),样本输出为

tensor([[ 0.1421,  0.4135, -1.0619,  0.0149,  0.0673, -0.3770,  0.4231,  2.2803,
-1.6939, -0.0071,  1.1131, -1.0019,  0.6593,  0.1366,  1.1033, -0.8804,
1.3676,  0.4115, -0.5671,  0.3314, -0.2599, -0.3082,  1.3644,  0.5788,
-0.1929, -2.0505,  0.4518,  0.8757, -0.2360, -0.4099, -0.5697, -1.5973,
-0.6638, -1.1523,  1.4425,  1.3651,  1.9371,  0.5698, -0.3541, -1.3883,
-0.0195, -1.0757, -1.4324, -1.6226, -2.4267,  0.3874, -0.7529,  1.4938,
-2.5773, -1.1962,  0.3759, -0.6143, -1.0444, -0.6443, -0.8130, -1.7283,
1.4167,  1.3945, -1.2695,  0.7289,  0.7777, -0.0094, -1.8108,  0.2126,
-0.2018, -0.4055, -0.7779, -0.8523,  0.0162,  0.2463,  0.5588, -0.7250,
-0.0128,  0.6272, -0.7729,  0.4259,  0.7596, -1.9500,  0.5853,  0.3764,
-0.1112,  0.7274, -2.8535, -0.0445,  0.4225,  1.2179,  0.2219, -0.7064,
-0.9654,  1.0501,  1.7142,  0.5312, -0.8180, -1.5697,  1.3062, -0.9321,
-0.1652, -1.5298, -0.3575, -1.2046, -0.6571, -0.7689, -0.7032,  1.0727,
-1.3259,  0.1200,  1.9357, -0.2519, -0.3717,  0.8054,  0.1180, -0.6921,
1.0245, -1.5500, -0.5280, -0.7462,  0.7924,  2.2701, -1.5094, -0.1973,
-1.5919,  0.4869,  0.6739, -0.5242,  0.2559, -0.0149, -0.5332, -1.8313,
0.3598,  0.0804, -0.0780, -0.2930, -0.2844, -0.4752, -0.9919,  0.1809,
0.7622, -2.5069, -0.7724, -0.9441,  1.6101,  0.6461, -0.8932,  0.0600,
0.6911,  0.5191, -0.1719, -0.5829, -0.9168,  1.5282,  1.4399,  0.3264,
-0.8894,  0.2880, -0.0697,  0.8977, -0.5004,  0.3844,  0.0925,  0.5592,
-0.1664,  0.8575, -1.0348,  0.7326, -0.2124,  0.7533,  0.6270, -0.9559,
-1.4159,  0.6788,  0.6163, -0.5951, -0.1403, -1.6088, -0.7731,  0.3876,
1.0429, -2.0960,  0.1726,  1.7446, -0.3963,  0.0785, -0.4701,  1.0074,
0.3319, -2.2675, -1.6163, -0.4003, -0.5468,  0.0452, -2.5586,  0.4747,
-0.0271, -1.2161,  1.2121,  1.8738, -1.2207, -0.9218, -0.1430,  0.2512,
-0.5236, -0.2544, -0.5868, -0.7086, -1.3328, -0.0243,  0.4759,  1.4125,
0.4947,  0.5054,  1.6253,  0.4198, -0.9150,  0.6374,  0.4581,  1.1527,
1.4440, -0.0590, -0.4601,  0.2490, -0.5739,  0.6798, -0.2156, -1.1386,
-0.5011, -0.7411,  0.2825, -0.2595,  0.8070,  0.5270,  0.2595, -0.1089,
0.4221, -0.7851,  0.7112, -0.3038,  0.6169, -0.1513, -0.5872,  0.3974,
0.2431,  0.4934, -0.9406, -0.9372,  1.4525,  0.1376,  0.2558,  0.0661,
0.3509,  2.1667,  2.8428,  0.9429, -0.6143, -1.0969,  0.0955,  0.0914]],
device='cuda:0', grad_fn=<EmbeddingBackward>)

self.embedding(input).view(1, 1, -1)的形状为(1, 1, 256),样本输出为

tensor([[[ 0.1421,  0.4135, -1.0619,  0.0149,  0.0673, -0.3770,  0.4231,
2.2803, -1.6939, -0.0071,  1.1131, -1.0019,  0.6593,  0.1366,
1.1033, -0.8804,  1.3676,  0.4115, -0.5671,  0.3314, -0.2599,
-0.3082,  1.3644,  0.5788, -0.1929, -2.0505,  0.4518,  0.8757,
-0.2360, -0.4099, -0.5697, -1.5973, -0.6638, -1.1523,  1.4425,
1.3651,  1.9371,  0.5698, -0.3541, -1.3883, -0.0195, -1.0757,
-1.4324, -1.6226, -2.4267,  0.3874, -0.7529,  1.4938, -2.5773,
-1.1962,  0.3759, -0.6143, -1.0444, -0.6443, -0.8130, -1.7283,
1.4167,  1.3945, -1.2695,  0.7289,  0.7777, -0.0094, -1.8108,
0.2126, -0.2018, -0.4055, -0.7779, -0.8523,  0.0162,  0.2463,
0.5588, -0.7250, -0.0128,  0.6272, -0.7729,  0.4259,  0.7596,
-1.9500,  0.5853,  0.3764, -0.1112,  0.7274, -2.8535, -0.0445,
0.4225,  1.2179,  0.2219, -0.7064, -0.9654,  1.0501,  1.7142,
0.5312, -0.8180, -1.5697,  1.3062, -0.9321, -0.1652, -1.5298,
-0.3575, -1.2046, -0.6571, -0.7689, -0.7032,  1.0727, -1.3259,
0.1200,  1.9357, -0.2519, -0.3717,  0.8054,  0.1180, -0.6921,
1.0245, -1.5500, -0.5280, -0.7462,  0.7924,  2.2701, -1.5094,
-0.1973, -1.5919,  0.4869,  0.6739, -0.5242,  0.2559, -0.0149,
-0.5332, -1.8313,  0.3598,  0.0804, -0.0780, -0.2930, -0.2844,
-0.4752, -0.9919,  0.1809,  0.7622, -2.5069, -0.7724, -0.9441,
1.6101,  0.6461, -0.8932,  0.0600,  0.6911,  0.5191, -0.1719,
-0.5829, -0.9168,  1.5282,  1.4399,  0.3264, -0.8894,  0.2880,
-0.0697,  0.8977, -0.5004,  0.3844,  0.0925,  0.5592, -0.1664,
0.8575, -1.0348,  0.7326, -0.2124,  0.7533,  0.6270, -0.9559,
-1.4159,  0.6788,  0.6163, -0.5951, -0.1403, -1.6088, -0.7731,
0.3876,  1.0429, -2.0960,  0.1726,  1.7446, -0.3963,  0.0785,
-0.4701,  1.0074,  0.3319, -2.2675, -1.6163, -0.4003, -0.5468,
0.0452, -2.5586,  0.4747, -0.0271, -1.2161,  1.2121,  1.8738,
-1.2207, -0.9218, -0.1430,  0.2512, -0.5236, -0.2544, -0.5868,
-0.7086, -1.3328, -0.0243,  0.4759,  1.4125,  0.4947,  0.5054,
1.6253,  0.4198, -0.9150,  0.6374,  0.4581,  1.1527,  1.4440,
-0.0590, -0.4601,  0.2490, -0.5739,  0.6798, -0.2156, -1.1386,
-0.5011, -0.7411,  0.2825, -0.2595,  0.8070,  0.5270,  0.2595,
-0.1089,  0.4221, -0.7851,  0.7112, -0.3038,  0.6169, -0.1513,
-0.5872,  0.3974,  0.2431,  0.4934, -0.9406, -0.9372,  1.4525,
0.1376,  0.2558,  0.0661,  0.3509,  2.1667,  2.8428,  0.9429,
-0.6143, -1.0969,  0.0955,  0.0914]]], device='cuda:0',
grad_fn=<ViewBackward>)

代码

这个代码有效,

rnn1 = nn.GRU(256, 128, 1)
input1 = torch.randn(100, 2, 256)
h01 = torch.randn(1, 2, 128)
output1, hn1 = rnn1(input1, h01)
print(input1.shape, h01.shape)
print(output1.shape, hn1.shape)

输出

torch.Size([100, 2, 256]) torch.Size([1, 2, 128])
torch.Size([100, 2, 128]) torch.Size([1, 2, 128])

代码

此代码也适用,

rnn1 = nn.GRU(256, 256)
input1 = torch.randn(1, 1, 256)
h01 = torch.randn(1, 1, 256)
output1, hn1 = rnn1(input1, h01)
print(input1.shape, h01.shape)
print(output1.shape, hn1.shape)

输出

torch.Size([1, 1, 256]) torch.Size([1, 1, 256])
torch.Size([1, 1, 256]) torch.Size([1, 1, 256])

代码

这不起作用,

rnn1 = nn.GRU(256, 256)
input1 = torch.randn(1, 256)
#input1 = input1.view(1, 1, -1)
h01 = torch.randn(1, 1, 256)
output1, hn1 = rnn1(input1, h01)
print(input1.shape, h01.shape)
print(output1.shape, hn1.shape)

输出

RuntimeError: input must have 3 dimensions, got 2

相关内容

  • 没有找到相关文章

最新更新