在PyTorch中,grad_fn属性究竟存储了什么以及它是如何使用的?



在PyTorch中,Tensor类具有grad_fn属性。这个引用了用来获取张量的操作:例如,如果a = b + 2,a.grad_fn将是AddBackward0。但是什么是"参考"?究竟意味着什么?

使用inspect.getmro(type(a.grad_fn))检查AddBackward0将声明AddBackward0的唯一基类是object。此外,这个类的源代码(实际上,在grad_fn中可能遇到的任何其他类)在源代码中找不到!

所有这些使我想到以下问题:

  1. grad_fn中究竟存储了什么?在反向传播期间如何调用它?
  2. 为什么在grad_fn中存储的对象没有某种共同的超类,为什么在GitHub上没有源代码?

grad_fn是一个函数"handle",允许访问适用的梯度函数。给定点的梯度是在反向传播过程中调整权重的系数。

"Handle"是对象描述符的通用术语,旨在提供对对象的适当访问。例如,当你打开一个文件时,open返回一个文件句柄。当你实例化一个类时,__init__函数返回一个句柄给创建的实例。句柄包含对所讨论项的数据和函数的引用(通常是内存地址)。

显示为通用的object类,因为它来自另一种语言的底层实现,因此它不会将精确地映射到Python的function类型。PyTorch处理语言间调用和返回。这个移交是预编译(共享对象)运行时系统的一部分。

这足以说明你所看到的吗?

相关内容

  • 没有找到相关文章

最新更新