根据现有检查点定义 TensorFlow 网络的变量键名
Define TensorFlow network key names according to an existing checkpoint
我使用 Nvidia DIGITS 训练了 LeNet-gray-28x28 图像检测 Tensorflow 模型,得到了我期望的结果。
现在,我需要在 DIGITS 之外对一些图像进行分类,并想使用我训练好的模型。
于是我拿到了 DIGITS 使用的 LeNet 模型定义,并创建了一个类来使用它:
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tflearn
from tflearn.layers.core import input_data
class LeNetModel():
    """LeNet for 28x28 single-channel images, wrapped in a tflearn DNN."""

    def gray28(self, nclasses):
        """Build the LeNet graph and wrap its logits in ``tflearn.DNN``.

        Args:
            nclasses: number of output units of the final ``fc2`` layer.

        Returns:
            A ``tflearn.DNN`` built around the ``fc2`` output (no activation).
        """
        x = input_data(shape=[None, 28, 28, 1])
        # scale (divide by MNIST std)
        # NOTE(review): this scaling is disabled; if the DIGITS checkpoint was
        # trained with it, inputs will be mis-scaled at inference -- TODO confirm.
        # x = x * 0.0125
        # NOTE(review): the keyword arguments originally passed to arg_scope were
        # lost from this listing -- TODO restore them if they are needed.
        # NOTE(review): these scopes create variables such as conv1/weights, while
        # the DIGITS checkpoint stores them under a "model/" prefix, which is why
        # restoring reports "Key conv1/biases not found in checkpoint".
        with slim.arg_scope([slim.conv2d, slim.fully_connected]):
            model = slim.conv2d(x, 20, [5, 5], padding='VALID', scope='conv1')
            model = slim.max_pool2d(model, [2, 2], padding='VALID', scope='pool1')
            model = slim.conv2d(model, 50, [5, 5], padding='VALID', scope='conv2')
            model = slim.max_pool2d(model, [2, 2], padding='VALID', scope='pool2')
            model = slim.flatten(model)
            model = slim.fully_connected(model, 500, scope='fc1')
            # is_training=False turns dropout into a pass-through for inference.
            model = slim.dropout(model, 0.5, is_training=False, scope='do1')
            model = slim.fully_connected(model, nclasses, activation_fn=None, scope='fc2')
        return tflearn.DNN(model)
我从 DIGITS 下载了模型,并(在另一个文件中)用下面的代码实例化它,但加载检查点时出现了下列错误:
# Build the tflearn-wrapped LeNet with 2 output classes (gray28 returns a tflearn.DNN).
self.ballmodel = LeNetModel().gray28(2)
2017-11-26 14:55:50.330524: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key conv1/biases not found in checkpoint
2017-11-26 14:55:50.330948: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key Global_Step not found in checkpoint
2017-11-26 14:55:50.331270: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key is_training not found in checkpoint
2017-11-26 14:55:50.331564: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key conv2/weights not found in checkpoint
2017-11-26 14:55:50.332823: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key conv1/weights not found in checkpoint
2017-11-26 14:55:50.332891: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key conv2/biases not found in checkpoint
2017-11-26 14:55:50.333620: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key fc2/weights not found in checkpoint
2017-11-26 14:55:50.334021: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key fc1/weights not found in checkpoint
2017-11-26 14:55:50.334173: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key fc1/biases not found in checkpoint
2017-11-26 14:55:50.334431: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key fc2/biases not found in checkpoint
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.NotFoundError: Key conv1/biases not found in checkpoint
[[Node: save_1/RestoreV2_1 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save_1/Const_0_0, save_1/RestoreV2_1/tensor_names, save_1/RestoreV2_1/shape_and_slices)]]
[[Node: save_1/RestoreV2_1/_19 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/gpu:0", send_device="/job:localhost/replica:0/task:0/cpu:0", send_device_incarnation=1, tensor_name="edge_38_save_1/RestoreV2_1", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"]()]]
于是我用 https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/inspect_checkpoint.py 脚本查看检查点中包含的键名,发现变量名都带有 `model/` 前缀(例如 `model/conv1/weights`),另外还有一个 `global_step` 键。因此我给每一层的 scope 都加上了 `model/` 前缀:
# Same layer stack with every scope prefixed by "model/", so the created
# variables (e.g. model/conv1/weights) carry the checkpoint's key names.
# NOTE(review): the last call was cut off in the listing; it is completed here by
# analogy with the unprefixed version (activation_fn=None, scope 'fc2').
model = slim.conv2d(x, 20, [5, 5], padding='VALID', scope='model/conv1')
model = slim.max_pool2d(model, [2, 2], padding='VALID', scope='model/pool1')
model = slim.conv2d(model, 50, [5, 5], padding='VALID', scope='model/conv2')
model = slim.max_pool2d(model, [2, 2], padding='VALID', scope='model/pool2')
model = slim.flatten(model)
model = slim.fully_connected(model, 500, scope='model/fc1')
model = slim.dropout(model, 0.5, is_training=False, scope='model/do1')
model = slim.fully_connected(model, nclasses,
                             activation_fn=None, scope='model/fc2')
