如何在LibTorch中使用collate_fn



我试图在libtorch中使用CNN实现基于图像的回归。问题是,我的图像有不同的大小,这将导致一个异常批处理图像。

首先,我创建我的dataset:

auto set = MyDataSet(pathToData).map(torch::data::transforms::Stack<>());
然后我创建dataLoader:
auto dataLoader = torch::data::make_data_loader(
std::move(set),
torch::data::DataLoaderOptions().batch_size(batchSize).workers(numWorkersDataLoader)
);

将在列车循环中批量处理数据抛出异常:

for (torch::data::Example<> &batch: *dataLoader) {
processBatch(model, optimizer, counter, batch);
}

,批大小大于1(批大小为1时,一切都工作得很好,因为不涉及任何堆叠)。例如,使用批大小为2,我会得到以下错误:

...
what():  stack expects each tensor to be equal size, but got [3, 1264, 532] at entry 0 and [3, 299, 294] at entry 1

我读到,例如可以使用collate_fn来实现一些填充(例如在这里),我只是不知道在哪里实现它。例如,torch::data::DataLoaderOptions不提供这样的东西。

有人知道怎么做吗?

我现在有办法了。总而言之,我将CNN拆分为Conv层和denselayer,并在批处理构建中使用torch::nn::AdaptiveMaxPool2d的输出。

为了做到这一点,我必须修改我的Dataset, Net和train/val/test-methods。在我的Net中,我增加了两个额外的forward函数。第一个通过所有卷积层传递数据并返回AdaptiveMaxPool2d层的输出。第二个是通过所有Dense-Layers传递数据。在实践中,这看起来像:

torch::Tensor forwardConLayer(torch::Tensor x) {
x = torch::relu(conv1(x));
x = torch::relu(conv2(x));
x = torch::relu(conv3(x));
x = torch::relu(ada1(x));
x = torch::flatten(x);
return x;
}
torch::Tensor forwardDenseLayer(torch::Tensor x) {
x = torch::relu(lin1(x));
x = lin2(x);
return x;
}

然后我覆盖get_batch方法并使用forwardConLayer来计算每个批处理条目。为了训练(正确地),我在构造批处理之前调用zero_grad()。所有这些看起来像:

std::vector<ExampleType> get_batch(at::ArrayRef<size_t> indices) override {
// impl from bash.h
this->net.zero_grad();
std::vector<ExampleType> batch;
batch.reserve(indices.size());
for (const auto i : indices) {
ExampleType batchEntry = get(i);
auto batchEntryData = (batchEntry.data).unsqueeze(0);
auto newBatchEntryData = this->net.forwardConLayer(batchEntryData);             
batchEntry.data = newBatchEntryData;
batch.push_back(batchEntry);
}
return batch;
}

最后,在所有我通常会调用forward的地方调用forwardDenseLayer,例如:

for (torch::data::Example<> &batch: *dataLoader) {
auto data = batch.data;
auto target = batch.target.squeeze();
auto output = model.forwardDenseLayer(data);
auto loss = torch::mse_loss(output, target);
LOG(INFO) << "Batch loss: " << loss.item<double>();
loss.backward();
optimizer.step();
}

如果数据加载器的工作线程数不为0,这个解决方案似乎会导致错误。错误是:

terminate called after thro9wing an instance of 'std::runtime_error'
what(): one of the variables needed for gradient computation has been modified by an inplace operation: [CPUFloatType [3, 12, 3, 3]] is at version 2; expected version 1 instead. ...

这个错误是有意义的,因为数据在批处理过程中通过CNN的头部。解决这个"问题"的办法;是将工人的数目设置为0。

相关内容

  • 没有找到相关文章

最新更新