How is TensorFlow SparseCategoricalCrossentropy Implemented?

I am working on a weighted version of SparseCategoricalCrossentropy. My current implementation converts y_true to one-hot form, computes the cross entropy, and then multiplies it with a weight matrix. I get identical output from my implementation and SparseCategoricalCrossentropy when the weights are all 1, but my problem is the one-hot encoding. I have many classes (32 + background), and with one-hot encoding I run out of memory because of the large image/batch sizes; this does not happen with SparseCategoricalCrossentropy. I am trying to figure out how the built-in loss is implemented (is there a way to avoid one-hot encoding, etc.). How is the built-in version implemented, or where is it implemented? Looking at [1] it may be implemented on the native side, but I cannot find it.

[1] https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/losses.py#L692
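For reference, here is a minimal sketch of the one-hot approach described above (the function name, shapes, the small epsilon, and the assumption of per-class weights are mine, not the actual code):

import tensorflow as tf

def weighted_xent_via_one_hot(y_true, y_pred, class_weights):
    # y_true: integer labels, shape [batch, H, W]
    # y_pred: predicted probabilities, shape [batch, H, W, num_classes]
    # class_weights: per-class weights, shape [num_classes] (assumed)
    num_classes = tf.shape(y_pred)[-1]
    one_hot = tf.one_hot(y_true, depth=num_classes)            # this tensor is what blows up memory
    per_class = -one_hot * tf.math.log(y_pred + 1e-7)          # elementwise cross entropy terms
    return tf.reduce_sum(per_class * class_weights, axis=-1)   # weight, then sum over classes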

The SparseCategoricalCrossentropy documentation has a "View Source on GitHub" tab you can click on. This will show you the implementation. Doing this leads us to line 666 of tensorflow.python.keras.losses. We can see from the class definition that it wraps a function sparse_categorical_crossentropy, which is defined on line 4867 of tensorflow.keras.backend. At the bottom of that function definition we can see it is a wrapper around tf.nn.sparse_softmax_cross_entropy_with_logits, whose definition can be found in tensorflow.python.ops.nn_ops. At the bottom of this function definition, we can see it is in turn a wrapper around gen_nn_ops.sparse_softmax_cross_entropy_with_logits.

If you look for gen_nn_ops, you won't find it: it is the name of the *.so file that Python imports to run TensorFlow's C++ op code. So what we are really looking for is a sparse softmax C++ kernel, which can be found in tensorflow.core.kernels.sparse_xent_op.cc. This op calls a functor which calls a method SparseXentEigenImpl, whose implementation can be found in the corresponding header file, sparse_xent_op.h. Starting at line 47 of that file, you can see how they produce the sparse loss.
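The practical upshot for your question is that the sparse op consumes integer labels directly, so no one-hot tensor is ever materialized on the Python side. A small illustration (the shapes are just an example):

import tensorflow as tf

logits = tf.random.normal([4, 33])                 # e.g. 32 classes + background
labels = tf.constant([0, 5, 32, 7], dtype=tf.int32)

# Integer class ids go straight into the op -- no tf.one_hot needed.
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
print(loss.shape)  # (4,)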

// Generator for calculation of the sparse Xent loss.
// This generator takes the logits, the sum of the exponentiated
// logits, and the label indices.  For each minibatch entry, ignoring
// the batch index b, it calculates:
//
//   loss[j] = (log(sum_exp_logits) - logits[j]) * 1{ j == label }
//
// for j = 0 .. num_classes.  This value must be summed over all j for
// the final loss.
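As a sanity check, the same per-example value can be reproduced directly from that formula (a sketch, using a made-up logits vector):

import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.1]])
label = 0  # the true class index for this single example

# log(sum_exp_logits) - logits[label], as in the generator comment above
manual = tf.math.log(tf.reduce_sum(tf.exp(logits[0]))) - logits[0, label]

builtin = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=tf.constant([label]), logits=logits)
# manual and builtin[0] give the same value (the kernel additionally
# subtracts the max logit before exponentiating, for numerical stability)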

And on line 224 there is a comment outlining the formula for the loss calculation:

//  sum(-labels *
//     ((logits - max_logits) - log(sum(exp(logits - max_logits)))))
//  along classes
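That is the usual max-shifted (numerically stable) log-softmax form. A rough Python transcription of the comment, with a one-hot labels tensor used only to mirror the notation:

import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.1]])
labels = tf.constant([[1.0, 0.0, 0.0]])            # one-hot here only to mirror the comment

max_logits = tf.reduce_max(logits, axis=-1, keepdims=True)
shifted = logits - max_logits
log_sum_exp = tf.math.log(tf.reduce_sum(tf.exp(shifted), axis=-1, keepdims=True))
loss = tf.reduce_sum(-labels * (shifted - log_sum_exp), axis=-1)   # sum along classes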

I'm not sure whether this helps you create a weighted op, but this is how the sparse xent is computed in TensorFlow.
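If it does help, one possible way to add per-class weights without ever building a one-hot tensor is to look the weight up from the integer label with tf.gather (a sketch only, assuming class_weights is a [num_classes] vector; not tested against your setup):

import tensorflow as tf

def weighted_sparse_xent(labels, logits, class_weights):
    # labels: integer class ids, shape [...]; logits: shape [..., num_classes]
    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    weights = tf.gather(class_weights, labels)    # per-example weight, no one-hot needed
    return per_example * weights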

Edit: There is also a method tf.nn.weighted_cross_entropy_with_logits. I'm not sure whether this meets your sparsity requirement, but it may be better than trying to implement something yourself.
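For completeness, a minimal usage sketch of that function (note it expects binary / multi-label targets and a pos_weight, rather than sparse integer labels):

import tensorflow as tf

labels = tf.constant([[1.0, 0.0], [0.0, 1.0]])    # binary / multi-label targets
logits = tf.constant([[0.5, -0.5], [-1.0, 2.0]])

# pos_weight > 1 up-weights the positive labels
loss = tf.nn.weighted_cross_entropy_with_logits(
    labels=labels, logits=logits, pos_weight=2.0)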