为什么我们需要 `int64` 作为损失函数中的 MNIST 标签,来自 tensorflow?
Why we need `int64` for MNIST labels in a loss function , from tensorflow?
代码摘自Tensorflow tutorial。该函数对 MNIST 数据集运行操作,该数据集是 0-9 的手写图片数据集。为什么要将标签投射到 int64
,我认为 int32
就足够了。
def loss(logits,labels):
labels = tf.to_int64(labels)
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
logits,labels,name='xentropy')
loss = tf.reduce_mean(cross_entropy,name='xentropy_mean')
return loss
这个documentation表示它可以是int32
或int64
。因此,您可以选择其中之一。在这里,他们更愿意选择 int64
.
引用文档:
labels
: Tensor of shape [d_0, d_1, ..., d_{r-2}]
and dtype int32
or int64
. Each entry in labels
must be an index in [0, num_classes)
. Other values will raise an exception when this op is run on CPU, and return NaN
for corresponding corresponding loss and gradient rows on GPU.
代码摘自Tensorflow tutorial。该函数对 MNIST 数据集运行操作,该数据集是 0-9 的手写图片数据集。为什么要将标签投射到 int64
,我认为 int32
就足够了。
def loss(logits,labels):
labels = tf.to_int64(labels)
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
logits,labels,name='xentropy')
loss = tf.reduce_mean(cross_entropy,name='xentropy_mean')
return loss
这个documentation表示它可以是int32
或int64
。因此,您可以选择其中之一。在这里,他们更愿意选择 int64
.
引用文档:
labels
: Tensor of shape[d_0, d_1, ..., d_{r-2}]
and dtypeint32
orint64
. Each entry inlabels
must be an index in[0, num_classes)
. Other values will raise an exception when this op is run on CPU, and returnNaN
for corresponding corresponding loss and gradient rows on GPU.