Custom metrics work with eager execution only

Based on the answer to a question I asked earlier, I'm trying to make the custom metrics word_accuracy and char_accuracy work with a CRNN-CTC model implementation in TensorFlow. It works perfectly well in the linked notebook after running:

import tensorflow as tf

tf.config.run_functions_eagerly(True)

Here is the custom CTC layer, along with the accuracy calculation function:

def calculate_accuracy(y_true, y_pred, metric, unknown_placeholder):
    y_pred = tf.stack(y_pred)
    y_true = tf.cast(y_true, y_pred.dtype)
    unknown_indices = tf.where(y_pred == -1)
    y_pred = tf.tensor_scatter_nd_update(
        y_pred,
        unknown_indices,
        tf.cast(tf.ones(unknown_indices.shape[0]) * unknown_placeholder, tf.int64),
    )
    if metric == 'word':
        return tf.where(tf.reduce_all(y_true == y_pred, 1)).shape[0] / y_true.shape[0]
    if metric == 'char':
        return tf.where(y_true == y_pred).shape[0] / tf.reduce_prod(y_true.shape)
    return 0


class CTCLayer(tf.keras.layers.Layer):
    def __init__(self, max_label_length, unknown_placeholder, **kwargs):
        super().__init__(**kwargs)
        self.max_label_length = max_label_length
        self.unknown_placeholder = unknown_placeholder

    def call(self, *args):
        y_true, y_pred = args
        batch_length = tf.cast(tf.shape(y_true)[0], dtype='int64')
        input_length = tf.cast(tf.shape(y_pred)[1], dtype='int64')
        label_length = tf.cast(tf.shape(y_true)[1], dtype='int64')
        input_length = input_length * tf.ones(shape=(batch_length, 1), dtype='int64')
        label_length = label_length * tf.ones(shape=(batch_length, 1), dtype='int64')
        loss = tf.keras.backend.ctc_batch_cost(
            y_true, y_pred, input_length, label_length
        )
        if y_true.shape[1] is not None:  # this is to prevent an error at model creation
            predictions = decode_batch_predictions(y_pred, self.max_label_length)
            self.add_metric(
                calculate_accuracy(
                    y_true, predictions, 'word', self.unknown_placeholder
                ),
                'word_accuracy',
            )
            self.add_metric(
                calculate_accuracy(
                    y_true, predictions, 'char', self.unknown_placeholder
                ),
                'char_accuracy',
            )
        self.add_loss(loss)
        return y_pred
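
For reference, this is what the two metrics are meant to compute, written as a minimal pure-Python sketch rather than the TensorFlow code above. The helper names replace_unknown, word_accuracy and char_accuracy are hypothetical and exist only for this illustration; labels and predictions are assumed to be padded to the same length.

```python
def replace_unknown(y_pred, unknown_placeholder):
    # Mirror the tensor_scatter_nd_update step: -1 entries become the placeholder.
    return [[unknown_placeholder if c == -1 else c for c in row] for row in y_pred]


def word_accuracy(y_true, y_pred):
    # A word counts as correct only if every character in the row matches.
    hits = [list(t) == list(p) for t, p in zip(y_true, y_pred)]
    return sum(hits) / len(hits)


def char_accuracy(y_true, y_pred):
    # Fraction of individual positions that match across the whole batch.
    pairs = [(a, b) for t, p in zip(y_true, y_pred) for a, b in zip(t, p)]
    return sum(a == b for a, b in pairs) / len(pairs)


y_true = [[1, 2, 3], [4, 5, 6]]
y_pred = replace_unknown([[1, 2, 3], [4, -1, 6]], unknown_placeholder=0)
print(word_accuracy(y_true, y_pred))  # 0.5 (one of two words fully correct)
print(char_accuracy(y_true, y_pred))  # five of six characters match
```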

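The if y_true.shape[1] is not None block is meant to prevent an error at model creation, because a placeholder is passed instead of an actual tensor. This is what happens when the if statement is absent (I get the same error with or without eager execution):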

3 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    697       except Exception as e:  # pylint:disable=broad-except
    698         if hasattr(e, 'ag_error_metadata'):
--> 699           raise e.ag_error_metadata.to_exception(e)
    700         else:
    701           raise

ValueError: Exception encountered when calling layer "ctc_loss" (type CTCLayer).

in user code:

    File "<ipython-input-6-fabf4ec5a640>", line 67, in call  *
        predictions = decode_batch_predictions(y_pred, self.max_label_length)
    File "<ipython-input-6-fabf4ec5a640>", line 23, in decode_batch_predictions  *
        results = tf.keras.backend.ctc_decode(
    File "/usr/local/lib/python3.7/dist-packages/keras/backend.py", line 6436, in ctc_decode
        inputs=y_pred, sequence_length=input_length)

    ValueError: Shape must be rank 1 but is rank 0 for '{{node ctc_loss/CTCGreedyDecoder}} = CTCGreedyDecoder[T=DT_FLOAT, blank_index=-1, merge_repeated=true](ctc_loss/Log_1, ctc_loss/Cast_9)' with input shapes: [31,?,20], [].


Call arguments received:
  • args=('tf.Tensor(shape=(None, None), dtype=float32)', 'tf.Tensor(shape=(None, 31, 20), dtype=float32)')

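Note: in graph execution, the shape of the labels is always (None, None), so the code under the if block that adds the metrics is never executed. To see the metrics working, just run the notebook I included without modification, and modify it afterwards to reproduce the error.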

With eager execution enabled, you should see the following:

/usr/local/lib/python3.7/dist-packages/tensorflow/python/data/ops/dataset_ops.py:4527: UserWarning: Even though the `tf.config.experimental_run_functions_eagerly` option is set, this option does not apply to tf.data functions. To force eager execution of tf.data functions, please use `tf.data.experimental.enable_debug_mode()`.
  "Even though the `tf.config.experimental_run_functions_eagerly` "
Epoch 1/100
     59/Unknown - 42s 177ms/step - loss: 18.1605 - word_accuracy: 0.0000e+00 - char_accuracy: 2.1186e-04
Epoch 00001: val_loss improved from inf to 17.36043, saving model to 1k_captcha.tf

59/59 [==============================] - 44s 213ms/step - loss: 18.1605 - word_accuracy: 0.0000e+00 - char_accuracy: 2.1186e-04 - val_loss: 17.3604 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0000e+00
Epoch 2/100
59/59 [==============================] - ETA: 0s - loss: 16.1261 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0021
Epoch 00002: val_loss improved from 17.36043 to 16.20875, saving model to 1k_captcha.tf

59/59 [==============================] - 13s 210ms/step - loss: 16.1261 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0021 - val_loss: 16.2087 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0000e+00
Epoch 3/100
59/59 [==============================] - ETA: 0s - loss: 15.8597 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0110
Epoch 00003: val_loss improved from 16.20875 to 16.11712, saving model to 1k_captcha.tf

59/59 [==============================] - 12s 204ms/step - loss: 15.8597 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0110 - val_loss: 16.1171 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0071
Epoch 4/100
59/59 [==============================] - ETA: 0s - loss: 15.3741 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0184
Epoch 00004: val_loss did not improve from 16.11712

59/59 [==============================] - 12s 207ms/step - loss: 15.3741 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0184 - val_loss: 16.6811 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0143
Epoch 5/100
59/59 [==============================] - ETA: 0s - loss: 14.9846 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0225
Epoch 00005: val_loss improved from 16.11712 to 15.23923, saving model to 1k_captcha.tf

59/59 [==============================] - 13s 214ms/step - loss: 14.9846 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0225 - val_loss: 15.2392 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0268
Epoch 6/100
59/59 [==============================] - ETA: 0s - loss: 14.4598 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0258
Epoch 00006: val_loss did not improve from 15.23923

59/59 [==============================] - 12s 207ms/step - loss: 14.4598 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0258 - val_loss: 18.6373 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0089
Epoch 7/100
59/59 [==============================] - ETA: 0s - loss: 13.8650 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0335
Epoch 00007: val_loss improved from 15.23923 to 14.37547, saving model to 1k_captcha.tf

59/59 [==============================] - 13s 215ms/step - loss: 13.8650 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0335 - val_loss: 14.3755 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0393
Epoch 8/100
59/59 [==============================] - ETA: 0s - loss: 13.1221 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0422
Epoch 00008: val_loss did not improve from 14.37547

59/59 [==============================] - 13s 208ms/step - loss: 13.1221 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0422 - val_loss: 14.4376 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0393
Epoch 9/100
59/59 [==============================] - ETA: 0s - loss: 12.2508 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0780
Epoch 00009: val_loss did not improve from 14.37547

59/59 [==============================] - 13s 211ms/step - loss: 12.2508 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0780 - val_loss: 14.8398 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0500
Epoch 10/100
59/59 [==============================] - ETA: 0s - loss: 11.0290 - word_accuracy: 0.0000e+00 - char_accuracy: 0.1460
Epoch 00010: val_loss did not improve from 14.37547

59/59 [==============================] - 13s 215ms/step - loss: 11.0290 - word_accuracy: 0.0000e+00 - char_accuracy: 0.1460 - val_loss: 14.4219 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.1054
Epoch 11/100
59/59 [==============================] - ETA: 0s - loss: 9.8587 - word_accuracy: 0.0011 - char_accuracy: 0.2004
Epoch 00011: val_loss improved from 14.37547 to 10.11944, saving model to 1k_captcha.tf

59/59 [==============================] - 13s 212ms/step - loss: 9.8587 - word_accuracy: 0.0011 - char_accuracy: 0.2004 - val_loss: 10.1194 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.1750
Epoch 12/100
59/59 [==============================] - ETA: 0s - loss: 8.6827 - word_accuracy: 0.0032 - char_accuracy: 0.2388
Epoch 00012: val_loss did not improve from 10.11944

59/59 [==============================] - 13s 216ms/step - loss: 8.6827 - word_accuracy: 0.0032 - char_accuracy: 0.2388 - val_loss: 10.3900 - val_word_accuracy: 0.0089 - val_char_accuracy: 0.1714
Epoch 13/100
59/59 [==============================] - ETA: 0s - loss: 7.4976 - word_accuracy: 0.0127 - char_accuracy: 0.3047
Epoch 00013: val_loss improved from 10.11944 to 8.38430, saving model to 1k_captcha.tf

59/59 [==============================] - 13s 215ms/step - loss: 7.4976 - word_accuracy: 0.0127 - char_accuracy: 0.3047 - val_loss: 8.3843 - val_word_accuracy: 0.0179 - val_char_accuracy: 0.2714
Epoch 14/100
59/59 [==============================] - ETA: 0s - loss: 6.6434 - word_accuracy: 0.0508 - char_accuracy: 0.3519
Epoch 00014: val_loss did not improve from 8.38430

59/59 [==============================] - 13s 217ms/step - loss: 6.6434 - word_accuracy: 0.0508 - char_accuracy: 0.3519 - val_loss: 9.5689 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.2571
Epoch 15/100
59/59 [==============================] - ETA: 0s - loss: 5.3200 - word_accuracy: 0.1398 - char_accuracy: 0.4271
Epoch 00015: val_loss improved from 8.38430 to 6.74445, saving model to 1k_captcha.tf

59/59 [==============================] - 13s 214ms/step - loss: 5.3200 - word_accuracy: 0.1398 - char_accuracy: 0.4271 - val_loss: 6.7445 - val_word_accuracy: 0.0804 - val_char_accuracy: 0.3482
Epoch 16/100
59/59 [==============================] - ETA: 0s - loss: 4.4252 - word_accuracy: 0.2108 - char_accuracy: 0.4799
Epoch 00016: val_loss improved from 6.74445 to 5.40682, saving model to 1k_captcha.tf

59/59 [==============================] - 13s 222ms/step - loss: 4.4252 - word_accuracy: 0.2108 - char_accuracy: 0.4799 - val_loss: 5.4068 - val_word_accuracy: 0.1161 - val_char_accuracy: 0.4446
Epoch 17/100
59/59 [==============================] - ETA: 0s - loss: 3.8119 - word_accuracy: 0.2691 - char_accuracy: 0.5206
Epoch 00017: val_loss improved from 5.40682 to 4.76755, saving model to 1k_captcha.tf

59/59 [==============================] - 13s 220ms/step - loss: 3.8119 - word_accuracy: 0.2691 - char_accuracy: 0.5206 - val_loss: 4.7676 - val_word_accuracy: 0.1964 - val_char_accuracy: 0.4929
Epoch 18/100
59/59 [==============================] - ETA: 0s - loss: 3.1290 - word_accuracy: 0.3379 - char_accuracy: 0.5712
Epoch 00018: val_loss improved from 4.76755 to 4.45828, saving model to 1k_captcha.tf

59/59 [==============================] - 13s 221ms/step - loss: 3.1290 - word_accuracy: 0.3379 - char_accuracy: 0.5712 - val_loss: 4.4583 - val_word_accuracy: 0.2768 - val_char_accuracy: 0.5375
Epoch 19/100
59/59 [==============================] - ETA: 0s - loss: 2.6048 - word_accuracy: 0.4163 - char_accuracy: 0.6267
Epoch 00019: val_loss improved from 4.45828 to 4.13174, saving model to 1k_captcha.tf

59/59 [==============================] - 13s 222ms/step - loss: 2.6048 - word_accuracy: 0.4163 - char_accuracy: 0.6267 - val_loss: 4.1317 - val_word_accuracy: 0.2054 - val_char_accuracy: 0.5143
Epoch 20/100
59/59 [==============================] - ETA: 0s - loss: 2.1555 - word_accuracy: 0.5117 - char_accuracy: 0.6979
Epoch 00020: val_loss improved from 4.13174 to 3.35257, saving model to 1k_captcha.tf

59/59 [==============================] - 13s 223ms/step - loss: 2.1555 - word_accuracy: 0.5117 - char_accuracy: 0.6979 - val_loss: 3.3526 - val_word_accuracy: 0.3482 - val_char_accuracy: 0.5518
Epoch 21/100
59/59 [==============================] - ETA: 0s - loss: 1.8185 - word_accuracy: 0.5604 - char_accuracy: 0.7284
Epoch 00021: val_loss did not improve from 3.35257

59/59 [==============================] - 13s 223ms/step - loss: 1.8185 - word_accuracy: 0.5604 - char_accuracy: 0.7284 - val_loss: 3.5486 - val_word_accuracy: 0.3304 - val_char_accuracy: 0.5500
Epoch 22/100
59/59 [==============================] - ETA: 0s - loss: 1.4279 - word_accuracy: 0.6578 - char_accuracy: 0.8021
Epoch 00022: val_loss improved from 3.35257 to 2.97987, saving model to 1k_captcha.tf

59/59 [==============================] - 14s 229ms/step - loss: 1.4279 - word_accuracy: 0.6578 - char_accuracy: 0.8021 - val_loss: 2.9799 - val_word_accuracy: 0.3750 - val_char_accuracy: 0.6679
Epoch 23/100
59/59 [==============================] - ETA: 0s - loss: 1.1666 - word_accuracy: 0.7278 - char_accuracy: 0.8417
Epoch 00023: val_loss did not improve from 2.97987

59/59 [==============================] - 13s 224ms/step - loss: 1.1666 - word_accuracy: 0.7278 - char_accuracy: 0.8417 - val_loss: 5.2543 - val_word_accuracy: 0.1429 - val_char_accuracy: 0.4768
Epoch 24/100
59/59 [==============================] - ETA: 0s - loss: 1.0938 - word_accuracy: 0.7511 - char_accuracy: 0.8576
Epoch 00024: val_loss improved from 2.97987 to 2.72415, saving model to 1k_captcha.tf

59/59 [==============================] - 14s 226ms/step - loss: 1.0938 - word_accuracy: 0.7511 - char_accuracy: 0.8576 - val_loss: 2.7242 - val_word_accuracy: 0.4911 - val_char_accuracy: 0.7250
Epoch 25/100
59/59 [==============================] - ETA: 0s - loss: 0.8378 - word_accuracy: 0.7977 - char_accuracy: 0.8837
Epoch 00025: val_loss improved from 2.72415 to 2.47315, saving model to 1k_captcha.tf

59/59 [==============================] - 13s 223ms/step - loss: 0.8378 - word_accuracy: 0.7977 - char_accuracy: 0.8837 - val_loss: 2.4731 - val_word_accuracy: 0.4554 - val_char_accuracy: 0.6964
Epoch 26/100
59/59 [==============================] - ETA: 0s - loss: 0.6497 - word_accuracy: 0.8633 - char_accuracy: 0.9195
Epoch 00026: val_loss improved from 2.47315 to 2.10521, saving model to 1k_captcha.tf

59/59 [==============================] - 14s 227ms/step - loss: 0.6497 - word_accuracy: 0.8633 - char_accuracy: 0.9195 - val_loss: 2.1052 - val_word_accuracy: 0.4821 - val_char_accuracy: 0.6929
Epoch 27/100
59/59 [==============================] - ETA: 0s - loss: 0.4810 - word_accuracy: 0.9153 - char_accuracy: 0.9528
Epoch 00027: val_loss did not improve from 2.10521

59/59 [==============================] - 14s 226ms/step - loss: 0.4810 - word_accuracy: 0.9153 - char_accuracy: 0.9528 - val_loss: 2.5292 - val_word_accuracy: 0.4375 - val_char_accuracy: 0.7054
Epoch 28/100
59/59 [==============================] - ETA: 0s - loss: 0.4621 - word_accuracy: 0.9121 - char_accuracy: 0.9500
Epoch 00028: val_loss did not improve from 2.10521

59/59 [==============================] - 14s 224ms/step - loss: 0.4621 - word_accuracy: 0.9121 - char_accuracy: 0.9500 - val_loss: 2.1713 - val_word_accuracy: 0.4821 - val_char_accuracy: 0.7268

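To reproduce the problem, you may need to restart the runtime if you have already run the notebook, then try running without eager execution: the metrics will never show. To reproduce the error, comment out the if y_true.shape[1] is not None line and merge the body of the if block with the rest of the code. What do I need to modify in the provided notebook to make the metrics work as demonstrated, without using eager execution?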

You may not like this solution, but you could try changing the calculate_accuracy and decode_batch_predictions functions so that they use only tf operations:

def decode_batch_predictions(predictions, max_label_length, char_lookup=None, increment=0):
    # Build a rank-1 vector holding one sequence length per batch item,
    # using the dynamic shape so it also works in graph mode.
    input_length = tf.cast(tf.ones(tf.shape(predictions)[0]), dtype=tf.int32) * tf.cast(tf.shape(predictions)[1], dtype=tf.int32)

    results = tf.keras.backend.ctc_decode(
        predictions, input_length=input_length, greedy=True
    )[0][0][:, :max_label_length] + increment

    if char_lookup:  # For inference
        output = []
        for result in results:
            result = tf.strings.reduce_join(char_lookup(result)).numpy().decode('utf-8')
            output.append(result)
        return output
    else:  # For training
        output = tf.TensorArray(tf.int64, size=0, dynamic_size=True)
        for result in results:
            output = output.write(output.size(), result)
        return output.stack()

def calculate_accuracy(y_true, y_pred, metric, unknown_placeholder):
    y_pred = tf.stack(y_pred)
    y_true = tf.cast(y_true, y_pred.dtype)
    unknown_indices = tf.where(y_pred == -1)
    y_pred = tf.tensor_scatter_nd_update(
        y_pred,
        unknown_indices,
        tf.cast(tf.ones(tf.shape(unknown_indices)[0]) * unknown_placeholder, tf.int64),
    )
    if metric == 'word':
        return tf.shape(tf.where(tf.reduce_all(y_true == y_pred, 1)))[0] / tf.shape(y_true)[0]
    if metric == 'char':
        return tf.shape(tf.where(y_true == y_pred))[0] / tf.reduce_prod(tf.shape(y_true))
    return 0

Writing example: 936/1040 [90.0 %] to e7fe398b-da12-4176-a91c-84a8ca076937-train.tfrecord
Writing example: 1040/1040 [100.0 %] to e7fe398b-da12-4176-a91c-84a8ca076937-valid.tfrecord
Epoch 1/100
     59/Unknown - 107s 470ms/step - loss: 18.2176 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0015
Epoch 00001: val_loss improved from inf to 16.23781, saving model to 1k_captcha.tf

This way you don't have to use tf.config.run_functions_eagerly(True) or the if y_true.shape[1] is not None check.
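
To illustrate why switching from .shape to tf.shape matters in graph mode, here is a small standalone sketch. It is not the notebook's code: the function name accuracies and the trace_log list are hypothetical, the unknown-placeholder substitution is left out, and it uses tf.reduce_mean over the match mask instead of counting tf.where results, which should give the same ratios.

```python
import tensorflow as tf

trace_log = []  # filled by a Python side effect that runs only while tracing


@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, None], dtype=tf.int64),
    tf.TensorSpec(shape=[None, None], dtype=tf.int64),
])
def accuracies(y_true, y_pred):
    # Static shape as seen at trace time: both dimensions are unknown here,
    # so a Python check like `y_true.shape[1] is not None` would be False.
    trace_log.append(y_true.shape)

    # Dynamic shape: a tensor that is evaluated when the graph actually runs.
    dynamic_shape = tf.shape(y_true)

    matches = tf.equal(y_true, y_pred)
    word = tf.reduce_mean(tf.cast(tf.reduce_all(matches, axis=1), tf.float32))
    char = tf.reduce_mean(tf.cast(matches, tf.float32))
    return dynamic_shape, word, char


y_true = tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.int64)
y_pred = tf.constant([[1, 2, 3], [4, 0, 6]], dtype=tf.int64)
shape, word, char = accuracies(y_true, y_pred)

print(trace_log[0])   # static shape recorded during tracing
print(shape.numpy())  # dynamic shape of the actual batch
print(float(word), float(char))
```

The static shape is a Python-level property fixed when the function is traced, while tf.shape returns a tensor, so any arithmetic built on it stays inside the graph and sees the real batch dimensions.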