属性错误:'tuple'对象在将输入馈送到 Pytorch LSTM 网络时没有属性"dim"



我正在尝试运行以下代码:

import matplotlib.pylab as plt
import numpy as np
import torch
import torch.nn as nn
class LSTM(nn.Module):
def __init__(self, input_shape, n_actions):
super(LSTM, self).__init__()
self.lstm = nn.LSTM(input_shape, 12)
self.hidden2tag = nn.Linear(12, n_actions)
def forward(self, x):
out = self.lstm(x)
out = self.hidden2tag(out)
return out

state = [(1,2,3,4,5),(2,3,4,5,6),(3,4,5,6,7),(4,5,6,7,8),(5,6,7,8,9),(6,7,8,9,0)]
device = torch.device("cuda")
net = LSTM(5, 3).to(device)
state_v = torch.FloatTensor(state).to(device)
q_vals_v = net(state_v.view(1, state_v.shape[0], state_v.shape[1]))
_, action = int(torch.max(q_vals_v, dim=1).item())

然后返回这个错误:

Traceback (most recent call last):
File "/home/dikkerj/Documents/PycharmProjects/LSTMReactor/QuestionStackoverflow.py", line 26, in <module>
q_vals_v = net(state_v.view(1, state_v.shape[0], state_v.shape[1]))
File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/modules/module.py", line 477, in __call__
result = self.forward(*input, **kwargs)
File "/home/dikkerj/Documents/PycharmProjects/LSTMReactor/QuestionStackoverflow.py", line 15, in forward
out = self.hidden2tag(out)
File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/modules/module.py", line 477, in __call__
result = self.forward(*input, **kwargs)
File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/modules/linear.py", line 55, in forward
return F.linear(input, self.weight, self.bias)
File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/functional.py", line 1022, in linear
if input.dim() == 2 and bias is not None:
AttributeError: 'tuple' object has no attribute 'dim'

有人知道怎么解决这个问题吗?(去掉张量是元组,这样它就可以被馈送到LSTM网络中(

pytorch LSTM返回一个元组
所以您会得到这个错误,因为您的线性层self.hidden2tag无法处理这个元组。

所以改变:

out = self.lstm(x)

out, states = self.lstm(x)

这将通过拆分元组来修复您的错误,使out只是您的输出张量。

然后,out存储隐藏状态,而states是另一个包含最后一个隐藏状态和单元格状态的元组。

您也可以在此处查看:
https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM

由于max()也返回一个元组,最后一行将出现另一个错误。但这应该很容易修复,而且是不同的错误:(

首先在numpy数组中转换状态:

state = np.array(state)

PyTorch可能在其API中缺少一个np.asarray

最新更新