如何为 TensorFlow 创建 "exact match" eval_metric_op?
How to create an "exact match" eval_metric_op for TensorFlow?
我正在尝试创建一个 eval_metric_op 函数,该函数将显示多标签分类问题在给定阈值下的完全匹配比例。下面的函数returns 0(不完全匹配)或1(完全匹配)根据给定的阈值。
def exact_match(y_true, y_logits, threshold):
y_pred = np.round(y_logits-threshold+0.5)
return int(np.array_equal(y_true, y_pred))
y_true = np.array([1,1,0])
y_logits = np.array([0.67, 0.9, 0.55])
print(exact_match(y_true, y_logits, 0.5))
print(exact_match(y_true, y_logits, 0.6))
0.5 的阈值会产生 [1,1,1] 的预测,这是不正确的,因此函数 returns 0。0.6 的阈值会产生 [1,1,0] 的预测,即正确所以函数 returns 1.
我想将这个函数变成一个 tensorflow eval metric op -- 有人可以建议最好的方法吗?
我可以使用下面的 tensorflow ops 得到相同的逻辑,但我不完全确定如何将其变成自定义 eval_metric_op:
import tensorflow as tf
def exact_match_fn(y_true, y_logits, threshold):
#pred = tf.equal(tf.round(y_logits), tf.round(y_true))
predictions = tf.to_float(tf.greater_equal(y_logits, threshold))
pred_match = tf.equal(predictions, tf.round(y_true))
exact_match = tf.reduce_min(tf.to_float(pred_match))
return exact_match
graph = tf.Graph()
with graph.as_default():
y_true = tf.constant([1,1,0], dtype=tf.float32)
y_logits = tf.constant([0.67,0.9,0.55], dtype=tf.float32)
exact_match_50 = exact_match_fn(y_true, y_logits, 0.5)
exact_match_60 = exact_match_fn(y_true, y_logits, 0.6)
sess = tf.InteractiveSession(graph=graph)
print(sess.run([exact_match_50, exact_match_60]))
以上代码将导致 exact_match_50 的 0(至少 1 个预测不正确)和 exact_match_60 的 1(所有标签正确)。
仅使用 tf.contrib.metrics.streaming_mean()
就足够了吗?还是有更好的选择?我会将其实现为:
tf.contrib.metrics.streaming_mean(exact_match(y_true, y_logits, threshold))
您的 exact_match_fn
的输出是一个可用于评估的操作。如果您想要一批的平均值,请将 reduce_min
更改为仅在相关轴上减少。
例如如果你们 y_true
/y_logits
每个人的形状都是 (batch_size, n)
def exact_match_fn(y_true, y_logits, threshold):
#pred = tf.equal(tf.round(y_logits), tf.round(y_true))
predictions = tf.to_float(tf.greater_equal(y_logits, threshold))
pred_match = tf.equal(predictions, tf.round(y_true))
exact_match = tf.reduce_min(tf.to_float(pred_match), axis=1)
return exact_match
def exact_match_prop_fn(*args):
return tf.reduce_mean(exact_match_fn(*args))
这将为您提供一批的平均值。如果你想要整个数据集的平均值,我只收集匹配项(或 correct
和 total
计数)并在 session/tensorflow 之外进行评估,但 streaming_mean
可能就是那样,不确定。
我正在尝试创建一个 eval_metric_op 函数,该函数将显示多标签分类问题在给定阈值下的完全匹配比例。下面的函数returns 0(不完全匹配)或1(完全匹配)根据给定的阈值。
def exact_match(y_true, y_logits, threshold):
y_pred = np.round(y_logits-threshold+0.5)
return int(np.array_equal(y_true, y_pred))
y_true = np.array([1,1,0])
y_logits = np.array([0.67, 0.9, 0.55])
print(exact_match(y_true, y_logits, 0.5))
print(exact_match(y_true, y_logits, 0.6))
0.5 的阈值会产生 [1,1,1] 的预测,这是不正确的,因此函数 returns 0。0.6 的阈值会产生 [1,1,0] 的预测,即正确所以函数 returns 1.
我想将这个函数变成一个 tensorflow eval metric op -- 有人可以建议最好的方法吗?
我可以使用下面的 tensorflow ops 得到相同的逻辑,但我不完全确定如何将其变成自定义 eval_metric_op:
import tensorflow as tf
def exact_match_fn(y_true, y_logits, threshold):
#pred = tf.equal(tf.round(y_logits), tf.round(y_true))
predictions = tf.to_float(tf.greater_equal(y_logits, threshold))
pred_match = tf.equal(predictions, tf.round(y_true))
exact_match = tf.reduce_min(tf.to_float(pred_match))
return exact_match
graph = tf.Graph()
with graph.as_default():
y_true = tf.constant([1,1,0], dtype=tf.float32)
y_logits = tf.constant([0.67,0.9,0.55], dtype=tf.float32)
exact_match_50 = exact_match_fn(y_true, y_logits, 0.5)
exact_match_60 = exact_match_fn(y_true, y_logits, 0.6)
sess = tf.InteractiveSession(graph=graph)
print(sess.run([exact_match_50, exact_match_60]))
以上代码将导致 exact_match_50 的 0(至少 1 个预测不正确)和 exact_match_60 的 1(所有标签正确)。
仅使用 tf.contrib.metrics.streaming_mean()
就足够了吗?还是有更好的选择?我会将其实现为:
tf.contrib.metrics.streaming_mean(exact_match(y_true, y_logits, threshold))
您的 exact_match_fn
的输出是一个可用于评估的操作。如果您想要一批的平均值,请将 reduce_min
更改为仅在相关轴上减少。
例如如果你们 y_true
/y_logits
每个人的形状都是 (batch_size, n)
def exact_match_fn(y_true, y_logits, threshold):
#pred = tf.equal(tf.round(y_logits), tf.round(y_true))
predictions = tf.to_float(tf.greater_equal(y_logits, threshold))
pred_match = tf.equal(predictions, tf.round(y_true))
exact_match = tf.reduce_min(tf.to_float(pred_match), axis=1)
return exact_match
def exact_match_prop_fn(*args):
return tf.reduce_mean(exact_match_fn(*args))
这将为您提供一批的平均值。如果你想要整个数据集的平均值,我只收集匹配项(或 correct
和 total
计数)并在 session/tensorflow 之外进行评估,但 streaming_mean
可能就是那样,不确定。