TensorFlow: Can't load trained model
I'm trying to train, save, and load a TensorFlow model using tflearn:
import numpy as np
import tflearn
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.normalization import local_response_normalization
from tflearn.layers.estimator import regression

# Building convolutional network
network = input_data(shape=[None, imageSize, imageSize, 1], name='input')
network = conv_2d(network, imageSize, self.windowSize, activation='relu', regularizer="L2")
network = max_pool_2d(network, 2)
network = local_response_normalization(network)
network = conv_2d(network, imageSize * 2, self.windowSize, activation='relu', regularizer="L2")
network = max_pool_2d(network, 2)
network = local_response_normalization(network)
network = fully_connected(network, (dim4 * dim4) * (imageSize * 2), activation='tanh')
network = dropout(network, keep)
network = fully_connected(network, (dim4 * dim4) * (imageSize * 2), activation='tanh')
network = dropout(network, keep)
network = fully_connected(network, n_classes, activation='softmax')
network = regression(network, optimizer='adam', learning_rate=self.learningRate,
                     loss='categorical_crossentropy', name='target')
model = tflearn.DNN(network, tensorboard_verbose=0, tensorboard_dir='some/dir')
model.fit(
    {'input': np.array(myData.train_x).reshape(-1, self.imageSize, self.imageSize, 1)}, {'target': myData.train_y}, n_epoch=self.epochs,
    validation_set=(
        {'input': np.array(myData.test_x).reshape(-1, self.imageSize, self.imageSize, 1)},
        {'target': myData.test_y}),
    snapshot_step=100, show_metric=True, run_id='convnet')
# Save a checkpoint under the prefix "some/path/model" (this writes several files, including model.meta)
model.save("some/path/model")
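For reference, model.save("some/path/model") does not write a single file. The sketch below assumes TensorFlow 1.x's default V2 checkpoint format. In that format, saving to a prefix typically produces <prefix>.index, <prefix>.data-00000-of-00001, <prefix>.meta and a "checkpoint" bookkeeping file, and the sketch labels what each one holds. describe_checkpoint_file is a hypothetical helper, not a tflearn or TensorFlow API.

```python
import os


def describe_checkpoint_file(filename):
    """Label a file found next to a checkpoint prefix.

    Assumption: TensorFlow 1.x V2 checkpoint layout (hypothetical helper,
    not part of tflearn or TensorFlow).
    """
    name = os.path.basename(filename)
    if name == "checkpoint":
        return "text file recording the latest checkpoint prefix"
    if name.endswith(".meta"):
        # Serialized MetaGraphDef: graph structure only.
        return "MetaGraphDef: the graph structure, no variable values"
    if name.endswith(".index"):
        return "index table mapping tensor names to data locations"
    if ".data-" in name:
        # e.g. model.data-00000-of-00001
        return "variable values (one shard)"
    return "not a checkpoint file"


for f in ["model.index", "model.data-00000-of-00001", "model.meta", "checkpoint"]:
    print(f, "->", describe_checkpoint_file(f))
```

Only the .index and .data files carry the trained weights. The .meta file describes the graph.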
This part works. Next, I do:
model_path = "some/path/model.meta"
if os.path.exists(model_path):
    model.load(model_path)
else:
    return "need to train the model"
prediction = self.model.predict([<some_input>])
print(str(prediction))
return prediction
This fails at model.load(model_path) with the following traceback:
DataLossError (see above for traceback): Unable to open table file some/path/model.meta: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?
[[Node: save_5/RestoreV2_4 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save_5/Const_0_0, save_5/RestoreV2_4/tensor_names, save_5/RestoreV2_4/shape_and_slices)]]
Caused by op 'save_5/RestoreV2_4', defined at:
What does this mean?
Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?
I can see that the model was saved correctly and the file is not empty. So why can't I load it?
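The "bad magic number" part of the error most likely refers to the footer check done when a file is opened as a table. The sketch below assumes TensorFlow's table format reuses the LevelDB footer, whose last 8 bytes hold the little-endian magic number 0xdb4775248b80fb57. has_table_magic is a hypothetical helper that performs the same kind of check on any file. A .meta file is a serialized protobuf rather than a table, so it would not be expected to pass.

```python
import os
import struct

# Assumption: sstable footer magic shared by LevelDB and TensorFlow's
# table format (tensorflow/core/lib/io), stored little-endian.
TABLE_MAGIC = 0xDB4775248B80FB57


def has_table_magic(path):
    """Return True if the file ends with the sstable footer magic number."""
    if os.path.getsize(path) < 8:
        return False  # too short to contain a footer at all
    with open(path, "rb") as f:
        f.seek(-8, os.SEEK_END)  # the magic occupies the final 8 bytes
        (magic,) = struct.unpack("<Q", f.read(8))
    return magic == TABLE_MAGIC
```

If the assumption holds, calling it on some/path/model.index should return True, while some/path/model.meta should return False. That mirrors the "not an sstable" message.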
Version info:
tensorflow==1.4.0
tensorflow-tensorboard==0.4.0rc2
tflearn==0.3.2
Python 3.6.3 :: Anaconda, Inc.
Answer:
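As mentioned in the comments, the path you save the variables to must include a ".ckpt" filename, for example model.save("some/path/model.ckpt"). You should then restore from that same ".ckpt" path, for example model.load("some/path/model.ckpt"), not from the ".meta" file. The ".meta" file only holds the graph definition, not the variable values. When it is passed to model.load(), TensorFlow tries to read it as a checkpoint table and fails with "not an sstable (bad magic number)".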
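Following the answer, the load step should use the checkpoint prefix, meaning the same string passed to model.save(), not one of the files on disk. One catch is that in TensorFlow 1.x's V2 format no file is named exactly after the prefix, so os.path.exists() on the prefix returns False. The sketch below uses a hypothetical helper, checkpoint_prefix. It strips a .meta, .index or .data-NNNNN-of-NNNNN suffix and checks for the .index file instead.

```python
import os
import re

# Suffixes TensorFlow 1.x (V2 format) appends to a checkpoint prefix.
_SUFFIX = re.compile(r"\.(meta|index|data-\d{5}-of-\d{5})$")


def checkpoint_prefix(path):
    """Return the checkpoint prefix for `path`, or None if none is saved.

    Hypothetical helper: accepts the bare prefix or any of its sibling files.
    """
    prefix = _SUFFIX.sub("", path)
    # The bare prefix is not a file, so test for the ".index" file instead.
    if os.path.exists(prefix + ".index"):
        return prefix
    return None
```

In the question's code, prefix = checkpoint_prefix(model_path) would replace the os.path.exists(model_path) check. model.load(prefix) would then be called only when prefix is not None.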