Create an instance of AdamParamState

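I need to create an instance of AdamParamState. I used adam.cpp as a reference and copied the code below from it. However, even with the headers listed, the compiler still does not recognize AdamParamState.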

Thanks in advance for any help or comments.

#include <torch/optim/adam.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/nn/module.h>
#include <torch/serialize/archive.h>
#include <torch/utils.h>

#include <ATen/ATen.h>
#include <iostream>  // for std::cout

void get_state(torch::optim::Optimizer *optimizer){
    for (auto& group : optimizer->param_groups()) {
        for (auto &p : group.params()) {
            if (!p.grad().defined()) {
                continue;
            }

            auto grad = p.grad();
            TORCH_CHECK(!grad.is_sparse(),
                        "Adam does not support sparse gradients"/*, please consider SparseAdam instead*/);
            ska::flat_hash_map<std::string, std::unique_ptr<torch::optim::OptimizerParamState>>& state_ = optimizer->state();
            auto param_state = state_.find(c10::guts::to_string(p.unsafeGetTensorImpl()));
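            // Print the tensor's rank and its first two sizes (0 if absent).
            auto tmp_ = p.dim();
            int64_t tmp_0 = 0;  // initialized: the original left these unset
            int64_t tmp_1 = 0;  // when the tensor has fewer than 1 or 2 dims
            if (tmp_ > 0)
                tmp_0 = p.size(0);
            if (tmp_ > 1)
                tmp_1 = p.size(1);

            std::cout << tmp_ << " " << tmp_0 << " " << tmp_1 << std::endl;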
//                auto& options = static_cast<AdamOptions&>(group.options());
            // Does not compile here: AdamParamState is not recognized.
            auto& state = static_cast<AdamParamState&>(*state_[c10::guts::to_string(p.unsafeGetTensorImpl())]);
        }
    }

}
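A side issue in the loop above, separate from the compile error: `state_[key]` behaves like `operator[]` on a standard map, so if the parameter has no state yet it inserts an empty `std::unique_ptr`, and `*state_[key]` then dereferences a null pointer. In adam.cpp the Adam state is created lazily inside `step()`, so before the first step the map may have no entry for a parameter. The sketch below shows a safer lookup with `find()`, as the `param_state` line already does. It is a minimal, self-contained illustration that assumes hypothetical stand-in types and a `std::unordered_map` in place of libtorch's `ska::flat_hash_map`; with libtorch you would cast to `torch::optim::AdamParamState` instead.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical stand-ins for the libtorch types so the sketch compiles on its
// own. The real classes live in namespace torch::optim, and the real state map
// is a ska::flat_hash_map rather than std::unordered_map.
struct OptimizerParamState {
    virtual ~OptimizerParamState() = default;
};

struct AdamParamState : OptimizerParamState {
    long step = 0;
};

using StateMap = std::unordered_map<std::string, std::unique_ptr<OptimizerParamState>>;

// Returns the Adam state stored under `key`, or nullptr if none exists yet.
// find() never inserts, so looking up a missing key leaves the map unchanged.
AdamParamState* find_adam_state(StateMap& state, const std::string& key) {
    auto it = state.find(key);
    if (it == state.end() || !it->second) {
        return nullptr;  // no state yet, e.g. before the first step()
    }
    // static_cast assumes the optimizer really is Adam; see the note below.
    return static_cast<AdamParamState*>(it->second.get());
}
```

If the optimizer might not be Adam, `dynamic_cast<AdamParamState*>` returns `nullptr` on a type mismatch instead of invoking undefined behavior, since the state base class is polymorphic.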

I found that this works, i.e. qualifying the name with its namespace:

auto& state = static_cast<torch::optim::AdamParamState&>(*state_[c10::guts::to_string(p.unsafeGetTensorImpl())]);
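Why the copied code compiled in adam.cpp but not elsewhere: adam.cpp defines its functions inside `namespace torch { namespace optim { ... } }`, so short names such as `AdamParamState` and `AdamOptions` resolve there. Outside that namespace they must be written in full, or brought in with a using-declaration. The sketch below reproduces this with a hypothetical stand-in namespace `mock_torch::optim` (not a real libtorch namespace) so it compiles without libtorch.

```cpp
#include <cassert>

// Hypothetical stand-in mirroring libtorch's layout; the real classes are
// torch::optim::OptimizerParamState and torch::optim::AdamParamState.
namespace mock_torch {
namespace optim {
struct OptimizerParamState {
    virtual ~OptimizerParamState() = default;
};
struct AdamParamState : OptimizerParamState {
    long step = 0;
};

// Code written inside the namespace (as in adam.cpp) can use the short name.
long step_inside(OptimizerParamState& s) {
    return static_cast<AdamParamState&>(s).step;
}
}  // namespace optim
}  // namespace mock_torch

// Outside the namespace the short name is not visible, so this would not
// compile, just like the code in the question:
// long broken(mock_torch::optim::OptimizerParamState& s) {
//     return static_cast<AdamParamState&>(s).step;
// }

// Fix 1: fully qualify the name (the approach used in the answer above).
long step_qualified(mock_torch::optim::OptimizerParamState& s) {
    return static_cast<mock_torch::optim::AdamParamState&>(s).step;
}

// Fix 2: a using-declaration limited to this function's scope.
long step_using(mock_torch::optim::OptimizerParamState& s) {
    using mock_torch::optim::AdamParamState;
    return static_cast<AdamParamState&>(s).step;
}
```

With libtorch the second form would be `using torch::optim::AdamParamState;` inside the function. A file-wide `using namespace torch::optim;` also works but pulls every name in that namespace into scope.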

Nice and simple!