pytorch / libtorch C++ 中的自定义子模块
Custom submodules in pytorch / libtorch C++
完全公开,几天前我在 PyTorch 论坛上问了同样的问题但没有得到回复,所以这在技术上是一个转贴,但我相信这仍然是一个好问题,因为我一直无法在网上的任何地方找到答案。这里是:
你能展示一个使用自定义模块的 register_module 的例子吗?
我在网上找到的唯一示例是将线性层或卷积层注册为子模块。
我尝试编写自己的模块并将其注册到另一个模块,但无法正常工作。
我的 IDE 告诉我 no instance of overloaded function "MyModel::register_module" matches the argument list -- argument types are: (const char [14], TreeEmbedding)
(TreeEmbedding 是我创建的另一个扩展 torch::nn::Module 的结构的名称。)
我错过了什么吗?这方面的一个例子会很有帮助。
编辑:附加上下文如下。
我有一个头文件 "model.h",其中包含以下内容:
struct TreeEmbedding : torch::nn::Module {
TreeEmbedding();
torch::Tensor forward(Graph tree);
};
struct MyModel : torch::nn::Module{
size_t embeddingSize;
TreeEmbedding treeEmbedding;
MyModel(size_t embeddingSize=10);
torch::Tensor forward(std::vector<Graph> clauses, std::vector<Graph> contexts);
};
我还有一个 cpp 文件 "model.cpp",其中包含以下内容:
MyModel::MyModel(size_t embeddingSize) :
embeddingSize(embeddingSize)
{
treeEmbedding = register_module("treeEmbedding", TreeEmbedding{});
}
这个设置还是和上面一样的错误。文档中的代码确实有效(使用线性层等内置组件),但使用自定义模块无效。在追踪到 torch::nn::Linear 之后,它看起来好像是 ModuleHolder
(不管那是什么...)
谢谢,
杰克
如果有人能提供更多详细信息,我会接受更好的答案,但为了以防万一有人想知道,我想我会提供我能找到的少量信息:
register_module 接受一个字符串作为它的第一个参数,它的第二个参数可以是一个 ModuleHolder(我不知道这是什么......)或者它可以是一个 shared_ptr 到你的模块。所以这是我的例子:
treeEmbedding = register_module<TreeEmbedding>("treeEmbedding", make_shared<TreeEmbedding>());
到目前为止,这似乎对我有用。
完全公开,几天前我在 PyTorch 论坛上问了同样的问题但没有得到回复,所以这在技术上是一个转贴,但我相信这仍然是一个好问题,因为我一直无法在网上的任何地方找到答案。这里是:
你能展示一个使用自定义模块的 register_module 的例子吗? 我在网上找到的唯一示例是将线性层或卷积层注册为子模块。
我尝试编写自己的模块并将其注册到另一个模块,但无法正常工作。
我的 IDE 告诉我 no instance of overloaded function "MyModel::register_module" matches the argument list -- argument types are: (const char [14], TreeEmbedding)
(TreeEmbedding 是我创建的另一个扩展 torch::nn::Module 的结构的名称。)
我错过了什么吗?这方面的一个例子会很有帮助。
编辑:附加上下文如下。
我有一个头文件 "model.h",其中包含以下内容:
struct TreeEmbedding : torch::nn::Module {
TreeEmbedding();
torch::Tensor forward(Graph tree);
};
struct MyModel : torch::nn::Module{
size_t embeddingSize;
TreeEmbedding treeEmbedding;
MyModel(size_t embeddingSize=10);
torch::Tensor forward(std::vector<Graph> clauses, std::vector<Graph> contexts);
};
我还有一个 cpp 文件 "model.cpp",其中包含以下内容:
MyModel::MyModel(size_t embeddingSize) :
embeddingSize(embeddingSize)
{
treeEmbedding = register_module("treeEmbedding", TreeEmbedding{});
}
这个设置还是和上面一样的错误。文档中的代码确实有效(使用线性层等内置组件),但使用自定义模块无效。在追踪到 torch::nn::Linear 之后,它看起来好像是 ModuleHolder
(不管那是什么...)
谢谢, 杰克
如果有人能提供更多详细信息,我会接受更好的答案,但为了以防万一有人想知道,我想我会提供我能找到的少量信息:
register_module 接受一个字符串作为它的第一个参数,它的第二个参数可以是一个 ModuleHolder(我不知道这是什么......)或者它可以是一个 shared_ptr 到你的模块。所以这是我的例子:
treeEmbedding = register_module<TreeEmbedding>("treeEmbedding", make_shared<TreeEmbedding>());
到目前为止,这似乎对我有用。