在张量流中,如何在将标量张量值移动到 GPU 之前访问它?
In tensorflow, how does one access a scalar tensor value before it is moved to the GPU?
在 tensorflow 中我注册了一个像这样的操作:
REGISTER_OP("RimeBSqrt")
.Input("stokes: FT")
.Input("alpha: FT")
.Input("frequency: FT")
.Input("ref_freq: FT")
.Output("b_sqrt: CT")
.Attr("FT: {float, double} = DT_FLOAT")
.Attr("CT: {complex64, complex128} = DT_COMPLEX64");
以上输入均为张量,
但是 ref_freq 是标量或 0-D 张量。
在我的 CPU 内核的 Compute() 方法中
我可以执行以下操作来提取标量:
const Tensor & in_ref_freq = context->input(3);
FT ref_freq = in_ref_freq.tensor<FT, 1>()(0);
但是,同样的代码会产生段错误
在我的 GPU 内核的 Compute() 方法中,因为
CPU 现在尝试访问
GPU 设备。无论如何拦截这个标量
发送到 GPU 之前的值?我想避免
以下额外的内存间接级别
一个 CUDA 内核:
template <typename FT>
__global__ void kernel(..., FT * ref_freq, ...)
{
FT value = ref_freq[0];
}
我不认为 Attr
是用于 ref_freq
的方法,因为它是可变的、可配置的值。
您可以指定 TensorFlow OpKernel
的一个或多个输入(或输出)在 "host memory" 中,这样您就可以访问 Compute()
中的值方法。为此,您需要修改 REGISTER_KERNEL_BUILDER()
调用以添加 .HostMemory("ref_freq")
指令:
REGISTER_KERNEL_BUILDER(
Name("RimeBSqrt")
.Device(tensorflow::DEVICE_GPU)
.TypeConstraint<float>("FT")
.TypeConstraint<tensorflow::complex64>("CT")
.HostMemory("ref_freq"),
RimeBSqrt<tensorflow::GPUDevice, float, tensorflow::complex64>);
在 tensorflow 中我注册了一个像这样的操作:
REGISTER_OP("RimeBSqrt")
.Input("stokes: FT")
.Input("alpha: FT")
.Input("frequency: FT")
.Input("ref_freq: FT")
.Output("b_sqrt: CT")
.Attr("FT: {float, double} = DT_FLOAT")
.Attr("CT: {complex64, complex128} = DT_COMPLEX64");
以上输入均为张量, 但是 ref_freq 是标量或 0-D 张量。 在我的 CPU 内核的 Compute() 方法中 我可以执行以下操作来提取标量:
const Tensor & in_ref_freq = context->input(3);
FT ref_freq = in_ref_freq.tensor<FT, 1>()(0);
但是,同样的代码会产生段错误 在我的 GPU 内核的 Compute() 方法中,因为 CPU 现在尝试访问 GPU 设备。无论如何拦截这个标量 发送到 GPU 之前的值?我想避免 以下额外的内存间接级别 一个 CUDA 内核:
template <typename FT>
__global__ void kernel(..., FT * ref_freq, ...)
{
FT value = ref_freq[0];
}
我不认为 Attr
是用于 ref_freq
的方法,因为它是可变的、可配置的值。
您可以指定 TensorFlow OpKernel
的一个或多个输入(或输出)在 "host memory" 中,这样您就可以访问 Compute()
中的值方法。为此,您需要修改 REGISTER_KERNEL_BUILDER()
调用以添加 .HostMemory("ref_freq")
指令:
REGISTER_KERNEL_BUILDER(
Name("RimeBSqrt")
.Device(tensorflow::DEVICE_GPU)
.TypeConstraint<float>("FT")
.TypeConstraint<tensorflow::complex64>("CT")
.HostMemory("ref_freq"),
RimeBSqrt<tensorflow::GPUDevice, float, tensorflow::complex64>);