I have started using PyTorch and cannot figure out how to find the mean and std that Normalize takes as input parameters.
I have seen this:
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) #https://pytorch.org/vision/stable/transforms.html#
And in another example:
transformation = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
]) #https://github.com/MicrosoftDocs/ml-basics/blob/master/challenges/05%20%20-%20Safari%20CNN%20Solution%20(PyTorch).ipynb
So, if I have a set of images, how should I know or obtain these values? And do the three values correspond to R, G, and B?
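Suppose you already have X_train, a list of NumPy arrays of shape 32x32x3 (height, width, channels):
import numpy as np
X_train = np.asarray(X_train) / 255  # scale pixel values to [0, 1]
# average over images, height and width, leaving one value per channel
train_mean = X_train.mean(axis=(0, 1, 2))
train_std = X_train.std(axis=(0, 1, 2))
Then you can pass the last two variables to your Normalize transform: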
transforms.Normalize(mean = train_mean ,std = train_std)
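On the R, G, B part of the question: each of the three values belongs to one color channel, and Normalize applies (x - mean) / std to each channel separately. The following is a minimal NumPy sketch of that arithmetic, not torchvision's own implementation. The random `X_train` is a hypothetical stand-in for a real dataset, and the channels-last layout (N, 32, 32, 3) is an assumption.

```python
import numpy as np

# Hypothetical stand-in for a real dataset: 100 random RGB images,
# 32x32 pixels, channels last, scaled to [0, 1].
rng = np.random.default_rng(0)
X_train = rng.integers(0, 256, size=(100, 32, 32, 3)).astype(np.float32) / 255

# Reduce over images, height and width; keep the channel axis.
train_mean = X_train.mean(axis=(0, 1, 2))  # shape (3,): one value per R, G, B
train_std = X_train.std(axis=(0, 1, 2))    # shape (3,)

# The per-channel arithmetic of normalization: (x - mean) / std.
# Broadcasting lines the 3 statistics up with the last (channel) axis.
X_norm = (X_train - train_mean) / train_std

print(train_mean.shape, train_std.shape)
print(X_norm.mean(axis=(0, 1, 2)), X_norm.std(axis=(0, 1, 2)))
```

After normalizing with the dataset's own statistics, each channel should end up with a mean close to 0 and a std close to 1. With mean and std both set to 0.5, the same formula instead maps values from [0, 1] to [-1, 1].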
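As I mentioned in the comments, you can use the ImageNet statistics for general-domain images. Otherwise, you can use the following snippet from the PyTorch forums to compute the mean and std of the dataset you need to normalize.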
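import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self):
        # random stand-in data: 100 images, 3 channels, 24x24
        self.data = torch.randn(100, 3, 24, 24)

    def __getitem__(self, index):
        x = self.data[index]
        return x

    def __len__(self):
        return len(self.data)

dataset = MyDataset()
loader = DataLoader(
    dataset,
    batch_size=10,
    num_workers=1,
    shuffle=False
)

mean = 0.
std = 0.
nb_samples = 0.
for data in loader:
    batch_samples = data.size(0)
    # flatten height and width: (batch, channels, H*W)
    data = data.view(batch_samples, data.size(1), -1)
    # per-image channel statistics, summed over the batch
    mean += data.mean(2).sum(0)
    std += data.std(2).sum(0)
    nb_samples += batch_samples

# average of the per-image statistics; note that the std obtained this way
# only approximates the std taken over all pixels of the dataset
mean /= nb_samples
std /= nb_samples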
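The loop above averages per-image standard deviations, which is usually close to, but not the same as, the standard deviation over every pixel in the dataset. If the exact value matters, one option is to accumulate per-channel sums and sums of squares and apply Var[x] = E[x^2] - (E[x])^2 at the end. This is a sketch under assumptions, not part of the forum snippet: the random tensor and the use of `TensorDataset` are stand-ins for a real dataset, and the result is the population (biased) std.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
# Hypothetical stand-in data: 100 images, 3 channels, 24x24.
data = torch.randn(100, 3, 24, 24)
loader = DataLoader(TensorDataset(data), batch_size=10, shuffle=False)

# Accumulate in float64 to limit rounding error in the running sums.
channel_sum = torch.zeros(3, dtype=torch.float64)
channel_sq_sum = torch.zeros(3, dtype=torch.float64)
n_pixels = 0

for (batch,) in loader:  # TensorDataset yields one-element tuples
    batch = batch.double()
    channel_sum += batch.sum(dim=(0, 2, 3))
    channel_sq_sum += (batch ** 2).sum(dim=(0, 2, 3))
    n_pixels += batch.size(0) * batch.size(2) * batch.size(3)

mean = channel_sum / n_pixels
# Population variance per channel: E[x^2] - (E[x])^2
std = (channel_sq_sum / n_pixels - mean ** 2).sqrt()

print(mean, std)
```

Because only running sums are kept, this works for datasets that do not fit in memory, at the cost of one full pass over the data.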