TensorFlow Keras 'accuracy' 引擎盖下的指标实现
TensorFlow Keras 'accuracy' metric under the hood implementation
使用 TensorFlow Keras 构建 classifier 时,通常会在编译步骤中通过指定 metrics=['accuracy']
来监控模型准确性:
model = tf.keras.Model(...)
model.compile(optimizer=..., loss=..., metrics=['accuracy'])
无论模型是否输出 logits 或 class 概率,以及模型是否期望地面实况标签是单热编码向量或整数索引(即,整数在区间 [0, n_classes)
).
如果想使用交叉熵损失,情况就不同了:上述四种情况的每一种组合都需要在编译步骤中传递不同的loss
值:
如果模型输出概率并且真实标签是单热编码的,那么loss='categorical_crossentropy'
有效。
如果模型输出概率并且真实标签是整数索引,那么loss='sparse_categorical_crossentropy'
有效。
如果模型输出 logits 并且真实标签是单热编码的,那么 loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True)
有效。
如果模型输出 logits 并且真实标签是整数索引,则 loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
有效。
似乎仅指定 loss='categorical_crossentropy'
不足以处理这四种情况,而指定 metrics=['accuracy']
就足够鲁棒了。
问题
当用户在模型编译步骤中指定 metrics=['accuracy']
时,幕后发生了什么,无论模型是否输出逻辑或概率以及地面真值标签是否是单热编码,都可以正确执行精度计算向量或整数索引?
我怀疑 logits 与概率的情况很简单,因为预测的 class 可以通过任何一种方式作为 argmax 获得,但理想情况下我希望指出 TensorFlow 2 源代码中计算的位置实际上已经完成了。
请注意,我目前正在使用 TensorFlow 2.0.0-rc1。
编辑
在纯 Keras 中,metrics=['accuracy']
在 Model.compile
method.
中显式处理
找到了:这在 tensorflow.python.keras.engine.training_utils.get_metric_function
中处理。特别是检查输出形状以确定使用哪个精度函数。
详细说明,在当前实施中 Model.compile
either delegates metric processing to Model._compile_eagerly
(if executing eagerly) or does it directly. In either case, Model._cache_output_metric_attributes
is called, which calls collect_per_output_metric_info
未加权和加权指标。此函数遍历提供的指标,对每个指标调用 get_metric_function
。
使用 TensorFlow Keras 构建 classifier 时,通常会在编译步骤中通过指定 metrics=['accuracy']
来监控模型准确性:
model = tf.keras.Model(...)
model.compile(optimizer=..., loss=..., metrics=['accuracy'])
无论模型是否输出 logits 或 class 概率,以及模型是否期望地面实况标签是单热编码向量或整数索引(即,整数在区间 [0, n_classes)
).
如果想使用交叉熵损失,情况就不同了:上述四种情况的每一种组合都需要在编译步骤中传递不同的loss
值:
如果模型输出概率并且真实标签是单热编码的,那么
loss='categorical_crossentropy'
有效。如果模型输出概率并且真实标签是整数索引,那么
loss='sparse_categorical_crossentropy'
有效。如果模型输出 logits 并且真实标签是单热编码的,那么
loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True)
有效。如果模型输出 logits 并且真实标签是整数索引,则
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
有效。
似乎仅指定 loss='categorical_crossentropy'
不足以处理这四种情况,而指定 metrics=['accuracy']
就足够鲁棒了。
问题
当用户在模型编译步骤中指定 metrics=['accuracy']
时,幕后发生了什么,无论模型是否输出逻辑或概率以及地面真值标签是否是单热编码,都可以正确执行精度计算向量或整数索引?
我怀疑 logits 与概率的情况很简单,因为预测的 class 可以通过任何一种方式作为 argmax 获得,但理想情况下我希望指出 TensorFlow 2 源代码中计算的位置实际上已经完成了。
请注意,我目前正在使用 TensorFlow 2.0.0-rc1。
编辑
在纯 Keras 中,metrics=['accuracy']
在 Model.compile
method.
找到了:这在 tensorflow.python.keras.engine.training_utils.get_metric_function
中处理。特别是检查输出形状以确定使用哪个精度函数。
详细说明,在当前实施中 Model.compile
either delegates metric processing to Model._compile_eagerly
(if executing eagerly) or does it directly. In either case, Model._cache_output_metric_attributes
is called, which calls collect_per_output_metric_info
未加权和加权指标。此函数遍历提供的指标,对每个指标调用 get_metric_function
。