保存图像时,PyTorch 对象对于数组来说太深



我正在尝试从以下github代表运行代码:

https://github.com/iamkrut/image_inpainting_resnet_unet

我没有更改代码中的任何内容,当代码尝试保存图像时,它会导致 ValueError,即对象太深。错误似乎来自这两行。

images = img_tensor.cpu().detach().permute(0,2,3,1)
plt.imsave(join(data_dir, 'samples', image), images[index,:,:,:3])

这是错误语句

File "train.py", line 205, in <module>
data_dir=args.data_dir)
File "train.py", line 94, in train_net
plt.imsave(join(data_dir, 'samples', image), images[index,:,:,:]);
File "C:ProgramDataAnaconda3envstorch2libsite-packagesmatplotlibpyplot.py", line 2140, in imsave
return matplotlib.image.imsave(fname, arr, **kwargs)
File "C:ProgramDataAnaconda3envstorch2libsite-packagesmatplotlibimage.py", line 1498, in imsave
_png.write_png(rgba, fname, dpi=dpi)
ValueError: object too deep for desired array

有谁知道可能导致这种情况的原因或如何解决它? 谢谢

matplotlib 包不理解 pytorch 数据类型(tensor(。 您应该将张量数组转换为 numpy 数组,然后使用 matplotlib 函数。

a = torch.rand(10, 3, 20, 20)
plt.imsave("test.jpg", a.cpu().detach().permute(0, 2, 3, 1)[0, ...]) # Error
plt.imsave("test.jpg", a.cpu().detach().permute(0, 2, 3, 1).numpy()[0, ...])

我设法通过将行更改为

images=img_tensor.cpu().numpy()[0]
images = np.transpose(images, (1,2,0))
plt.imsave(join(data_dir, 'samples', image), images)

仍然不确定以前的版本出了什么问题。所以如果有人知道,请告诉我。

最新更新