Tensorflow:调用 tf.reduce_sum 时如何忽略数组的一部分

Tensorflow: How to ignore parts of an array when calling tf.reduce_sum

我想改变典型的 MSE 损失函数。现在我有以下代码:

squared_difference = tf.reduce_sum(tf.square(target - output), [1])
mse_loss = tf.reduce_mean(squared_difference)

两个张量的形状都是[batch_size, 10],目标的一个例子是[0,1,2,3,0.5,0.5,0.5,7,8,9]0.5 始终位于索引 4、5 和 6。

我现在要做的是完全忽略这些指标,如果网络输出在这些指标处没有 0.5,则不要增加损失。

所以如果输出是 [0,1,2,3,20,10,14,7,8,9] 损失应该是 0.

实现此目标的最佳方法是什么?

您可以通过多种方式处理此问题。一种直接的方法是使用 tf.losses.mean_squared_errorweights 参数。传递一个 bsz x labels 张量,它用作一种掩码,其中 1 表示您要考虑的值,0 表示忽略。大多数损失函数都存在 weights 参数。