使用 Pytorch 进行深度学习:理解神经网络示例



我正在阅读 Pytorch 文档,我有几个关于引入的神经网络的问题。该文档定义了以下网络:

import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
# 1 input image channel, 6 output channels, 3x3 square convolution
# kernel
self.conv1 = nn.Conv2d(1, 6, 3)
self.conv2 = nn.Conv2d(6, 16, 3)
# an affine operation: y = Wx + b
self.fc1 = nn.Linear(16 * 6 * 6, 120)  # 6*6 from image dimension
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
# Max pooling over a (2, 2) window
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
# If the size is a square you can only specify a single number
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, self.num_flat_features(x))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
def num_flat_features(self, x):
size = x.size()[1:]  # all dimensions except the batch dimension
num_features = 1
for s in size:
num_features *= s
return num_features

稍后,发表以下声明:

让我们尝试一个随机的 32x32 输入。注意:此网络(LeNet(的预期输入大小为32x32。要在MNIST数据集上使用此网络,请将数据集中的图像大小调整为32x32。

问题 1:为什么图像需要为 32x32(我假设这意味着 32 x 32 像素(?

第一个卷积将六个内核应用于一个图像,每个内核都是 3x3。这意味着,如果输入通道为 32x32,则六个输出通道的尺寸均为 30x30(3x3 内核网格会使宽度和高度损失 2 个像素(。第二个卷积应用了更多的内核,因此现在有 16 个尺寸为 28x28 的输出通道(同样,3x3 内核网格会让你在宽度和高度上损失 2 个像素(。现在我预计下一层有 16x28x28 个节点,因为 16 个输出通道中的每一个都有 28x28 像素。不知何故,这是不正确的,下一层包含 16x6x6 节点。为什么会这样呢?

问题2:第二个卷积层从六个输入通道变为十六个输出通道。这是怎么做到的?

在第一个卷积层中,我们从一个输入通道转到六个输入通道,这对我来说很有意义。您只需将六个内核应用于单个输入通道即可到达六个输出通道。从六个输入通道到十六个输出通道对我来说没有多大意义。如何应用不同的内核?您是否将两个内核应用于前五个输入通道以到达十个输出通道,并将六个内核应用于最后一个输入通道,以便总共达到十六个输出通道?还是神经网络学会了使用 x 内核并将它们应用于它认为最合适的输入通道?

我现在可以自己回答这些问题了。

问题 1:要了解为什么需要 32x32 图像才能使此神经网络工作,请考虑以下事项:

第 1 层:首先,使用 3x3 内核应用卷积。由于图像的尺寸为 32x32,这将导致网格为 30x30。接下来,将最大池化应用于网格,使用 2x2 内核和 2 步幅生成维度为 15x15 的网格。

第 2 层:首先,使用 3x3 内核将卷积应用于 15x15 网格,从而产生 13x13 网格。接下来,使用 2x2 内核和 2 步幅应用最大池化,从而生成维度为 6x6 的网格。我们得到一个 6x6 网格而不是一个 7x7 网格,因为默认情况下使用楼层函数而不是 ceil 函数。

由于第 2 层的卷积有 16 个输出通道,因此第一层线性层需要 16x6x6 个节点!我们看到所需的输入确实是 32x32 图像。

问题 2:每个输出通道都是通过对每个输入通道应用六个不同的内核并对结果求和来创建的。文档中对此进行了说明。

最新更新