在自定义损失函数中根据条件采取行动

Act on condition in custom loss function

我在 tensorflow 中定义了一个模型,作为损失函数,我编写了自己的自定义损失函数。一般来说,我在正常的MSE函数之外加了一个惩罚项:

return K.mean(K.square(y_true - y_pred)) + penalty

这个惩罚是由一些输入元素和一些目标之间的乘积和总和给出的(我没有报告代码,因为它与问题无关)。

现在,一切正常,但我需要向数据集添加一种新的示例,其中某些特征和输出的值为零。

因此,在批处理中(在训练期间)两种类型的示例,即零值和零值,我需要一种应用于整个批处理的条件语句,以便我可以计算 penalty 根据条件的评估以两种不同的方式。

可以做到吗?

您可以使用 tf.where :

例如:

  • 如果元素为零,returns本身
  • 如果元素不为零,returns istself + 1
>>> batch_1 = tf.constant([0.,1.,0.,0.5,3.,-2.,0.])
>>> tf.where(batch_1 == 0, batch_1, batch_1+1)
 <tf.Tensor: shape=(7,), dtype=float32, numpy=array([0. , 2. , 0. , 1.5, 4. , -1. , 0. ], dtype=float32)>

它也适用于多维张量,例如:

  • 如果批次的第二个元素为零,则对第一个元素求和
  • 否则,对整批求和
>>> a = tf.constant([[1,0,3],[3,4,9],,[8,0,1]],tf.float32)
>>> tf.where(a[:,1]==0,tf.reduce_sum(a[:,:1],axis=1), tf.reduce_sum(a,axis=1))
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 1., 16.,  8.], dtype=float32)>