如何为 TensorFlow 创建 "exact match" eval_metric_op?

How to create an "exact match" eval_metric_op for TensorFlow?

我正在尝试创建一个 eval_metric_op 函数,该函数将显示多标签分类问题在给定阈值下的完全匹配比例。下面的函数returns 0(不完全匹配)或1(完全匹配)根据给定的阈值。

def exact_match(y_true, y_logits, threshold):
    y_pred = np.round(y_logits-threshold+0.5)
    return int(np.array_equal(y_true, y_pred))

y_true = np.array([1,1,0])
y_logits = np.array([0.67, 0.9, 0.55])

print(exact_match(y_true, y_logits, 0.5))
print(exact_match(y_true, y_logits, 0.6))

0.5 的阈值会产生 [1,1,1] 的预测,这是不正确的,因此函数 returns 0。0.6 的阈值会产生 [1,1,0] 的预测,即正确所以函数 returns 1.

我想将这个函数变成一个 tensorflow eval metric op -- 有人可以建议最好的方法吗?

我可以使用下面的 tensorflow ops 得到相同的逻辑,但我不完全确定如何将其变成自定义 eval_metric_op:

import tensorflow as tf

def exact_match_fn(y_true, y_logits, threshold):
    #pred = tf.equal(tf.round(y_logits), tf.round(y_true))
    predictions = tf.to_float(tf.greater_equal(y_logits, threshold))
    pred_match = tf.equal(predictions, tf.round(y_true))
    exact_match = tf.reduce_min(tf.to_float(pred_match))
    return exact_match

graph = tf.Graph()
with graph.as_default():
    y_true = tf.constant([1,1,0], dtype=tf.float32)
    y_logits = tf.constant([0.67,0.9,0.55], dtype=tf.float32)
    exact_match_50 = exact_match_fn(y_true, y_logits, 0.5)
    exact_match_60 = exact_match_fn(y_true, y_logits, 0.6)

sess = tf.InteractiveSession(graph=graph)
print(sess.run([exact_match_50, exact_match_60]))

以上代码将导致 exact_match_50 的 0(至少 1 个预测不正确)和 exact_match_60 的 1(所有标签正确)。

仅使用 tf.contrib.metrics.streaming_mean() 就足够了吗?还是有更好的选择?我会将其实现为:

tf.contrib.metrics.streaming_mean(exact_match(y_true, y_logits, threshold))

您的 exact_match_fn 的输出是一个可用于评估的操作。如果您想要一批的平均值,请将 reduce_min 更改为仅在相关轴上减少。

例如如果你们 y_true/y_logits 每个人的形状都是 (batch_size, n)

def exact_match_fn(y_true, y_logits, threshold):
    #pred = tf.equal(tf.round(y_logits), tf.round(y_true))
    predictions = tf.to_float(tf.greater_equal(y_logits, threshold))
    pred_match = tf.equal(predictions, tf.round(y_true))
    exact_match = tf.reduce_min(tf.to_float(pred_match), axis=1)
    return exact_match


def exact_match_prop_fn(*args):
    return tf.reduce_mean(exact_match_fn(*args))

这将为您提供一批的平均值。如果你想要整个数据集的平均值,我只收集匹配项(或 correcttotal 计数)并在 session/tensorflow 之外进行评估,但 streaming_mean 可能就是那样,不确定。