PyTorch 报错:RuntimeError: size mismatch, m1: [288 x 9], m2: [2592 x 256]



这是我得到的错误

*** RuntimeError: size mismatch, m1: [288 x 9], m2: [2592 x 256] at /tmp/pip-req-build-4baxydiv/aten/src/TH/generic/THTensorMath.cpp:197

这是我的模型:

class DQN(nn.Module):
    """Convolutional Q-network: two conv layers followed by two linear layers.

    The linear layer sizes are hard-coded for inputs of shape
    ``(batch, 4, 84, 84)``.

    Args:
        num_actions: Number of outputs of the final linear layer
            (one value per action).
        lr: Accepted for interface compatibility; not used inside this
            class (presumably consumed by an optimizer elsewhere --
            confirm against the caller).
    """

    def __init__(self, num_actions, lr):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=4, out_channels=16, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(
            in_channels=16, out_channels=32, kernel_size=4, stride=2)
        # Conv output size per spatial dim: (W - K + 2P) / S + 1
        #   conv1: (84 - 8 + 0) / 4 + 1 = 20
        #   conv2: (20 - 4 + 0) / 2 + 1 = 9
        # so the flattened conv2 output has 32 * 9 * 9 = 2592 features.
        self.fc = nn.Linear(in_features=32 * 9 * 9, out_features=256)
        self.out = nn.Linear(in_features=256, out_features=num_actions)

    def forward(self, state):
        """Compute one output value per action for a batch of states.

        Args:
            state: Tensor of shape ``(batch, 4, 84, 84)``.

        Returns:
            Tensor of shape ``(batch, num_actions)``.
        """
        # (1) Hidden Conv. Layer
        self.layer1 = F.relu(self.conv1(state))
        # (2) Hidden Conv. Layer
        self.layer2 = F.relu(self.conv2(self.layer1))
        # nn.Linear only multiplies over the last dimension, so the
        # (batch, 32, 9, 9) feature map must be flattened to (batch, 2592)
        # first; passing it unflattened raises
        # "size mismatch, m1: [288 x 9], m2: [2592 x 256]".
        # Flattening from dim 1 keeps the batch dimension intact.
        flat_features = self.layer2.flatten(start_dim=1)
        # (3) Hidden Linear Layer
        self.layer3 = self.fc(flat_features)
        # (4) Output
        actions = self.out(self.layer3)
        return actions

错误在 `self.layer3 = self.fc(self.layer2)` 这一行触发。`state` 是一个形状为 (1, 4, 84, 84) 的 PyTorch 张量。

完整的回溯是

Traceback (most recent call last):
File "/home/infinity/Projects/Exercice_Project/AI_Exercices/Atari_2600_Breakout.py", line 228, in <module>
action = agent.choose_action(state, policy_network)
File "/home/infinity/Projects/Exercice_Project/AI_Exercices/Atari_2600_Breakout.py", line 166, in choose_action
return policy_net(state).argmax(dim=1).to(self.device)
File "/home/infinity/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "/home/infinity/Projects/Exercice_Project/AI_Exercices/Atari_2600_Breakout.py", line 101, in forward
self.layer3 = self.fc(self.layer2)
File "/home/infinity/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "/home/infinity/anaconda3/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 87, in forward
return F.linear(input, self.weight, self.bias)
File "/home/infinity/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py", line 1372, in linear
output = input.matmul(weight.t())
RuntimeError: size mismatch, m1: [288 x 9], m2: [2592 x 256] at /tmp/pip-req-build-4baxydiv/aten/src/TH/generic/THTensorMath.cpp:197

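解决方法:第二个卷积层的输出形状为 (1, 32, 9, 9),而 `nn.Linear` 只对最后一维做矩阵乘法,因此它被当作 [288 x 9] 的矩阵,与权重 [2592 x 256] 不匹配。就在 `self.layer3 = self.fc(self.layer2)` 之前,我不得不添加一行 `input_layer3 = self.layer2.reshape(-1, 32*9*9)` 把卷积输出展平,然后把 `input_layer3`(而不是 `self.layer2`)传给 `self.fc`。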

相关内容

最新更新