keras 中的 CTC 损失实现

CTC loss implementation in keras

我正在尝试使用 keras 为我的简化神经网络实施 CTC 损失:

  
def ctc_lambda_func(args):
    y_pred, y_train, input_length, label_length = args
 
    return K.ctc_batch_cost(y_train, y_pred, input_length, label_length)


x_train = x_train.reshape(x_train.shape[0],20, 10).astype('float32')

input_data = layers.Input(shape=(20,10,))
x=layers.Convolution1D(filters=256, kernel_size=3,  padding="same", strides=1, use_bias=False ,activation= 'relu')(input_data)
x=layers.BatchNormalization()(x)
x=layers.Dropout(0.2)(x)

x=layers.Bidirectional (LSTM(units=200 , return_sequences=True)) (x)
x=layers.BatchNormalization()(x)
x=layers.Dropout(0.2)(x)


y_pred=outputs = layers.Dense(5, activation='softmax')(x)
fun = Model(input_data, y_pred)
# fun.summary()

label_length=np.zeros((3800,1))
input_length=np.zeros((3800,1))

for i in range (3799):
    label_length[i,0]=4
    input_length[i,0]=5 
  
y_train = np.array(y_train)
x_train = np.array(x_train)
input_length = np.array(input_length)
label_length = np.array(label_length) 

  
loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([y_pred, y_train, input_length, label_length])
model =keras.models.Model(inputs=[input_data, y_train, input_length, label_length], outputs=loss_out)
model.compile(loss={'ctc': lambda y_train, y_pred: y_pred}, optimizer = 'adam')
model.fit(x=[x_train, y_train, input_length, label_length],  epochs=10, batch_size=100)

我们有 y_true(或 y_train),维度为 (3800,4),因此我输入 label_length=4 和 input_length=5 (+ 1 表示空白)

我遇到这个错误:

ValueError: Input tensors to a Model must come from `tf.keras.Input`. Received: [[0. 1. 0. 0.]
 [0. 1. 0. 0.]
 [0. 1. 0. 0.]
 ...
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]] (missing previous layer metadata).

y_true是这样的:

 [[0. 1. 0. 0.]
 [0. 1. 0. 0.]
 ...
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]]

我的问题是什么?

你误解了长度。不是label的个数类,是序列的实际长度。 CTC只能用于目标符号数小于输入状态数的情况。从技术上讲,输入和输出的数量是相同的,但一些输出是空白。 (这通常发生在语音识别中,您有大量的输入信号 windows 而输出中的音素相对较少。)

假设您必须填充输入和输出才能将它们放在一起:

  • input_length 应为批次中的每个项目包含多少输入实际有效,即不填充;

  • label_length 应包含模型应为批次中的每个项目生成多少个 non-blank 标签。