Tensorflow CTC 损失:ctc_merge_repeated 参数
Tensorflow CTC loss: ctc_merge_repeated parameter
我正在使用 Tensorflow 1.0 及其 CTC 损失 [1]。
在训练时,我有时会收到 "No valid path found." 警告(这会损害学习)。 不是,因为其他 Tensorflow 用户有时报告的学习率很高。
稍微分析了一下,我发现了导致这个警告的模式:
- 将长度为 seqLen
的输入序列输入 ctc_loss
- 正在输入一个带有 labelLen 个字符的标签
- label 中有 numRepeatedChars 个重复字符,其中我将 "ab" 计为 0,"aa" 计为 1,"aaa" 计为 2 和等等
- 警告发生,当:seqLen - labelLen < numRepeatedChars
三个例子:
- Ex.1: label="abb", len(label)=3, len(inputSequence)=3 => (3-3=0)<1 is true --> warning
- Ex.2: label="abb", len(label)=3, len(inputSequence)=4 => (4-3=1)<1 is false --> 无警告
- Ex.3: label="bbb", len(label)=3, len(inputSequence)=4 => (4-3=1)<2 is true --> warning
当我现在设置 ctc_loss 参数 ctc_merge_repeated=False 时,警告就会消失。
三个问题:
- Q1:为什么出现重复字符时会有警告?我想,只要输入序列不比目标标注短,就没有问题。当重复的字符合并到标签中时,它会变得更短,因此输入序列不短的条件仍然成立。
- Q2:为什么默认设置的ctc_loss会产生这个警告?重复字符在使用 CTC 的领域很常见,例如手写文本识别 (HTR)
- Q3:做HTR应该用什么设置?当然标签可以有重复的字符。因此 ctc_merge_repeated=False 是有道理的。有什么建议么?
Python 重现警告的程序:
import tensorflow as tf
import numpy as np
def createGraph():
tinputs=tf.placeholder(tf.float32, [100, 1, 65]) # max 100 time steps, 1 batch element, 64+1 classes
tlabels=tf.SparseTensor(tf.placeholder(tf.int64, shape=[None,2]) , tf.placeholder(tf.int32,[None]), tf.placeholder(tf.int64,[2])) # labels
tseqLen=tf.placeholder(tf.int32, [None]) # list of sequence length in batch
tloss=tf.reduce_mean(tf.nn.ctc_loss(labels=tlabels, inputs=tinputs, sequence_length=tseqLen, ctc_merge_repeated=True)) # ctc loss
return (tinputs, tlabels, tseqLen, tloss)
def getNextBatch(nc): # next batch with given number of chars in label
indices=[[0,i] for i in range(nc)]
values=[i%65 for i in range(nc)]
values[0]=0
values[1]=0 # TODO: (un)comment this to trigger warning
shape=[1, nc]
labels=tf.SparseTensorValue(indices, values, shape)
seqLen=[nc]
inputs=np.random.rand(100, 1, 65)
return (labels, inputs, seqLen)
(tinputs, tlabels, tseqLen, tloss)=createGraph()
sess=tf.Session()
sess.run(tf.global_variables_initializer())
nc=3 # number of chars in label
print('next batch with 1 element has label len='+str(nc))
(labels, inputs, seqLen)=getNextBatch(nc)
res=sess.run([tloss], { tlabels: labels, tinputs:inputs, tseqLen:seqLen } )
这是警告来自的 C++ Tensorflow 代码 [2]:
// It is possible that no valid path is found if the activations for the
// targets are zero.
if (log_p_z_x == kLogZero) {
LOG(WARNING) << "No valid path found.";
dy_b = y;
return;
}
[1] https://www.tensorflow.org/versions/r1.0/api_docs/python/tf/nn/ctc_loss
[2] https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/ctc/ctc_loss_calculator.cc
好的,明白了,这不是错误,这就是 CTC 的工作原理:让我们举一个发生警告的例子:输入序列的长度是 2,标签是 "aa"(也是长度 2) .
现在产生 "aa" 的最短路径是 a->blank->a(长度 3)。
但是对于一个标签"ab",最短路径是a->b(长度为2)。
这说明了为什么对于 "aa" 中的重复标签,输入序列必须更长。它只是通过插入空白在 CTC 中对重复标签进行编码的方式。
因此,在固定输入大小时,标签重复会减少允许标签的最大长度。
我正在使用 Tensorflow 1.0 及其 CTC 损失 [1]。 在训练时,我有时会收到 "No valid path found." 警告(这会损害学习)。 不是,因为其他 Tensorflow 用户有时报告的学习率很高。
稍微分析了一下,我发现了导致这个警告的模式:
- 将长度为 seqLen 的输入序列输入 ctc_loss
- 正在输入一个带有 labelLen 个字符的标签
- label 中有 numRepeatedChars 个重复字符,其中我将 "ab" 计为 0,"aa" 计为 1,"aaa" 计为 2 和等等
- 警告发生,当:seqLen - labelLen < numRepeatedChars
三个例子:
- Ex.1: label="abb", len(label)=3, len(inputSequence)=3 => (3-3=0)<1 is true --> warning
- Ex.2: label="abb", len(label)=3, len(inputSequence)=4 => (4-3=1)<1 is false --> 无警告
- Ex.3: label="bbb", len(label)=3, len(inputSequence)=4 => (4-3=1)<2 is true --> warning
当我现在设置 ctc_loss 参数 ctc_merge_repeated=False 时,警告就会消失。
三个问题:
- Q1:为什么出现重复字符时会有警告?我想,只要输入序列不比目标标注短,就没有问题。当重复的字符合并到标签中时,它会变得更短,因此输入序列不短的条件仍然成立。
- Q2:为什么默认设置的ctc_loss会产生这个警告?重复字符在使用 CTC 的领域很常见,例如手写文本识别 (HTR)
- Q3:做HTR应该用什么设置?当然标签可以有重复的字符。因此 ctc_merge_repeated=False 是有道理的。有什么建议么?
Python 重现警告的程序:
import tensorflow as tf
import numpy as np
def createGraph():
tinputs=tf.placeholder(tf.float32, [100, 1, 65]) # max 100 time steps, 1 batch element, 64+1 classes
tlabels=tf.SparseTensor(tf.placeholder(tf.int64, shape=[None,2]) , tf.placeholder(tf.int32,[None]), tf.placeholder(tf.int64,[2])) # labels
tseqLen=tf.placeholder(tf.int32, [None]) # list of sequence length in batch
tloss=tf.reduce_mean(tf.nn.ctc_loss(labels=tlabels, inputs=tinputs, sequence_length=tseqLen, ctc_merge_repeated=True)) # ctc loss
return (tinputs, tlabels, tseqLen, tloss)
def getNextBatch(nc): # next batch with given number of chars in label
indices=[[0,i] for i in range(nc)]
values=[i%65 for i in range(nc)]
values[0]=0
values[1]=0 # TODO: (un)comment this to trigger warning
shape=[1, nc]
labels=tf.SparseTensorValue(indices, values, shape)
seqLen=[nc]
inputs=np.random.rand(100, 1, 65)
return (labels, inputs, seqLen)
(tinputs, tlabels, tseqLen, tloss)=createGraph()
sess=tf.Session()
sess.run(tf.global_variables_initializer())
nc=3 # number of chars in label
print('next batch with 1 element has label len='+str(nc))
(labels, inputs, seqLen)=getNextBatch(nc)
res=sess.run([tloss], { tlabels: labels, tinputs:inputs, tseqLen:seqLen } )
这是警告来自的 C++ Tensorflow 代码 [2]:
// It is possible that no valid path is found if the activations for the
// targets are zero.
if (log_p_z_x == kLogZero) {
LOG(WARNING) << "No valid path found.";
dy_b = y;
return;
}
[1] https://www.tensorflow.org/versions/r1.0/api_docs/python/tf/nn/ctc_loss
[2] https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/ctc/ctc_loss_calculator.cc
好的,明白了,这不是错误,这就是 CTC 的工作原理:让我们举一个发生警告的例子:输入序列的长度是 2,标签是 "aa"(也是长度 2) .
现在产生 "aa" 的最短路径是 a->blank->a(长度 3)。 但是对于一个标签"ab",最短路径是a->b(长度为2)。 这说明了为什么对于 "aa" 中的重复标签,输入序列必须更长。它只是通过插入空白在 CTC 中对重复标签进行编码的方式。
因此,在固定输入大小时,标签重复会减少允许标签的最大长度。