如何将 decode_batch_predictions() 方法添加到 Keras 验证码 OCR 模型中?
How can I add the decode_batch_predictions() method into the Keras Captcha OCR model?
当前Keras Captcha OCR modelreturns一个CTC编码输出,需要推理后解码
要对此进行解码,需要 运行 在推理后将解码效用函数作为一个单独的步骤。
preds = prediction_model.predict(batch_images)
pred_texts = decode_batch_predictions(preds)
解码后的效用函数使用 keras.backend.ctc_decode
,后者又使用贪心或波束搜索解码器。
# A utility function to decode the output of the network
def decode_batch_predictions(pred):
input_len = np.ones(pred.shape[0]) * pred.shape[1]
# Use greedy search. For complex tasks, you can use beam search
results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
:, :max_length
]
# Iterate over the results and get back the text
output_text = []
for res in results:
res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
output_text.append(res)
return output_text
我想使用 Keras 训练验证码 OCR 模型,returns CTC 解码为输出,推理后不需要额外的解码步骤。
我该如何实现?
你的问题可以有两种解释。一个是:我想要一个神经网络来解决 CTC 解码步骤已经在网络学习内容中的问题。另一个是您想要一个模型 class 在其内部执行此 CTC 解码,而不使用外部功能函数。
我不知道第一个问题的答案。我什至不知道它是否可行。无论如何,这听起来像是一个困难的理论问题,如果您在这里运气不好,您可能想尝试将其发布到 datascience.stackexchange.com,这是一个更注重理论的社区。
现在,如果您要解决的是问题的第二个工程版本,那我可以帮助您。该问题的解决方案如下:
您需要使用您想要的方法将 class keras.models.Model
替换为 class。我浏览了您发布的 link 中的教程,并附带了以下内容 class:
class ModifiedModel(keras.models.Model):
# A utility function to decode the output of the network
def decode_batch_predictions(self, pred):
input_len = np.ones(pred.shape[0]) * pred.shape[1]
# Use greedy search. For complex tasks, you can use beam search
results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
:, :max_length
]
# Iterate over the results and get back the text
output_text = []
for res in results:
res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
output_text.append(res)
return output_text
def predict_texts(self, batch_images):
preds = self.predict(batch_images)
return self.decode_batch_predictions(preds)
你可以给它起你想要的名字,这只是为了说明目的。
定义此 class 后,您将替换行
# Get the prediction model by extracting layers till the output layer
prediction_model = keras.models.Model(
model.get_layer(name="image").input, model.get_layer(name="dense2").output
)
和
prediction_model = ModifiedModel(
model.get_layer(name="image").input, model.get_layer(name="dense2").output
)
然后你可以替换这些行
preds = prediction_model.predict(batch_images)
pred_texts = decode_batch_predictions(preds)
和
pred_texts = prediction_model.predict_texts(batch_images)
实现此目的最可靠的方法是添加一个方法,该方法被称为模型定义的一部分:
def CTCDecoder():
def decoder(y_pred):
input_shape = tf.keras.backend.shape(y_pred)
input_length = tf.ones(shape=input_shape[0]) * tf.keras.backend.cast(
input_shape[1], 'float32')
unpadded = tf.keras.backend.ctc_decode(y_pred, input_length)[0][0]
unpadded_shape = tf.keras.backend.shape(unpadded)
padded = tf.pad(unpadded,
paddings=[[0, 0], [0, input_shape[1] - unpadded_shape[1]]],
constant_values=-1)
return padded
return tf.keras.layers.Lambda(decoder, name='decode')
然后定义模型如下:
prediction_model = keras.models.Model(inputs=inputs, outputs=CTCDecoder()(model.output))
归功于 tulasiram58827。
此实现支持导出到 TFLite,但仅支持 float32。 Quantized (int8) TFLite export 仍然抛出错误,并且是 TF team 的 open ticket。
当前Keras Captcha OCR modelreturns一个CTC编码输出,需要推理后解码
要对此进行解码,需要 运行 在推理后将解码效用函数作为一个单独的步骤。
preds = prediction_model.predict(batch_images)
pred_texts = decode_batch_predictions(preds)
解码后的效用函数使用 keras.backend.ctc_decode
,后者又使用贪心或波束搜索解码器。
# A utility function to decode the output of the network
def decode_batch_predictions(pred):
input_len = np.ones(pred.shape[0]) * pred.shape[1]
# Use greedy search. For complex tasks, you can use beam search
results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
:, :max_length
]
# Iterate over the results and get back the text
output_text = []
for res in results:
res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
output_text.append(res)
return output_text
我想使用 Keras 训练验证码 OCR 模型,returns CTC 解码为输出,推理后不需要额外的解码步骤。
我该如何实现?
你的问题可以有两种解释。一个是:我想要一个神经网络来解决 CTC 解码步骤已经在网络学习内容中的问题。另一个是您想要一个模型 class 在其内部执行此 CTC 解码,而不使用外部功能函数。
我不知道第一个问题的答案。我什至不知道它是否可行。无论如何,这听起来像是一个困难的理论问题,如果您在这里运气不好,您可能想尝试将其发布到 datascience.stackexchange.com,这是一个更注重理论的社区。
现在,如果您要解决的是问题的第二个工程版本,那我可以帮助您。该问题的解决方案如下:
您需要使用您想要的方法将 class keras.models.Model
替换为 class。我浏览了您发布的 link 中的教程,并附带了以下内容 class:
class ModifiedModel(keras.models.Model):
# A utility function to decode the output of the network
def decode_batch_predictions(self, pred):
input_len = np.ones(pred.shape[0]) * pred.shape[1]
# Use greedy search. For complex tasks, you can use beam search
results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
:, :max_length
]
# Iterate over the results and get back the text
output_text = []
for res in results:
res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
output_text.append(res)
return output_text
def predict_texts(self, batch_images):
preds = self.predict(batch_images)
return self.decode_batch_predictions(preds)
你可以给它起你想要的名字,这只是为了说明目的。 定义此 class 后,您将替换行
# Get the prediction model by extracting layers till the output layer
prediction_model = keras.models.Model(
model.get_layer(name="image").input, model.get_layer(name="dense2").output
)
和
prediction_model = ModifiedModel(
model.get_layer(name="image").input, model.get_layer(name="dense2").output
)
然后你可以替换这些行
preds = prediction_model.predict(batch_images)
pred_texts = decode_batch_predictions(preds)
和
pred_texts = prediction_model.predict_texts(batch_images)
实现此目的最可靠的方法是添加一个方法,该方法被称为模型定义的一部分:
def CTCDecoder():
def decoder(y_pred):
input_shape = tf.keras.backend.shape(y_pred)
input_length = tf.ones(shape=input_shape[0]) * tf.keras.backend.cast(
input_shape[1], 'float32')
unpadded = tf.keras.backend.ctc_decode(y_pred, input_length)[0][0]
unpadded_shape = tf.keras.backend.shape(unpadded)
padded = tf.pad(unpadded,
paddings=[[0, 0], [0, input_shape[1] - unpadded_shape[1]]],
constant_values=-1)
return padded
return tf.keras.layers.Lambda(decoder, name='decode')
然后定义模型如下:
prediction_model = keras.models.Model(inputs=inputs, outputs=CTCDecoder()(model.output))
归功于 tulasiram58827。
此实现支持导出到 TFLite,但仅支持 float32。 Quantized (int8) TFLite export 仍然抛出错误,并且是 TF team 的 open ticket。