TF 2.3.0 training keras model using tf dataset with sample weights 不适用于指标
TF 2.3.0 training keras model using tf dataset with sample weights does not apply to metrics
我将 sample_weight 作为 tf.data.Dataset 中的第三个元组传入(在掩码的上下文中使用它,所以我的 sample_weight 要么是 0,要么是 1。问题是sample_weight 似乎没有应用于指标计算。(参考:https://www.tensorflow.org/guide/keras/train_and_evaluate#sample_weights)
这是代码片段:
train_ds = tf.data.Dataset.from_tensor_slices((imgs, labels, masks))
train_ds = train_ds.shuffle(1024).repeat().batch(32).prefetch(buffer_size=AUTO)
model.compile(optimizer = Adam(learning_rate=1e-4),
loss = SparseCategoricalCrossentropy(),
metrics = ['sparse_categorical_accuracy'])
model.fit(train_ds, steps_per_epoch = len(imgs)//32, epochs = 20)
训练后的损失非常接近于零,但sparse_categorical_accuracy不是(大约0.89)。因此,我高度怀疑为构造 tf.dataset 而传入的任何 sample_weight(掩码)都不会在训练期间报告指标时得到应用,而损失似乎是正确的。我通过 运行 对未单独屏蔽的子集的预测进一步确认,并确认准确度为 1.0
此外,根据文档:
https://www.tensorflow.org/api_docs/python/tf/keras/metrics/SparseCategoricalAccuracy
指标有 3 个参数:y_true、y_pred、sample_weight
那么如何在度量计算中通过 sample_weight 呢?这是keras框架内model.fit(...)的责任吗?到目前为止,我找不到任何示例。
经过一些调试和文档阅读,我发现 .compile 中有 weighted_metrics 参数,我应该使用它而不是 metrics=。我确认这修复了我在共享 colab 中的测试用例。
model.compile(optimizer = Adam(learning_rate=1e-4),
loss = SparseCategoricalCrossentropy(),
weighted_metrics = [SparseCategoricalAccuracy()])
我将 sample_weight 作为 tf.data.Dataset 中的第三个元组传入(在掩码的上下文中使用它,所以我的 sample_weight 要么是 0,要么是 1。问题是sample_weight 似乎没有应用于指标计算。(参考:https://www.tensorflow.org/guide/keras/train_and_evaluate#sample_weights)
这是代码片段:
train_ds = tf.data.Dataset.from_tensor_slices((imgs, labels, masks))
train_ds = train_ds.shuffle(1024).repeat().batch(32).prefetch(buffer_size=AUTO)
model.compile(optimizer = Adam(learning_rate=1e-4),
loss = SparseCategoricalCrossentropy(),
metrics = ['sparse_categorical_accuracy'])
model.fit(train_ds, steps_per_epoch = len(imgs)//32, epochs = 20)
训练后的损失非常接近于零,但sparse_categorical_accuracy不是(大约0.89)。因此,我高度怀疑为构造 tf.dataset 而传入的任何 sample_weight(掩码)都不会在训练期间报告指标时得到应用,而损失似乎是正确的。我通过 运行 对未单独屏蔽的子集的预测进一步确认,并确认准确度为 1.0
此外,根据文档:
https://www.tensorflow.org/api_docs/python/tf/keras/metrics/SparseCategoricalAccuracy
指标有 3 个参数:y_true、y_pred、sample_weight
那么如何在度量计算中通过 sample_weight 呢?这是keras框架内model.fit(...)的责任吗?到目前为止,我找不到任何示例。
经过一些调试和文档阅读,我发现 .compile 中有 weighted_metrics 参数,我应该使用它而不是 metrics=。我确认这修复了我在共享 colab 中的测试用例。
model.compile(optimizer = Adam(learning_rate=1e-4),
loss = SparseCategoricalCrossentropy(),
weighted_metrics = [SparseCategoricalAccuracy()])