How to find the best mean and std values for Normalize in torchvision.transforms



I have started using PyTorch and cannot understand how to find the mean and std to pass as the input parameters of Normalize.

I have seen this:

transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) #https://pytorch.org/vision/stable/transforms.html#

And in another example:

transformation = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])  # https://github.com/MicrosoftDocs/ml-basics/blob/master/challenges/05%20%20-%20Safari%20CNN%20Solution%20(PyTorch).ipynb

So, if I have a set of images, how should I know or obtain these values? And are these three parameters related to R, G, B?

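Assuming you already have X_train, a list of numpy arrays of shape 32x32x3 (height x width x RGB channels):

import numpy as np

X_train = np.asarray(X_train, dtype=np.float32) / 255  # scale pixels to [0, 1], like ToTensor()
# Keep the last (channel) axis intact: (N, 32, 32, 3) -> (N*32*32, 3)
train_mean = X_train.reshape(-1, 3).mean(axis=0)  # per-channel (R, G, B) mean
train_std = X_train.reshape(-1, 3).std(axis=0)    # per-channel (R, G, B) std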

Then you can pass the last two variables to your Normalize transform:

transforms.Normalize(mean=train_mean, std=train_std)
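As a quick sanity check, here is a minimal NumPy-only sketch, assuming channel-last images already scaled to [0, 1]; the random data and the helper name `normalize_hwc` are hypothetical. It applies the same per-channel formula that Normalize uses, `(x - mean[c]) / std[c]`, and confirms that each channel of the result has a mean close to 0 and a std close to 1. This also shows why the three values are per channel (R, G, B order for RGB images), and why `reshape(-1, 3)` is used: it keeps the channel axis intact.

```python
import numpy as np

def normalize_hwc(images, mean, std):
    """Apply (x - mean[c]) / std[c] per channel to (N, H, W, C) images."""
    # mean and std have shape (C,) and broadcast over the last axis
    return (images - mean) / std

rng = np.random.default_rng(0)
# Hypothetical stand-in for X_train: 50 RGB images of 32x32, scaled to [0, 1]
X_train = rng.random((50, 32, 32, 3)).astype(np.float32)
X_train[..., 0] *= 0.5  # make the channel statistics differ

train_mean = X_train.reshape(-1, 3).mean(axis=0)
train_std = X_train.reshape(-1, 3).std(axis=0)

X_norm = normalize_hwc(X_train, train_mean, train_std)
print(X_norm.reshape(-1, 3).mean(axis=0))  # approximately [0, 0, 0]
print(X_norm.reshape(-1, 3).std(axis=0))   # approximately [1, 1, 1]
```

With real data, `train_mean` and `train_std` from the snippet above would simply replace the random stand-in.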

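As I mentioned in the comments, you can use the ImageNet statistics for general-domain images. Otherwise, you can compute the mean and standard deviation of the dataset you need to normalize by loading it in PyTorch form (a Dataset and a DataLoader):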

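import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self):
        # Dummy data: 100 images, 3 channels, 24x24 pixels
        self.data = torch.randn(100, 3, 24, 24)

    def __getitem__(self, index):
        x = self.data[index]
        return x

    def __len__(self):
        return len(self.data)

dataset = MyDataset()
loader = DataLoader(
    dataset,
    batch_size=10,
    num_workers=1,
    shuffle=False
)

mean = 0.
std = 0.
nb_samples = 0.
for data in loader:
    batch_samples = data.size(0)
    # Flatten the spatial dims: (B, C, H, W) -> (B, C, H*W)
    data = data.view(batch_samples, data.size(1), -1)
    # Per-image, per-channel mean and std, summed over the batch
    mean += data.mean(2).sum(0)
    std += data.std(2).sum(0)
    nb_samples += batch_samples
mean /= nb_samples
# Note: this is the average of per-image stds, an approximation of the std over all pixels
std /= nb_samples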
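The loop above averages per-image standard deviations, which is close to, but not the same as, the standard deviation over all pixels in the dataset. If the exact value is wanted, one option (a sketch, not the answer author's method; the function name `channel_stats` is hypothetical) is to accumulate per-channel sums of x and x² over the loader and use Var[x] = E[x²] - E[x]². It assumes each batch is a single tensor of shape (B, C, H, W).

```python
import torch
from torch.utils.data import DataLoader

def channel_stats(loader):
    """Exact per-channel mean and population std over all pixels in `loader`.

    Assumes each batch is a tensor of shape (B, C, H, W).
    """
    total = None     # per-channel sum of x
    total_sq = None  # per-channel sum of x**2
    count = 0        # number of pixels per channel seen so far
    for batch in loader:
        # Accumulate in float64 to reduce round-off in the x**2 sums
        batch = batch.to(torch.float64)
        c = batch.shape[1]
        flat = batch.transpose(0, 1).reshape(c, -1)  # (C, B*H*W)
        if total is None:
            total = torch.zeros(c, dtype=torch.float64)
            total_sq = torch.zeros(c, dtype=torch.float64)
        total += flat.sum(dim=1)
        total_sq += (flat ** 2).sum(dim=1)
        count += flat.shape[1]
    mean = total / count
    var = total_sq / count - mean ** 2
    return mean, var.clamp(min=0).sqrt()

# Usage: a plain tensor works as a map-style dataset for DataLoader
data = torch.rand(100, 3, 24, 24)
loader = DataLoader(data, batch_size=10, shuffle=False)
mean, std = channel_stats(loader)

# Reference values computed directly on the full tensor
ref = data.to(torch.float64).transpose(0, 1).reshape(3, -1)
ref_mean = ref.mean(dim=1)
ref_std = ((ref - ref_mean[:, None]) ** 2).mean(dim=1).sqrt()
print(mean, ref_mean)
print(std, ref_std)
```

The resulting `mean` and `std` can be passed to `transforms.Normalize` (e.g. via `.tolist()`). Sums of x² can lose precision on very large datasets, which is why the sketch accumulates in float64.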
