LibTorch 中是否有等价于 torch.distributions.Normal 的 C++ API 用于 PyTorch?
Is there an equivalent of torch.distributions.Normal in LibTorch, the C++ API for PyTorch?
我正在使用随机策略实现策略梯度算法,并且由于“辅助”非 PyTorch 操作在 Python 中很慢,我想在 C++ 中实现该算法。有没有办法在 PyTorch C++ 中实现正态分布 API?
Python implementation actually calls the C++ back-end in the at::
namespace (CPU, CUDA, where I found this)。在 PyTorch 团队 and/or 贡献者在 LibTorch 中实现 front-end 之前,您可以使用类似的方法解决它(我只实现了 rsample()
和 log_prob()
因为这是我需要的这个用例):
constexpr double lz = log(sqrt(2 * M_PI));
class Normal {
torch::Tensor mean, stddev, var, log_std;
public:
Normal(const torch::Tensor &mean, const torch::Tensor &std) : mean(mean), stddev(std), var(std * std), log_std(std.log()) {}
torch::Tensor rsample() {
auto device = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
auto eps = torch::randn(1).to(device);
return this->mean + eps * this->stddev;
}
torch::Tensor log_prob(const torch::Tensor &value) {
// log [exp(-(x-mu)^2/(2 sigma^2)) / (sqrt(2 pi) * sigma)] =
// = log [exp(-(x-mu)^2/(2 sigma^2))] - log [sqrt(2 pi) * sigma] =
// = -(x - mu)^2 / (2 sigma^2) - log(sigma) - log(sqrt(2 pi))
return -(value - this->mean)*(value - this->mean) / (2 * this->var) - this->log_std - lz;
}
};
我正在使用随机策略实现策略梯度算法,并且由于“辅助”非 PyTorch 操作在 Python 中很慢,我想在 C++ 中实现该算法。有没有办法在 PyTorch C++ 中实现正态分布 API?
Python implementation actually calls the C++ back-end in the at::
namespace (CPU, CUDA, where I found this)。在 PyTorch 团队 and/or 贡献者在 LibTorch 中实现 front-end 之前,您可以使用类似的方法解决它(我只实现了 rsample()
和 log_prob()
因为这是我需要的这个用例):
constexpr double lz = log(sqrt(2 * M_PI));
class Normal {
torch::Tensor mean, stddev, var, log_std;
public:
Normal(const torch::Tensor &mean, const torch::Tensor &std) : mean(mean), stddev(std), var(std * std), log_std(std.log()) {}
torch::Tensor rsample() {
auto device = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
auto eps = torch::randn(1).to(device);
return this->mean + eps * this->stddev;
}
torch::Tensor log_prob(const torch::Tensor &value) {
// log [exp(-(x-mu)^2/(2 sigma^2)) / (sqrt(2 pi) * sigma)] =
// = log [exp(-(x-mu)^2/(2 sigma^2))] - log [sqrt(2 pi) * sigma] =
// = -(x - mu)^2 / (2 sigma^2) - log(sigma) - log(sqrt(2 pi))
return -(value - this->mean)*(value - this->mean) / (2 * this->var) - this->log_std - lz;
}
};