Transfer learning with the tf.estimator.Estimator framework
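I'm trying to do transfer learning with an Inception-resnet v2 model pretrained on imagenet, using my own dataset and classes.
My original codebase was a modification of a tf.slim sample that I can no longer find, and I'm now trying to rewrite the same code using the tf.estimator.* framework.
However, I'm running into a problem: I want to load only some of the weights from the pretrained checkpoint and initialize the remaining layers with their default initializers.
Researching the problem, I found this GitHub issue and this question, both of which mention the need to use tf.train.init_from_checkpoint in my model_fn. I tried, but since neither includes an example, I guess I got something wrong.
This is my minimal example: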
import sys
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import tensorflow as tf
import numpy as np

import inception_resnet_v2

NUM_CLASSES = 900
IMAGE_SIZE = 299

def input_fn(mode, num_classes, batch_size=1):
    # some code that loads images, reshapes them to 299x299x3 and batches them
    images = tf.constant(np.zeros([batch_size, IMAGE_SIZE, IMAGE_SIZE, 3], np.float32))
    labels = tf.one_hot(tf.constant(np.zeros([batch_size], np.int32)), NUM_CLASSES)
    return images, labels

def model_fn(images, labels, num_classes, mode):
    with tf.contrib.slim.arg_scope(inception_resnet_v2.inception_resnet_v2_arg_scope()):
        logits, end_points = inception_resnet_v2.inception_resnet_v2(
            images,
            num_classes,
            is_training=(mode == tf.estimator.ModeKeys.TRAIN))

    predictions = {
        'classes': tf.argmax(input=logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    exclude = ['InceptionResnetV2/Logits', 'InceptionResnetV2/AuxLogits']
    variables_to_restore = tf.contrib.slim.get_variables_to_restore(exclude=exclude)
    scopes = {os.path.dirname(v.name) for v in variables_to_restore}
    tf.train.init_from_checkpoint('inception_resnet_v2_2016_08_30.ckpt',
                                  {s + '/': s + '/' for s in scopes})

    tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
    total_loss = tf.losses.get_total_loss()  # obtain the regularization losses as well

    # Configure the training op
    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()
        optimizer = tf.train.AdamOptimizer(learning_rate=0.00002)
        train_op = optimizer.minimize(total_loss, global_step)
    else:
        train_op = None

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=total_loss,
        train_op=train_op)

def main(unused_argv):
    # Create the Estimator
    classifier = tf.estimator.Estimator(
        model_fn=lambda features, labels, mode: model_fn(features, labels, NUM_CLASSES, mode),
        model_dir='model/MCVE')

    # Train the model
    classifier.train(
        input_fn=lambda: input_fn(tf.estimator.ModeKeys.TRAIN, NUM_CLASSES, batch_size=1),
        steps=1000)

    # Evaluate the model and print results
    eval_results = classifier.evaluate(
        input_fn=lambda: input_fn(tf.estimator.ModeKeys.EVAL, NUM_CLASSES, batch_size=1))
    print()
    print('Evaluation results:\n    %s' % eval_results)

if __name__ == '__main__':
    tf.app.run(main=main, argv=[sys.argv[0]])
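where inception_resnet_v2 is the model implementation in Tensorflow's models repository.
If I run this script, I get a bunch of info logs from init_from_checkpoint, but then, at session creation time, it seems to try to load the Logits weights from the checkpoint and fails because the shapes are incompatible. This is the full traceback: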
Traceback (most recent call last):
File "<ipython-input-6-06fadd69ae8f>", line 1, in <module>
runfile('C:/Users/1/Desktop/transfer_learning_tutorial-master/MCVE.py', wdir='C:/Users/1/Desktop/transfer_learning_tutorial-master')
File "C:\Users\Anaconda3\envs\tensorflow\lib\site-packages\spyder\utils\site\sitecustomize.py", line 710, in runfile
execfile(filename, namespace)
File "C:\Users\Anaconda3\envs\tensorflow\lib\site-packages\spyder\utils\site\sitecustomize.py", line 101, in execfile
exec(compile(f.read(), filename, 'exec'), namespace)
File "C:/Users/1/Desktop/transfer_learning_tutorial-master/MCVE.py", line 77, in <module>
tf.app.run(main=main, argv=[sys.argv[0]])
File "C:\Users\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\platform\app.py", line 48, in run
_sys.exit(main(_sys.argv[:1] + flags_passthrough))
File "C:/Users/1/Desktop/transfer_learning_tutorial-master/MCVE.py", line 68, in main
steps=1000)
File "C:\Users\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 302, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "C:\Users\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 780, in _train_model
log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
File "C:\Users\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 368, in MonitoredTrainingSession
stop_grace_period_secs=stop_grace_period_secs)
File "C:\Users\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 673, in __init__
stop_grace_period_secs=stop_grace_period_secs)
File "C:\Users\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 493, in __init__
self._sess = _RecoverableSession(self._coordinated_creator)
File "C:\Users\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 851, in __init__
_WrappedSession.__init__(self, self._create_session())
File "C:\Users\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 856, in _create_session
return self._sess_creator.create_session()
File "C:\Users\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 554, in create_session
self.tf_sess = self._session_creator.create_session()
File "C:\Users\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 428, in create_session
init_fn=self._scaffold.init_fn)
File "C:\Users\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\session_manager.py", line 279, in prepare_session
sess.run(init_op, feed_dict=init_feed_dict)
File "C:\Users\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 889, in run
run_metadata_ptr)
File "C:\Users\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1120, in _run
feed_dict_tensor, options, run_metadata)
File "C:\Users\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1317, in _do_run
options, run_metadata)
File "C:\Users\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1336, in _do_call
raise type(e)(node_def, op, message)
InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [900] rhs shape= [1001] [[Node: Assign_1145 = Assign[T=DT_FLOAT,
_class=["loc:@InceptionResnetV2/Logits/Logits/biases"], use_locking=true, validate_shape=true,
_device="/job:localhost/replica:0/task:0/device:CPU:0"](InceptionResnetV2/Logits/Logits/biases, checkpoint_initializer_1145)]]
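What am I doing wrong when using init_from_checkpoint? How exactly are we supposed to "use" it in our model_fn? And why is the estimator trying to load the Logits' weights from the checkpoint when I explicitly told it not to?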
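Update:
Following the suggestion in the comments, I tried alternative ways to call tf.train.init_from_checkpoint.
Using {v.name: v.name}
If, as suggested in the comments, I replace the assignment map in the call with {v.name:v.name for v in variables_to_restore}, I get this error: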
ValueError: Assignment map with scope only name InceptionResnetV2/Conv2d_2a_3x3 should map
to scope only InceptionResnetV2/Conv2d_2a_3x3/weights:0. Should be 'scope/': 'other_scope/'.
Using {v.name: v}
If I instead try the name:variable mapping, I get the following error:
ValueError: Tensor InceptionResnetV2/Conv2d_2a_3x3/weights:0 is not found in
inception_resnet_v2_2016_08_30.ckpt checkpoint
{'InceptionResnetV2/Repeat_2/block8_4/Branch_1/Conv2d_0c_3x1/BatchNorm/moving_mean': [256],
'InceptionResnetV2/Repeat/block35_9/Branch_0/Conv2d_1x1/BatchNorm/beta': [32], ...
The error goes on to list what I think are all the variable names in the checkpoint (or could they be the scopes instead?).
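Update (2)
After inspecting the latest error above, I see that InceptionResnetV2/Conv2d_2a_3x3/weights is in the list of checkpoint variables. The problem is the :0 at the end!
I will now verify whether this indeed solves the problem and post an answer if it does.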
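Thanks to @KathyWu's comment, I got on the right track and found the problem.
Indeed, the way I was computing the scopes would include the InceptionResnetV2/ scope, which triggers the loading of all variables "under" that scope (i.e., all variables in the network). Replacing it with the correct dictionary, however, was not trivial.
Of the possible scope modes that init_from_checkpoint accepts, the one I had to use was the 'scope_variable_name': variable one, but without using the actual variable.name attribute as the key.
variable.name looks like 'some_scope/variable_name:0'. That :0 is not part of the variable's name in the checkpoint, so using v.name directly as the key (as in the {v.name: v} attempt above) raises the "not found in checkpoint" error.
The trick to make it work was stripping the tensor index from the name: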
tf.train.init_from_checkpoint('inception_resnet_v2_2016_08_30.ckpt',
                              {v.name.split(':')[0]: v for v in variables_to_restore})
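To make the key construction easier to inspect outside of TensorFlow, here is a minimal sketch of the same dictionary comprehension applied to plain Python objects. FakeVariable and build_assignment_map are hypothetical names introduced only for this illustration; FakeVariable merely stands in for a tf.Variable and carries nothing but a name attribute.

```python
from collections import namedtuple

# Hypothetical stand-in for tf.Variable; only the `name` attribute matters here.
FakeVariable = namedtuple('FakeVariable', ['name'])


def build_assignment_map(variables):
    """Map names without the ':0' tensor index to the variable objects."""
    return {v.name.split(':')[0]: v for v in variables}


# Assumed example names, written in the form that `v.name` returns.
variables_to_restore = [
    FakeVariable('InceptionResnetV2/Conv2d_2a_3x3/weights:0'),
    FakeVariable('InceptionResnetV2/Conv2d_2a_3x3/BatchNorm/beta:0'),
]

assignment_map = build_assignment_map(variables_to_restore)
for key in sorted(assignment_map):
    # Each key is the variable name with the tensor index removed.
    print(key)
```

The keys now have the same form as the names listed in the "not found" error message, while the values are still the variable objects themselves.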
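I found that {s+'/':s+'/' for s in scopes} didn't work simply because variables_to_restore contains something like "global_step", so the scopes include the global scope, which can include everything. You need to print variables_to_restore, find the "global_step" entry, and put it in "exclude".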
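As a rough illustration of that point, the sketch below reproduces the question's scope computation on plain strings, with no TensorFlow involved. The names list and the scope_map helper are assumptions made up for this example; only os.path.dirname and the dictionary comprehension come from the question's code.

```python
import os

# Assumed example names, written in the form that `v.name` returns.
names = [
    'global_step:0',
    'InceptionResnetV2/Conv2d_1a_3x3/weights:0',
    'InceptionResnetV2/Conv2d_1a_3x3/BatchNorm/beta:0',
]


def scope_map(variable_names):
    """Mirror the question's {s + '/': s + '/'} construction on plain strings."""
    scopes = {os.path.dirname(n) for n in variable_names}
    return {s + '/': s + '/' for s in scopes}


with_global_step = scope_map(names)
without_global_step = scope_map(
    [n for n in names if not n.startswith('global_step')])

# 'global_step:0' has no directory part, so its scope is '' and its key is '/'.
print(sorted(with_global_step))
print(sorted(without_global_step))
```

With global_step present, the map gains a bare '/' key, which is the "global scope" described above; once that name is filtered out, only the per-layer scopes remain.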