具有类似 argwhere 检查的 Keras 自定义损失函数

Keras custom loss function with argwhere-like check

我正在尝试在 Keras 中为生成矩阵的生成器创建自定义损失函数。该矩阵由较多的元素和较少的中心组成。与元素相比,中心具有更高的价值 - 元素具有价值 <0.1,而中心应达到价值 >0.5。重要的是中心位于完全正确的索引处,而适合元素则不太重要。这就是为什么我要尝试创建会执行以下操作的损失:

  1. select y_true 中值为 >0.5 的所有元素,在 numpy 中我会做 indices = np.argwhere(y_true>0.5)
  2. 比较 y_truey_pred 给定索引处的值,例如 loss=(K.square(y_pred[indices]-y_true[indices]))
  3. select 所有其他元素 indices_low = np.argwhere(y_true<0.5)
  4. 与步骤 2 相同,即保存为 loss_low
  5. return加权损失,即return loss*100+loss_low,只是为了给更重要的数据更高的权重

但是,我找不到在 keras 后端实现此目的的方法,我找到了 ,试图寻找与我的问题类似的东西,但似乎没有 tf.argwhere(无法在文档中找到,也无法浏览 net/SO)。那么我该如何实现呢?

请注意,中心的数量和位置可能会有所不同,并且生成器从一开始就很糟糕,因此它不会生成任何东西,或者会生成比实际更多的东西,所以我认为我不能简单地使用 tf.where。我在这里可能不正确,因为我是自定义损失函数的新手,欢迎任何想法。

编辑

毕竟 K.tf.where 正是我要找的,所以我试了一下:

def custom_mse():
    def mse(y_true, y_pred):
        indices = K.tf.where(y_true>0.5)
        loss = K.square(y_true[indices]-y_pred[indices])  
        indices = K.tf.where(y_true<0.5)
        loss_low = K.square(y_true[indices]-y_pred[indices]) 
        return 100*loss+loss_low
    return mse

但这一直在抛出错误:

ValueError: Shape must be rank 1 but is rank 3 for 'loss_1/Generator_loss/strided_slice' (op: 'StridedSlice') with input shapes: [?,?,?,?], [1,?,4], [1,?,4], [1].

如何使用 where 输出?

过了一段时间我终于找到了正确的解决方案,所以它可能对以后的人有所帮助:

首先,我的代码因我长期使用 numpy 和 Pandas 而产生偏差,因此我希望 tf 元素可以被寻址为 y_true[indices],实际上有内置函数 tf.gathertf.gather_nd 用于获取张量的元素。但是,由于两个损失中的元素数量不同,我不能使用它,因为一起计算损失会导致不正确的大小错误。

这让我采用了不同的方法,。理解已接受答案中的代码,我发现您不仅可以使用 tf.where 来获取索引,还可以将掩码应用于您的张量。我的问题的最终解决方案是在输入张量上应用两个掩码并计算两个损失,一个是我计算较高值的损失,一个是我计算较低值的损失,然后乘以应该具有更高权重的损失。

def custom_mse():
    def mse(y_true, y_pred):
        great = K.tf.greater(y_true,0.5)
        loss = K.square(tf.where(great, y_true, tf.zeros(tf.shape(y_true)))-tf.where(great, y_pred, tf.zeros(tf.shape(y_pred))))
        
        lower = K.tf.less(y_true,0.5)
        loss_low = K.square(tf.where(lower, y_true, tf.zeros(tf.shape(y_true)))-tf.where(lower, y_pred, tf.zeros(tf.shape(y_pred))))
        return 100*loss+loss_low
    return mse