我正在尝试使一些与pytorch 1.5.0一起工作的python3代码在新版本上也能正常工作(我目前使用的是pytorch 1.9.0(。更具体地说,我正在尝试更新进行快速傅立叶变换的代码。我正在尝试用pytorch 1.9.0中的torch.fft.fftn((和torch.view_as_real((替换pytorch 1.5.0中的torch.rfft((。我注意到,当我运行以下程序时,我得到的输出略有不同:
使用PyTorch 1.5.0:
import torch
import numpy as np
arr = torch.from_numpy(np.array([[1.,2.,3.,4.,5.],
[6.,7.,8.,9.,10.],
[11.,12.,13.,14.,15.],
[16.,17.,18.,19.,20.]]))
ftt_arr = torch.rfft(arr,2,onesided=False)
print(fft_arr)
使用PyTorch 1.9.0:
import torch
import numpy as np
arr = torch.from_numpy(np.array([[1.,2.,3.,4.,5.],
[6.,7.,8.,9.,10.],
[11.,12.,13.,14.,15.],
[16.,17.,18.,19.,20.]]))
fft_arr = torch.fft.fftn(arr,norm="backward")
fft_arr = torch.view_as_real(fft_arr)
print(fft_arr)
两个快速傅立叶变换的输出如下:
pytorch 1.5.0:
tensor([[[211.0000, 0.0000],
[-10.8090, 13.1760],
[ -9.6910, 4.2003],
[ -9.6910, -4.2003],
[-10.8090, -13.1760]],
[[-50.0000, 51.0000],
[ 0.5878, -0.8090],
[ -0.9511, 0.3090],
[ 0.9511, 0.3090],
[ -0.5878, -0.8090]],
[[-51.0000, 0.0000],
[ 0.8090, 0.5878],
[ -0.3090, -0.9511],
[ -0.3090, 0.9511],
[ 0.8090, -0.5878]],
[[-50.0000, -51.0000],
[ -0.5878, 0.8090],
[ 0.9511, -0.3090],
[ -0.9511, -0.3090],
[ 0.5878, 0.8090]]], dtype=torch.float64)
pytorch 1.9.0:
tensor([[[ 2.1000e+02, 0.0000e+00],
[-1.0000e+01, 1.3764e+01],
[-1.0000e+01, 3.2492e+00],
[-1.0000e+01, -3.2492e+00],
[-1.0000e+01, -1.3764e+01]],
[[-5.0000e+01, 5.0000e+01],
[ 2.2204e-15, 0.0000e+00],
[ 1.7764e-15, -4.4409e-16],
[ 1.7764e-15, -4.4409e-16],
[ 2.2204e-15, 0.0000e+00]],
[[-5.0000e+01, 0.0000e+00],
[-1.7764e-15, 0.0000e+00],
[-8.8818e-16, 0.0000e+00],
[-8.8818e-16, 0.0000e+00],
[-1.7764e-15, 0.0000e+00]],
[[-5.0000e+01, -5.0000e+01],
[ 2.2204e-15, 0.0000e+00],
[ 1.7764e-15, 4.4409e-16],
[ 1.7764e-15, 4.4409e-16],
[ 2.2204e-15, 0.0000e+00]]], dtype=torch.float64)
所有的输出值似乎都相差+/-1左右,我无法解释或调和这一点。
我希望以下内容能帮助
import torch
a = torch.arange(0,4).view(1,1,2,2).float()
print(a)
现在开始PyTorch 1.9 的代码
def dft_amp(img):
fft_im = torch.view_as_real(torch.fft.rfftn(img, dim=(2,3),norm="backward"))
#torch.rfft( img, signal_ndim=2, onesided=False )
print('Pytorch FFT 1.9',fft_im)
fft_amp = fft_im[:,:,:,:,0]**2 + fft_im[:,:,:,:,1]**2
return torch.sqrt(fft_amp + 1e-10)
b = dft_amp(a)
print('Pytorch 1.9 amp', b)
PyTorch 1.9的输出是
tensor([[[[0., 1.],
[2., 3.]]]])
Pytorch FFT 1.9 tensor([[[[[ 6., 0.],
[-2., 0.]],
[[-4., 0.],
[ 0., 0.]]]]])
Pytorch 1.9 amp tensor([[[[6.0000e+00, 2.0000e+00],
[4.0000e+00, 1.0000e-05]]]])
现在开始PyTorch 1.5 的代码
def dft_amp(img):
fft_im = torch.rfft( img, signal_ndim=2, onesided=False )#torch.view_as_real(torch.fft.rfftn(a, dim=(2,3),norm="backward"))
print('Pytorch FFT 1.5',fft_im)
fft_amp = fft_im[:,:,:,:,0]**2 + fft_im[:,:,:,:,1]**2
return torch.sqrt(fft_amp + 1e-10)
b = dft_amp(a)
print('Pytorch 1.5 amp', b)
PyTorch 1.5版本的输出是
tensor([[[[0., 1.],
[2., 3.]]]])
Pytorch FFT 1.5 tensor([[[[[ 6., 0.],
[-2., 0.]],
[[-4., 0.],
[ 0., 0.]]]]])
1.5 amp tensor([[[[6.0000e+00, 2.0000e+00],
[4.0000e+00, 1.0000e-05]]]])
两个版本中的值匹配!