如何像np.where一样使用Tensorflow.where?

How To Use Tensorflow.where in the same way as np.where?

我正在尝试制作一个计算 MSE 的自定义损失函数,但忽略所有低于某个阈值(接近 0)的点。我可以通过以下方式使用 numpy 数组实现此目的。

import numpy as np

a = np.random.normal(size=(4,4))
b = np.random.normal(size=(4,4))
temp_a = a[np.where(a>0.5)] # Your threshold condition
temp_b = b[np.where(a>0.5)]
mse = mean_squared_error(temp_a, temp_b)

但我不知道如何使用 keras 后端执行此操作。我的自定义损失函数不起作用,因为 numpy 无法对张量进行运算。

def customMSE(y_true, y_pred):
    '''
    Correct predictions of 0 do not affect performance.
    '''
    y_true_ = y_true[tf.where(y_true>0.1)] # Your threshold condition
    y_pred_ = y_pred[tf.where(y_true>0.1)]
    mse = K.mean(K.square(y_pred_ - y_true_), axis=1)
    return mse

但是当我这样做时,返回错误

ValueError: Shape must be rank 1 but is rank 3 for '{{node customMSE/strided_slice}} = StridedSlice[Index=DT_INT64, T=DT_FLOAT, begin_mask=0, ellipsis_mask=0, end_mask=0, new_axis_mask=0, shrink_axis_mask=1](cond_2/Identity_1, customMSE/strided_slice/stack, customMSE/strided_slice/stack_1, customMSE/strided_slice/Cast)' with input shapes: [?,?,?,?], [1,?,4], [1,?,4], [1].```

您可以在损失函数中使用tf.where代替np.where

如果你想损失相应的真实值低于阈值的预测,你可以像这样编写自定义函数:

def my_loss_threshold(threshold):
    def my_loss(y_true,y_pred):
        # keep predictions pixels where their corresponding y_true is above a threshold
        y_pred = tf.gather_nd(y_pred, tf.where(y_true>=threshold))
        # keep image pixels where they're above a threshold
        y_true = tf.gather_nd(y_true, tf.where(y_true>=threshold))
        # compute MSE between filtered pixels
        loss = tf.square(y_true-y_pred)
        # return mean of losses
        return tf.reduce_mean(loss)
    return my_loss 

model.compile(loss=my_loss_threshold(threshold=0.1), optimizer="adam")

我将损失函数包装到另一个函数中,因此您可以将阈值作为超参数传递给模型编译。