我最近被介绍给PyTorch,并开始浏览库的文档和教程。 在"使用 numpy 和 scipy 创建扩展"教程中的"无参数示例"下,使用 numpy 创建了一个名为BadFFTFunction
的示例函数。
函数的说明指出:
"这一层并没有特别做任何有用或数学上的事情。 正确。
它被恰当地命名为BadFFTFunction">
该函数及其用法给出如下:
from numpy.fft import rfft2, irfft2
class BadFFTFunction(Function):
def forward(self, input):
numpy_input = input.numpy()
result = abs(rfft2(numpy_input))
return torch.FloatTensor(result)
def backward(self, grad_output):
numpy_go = grad_output.numpy()
result = irfft2(numpy_go)
return torch.FloatTensor(result)
def incorrect_fft(input):
return BadFFTFunction()(input)
input = Variable(torch.randn(8, 8), requires_grad=True)
result = incorrect_fft(input)
print(result.data)
result.backward(torch.randn(result.size()))
print(input.grad)
不幸的是,我最近才接触到信号处理,并且不确定此函数中(可能明显的(错误在哪里。
我想知道,如何修复此功能以使其向前和向后输出正确?
如何修复BadFFTFunction
,以便在 PyTorch 中使用可微分的 FFT 函数?
我认为错误是:首先,尽管该函数的名称中包含FFT,但仅返回FFT输出的振幅/绝对值,而不是完整的复系数。此外,仅使用逆FFT来计算振幅的梯度在数学上可能没有多大意义(?
有一个名为pytorch-fft的软件包,它试图在pytorch中提供FFT功能。您可以在此处查看一些自动grad功能的实验性代码。另请注意本期的讨论。
从 1.8 版本开始,PyTorch 有torch.fft
:
torch.fft.fft(input)