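I am following the timm tutorial on data augmentation to increase the number of images in my dataset. I implemented the same code as in their tutorial, but it does not work.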
import numpy as np
import torch
from PIL import Image
from timm.data.transforms_factory import create_transform
a = create_transform(224, is_training=True)
print(a)
pets_image_paths = './download.png'
image = Image.open(pets_image_paths)
# We can convert this into a tensor, and transpose the channels into the format that PyTorch expects:
np_image = np.array(image, dtype=np.float32)
image = torch.as_tensor(np_image).transpose(2, 0)[None]
from timm.data.transforms import RandomResizedCropAndInterpolation
tfm = RandomResizedCropAndInterpolation(size=350, interpolation='random')
import matplotlib.pyplot as plt
fig, ax = plt.subplots(2, 4, figsize=(10, 5))
for idx, im in enumerate([tfm(image) for i in range(4)]):
    ax[0, idx].imshow(im)
for idx, im in enumerate([tfm(image) for i in range(4)]):
    ax[1, idx].imshow(im)
fig.tight_layout()
plt.show()
Traceback:
Traceback (most recent call last):
  File "/home/cvpr/PycharmProjects/timm_tutorials/9_augmentation.py", line 24, in <module>
    for idx, im in enumerate([tfm(image) for i in range(4)]):
  File "/home/cvpr/PycharmProjects/timm_tutorials/9_augmentation.py", line 24, in <listcomp>
    for idx, im in enumerate([tfm(image) for i in range(4)]):
  File "/home/cvpr/anaconda3/envs/timm_tutorials/lib/python3.8/site-packages/timm/data/transforms.py", line 181, in __call__
    i, j, h, w = self.get_params(img, self.scale, self.ratio)
  File "/home/cvpr/anaconda3/envs/timm_tutorials/lib/python3.8/site-packages/timm/data/transforms.py", line 143, in get_params
    area = img.size[0] * img.size[1]
TypeError: 'builtin_function_or_method' object is not subscriptable
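As the previous answer mentioned, RandomResizedCropAndInterpolation expects a PIL.Image.
You can check the Note in the tutorial:
Note: RandomResizedCropAndInterpolation expects the input to be an instance of PIL.Image, not a torch.tensor.
So you can remove the tensor conversion: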
import numpy as np
import torch
from PIL import Image
from timm.data.transforms_factory import create_transform
a = create_transform(224, is_training=True)
print(a)
pets_image_paths = './download.png'
image = Image.open(pets_image_paths)
from timm.data.transforms import RandomResizedCropAndInterpolation
tfm = RandomResizedCropAndInterpolation(size=350, interpolation='random')
import matplotlib.pyplot as plt
fig, ax = plt.subplots(2, 4, figsize=(10, 5))
for idx, im in enumerate([tfm(image) for i in range(4)]):
    ax[0, idx].imshow(im)
for idx, im in enumerate([tfm(image) for i in range(4)]):
    ax[1, idx].imshow(im)
fig.tight_layout()
plt.show()
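To see why the tensor version fails with the TypeError in the question's traceback, here is a minimal sketch. The 64x48 blank image and the zero tensor are stand-ins chosen for illustration, not the asker's data. On a PIL image, `size` is a `(width, height)` tuple, while on a torch tensor `size` is a method, so `img.size[0]` cannot be subscripted.

```python
import torch
from PIL import Image

# Illustrative stand-ins (assumptions, not the asker's './download.png').
pil_img = Image.new("RGB", (64, 48))      # width=64, height=48
tensor_img = torch.zeros(1, 3, 64, 48)

# PIL: `size` is a plain (width, height) tuple, so indexing works.
print(pil_img.size)
area = pil_img.size[0] * pil_img.size[1]
print(area)

# torch: `size` is a bound method; indexing it without calling it raises TypeError,
# which is what happens inside get_params when a tensor is passed in.
try:
    tensor_img.size[0]
except TypeError as exc:
    error_message = str(exc)
    print(error_message)
```

Calling `tensor_img.size()` with parentheses would work on the tensor, but the transform indexes `img.size` directly, which is why it needs a PIL image.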
RandomResizedCropAndInterpolation expects the input to be an instance of PIL.Image, not a torch.tensor. If you need to convert it to a tensor, you have to do that afterwards.