专辑错误:__call__() 得到一个意外的关键字参数"force_apply"



我尝试做一个简单的数据加载器:

train_transforms = A.Compose(
[
A.GaussNoise(always_apply=False, p=0.4, var_limit=(0, 70.0)),
A.Blur(always_apply=False, p=0.3, blur_limit=(3, 7)),
A.RandomResizedCrop(always_apply=False, p=0.2, height=128, width=128, scale=(0.7, 1.0),
ratio=(0.75, 1.3333333333333333), interpolation=0),
A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
A.RandomBrightnessContrast(p=0.5),
transforms.ToTensor()
]
)
class SixKDataset(Dataset):
def __init__(self, data_path, transform=None):
self.data = np.load(data_path, allow_pickle=True)
self.transform = transform
def __len__(self):
return self.data.shape[0]
def __getitem__(self, idx):
ret = np.squeeze(self.data[idx, 0, :, :, :]), np.squeeze(self.data[idx, 1, :, :, :])
if self.transform:
image1 = self.transform(image = ret[0])['image']
image2 = self.transform(image = ret[1])['image']
return image1, image2
return np.squeeze(ret[0]), np.squeeze(ret[1])

但是我得到这个错误:TypeError:调用()得到一个意外的关键字参数'force_apply'

我不知道如何解决这个问题

我认为您需要像这样定义额外的目标:

train_transforms = A.Compose(
[
A.GaussNoise(always_apply=False, p=0.4, var_limit=(0, 70.0)),
A.Blur(always_apply=False, p=0.3, blur_limit=(3, 7)),
A.RandomResizedCrop(always_apply=False, p=0.2, height=128, width=128, scale=(0.7, 1.0),
ratio=(0.75, 1.3333333333333333), interpolation=0),
A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
A.RandomBrightnessContrast(p=0.5),
transforms.ToTensor()
], additional_targets={'image0': 'image'}
)
然后在getitem中使用它们像这样:
def __getitem__(self, idx):
ret = np.squeeze(self.data[idx, 0, :, :, :]), np.squeeze(self.data[idx, 1, :, :, :])
if self.transform:
transformed = transform(image=ret[0], image0=ret[1])
return np.squeeze(transformed[image]), np.squeeze(transformed[image0])

https://albumentations。Ai/docs/examples/example_multi_target/look here

最新更新