如何在 PyTorch 中正确使用 Numpy 的 FFT 函数?



我最近被介绍给PyTorch,并开始浏览库的文档和教程。 在"使用 numpy 和 scipy 创建扩展"教程中的"无参数示例"下,使用 numpy 创建了一个名为BadFFTFunction的示例函数。

函数的说明指出:

"这一层并没有特别做任何有用或数学上的事情。 正确。

它被恰当地命名为BadFFTFunction">

该函数及其用法给出如下:

from numpy.fft import rfft2, irfft2
class BadFFTFunction(Function):
def forward(self, input):
numpy_input = input.numpy()
result = abs(rfft2(numpy_input))
return torch.FloatTensor(result)
def backward(self, grad_output):
numpy_go = grad_output.numpy()
result = irfft2(numpy_go)
return torch.FloatTensor(result)
def incorrect_fft(input):
return BadFFTFunction()(input)
input = Variable(torch.randn(8, 8), requires_grad=True)
result = incorrect_fft(input)
print(result.data)
result.backward(torch.randn(result.size()))
print(input.grad)

不幸的是,我最近才接触到信号处理,并且不确定此函数中(可能明显的(错误在哪里。

我想知道,如何修复此功能以使其向前和向后输出正确?

如何修复BadFFTFunction,以便在 PyTorch 中使用可微分的 FFT 函数?

我认为错误是:首先,尽管该函数的名称中包含FFT,但仅返回FFT输出的振幅/绝对值,而不是完整的复系数。此外,仅使用逆FFT来计算振幅的梯度在数学上可能没有多大意义(?

有一个名为pytorch-fft的软件包,它试图在pytorch中提供FFT功能。您可以在此处查看一些自动grad功能的实验性代码。另请注意本期的讨论。

从 1.8 版本开始,PyTorch 有torch.fft

torch.fft.fft(input)

最新更新