如何在 LibTorch 中使用 collate_fn

How to use collate_fn in LibTorch

我正在尝试使用 libtorch 中的 CNN 实现基于图像的回归。问题是,我的图像尺寸不同,这会导致图像批处理异常。

首先,我创建我的 dataset:

auto set = MyDataSet(pathToData).map(torch::data::transforms::Stack<>());

然后我创建 dataLoader:

auto dataLoader = torch::data::make_data_loader(
    std::move(set),
    torch::data::DataLoaderOptions().batch_size(batchSize).workers(numWorkersDataLoader)
);

训练循环中的批处理数据会抛出异常:

for (torch::data::Example<> &batch: *dataLoader) {
        processBatch(model, optimizer, counter, batch);
}

批量大小大于 1(批量大小为 1 时一切正常,因为不涉及任何堆叠)。例如,我将使用 2 的批处理大小得到以下错误:

...
what():  stack expects each tensor to be equal size, but got [3, 1264, 532] at entry 0 and [3, 299, 294] at entry 1

我读到有人可以使用 collate_fn 来实现一些填充(例如 here),我只是不知道在哪里实现它。例如 torch::data::DataLoaderOptions 不提供这样的东西。

有人知道怎么做吗?

我现在有办法了。总而言之,我将我的 CNN 拆分为 Conv- 和 Denselayers,并在批量构建中使用 torch::nn::AdaptiveMaxPool2d 的输出。

为此,我必须修改我的数据集、Net 和 train/val/test-methods。在我的网络中,我添加了两个额外的前向函数。第一个通过所有 Conv-Layers 和 returns 一个 AdaptiveMaxPool2d-Layer 的输出传递数据。第二个通过所有密集层传递数据。实际上这看起来像:

torch::Tensor forwardConLayer(torch::Tensor x) {
    x = torch::relu(conv1(x));
    x = torch::relu(conv2(x));
    x = torch::relu(conv3(x));
    x = torch::relu(ada1(x));
    x = torch::flatten(x);
    return x;
}

torch::Tensor forwardDenseLayer(torch::Tensor x) {
    x = torch::relu(lin1(x));
    x = lin2(x);
    return x;
}

然后我覆盖 get_batch 方法并使用 forwardConLayer 计算每个批次条目。为了(正确地)训练,我在构建批次之前调用 zero_grad() 。总而言之,这看起来像:

std::vector<ExampleType> get_batch(at::ArrayRef<size_t> indices) override {
    // impl from bash.h
    this->net.zero_grad();
    std::vector<ExampleType> batch;
    batch.reserve(indices.size());
    for (const auto i : indices) {
        ExampleType batchEntry = get(i);
        auto batchEntryData = (batchEntry.data).unsqueeze(0);
        auto newBatchEntryData = this->net.forwardConLayer(batchEntryData);             
        batchEntry.data = newBatchEntryData;
        batch.push_back(batchEntry);
    }
    return batch;
}

最后,我在通常会调用 forward 的所有地方调用 forwardDenseLayer,例如:

    for (torch::data::Example<> &batch: *dataLoader) {
        auto data = batch.data;
        auto target = batch.target.squeeze();

        auto output = model.forwardDenseLayer(data);
        auto loss = torch::mse_loss(output, target);
        LOG(INFO) << "Batch loss: " << loss.item<double>();

        loss.backward();
        optimizer.step();
    }

更新

如果数据加载器的 worker 数量不为 0,此解决方案似乎会导致错误。错误是:

terminate called after thro9wing an instance of 'std::runtime_error'
  what(): one of the variables needed for gradient computation has been modified by an inplace operation: [CPUFloatType [3, 12, 3, 3]] is at version 2; expected version 1 instead. ...

这个错误确实有意义,因为数据在批处理过程中通过了 CNN 的头部。这个“问题”的解决方案是将worker的数量设置为0。