调整pytorch 1.5.0和pytorch 1.9.0之间fft输出的差异



我正在尝试使一些与pytorch 1.5.0一起工作的python3代码在新版本上也能正常工作(我目前使用的是pytorch 1.9.0(。更具体地说,我正在尝试更新进行快速傅立叶变换的代码。我正在尝试用pytorch 1.9.0中的torch.fft.fftn((和torch.view_as_real((替换pytorch 1.5.0中的torch.rfft((。我注意到,当我运行以下程序时,我得到的输出略有不同:

使用PyTorch 1.5.0:

import torch
import numpy as np
arr = torch.from_numpy(np.array([[1.,2.,3.,4.,5.],
[6.,7.,8.,9.,10.],
[11.,12.,13.,14.,15.],
[16.,17.,18.,19.,20.]]))
ftt_arr = torch.rfft(arr,2,onesided=False)
print(fft_arr)

使用PyTorch 1.9.0:

import torch
import numpy as np
arr = torch.from_numpy(np.array([[1.,2.,3.,4.,5.],
[6.,7.,8.,9.,10.],
[11.,12.,13.,14.,15.],
[16.,17.,18.,19.,20.]]))
fft_arr = torch.fft.fftn(arr,norm="backward")
fft_arr = torch.view_as_real(fft_arr)
print(fft_arr)

两个快速傅立叶变换的输出如下:

pytorch 1.5.0:

tensor([[[211.0000,   0.0000],
[-10.8090,  13.1760],
[ -9.6910,   4.2003],
[ -9.6910,  -4.2003],
[-10.8090, -13.1760]],
[[-50.0000,  51.0000],
[  0.5878,  -0.8090],
[ -0.9511,   0.3090],
[  0.9511,   0.3090],
[ -0.5878,  -0.8090]],
[[-51.0000,   0.0000],
[  0.8090,   0.5878],
[ -0.3090,  -0.9511],
[ -0.3090,   0.9511],
[  0.8090,  -0.5878]],
[[-50.0000, -51.0000],
[ -0.5878,   0.8090],
[  0.9511,  -0.3090],
[ -0.9511,  -0.3090],
[  0.5878,   0.8090]]], dtype=torch.float64)

pytorch 1.9.0:

tensor([[[ 2.1000e+02,  0.0000e+00],
[-1.0000e+01,  1.3764e+01],
[-1.0000e+01,  3.2492e+00],
[-1.0000e+01, -3.2492e+00],
[-1.0000e+01, -1.3764e+01]],
[[-5.0000e+01,  5.0000e+01],
[ 2.2204e-15,  0.0000e+00],
[ 1.7764e-15, -4.4409e-16],
[ 1.7764e-15, -4.4409e-16],
[ 2.2204e-15,  0.0000e+00]],
[[-5.0000e+01,  0.0000e+00],
[-1.7764e-15,  0.0000e+00],
[-8.8818e-16,  0.0000e+00],
[-8.8818e-16,  0.0000e+00],
[-1.7764e-15,  0.0000e+00]],
[[-5.0000e+01, -5.0000e+01],
[ 2.2204e-15,  0.0000e+00],
[ 1.7764e-15,  4.4409e-16],
[ 1.7764e-15,  4.4409e-16],
[ 2.2204e-15,  0.0000e+00]]], dtype=torch.float64)

所有的输出值似乎都相差+/-1左右,我无法解释或调和这一点。

我希望以下内容能帮助

import torch
a = torch.arange(0,4).view(1,1,2,2).float()
print(a)

现在开始PyTorch 1.9 的代码

def dft_amp(img):
fft_im = torch.view_as_real(torch.fft.rfftn(img, dim=(2,3),norm="backward"))
#torch.rfft( img, signal_ndim=2, onesided=False )
print('Pytorch FFT 1.9',fft_im)
fft_amp = fft_im[:,:,:,:,0]**2 + fft_im[:,:,:,:,1]**2
return torch.sqrt(fft_amp + 1e-10)
b = dft_amp(a)
print('Pytorch 1.9 amp', b)

PyTorch 1.9的输出是

tensor([[[[0., 1.],
[2., 3.]]]])
Pytorch FFT 1.9 tensor([[[[[ 6.,  0.],
[-2.,  0.]],
[[-4.,  0.],
[ 0.,  0.]]]]])
Pytorch 1.9 amp tensor([[[[6.0000e+00, 2.0000e+00],
[4.0000e+00, 1.0000e-05]]]])

现在开始PyTorch 1.5 的代码

def dft_amp(img):
fft_im = torch.rfft( img, signal_ndim=2, onesided=False )#torch.view_as_real(torch.fft.rfftn(a, dim=(2,3),norm="backward"))

print('Pytorch FFT 1.5',fft_im)
fft_amp = fft_im[:,:,:,:,0]**2 + fft_im[:,:,:,:,1]**2
return torch.sqrt(fft_amp + 1e-10)
b = dft_amp(a)
print('Pytorch 1.5 amp', b)

PyTorch 1.5版本的输出是

tensor([[[[0., 1.],
[2., 3.]]]])
Pytorch FFT 1.5 tensor([[[[[ 6.,  0.],
[-2.,  0.]],
[[-4.,  0.],
[ 0.,  0.]]]]])
1.5 amp tensor([[[[6.0000e+00, 2.0000e+00],
[4.0000e+00, 1.0000e-05]]]])

两个版本中的值匹配!

最新更新