如果self.transforms是什么意思



是什么

if self.transforms:
data = self.transforms(data)

做什么?我不明白这条线背后的逻辑——这条线使用的条件是什么?

我正在阅读一篇关于使用pytorch基于以下实现创建自定义数据集的文章:

#custom dataset
class MNISTDataset(Dataset):
def __init__(self, images, labels=None, transforms=None):
self.X = images
self.y = labels
self.transforms = transforms

def __len__(self):
return (len(self.X))

def __getitem__(self, i):
data = self.X.iloc[i, :]
data = np.asarray(data).astype(np.uint8).reshape(28, 28, 1)

if self.transforms:
data = self.transforms(data)

if self.y is not None:
return (data, self.y[i])
else:
return data
train_data = MNISTDataset(train_images, train_labels, transform)
test_data = MNISTDataset(test_images, test_labels, transform)
# dataloaders
trainloader = DataLoader(train_data, batch_size=128, shuffle=True)
testloader = DataLoader(test_data, batch_size=128, shuffle=True)

谢谢!我基本上是在试图理解它为什么有效&它如何将转换应用于数据。

数据集MNISTDataset可以通过转换函数进行初始化。如果给定这样的变换函数,则将其保存在self.transforms中,否则将保持其默认值None。当用__getitem__调用新项时,它首先检查转换是否为真值,在这种情况下,它检查self.transforms是否可以强制为True,这是可调用对象的情况。否则,这意味着CCD_ 7一开始就没有被提供,并且在CCD_ 8上没有应用变换函数。


这里有一个一般的例子,在torch/torchvision上下文之外:

def do(x, callback=None):
if callback: # will be True if callback is a function/lambda
return callback(x)
return x
do(2) # returns 2
do(2, callback=lambda x: 2*x) # returns 4

最新更新