Tensorflow/tfjs "When inputs is an array, neither initialState or constants should be provided"

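I am trying to build a chatbot: I create, train, and convert the neural network in Python (tensorflow/keras), then use it in my Angular application through tensorflow/tfjs. I am following the example found here: https://github.com/tensorflow/tfjs-examples/tree/master/translation but I am also trying to add an embedding layer.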

Creating the model:

latent_dim = 200
encoder_inputs = Input(shape=(encoder_max_length, ), dtype='int32', )
encoder_embedding = Embedding(num_tokens,
                             embedd_size,
                             weights=[word2em],
                             input_length=encoder_max_length,
                             mask_zero=True,
                             trainable=False
                             )(encoder_inputs)
encoder_outputs, state_h, state_c = LSTM(latent_dim, return_state=True)(encoder_embedding)
encoder_states = [state_h, state_c]

decoder_inputs = Input(shape=(decoder_max_length, ), dtype='int32', )
decoder_embedding = Embedding(num_tokens,
                             embedd_size,
                             weights=[word2em],
                             input_length=decoder_max_length,
                             mask_zero=True,
                             trainable=False
                             )(decoder_inputs)
decoder_LSTM = LSTM(latent_dim, return_state=True, return_sequences=True)
decoder_outputs, _, _ = decoder_LSTM(decoder_embedding, initial_state=encoder_states)
decoder_dense = Dense(num_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)

model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

Creating the encoder/decoder models (tfjs):

If I remove 'initialState' from LSTM.apply(), the error goes away, but the result is the same sentence no matter what the input is.

  prepareEncoder(model) {
    const encInputs = model.input[0];
    const stateH = model.layers[4].output[1];
    const stateC = model.layers[4].output[2];
    const encoderStates = [stateH, stateC];
    this.encoder = tf.model({inputs: encInputs, outputs: encoderStates});
  }

  prepareDecoder(model) {
    const tmp = model.layers[4].output[1];
    const latentDim = tmp.shape[tmp.shape.length - 1];

    const decoderStateInputH = tf.input({shape: [latentDim], name: 'decoder_state_input_h'});
    const decoderStateInputC = tf.input({shape: [latentDim], name: 'decoder_state_input_c'});
    const decoderStateInputs = [decoderStateInputH, decoderStateInputC];

    const decoderLSTM = model.layers[5];
    const decoderInputs = model.input[1];
    const decoderEmbedding = decoderLSTM.input[0];

    const applyOutputs = decoderLSTM.apply(decoderEmbedding, {initialState: decoderStateInputs});
    let decoderOutputs = applyOutputs[0];
    const decoderStateH = applyOutputs[1];
    const decoderStateC = applyOutputs[2];

    const decoderStates = [decoderStateH, decoderStateC];
    const decoderDense = model.layers[6];
    decoderOutputs = decoderDense.apply(decoderOutputs);

    this.decoder = tf.model({
      inputs: [decoderInputs].concat(decoderStateInputs),
      outputs: [decoderOutputs].concat(decoderStates)
    });
  }

Using the models for prediction:

  botReply(input_seq) {
    let states_value = this.encoder.predict(input_seq);
    let target_seq = tf.buffer([1, data['dec_max_length']], 'int32');
    target_seq.set(dict['<START>'], 0, 0);
    let stop_condition = false;
    let decoded_sentence = '';
    let word_count = 1;
    while (!stop_condition) {
      let predict_outputs = this.decoder.predict([target_seq.toTensor()].concat(states_value));

This is where everything fails. I get the following error:

ERROR Error: When inputs is an array, neither initialState or constants should be provided
    at standardizeArgs (recurrent.js:54)
    at LSTM.apply (recurrent.js:465)
    at execute (executor.js:275)
    at training.js:856
    at engine.js:307
    at Engine.scopedRun (engine.js:317)
    at Engine.tidy (engine.js:306)
    at Module.tidy (globals.js:166)
    at training.js:839
    at engine.js:307
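
For context, the trace points at an argument check named standardizeArgs that runs inside LSTM.apply. The sketch below is a simplified, hypothetical re-creation of that kind of check, written only to show when such a message would be raised. It is not the tfjs source: the function name standardizeArgsSketch, the string placeholders standing in for tensors, and the exact unpacking logic are all assumptions.

```javascript
// Hypothetical, simplified re-creation of the kind of check that raises the
// error above. This is NOT the tfjs source; names and structure are assumed.
function standardizeArgsSketch(inputs, initialState, constants) {
  if (Array.isArray(inputs)) {
    // The states are assumed to be packed into the inputs array already, so
    // passing them again via initialState/constants is rejected as ambiguous.
    if (initialState != null || constants != null) {
      throw new Error(
        'When inputs is an array, neither initialState or constants should be provided');
    }
    // First element is the sequence input; the rest are treated as states.
    initialState = inputs.length > 1 ? inputs.slice(1) : null;
    inputs = inputs[0];
  }
  return { inputs, initialState, constants };
}

// A single input plus an explicit initialState passes the check.
const ok = standardizeArgsSketch('embedded', ['h', 'c'], null);

// An array of inputs plus an explicit initialState trips it.
let message = null;
try {
  standardizeArgsSketch(['embedded', 'h', 'c'], ['h', 'c'], null);
} catch (err) {
  message = err.message;
}
```

Under these assumptions, the message would appear whenever the layer receives the states twice: once packed into the inputs array and once through initialState. That matches the observation that dropping 'initialState' silences the error, although it does not explain why the predictions then stop depending on the input.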

Some things that may be worth mentioning.

Any help would be greatly appreciated, as I have really run out of ideas.

Try this: remove 'mask_zero=True' from the Embedding layers and see whether that resolves the problem.
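
After re-exporting, one way to confirm that the change reached the converted model is to inspect its JSON topology. The helper below is a minimal sketch that assumes the converted file keeps Keras-style layer entries with `class_name` and `config.mask_zero` keys. The function name findMaskedEmbeddings and the example topology are hypothetical and made up for illustration.

```javascript
// Hypothetical helper: walks a parsed model.json-like object and collects the
// names of Embedding layers whose config has mask_zero enabled.
// Assumes Keras-style entries shaped like {class_name, config: {name, mask_zero}}.
function findMaskedEmbeddings(node, found = []) {
  if (Array.isArray(node)) {
    node.forEach((child) => findMaskedEmbeddings(child, found));
  } else if (node !== null && typeof node === 'object') {
    if (node.class_name === 'Embedding' && node.config && node.config.mask_zero === true) {
      found.push(node.config.name);
    }
    // Recurse into every property so the exact nesting does not matter.
    Object.values(node).forEach((child) => findMaskedEmbeddings(child, found));
  }
  return found;
}

// Made-up topology fragment, for illustration only.
const exampleTopology = {
  model_config: {
    config: {
      layers: [
        { class_name: 'Embedding', config: { name: 'embedding_1', mask_zero: true } },
        { class_name: 'Embedding', config: { name: 'embedding_2', mask_zero: false } },
        { class_name: 'LSTM', config: { name: 'lstm_1' } },
      ],
    },
  },
};

const masked = findMaskedEmbeddings(exampleTopology);
```

In practice the same function could be fed the parsed contents of the converted model file. An empty result would suggest that no Embedding layer in the export still has masking turned on.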