Pytorch模块如何进行后退道具



在按照扩展pytorch的说明 - 添加模块时,我注意到扩展模块时注意到,我们实际上不必实现向后的功能。我们唯一需要的是在向前函数中应用功能实例,而pytorch可以在执行后退道具时在功能实例中自动调用一个函数。这对我来说似乎是神奇的,因为我们甚至没有注册使用的功能实例。我研究了源代码,但没有找到任何相关的内容。谁能向我指出一个实际发生的地方吗?

不必实现backward()是Pytorch或任何其他DL框架如此有价值的原因。实际上,仅在您需要弄乱网络梯度的非常具体的情况下(或创建无法使用Pytorch的内置功能表达的自定义函数时(才能实现backward()(。

pytorch使用计算图计算向后梯度,该计算图可以跟踪您的正向通行过程中所做的操作。在Variable上进行的任何操作都隐含地在此处注册。然后,这是从称为变量的向后向后的问题,并应用导数链规则来计算梯度。

pytorch的关于图表的可视化及其通常如何工作。如果您想要更多详细信息,我还建议您在Google上查找计算图和自动射击机制。

编辑:所有这些发生的源代码将在Pytorch代码库的C部分中,在该代码库中实现了实际图。经过一些挖掘后,我发现了这一点:

/// Evaluates the function on the given inputs and returns the result of the
/// function call.
variable_list operator()(const variable_list& inputs) {
    profiler::RecordFunction rec(this);
    if (jit::tracer::isTracingVar(inputs)) {
        return traced_apply(inputs);
    }
    return apply(inputs);
}

因此,在每个功能中,Pytorch首先检查其输入是否需要跟踪,并在此处执行trace_apply((。您可以看到正在创建的节点并将其附加到图表:

// Insert a CppOp in the trace.
auto& graph = state->graph;
std::vector<VariableFlags> var_flags;
for(auto & input: inputs) {
    var_flags.push_back(VariableFlags::of(input));
}
auto* this_node = graph->createCppOp(get_shared_ptr(), std::move(var_flags));
// ...
for (auto& input: inputs) {
    this_node->addInput(tracer::getValueTrace(state, input));
}
graph->appendNode(this_node);

我最好的猜测是,每个函数对象在执行时都会注册自身及其输入(如果需要(。每个非功能调用(例如变量dot(((简单地辩护到相应的函数,因此仍然适用。

注意:我不参加Pytorch的开发,绝不是其架构的专家。任何更正或添加都将受到欢迎。

也许我不正确,但是我有不同的视图。

向后函数定义并通过正向函数调用。

例如:

#!/usr/bin/env python
# encoding: utf-8
###############################################################
# Parametrized example
# --------------------
#
# This implements a layer with learnable weights.
#
# It implements the Cross-correlation with a learnable kernel.
#
# In deep learning literature, it’s confusingly referred to as
# Convolution.
#
# The backward computes the gradients wrt the input and gradients wrt the
# filter.
#
# **Implementation:**
#
# *Please Note that the implementation serves as an illustration, and we
# did not verify it’s correctness*
import torch
from torch.autograd import Function
from torch.autograd import Variable
from scipy.signal import convolve2d, correlate2d
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter

class ScipyConv2dFunction(Function):
    @staticmethod
    def forward(ctx, input, filter):
        result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
        ctx.save_for_backward(input, filter)
        return input.new(result)
    @staticmethod
    def backward(ctx, grad_output):
        input, filter = ctx.saved_tensors
        grad_output = grad_output.data
        grad_input = convolve2d(grad_output.numpy(), filter.t().numpy(), mode='full')
        grad_filter = convolve2d(input.numpy(), grad_output.numpy(), mode='valid')
        return Variable(grad_output.new(grad_input)), 
            Variable(grad_output.new(grad_filter))

class ScipyConv2d(Module):
    def __init__(self, kh, kw):
        super(ScipyConv2d, self).__init__()
        self.filter = Parameter(torch.randn(kh, kw))
    def forward(self, input):
        return ScipyConv2dFunction.apply(input, self.filter)
###############################################################
# **Example usage:**
module = ScipyConv2d(3, 3)
print(list(module.parameters()))
input = Variable(torch.randn(10, 10), requires_grad=True)
output = module(input)
print(output)
output.backward(torch.randn(8, 8))
print(input.grad)

在此示例中,向后函数由ScipyConv2DFunction函数定义。

和scipyConv2DFunction通过正向函数调用。

我正确吗?

最新更新