我有train_x,valid_x从trainX中分离出来,train_y valid_y从trainY中分离出来,它们的形状如下。 我想对标签的图像进行分类 标签 = set(["面孔"、"豹子"、"摩托车"、"飞机"](。
print(train_x.shape, len(train_y))
torch.Size([1339, 96, 96, 3]) 1339
print(valid_x.shape, len(valid_y))
torch.Size([335, 96, 96, 3]) 335
print(testX.shape, len(testY))
torch.Size([559, 96, 96, 3]) 559
所以我想对数据批量使用常规训练/有效代码,如下所示:
#train the network
n_epochs = 20
valid_loss = []
train_loss = []
for epoch in range(1,n_epochs+1):
cur_train_loss = 0.0
cur_valid_loss = 0.0
#####################
#### Train model ####
#####################
cnn_model.train()
for data, target in trainLoader:
if train_on_gpu:
data, target = data.cuda(), target.cuda()
optimizer.zero_grad()
output = cnn_model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
cur_train_loss += loss.item() * data.size(0)
########################
#### Validate model ####
########################
cnn_model.eval()
for data, target in validLoader:
if train_on_gpu:
data, target = data.cuda(), target.cuda()
output = cnn_model(data)
loss = criterion(output, target)
cur_valid_loss += loss.item() * data.size(0)
# calculate avg loss
avg_train_loss = cur_train_loss / len(trainLoader.sampler)
avg_valid_loss = cur_valid_loss / len(validLoader.sampler)
train_loss.append(avg_train_loss)
valid_loss.append(avg_valid_loss)
print('Epoch: {} t train_loss: {:.6f} t valid_loss: {:.6f}'.format(epoch, avg_train_loss, avg_valid_loss))
那么我必须为此做些什么呢? 我已经搜索了,但没有发现任何具体的东西。我想为此使用 PyTorch。我已经为另一个类似的问题构建了模型,但我使用 DataLoader 一次加载一批数据进行训练和验证。
您可以使用torch.utils.data.TensorDataset
创建一个数据集,其中每个train_x
样本都与train_y
中相应的标签相关联,以便DataLoader
可以像您习惯的那样创建批处理。
from torch.utils.data import DataLoader, TensorDataset
train_dataset = TensorDataset(train_x, train_y)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_dataset = TensorDataset(valid_x, valid_y)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataset = TensorDataset(testX, testY)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)