- 我感觉这不是解决问题的正确方法
- 我无法修复两个键:
- Global_Step(我的检查点中有的是小写的 global_step 键)
- is_training(不知道是什么)
首先,我混用了 contrib/slim 和 tflearn,这虽然可行,但并不合适。所以我只用 slim 重写了网络:
import tensorflow as tf
import tensorflow.contrib.slim as slim
class LeNetModel():
    """LeNet for 28x28 grayscale images, built with TF-slim only."""

    def gray28(self, nclasses, is_training=False, input_scale=1.0):
        """Build the LeNet graph and return its input placeholder and logits.

        Args:
            nclasses: number of output units of the final ``fc2`` layer.
            is_training: forwarded to the dropout layer. Defaults to False so
                inference is deterministic (dropout becomes a pass-through);
                the previous hard-coded True randomly dropped activations at
                prediction time.
            input_scale: factor applied to the reshaped input before the first
                convolution. The default 1.0 feeds values unchanged.

        Returns:
            ``(x, logits)``: ``x`` is a float32 placeholder of shape
            ``[None, 28, 28]`` (any batch size, including 1); ``logits`` is the
            ``fc2`` output with no activation.
        """
        x = tf.placeholder(tf.float32, shape=[None, 28, 28], name="x")
        # Add the single channel dimension expected by conv2d.
        rs = tf.reshape(x, shape=[-1, 28, 28, 1])
        # The DIGITS LeNet definition this was copied from has a "divide by
        # MNIST std" step (x * 0.0125). If the checkpoint was trained with it,
        # pass input_scale=0.0125 -- TODO confirm against the DIGITS network.
        if input_scale != 1.0:
            rs = rs * input_scale
        # NOTE(review): the keyword arguments originally passed to arg_scope were
        # lost from this listing -- TODO restore them if they are needed.
        with slim.arg_scope([slim.conv2d, slim.fully_connected]):
            model = slim.conv2d(rs, 20, [5, 5], padding='VALID', scope='conv1')
            model = slim.max_pool2d(model, [2, 2], padding='VALID', scope='pool1')
            model = slim.conv2d(model, 50, [5, 5], padding='VALID', scope='conv2')
            model = slim.max_pool2d(model, [2, 2], padding='VALID', scope='pool2')
            model = slim.flatten(model)
            model = slim.fully_connected(model, 500, scope='fc1')
            model = slim.dropout(model, 0.5, is_training=is_training, scope='do1')
            model = slim.fully_connected(model, nclasses, activation_fn=None, scope='fc2')
        return x, model
gray28 返回占位符 x 和模型输出,我用它们来加载 DIGITS 预训练的模型(检查点):
import tensorflow as tf
import tensorflow.contrib.slim as slim
import cv2
from models.lenet import LeNetModel
def image(path):
    """Load *path* as a grayscale image and resize it to 28x28.

    Raises:
        IOError: if OpenCV cannot read the file (missing, unreadable or an
            unsupported format). ``cv2.imread`` signals this by returning None
            rather than raising, which would otherwise surface later as an
            opaque ``cv2.resize`` error.
    """
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise IOError("could not read image: %s" % path)
    return cv2.resize(img, dsize=(28, 28))
def name_in_checkpoint(var):
    """Return the checkpoint key for *var*: its op name prefixed with ``model/``.

    The DIGITS checkpoint names the same variables with that extra prefix, so
    this is used to build the {checkpoint_key: variable} map for the Saver.
    """
    op_name = var.op.name
    return ''.join(['model/', op_name])
# Build the graph: `x` is the input placeholder, `model` the fc2 logits.
x, model = LeNetModel().gray28(2)

# Map checkpoint key -> graph variable for everything except "is_training".
# NOTE(review): dropout receives a Python bool in the slim-only model, so there
# is probably no "is_training" variable and this exclusion is a harmless leftover.
variables_to_restore = {
    name_in_checkpoint(var): var
    for var in slim.get_variables_to_restore(exclude=["is_training"])
}

# A Saver given a {name: variable} dict looks each variable up by that name.
restorer = tf.train.Saver(variables_to_restore)

# Restore the weights, then print the raw logits for each test image in turn.
with tf.Session() as sess:
    restorer.restore(sess, "src/prototype/models/snapshot_5.ckpt")
    print("Model restored.")
    for img_path in ("/home/damien/Vidéos/1/positives/img/1-img143.jpg",
                     "/home/damien/Vidéos/0/positives/img/1-img1.jpg"):
        print(sess.run(model, feed_dict={x: [image(img_path)]}))
