libtorch (PyTorch C++) 奇怪的 class 语法

libtorch (PyTorch C++) weird class syntax

在 GitHub Here 上的官方 PyTorch C++ 示例中 你可以看到一个奇怪的定义 class:

class CustomDataset : public torch::data::datasets::Dataset<CustomDataset> {...}

我的理解是,这定义了一个class CustomDataset which "inherits from" or "extends" torch::data::datasets::Dataset<CustomDataset>。这对我来说很奇怪,因为我们正在创建的 class 是从另一个 class 继承的,而另一个 class 是由我们正在创建的 class 参数化的……这到底是怎么工作的?这是什么意思?这在我看来就像是 Integer class 继承自 vector<Integer>,这似乎很荒谬。

这就是curiously-recurring template pattern,简称CRTP。这种技术的一个主要优点是它启用了所谓的 静态多态性 ,这意味着 torch::data::datasets::Dataset 中的函数可以调用 CustomDataset 中的函数,而无需进行这些函数是虚拟的(从而处理虚拟方法调度等运行时的混乱)。您还可以根据自定义数据集类型的属性执行编译时元编程,例如编译时enable_ifs。

对于 PyTorch,BaseDatasetDataset 的超类)大量使用此技术来支持映射和过滤等操作:

  template <typename TransformType>
  MapDataset<Self, TransformType> map(TransformType transform) & {
    return datasets::map(static_cast<Self&>(*this), std::move(transform));
  }

注意 this 到派生类型的静态转换(只要正确应用 CRTP 就合法); datasets::map 构造一个 MapDataset 对象,该对象也由数据集类型参数化,允许 MapDataset 实现静态调用 get_batch 等方法(或遇到 编译时 错误,如果它们不存在)。

此外,由于 MapDataset 接收自定义数据集类型作为类型参数,编译时元编程是可能的:

  /// The implementation of `get_batch()` for the stateless case, which simply
  /// applies the transform to the output of `get_batch()` from the dataset.
  template <
      typename D = SourceDataset,
      typename = torch::disable_if_t<D::is_stateful>>
  OutputBatchType get_batch_impl(BatchRequestType indices) {
    return transform_.apply_batch(dataset_.get_batch(std::move(indices)));
  }

  /// The implementation of `get_batch()` for the stateful case. Here, we follow
  /// the semantics of `Optional.map()` in many functional languages, which
  /// applies a transformation to the optional's content when the optional
  /// contains a value, and returns a new optional (of a different type)  if the
  /// original optional returned by `get_batch()` was empty.
  template <typename D = SourceDataset>
  torch::enable_if_t<D::is_stateful, OutputBatchType> get_batch_impl(
      BatchRequestType indices) {
    if (auto batch = dataset_.get_batch(std::move(indices))) {
      return transform_.apply_batch(std::move(*batch));
    }
    return nullopt;
  }

请注意,条件启用取决于 SourceDataset,我们之所以可用,是因为数据集是使用此 CRTP 模式参数化的。