Keras 后端:argmax 如果高于阈值,否则 -1

Keras backend: argmax if above threshold, else -1

我想创建一个自定义精度函数,仅当 argmax 处的值超过阈值时才将 argmax 用于 y_pred,否则为 -1。

就支持的 Keras 而言,这将是 sparse_categorical_accuracy:

的修改
return backend.cast(
    backend.equal(
        backend.flatten(y_true),
        backend.cast(backend.argmax(y_pred, axis=-1),
                     backend.floatx())),
    backend.floatx())

所以,而不是:

backend.argmax(y_pred, axis=-1)

我需要一个具有伪代码逻辑的函数:

argmax_values = backend.argmax(y_pred, axis=-1)
argmax_values if y_pred[argmax_values] > threshold else -1

举个具体的例子,如果:

x = [[0.75, 0.25], [0.85, 0.15], [0.5, 0.5], [0.95, 0.05]]

threshold=0.8,则所需函数的结果将是:

[-1, 0, -1, 0]

如何使用 Keras 后端实现此目的?我的 Keras 版本是 2.2.4,所以我无法访问 TensorFlow 2 后端。

您可以使用 K.switch 根据条件从两个不同的张量中有条件地分配值。使用 K.switch,您想要的函数将是:

from keras import backend as K

def argmax_w_threshold(y_pred, threshold=0.8):
    argmax_values = K.cast(K.argmax(y_pred, axis=-1), K.floatx())
    return K.switch(
        K.max(y_pred, axis=-1) > threshold,
        argmax_values, 
        -1. * K.ones_like(argmax_values)
    )

请注意,K.switchthenelse 部分的张量必须具有相同的形状,因此使用 K.ones_like.

以你的例子为例:

>>> import tensorflow as tf
>>> sess = tf.InteractiveSession()
>>> x = [[0.75, 0.25], [0.85, 0.15], [0.5, 0.5], [0.95, 0.05]]
>>> sess.run(argmax_w_threshold(x))
array([-1.,  0., -1.,  0.], dtype=float32)