我使用 Nvidia DIGITS 训练了 LeNet-gray-28x28 图像检测 Tensorflow 模型,得到了我期望的结果。 现在,我需要在 DIGITS 之外对一些图像进行分类,并想使用我训练好的模型。
于是我拿到了 DIGITS 使用的 LeNet 模型定义,并创建了一个类来使用它:
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tflearn
from tflearn.layers.core import input_data
class LeNetModel():
    """LeNet for 28x28 single-channel images, wrapped in a tflearn DNN."""

    def gray28(self, nclasses):
        """Build the LeNet graph and wrap its logits in ``tflearn.DNN``.

        Args:
            nclasses: number of output units of the final ``fc2`` layer.

        Returns:
            A ``tflearn.DNN`` built around the ``fc2`` output (no activation).
        """
        x = input_data(shape=[None, 28, 28, 1])
        # scale (divide by MNIST std)
        # NOTE(review): this scaling is disabled; if the DIGITS checkpoint was
        # trained with it, inputs will be mis-scaled at inference -- TODO confirm.
        # x = x * 0.0125
        # NOTE(review): the keyword arguments originally passed to arg_scope were
        # lost from this listing -- TODO restore them if they are needed.
        # NOTE(review): these scopes create variables such as conv1/weights, while
        # the DIGITS checkpoint stores them under a "model/" prefix, which is why
        # restoring reports "Key conv1/biases not found in checkpoint".
        with slim.arg_scope([slim.conv2d, slim.fully_connected]):
            model = slim.conv2d(x, 20, [5, 5], padding='VALID', scope='conv1')
            model = slim.max_pool2d(model, [2, 2], padding='VALID', scope='pool1')
            model = slim.conv2d(model, 50, [5, 5], padding='VALID', scope='conv2')
            model = slim.max_pool2d(model, [2, 2], padding='VALID', scope='pool2')
            model = slim.flatten(model)
            model = slim.fully_connected(model, 500, scope='fc1')
            # is_training=False turns dropout into a pass-through for inference.
            model = slim.dropout(model, 0.5, is_training=False, scope='do1')
            model = slim.fully_connected(model, nclasses, activation_fn=None, scope='fc2')
        return tflearn.DNN(model)
我从 DIGITS 下载了模型,并(在另一个文件中)用下面的代码实例化它,但加载检查点时出现了下列错误:
# Build the tflearn-wrapped LeNet with 2 output classes (gray28 returns a tflearn.DNN).
self.ballmodel = LeNetModel().gray28(2)
2017-11-26 14:55:50.330524: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key conv1/biases not found in checkpoint
2017-11-26 14:55:50.330948: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key Global_Step not found in checkpoint
2017-11-26 14:55:50.331270: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key is_training not found in checkpoint
2017-11-26 14:55:50.331564: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key conv2/weights not found in checkpoint
2017-11-26 14:55:50.332823: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key conv1/weights not found in checkpoint
2017-11-26 14:55:50.332891: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key conv2/biases not found in checkpoint
2017-11-26 14:55:50.333620: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key fc2/weights not found in checkpoint
2017-11-26 14:55:50.334021: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key fc1/weights not found in checkpoint
2017-11-26 14:55:50.334173: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key fc1/biases not found in checkpoint
2017-11-26 14:55:50.334431: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key fc2/biases not found in checkpoint
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.NotFoundError: Key conv1/biases not found in checkpoint
[[Node: save_1/RestoreV2_1 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save_1/Const_0_0, save_1/RestoreV2_1/tensor_names, save_1/RestoreV2_1/shape_and_slices)]]
[[Node: save_1/RestoreV2_1/_19 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/gpu:0", send_device="/job:localhost/replica:0/task:0/cpu:0", send_device_incarnation=1, tensor_name="edge_38_save_1/RestoreV2_1", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"]()]]
于是我用 https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/inspect_checkpoint.py 脚本查看检查点中包含的键名,发现变量名都带有 `model/` 前缀(例如 `model/conv1/weights`),另外还有一个 `global_step` 键。因此我给每一层的 scope 都加上了 `model/` 前缀:
# Same layer stack with every scope prefixed by "model/", so the created
# variables (e.g. model/conv1/weights) carry the checkpoint's key names.
# NOTE(review): the last call was cut off in the listing; it is completed here by
# analogy with the unprefixed version (activation_fn=None, scope 'fc2').
model = slim.conv2d(x, 20, [5, 5], padding='VALID', scope='model/conv1')
model = slim.max_pool2d(model, [2, 2], padding='VALID', scope='model/pool1')
model = slim.conv2d(model, 50, [5, 5], padding='VALID', scope='model/conv2')
model = slim.max_pool2d(model, [2, 2], padding='VALID', scope='model/pool2')
model = slim.flatten(model)
model = slim.fully_connected(model, 500, scope='model/fc1')
model = slim.dropout(model, 0.5, is_training=False, scope='model/do1')
model = slim.fully_connected(model, nclasses,
                             activation_fn=None, scope='model/fc2')
