Cannot restore from checkpoint: bidirectional/backward_lstm/bias
I am trying to create a simple LSTM-based RNN in tensor2tensor.
Training seems to work so far, but I cannot restore the model. Trying to do so throws a NotFoundError
pointing to the bias node of the LSTM:
NotFoundError: ..
Key bidirectional/backward_lstm/bias not found in checkpoint
I have no idea why this happens.
This was actually meant as a workaround for another issue, where I run into a similar problem using the LSTM from tensor2tensor (https://github.com/tensorflow/tensor2tensor/issues/1616).
Environment
$ pip freeze | grep tensor
mesh-tensorflow==0.0.5
tensor2tensor==1.12.0
tensorboard==1.12.0
tensorflow-datasets==1.0.2
tensorflow-estimator==1.13.0
tensorflow-gpu==1.12.0
tensorflow-metadata==0.9.0
tensorflow-probability==0.5.0
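Model body
import tensorflow as tf
from tensorflow.keras.layers import LSTM, Activation, Bidirectional, concatenate, dot

def body(self, features):
    # T2T passes 4-D tensors to the body; drop axis 2 to get [batch, time, depth]
    inputs = features['inputs'][:, :, 0, :]
    hparams = self._hparams
    problem = hparams.problem
    encoders = problem.feature_info  # not used below
    max_input_length = 350  # not used below
    max_output_length = 350  # not used below
    # Encoder: 128 units per direction, concatenated to 256 features per step
    encoder = Bidirectional(LSTM(128, return_sequences=True, unroll=False), merge_mode='concat')(inputs)
    # Last encoder step (256 wide) initializes both h and c of the 256-unit decoder
    encoder_last = encoder[:, -1, :]
    decoder = LSTM(256, return_sequences=True, unroll=False)(inputs, initial_state=[encoder_last, encoder_last])
    # Dot-product attention scores [batch, dec_time, enc_time], softmax over encoder steps
    attention = dot([decoder, encoder], axes=[2, 2])
    attention = Activation('softmax', name='attention')(attention)
    context = dot([attention, encoder], axes=[2, 1])
    concat = concatenate([context, decoder])
    # Return to the 4-D layout [batch, time, 1, depth]
    return tf.expand_dims(concat, 2)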
Full error
NotFoundError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:
Key while/lstm_keras/parallel_0_4/lstm_keras/lstm_keras/body/bidirectional/backward_lstm/bias not found in checkpoint
[[node save/RestoreV2 (defined at /home/sfalk/tmp/pycharm_project_265/asr/model/persistence.py:282) = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]
What could be the problem, and how can I fix it?
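One way to narrow the error down is to compare the names the graph expects with the keys actually stored in the checkpoint. In TF 1.x the two lists can be collected with `tf.train.list_variables(checkpoint_path)` (which returns `(name, shape)` pairs) and `[v.name for v in tf.global_variables()]`. The sketch below assumes both lists are already available as plain strings and does only the comparison, in pure Python. The helper names `extra_scope` and `explain_missing` are hypothetical, and the checkpoint keys in the demo are made up for illustration; only the long graph name is taken from the error above.

```python
from typing import Dict, List, Optional


def _parts(name: str) -> List[str]:
    """Split a variable name into scope components, ignoring a ':0' output suffix."""
    return name.split(":")[0].split("/")


def extra_scope(graph_name: str, ckpt_name: str) -> Optional[str]:
    """Leading scope that graph_name carries on top of ckpt_name.

    Returns '' for an exact match and None if ckpt_name is not a
    component-wise suffix of graph_name.
    """
    g, c = _parts(graph_name), _parts(ckpt_name)
    if len(c) > len(g) or g[len(g) - len(c):] != c:
        return None
    return "/".join(g[: len(g) - len(c)])


def explain_missing(graph_names: List[str], ckpt_names: List[str]) -> Dict[str, Optional[str]]:
    """For each graph variable absent from the checkpoint, return the checkpoint
    key it most likely corresponds to, or None if no key matches."""
    ckpt = {n.split(":")[0] for n in ckpt_names}
    report: Dict[str, Optional[str]] = {}
    for full_name in graph_names:
        name = full_name.split(":")[0]
        if name in ckpt:
            continue  # this variable would restore fine
        # checkpoint keys that this graph name ends with, component by component
        candidates = [c for c in ckpt if extra_scope(name, c) is not None]
        # the longest matching suffix is the most specific candidate
        report[name] = max(candidates, key=lambda c: len(_parts(c)), default=None)
    return report


graph = [
    "while/lstm_keras/parallel_0_4/lstm_keras/lstm_keras/body/bidirectional/backward_lstm/bias:0",
    "lstm_keras/body/lstm/kernel:0",  # hypothetical variable that restores fine
]
ckpt = [
    "lstm_keras/body/bidirectional/backward_lstm/bias",  # hypothetical checkpoint key
    "lstm_keras/body/lstm/kernel",
]
for missing, match in explain_missing(graph, ckpt).items():
    print(missing, "->", match, "| extra scope:", extra_scope(missing, match) if match else None)
```

Matching on whole scope components rather than raw string suffixes keeps, for example, `backward_lstm/bias` from being confused with a key ending in `ward_lstm/bias`.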
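This seems to be related to https://github.com/tensorflow/tensor2tensor/issues/1486. When restoring from a checkpoint with tensor2tensor, "while" appears to be prepended to the key names. It looks like an unresolved bug; your input on GitHub would be appreciated.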
I would have posted this as a comment if I could, but my reputation is too low. Cheers.
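If, as the answer suggests, the only difference is a spurious leading `while/` scope on the graph side, one possible workaround is to restore through a `tf.train.Saver` whose `var_list` is a dict mapping each checkpoint key to the matching graph variable. The helper below (`strip_prefix_map` is a hypothetical name) only builds that name mapping in pure Python; the TF 1.x wiring is shown in comments, uses hypothetical `checkpoint_path` and `sess` names, and was not run. Note that the key in the full error also contains `parallel_0_4` and repeated `lstm_keras` scopes, so whether stripping `while/` alone is enough is an assumption to check against the checkpoint's actual keys; the `prefix` argument can be adjusted.

```python
from typing import Dict, Iterable


def strip_prefix_map(graph_names: Iterable[str], ckpt_names: Iterable[str],
                     prefix: str = "while/") -> Dict[str, str]:
    """Map checkpoint key -> graph variable name after dropping a spurious prefix.

    Graph names are compared without their ':0' suffix. Graph variables whose
    (stripped) name is not in the checkpoint are left out of the mapping, so a
    Saver built from it would not restore them.
    """
    ckpt = set(ckpt_names)
    mapping: Dict[str, str] = {}
    for full_name in graph_names:
        name = full_name.split(":")[0]
        key = name[len(prefix):] if name.startswith(prefix) else name
        if key not in ckpt:
            continue
        if key in mapping:
            raise ValueError(f"two graph variables map to checkpoint key {key!r}")
        mapping[key] = full_name
    return mapping


# With TF 1.x, the mapping could feed a Saver, whose var_list may be a dict
# from checkpoint name to Variable (hypothetical wiring, not executed here):
#   by_name = {v.name: v for v in tf.global_variables()}
#   ckpt_keys = [n for n, _ in tf.train.list_variables(checkpoint_path)]
#   var_map = strip_prefix_map(by_name, ckpt_keys)
#   saver = tf.train.Saver(var_list={k: by_name[v] for k, v in var_map.items()})
#   saver.restore(sess, checkpoint_path)
```

Raising on collisions is deliberate: if both a `while/`-prefixed and an unprefixed copy of the same variable exist in the graph, silently picking one would hide the real problem.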