为什么我们需要 `int64` 作为损失函数中的 MNIST 标签,来自 tensorflow?

Why we need `int64` for MNIST labels in a loss function , from tensorflow?

代码摘自Tensorflow tutorial。该函数对 MNIST 数据集运行操作,该数据集是 0-9 的手写图片数据集。为什么要将标签投射到 int64,我认为 int32 就足够了。

def loss(logits,labels):
    labels = tf.to_int64(labels)
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits,labels,name='xentropy')
    loss = tf.reduce_mean(cross_entropy,name='xentropy_mean')
    return loss

这个documentation表示它可以是int32int64。因此,您可以选择其中之一。在这里,他们更愿意选择 int64.

引用文档:

labels: Tensor of shape [d_0, d_1, ..., d_{r-2}] and dtype int32 or int64. Each entry in labels must be an index in [0, num_classes). Other values will raise an exception when this op is run on CPU, and return NaN for corresponding corresponding loss and gradient rows on GPU.