- 我感觉这不是解决问题的正确方法
- 我无法修复两个键:
- Global_Step(我的检查点中有的是小写的 global_step 键)
- is_training(不知道是什么)
首先,我混用了 contrib/slim 和 tflearn,这虽然可行,但并不合适。所以我只用 slim 重写了网络:
import tensorflow as tf
import tensorflow.contrib.slim as slim
class LeNetModel():
    """LeNet for 28x28 grayscale images, built with TF-slim only."""

    def gray28(self, nclasses, is_training=False, input_scale=1.0):
        """Build the LeNet graph and return its input placeholder and logits.

        Args:
            nclasses: number of output units of the final ``fc2`` layer.
            is_training: forwarded to the dropout layer. Defaults to False so
                inference is deterministic (dropout becomes a pass-through);
                the previous hard-coded True randomly dropped activations at
                prediction time.
            input_scale: factor applied to the reshaped input before the first
                convolution. The default 1.0 feeds values unchanged.

        Returns:
            ``(x, logits)``: ``x`` is a float32 placeholder of shape
            ``[None, 28, 28]`` (any batch size, including 1); ``logits`` is the
            ``fc2`` output with no activation.
        """
        x = tf.placeholder(tf.float32, shape=[None, 28, 28], name="x")
        # Add the single channel dimension expected by conv2d.
        rs = tf.reshape(x, shape=[-1, 28, 28, 1])
        # The DIGITS LeNet definition this was copied from has a "divide by
        # MNIST std" step (x * 0.0125). If the checkpoint was trained with it,
        # pass input_scale=0.0125 -- TODO confirm against the DIGITS network.
        if input_scale != 1.0:
            rs = rs * input_scale
        # NOTE(review): the keyword arguments originally passed to arg_scope were
        # lost from this listing -- TODO restore them if they are needed.
        with slim.arg_scope([slim.conv2d, slim.fully_connected]):
            model = slim.conv2d(rs, 20, [5, 5], padding='VALID', scope='conv1')
            model = slim.max_pool2d(model, [2, 2], padding='VALID', scope='pool1')
            model = slim.conv2d(model, 50, [5, 5], padding='VALID', scope='conv2')
            model = slim.max_pool2d(model, [2, 2], padding='VALID', scope='pool2')
            model = slim.flatten(model)
            model = slim.fully_connected(model, 500, scope='fc1')
            model = slim.dropout(model, 0.5, is_training=is_training, scope='do1')
            model = slim.fully_connected(model, nclasses, activation_fn=None, scope='fc2')
        return x, model
gray28 返回占位符 x 和模型输出,我用它们来加载 DIGITS 预训练的模型(检查点):
import tensorflow as tf
import tensorflow.contrib.slim as slim
import cv2
from models.lenet import LeNetModel
def image(path):
    """Load *path* as a grayscale image and resize it to 28x28.

    Raises:
        IOError: if OpenCV cannot read the file (missing, unreadable or an
            unsupported format). ``cv2.imread`` signals this by returning None
            rather than raising, which would otherwise surface later as an
            opaque ``cv2.resize`` error.
    """
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise IOError("could not read image: %s" % path)
    return cv2.resize(img, dsize=(28, 28))
def name_in_checkpoint(var):
    """Return the checkpoint key for *var*: its op name prefixed with ``model/``.

    The DIGITS checkpoint names the same variables with that extra prefix, so
    this is used to build the {checkpoint_key: variable} map for the Saver.
    """
    op_name = var.op.name
    return ''.join(['model/', op_name])
# Build the graph: `x` is the input placeholder, `model` the fc2 logits.
x, model = LeNetModel().gray28(2)

# Map checkpoint key -> graph variable for everything except "is_training".
# NOTE(review): dropout receives a Python bool in the slim-only model, so there
# is probably no "is_training" variable and this exclusion is a harmless leftover.
variables_to_restore = {
    name_in_checkpoint(var): var
    for var in slim.get_variables_to_restore(exclude=["is_training"])
}

# A Saver given a {name: variable} dict looks each variable up by that name.
restorer = tf.train.Saver(variables_to_restore)

# Restore the weights, then print the raw logits for each test image in turn.
with tf.Session() as sess:
    restorer.restore(sess, "src/prototype/models/snapshot_5.ckpt")
    print("Model restored.")
    for img_path in ("/home/damien/Vidéos/1/positives/img/1-img143.jpg",
                     "/home/damien/Vidéos/0/positives/img/1-img1.jpg"):
        print(sess.run(model, feed_dict={x: [image(img_path)]}))