pytorch在网络中使用python计算代码会正确执行吗



以下面的伪代码为例:

class():
def forward(input):
x = some_torch_layers(input)
x = some_torch_layers(x)
...
x = sum(x) # or numpy or other operations
x = some_torch_layers(x)
return x

pytorch网运行良好吗?尤其是sum(x)在后向处理中表现良好。

TL;DR
否。

为了让PyTorch"表现良好",它需要通过网络传播梯度。PyTorch不知道(也不知道(如何区分arbitary numpy代码,它只能通过PyTorch张量运算传播梯度
在您的示例中,梯度将停止在numpysum,因此只有最顶部的火炬层将被训练(numpy操作和criterion之间的层(,其他层(输入和numpy操作之间(将具有零梯度,因此它们的参数将在整个训练过程中保持固定。

最新更新