在PyTorch中,Tensor
类具有grad_fn
属性。这个引用了用来获取张量的操作:例如,如果a = b + 2
,a.grad_fn
将是AddBackward0
。但是什么是"参考"?究竟意味着什么?
使用inspect.getmro(type(a.grad_fn))
检查AddBackward0
将声明AddBackward0
的唯一基类是object
。此外,这个类的源代码(实际上,在grad_fn
中可能遇到的任何其他类)在源代码中找不到!
所有这些使我想到以下问题:
grad_fn
中究竟存储了什么?在反向传播期间如何调用它?- 为什么在
grad_fn
中存储的对象没有某种共同的超类,为什么在GitHub上没有源代码?
grad_fn
是一个函数"handle",允许访问适用的梯度函数。给定点的梯度是在反向传播过程中调整权重的系数。
"Handle"是对象描述符的通用术语,旨在提供对对象的适当访问。例如,当你打开一个文件时,open
返回一个文件句柄。当你实例化一个类时,__init__
函数返回一个句柄给创建的实例。句柄包含对所讨论项的数据和函数的引用(通常是内存地址)。
显示为通用的object
类,因为它来自另一种语言的底层实现,因此它不会将精确地映射到Python的function
类型。PyTorch处理语言间调用和返回。这个移交是预编译(共享对象)运行时系统的一部分。
这足以说明你所看到的吗?