在输入和输出整数数组的神经网络中,我应该为 PyTorch 参数使用什么 dtype?

What dtype should I use for PyTorch parameters in a neural network that inputs and outputs arrays of integers?

我目前正在 PyTorch 中构建一个神经网络,它接受 整数 的张量并输出 整数 的张量。只有 少量正整数是 "allowed"(如 0、1、2、3 和 4)作为输入和输出张量的元素。

神经网络通常连续工作 space。 例如,层与层之间的非线性激活函数是连续的,并将整数映射到实数(包括非整数)。

是否最好在内部使用 torch.uint8 之类的无符号整数作为网络的权重和偏差以及一些将整数映射到整数的自定义激活函数?

或者我应该使用像 torch.float32 这样的高精度浮点数,然后通过将实数分箱到最接近的整数来最后舍入?我认为第二种策略是要走的路,但也许我错过了一些效果很好的东西。

在不太了解您的申请的情况下,我会选择 torch.float32 四舍五入。主要原因是,如果您使用 GPU 来计算您的神经网络,它将要求权重和数据采用 float32 数据类型。如果你不打算训练你的神经网络并且你想在 CPU 上 运行,那么像 torch.uint8 这样的数据类型可能会帮助你,因为你可以在每个时间间隔(即你的应用程序)获得更多指令应该 运行 更快)。如果这没有给您留下线索,那么请更具体地说明您的申请。