如何子类化`tf.train.Saver()`?
How to subclass `tf.train.Saver()`?
我正在 colaboratory
上训练,有时会失去与服务器的连接。闲置 90 分钟后,虚拟机也将重置。
我想用回调覆盖 tf.train.Saver.save()
,以便我可以按时间或步长间隔将检查点复制到我的 Google 云存储帐户。
参见:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/saver.py
#
# override tf_saver, add callback after save()
#
import os
import tensorflow as tf
from tensorflow.python.training import saver as tf_saver
## override saver
class Saver_with_callback(tf_saver.Saver):
_callback_op = None
def __init__(self, callback_op, **kwargs ):
self._callback_op = callback_op
super(tf_saver.Saver, self).__init__(**kwargs)
def save(self, sess, save_path, **kwargs):
"""
see: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/saver.py
"""
model_checkpoint_path = super.save(sess, save_path, **kwargs)
if self._callback_op is not None:
## call on a new thread?
self._callback_op(sess, save_path,
model_checkpoint_path=model_checkpoint_path,
**kwargs)
return model_checkpoint_path
但是当我 运行 slim.learning.train(saver=callback_saver)
时出现错误
final_loss = slim.learning.train(train_op, log_dir,
init_fn=init_fn,
global_step=global_step,
number_of_steps=steps,
save_summaries_secs=300,
save_interval_secs=600,
saver=callback_saver,
# saver=tf_saver.Saver(),
)
错误:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-41-dfb09327cccd> in <module>()
149 save_summaries_secs=300,
150 save_interval_secs=600,
--> 151 saver=callback_saver,
152 # saver=tf_saver.Saver,
153 )
/anaconda/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/slim/python/slim/learning.py in train(train_op, logdir, train_step_fn, train_step_kwargs, log_every_n_steps, graph, master, is_chief, global_step, number_of_steps, init_op, init_feed_dict, local_init_op, init_fn, ready_op, summary_op, save_summaries_secs, summary_writer, startup_delay_steps, saver, save_interval_secs, sync_optimizer, session_config, session_wrapper, trace_every_n_steps)
730 save_summaries_secs=save_summaries_secs,
731 save_model_secs=save_interval_secs,
--> 732 init_fn=init_fn)
733
734 if summary_writer is not None:
/anaconda/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/supervisor.py in __init__(self, graph, ready_op, ready_for_local_init_op, is_chief, init_op, init_feed_dict, local_init_op, logdir, summary_op, saver, global_step, save_summaries_secs, save_model_secs, recovery_wait_secs, stop_grace_secs, checkpoint_basename, session_manager, summary_writer, init_fn)
304 self._meta_graph_def = meta_graph.create_meta_graph_def(
305 graph_def=graph.as_graph_def(add_shapes=True),
--> 306 saver_def=self._saver.saver_def if self._saver else None)
307 self._is_chief = is_chief
308 self._coord = coordinator.Coordinator()
AttributeError: 'Saver_with_callback' object has no attribute 'saver_def'
``
isinstance(callback_saver, tf_saver.Saver)==True
如果我使用 saver=tf_saver.Saver()
它工作正常。
您没有在 Saver_with_callback.__init__()
中调用 tf_saver.Saver
的 __init__
函数。
当您调用 super(tf_saver.Saver, self).__init__(**kwargs)
时调用了 tf_saver.Saver
的父 class 的 __init__
函数。
这是因为 super(tf_saver.Saver, self)
returns tf_saver.Saver
的父 class,而不是您期望的 tf_saver.Saver
本身。
你应该打电话给
super(Saver_with_callback, self).__init__(**kwargs)
或 Python3,只需
super().__init__(**kwargs)
我正在 colaboratory
上训练,有时会失去与服务器的连接。闲置 90 分钟后,虚拟机也将重置。
我想用回调覆盖 tf.train.Saver.save()
,以便我可以按时间或步长间隔将检查点复制到我的 Google 云存储帐户。
参见:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/saver.py
#
# override tf_saver, add callback after save()
#
import os
import tensorflow as tf
from tensorflow.python.training import saver as tf_saver
## override saver
class Saver_with_callback(tf_saver.Saver):
_callback_op = None
def __init__(self, callback_op, **kwargs ):
self._callback_op = callback_op
super(tf_saver.Saver, self).__init__(**kwargs)
def save(self, sess, save_path, **kwargs):
"""
see: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/saver.py
"""
model_checkpoint_path = super.save(sess, save_path, **kwargs)
if self._callback_op is not None:
## call on a new thread?
self._callback_op(sess, save_path,
model_checkpoint_path=model_checkpoint_path,
**kwargs)
return model_checkpoint_path
但是当我 运行 slim.learning.train(saver=callback_saver)
final_loss = slim.learning.train(train_op, log_dir,
init_fn=init_fn,
global_step=global_step,
number_of_steps=steps,
save_summaries_secs=300,
save_interval_secs=600,
saver=callback_saver,
# saver=tf_saver.Saver(),
)
错误:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-41-dfb09327cccd> in <module>()
149 save_summaries_secs=300,
150 save_interval_secs=600,
--> 151 saver=callback_saver,
152 # saver=tf_saver.Saver,
153 )
/anaconda/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/slim/python/slim/learning.py in train(train_op, logdir, train_step_fn, train_step_kwargs, log_every_n_steps, graph, master, is_chief, global_step, number_of_steps, init_op, init_feed_dict, local_init_op, init_fn, ready_op, summary_op, save_summaries_secs, summary_writer, startup_delay_steps, saver, save_interval_secs, sync_optimizer, session_config, session_wrapper, trace_every_n_steps)
730 save_summaries_secs=save_summaries_secs,
731 save_model_secs=save_interval_secs,
--> 732 init_fn=init_fn)
733
734 if summary_writer is not None:
/anaconda/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/supervisor.py in __init__(self, graph, ready_op, ready_for_local_init_op, is_chief, init_op, init_feed_dict, local_init_op, logdir, summary_op, saver, global_step, save_summaries_secs, save_model_secs, recovery_wait_secs, stop_grace_secs, checkpoint_basename, session_manager, summary_writer, init_fn)
304 self._meta_graph_def = meta_graph.create_meta_graph_def(
305 graph_def=graph.as_graph_def(add_shapes=True),
--> 306 saver_def=self._saver.saver_def if self._saver else None)
307 self._is_chief = is_chief
308 self._coord = coordinator.Coordinator()
AttributeError: 'Saver_with_callback' object has no attribute 'saver_def'
``
isinstance(callback_saver, tf_saver.Saver)==True
如果我使用 saver=tf_saver.Saver()
它工作正常。
您没有在 Saver_with_callback.__init__()
中调用 tf_saver.Saver
的 __init__
函数。
当您调用 super(tf_saver.Saver, self).__init__(**kwargs)
时调用了 tf_saver.Saver
的父 class 的 __init__
函数。
这是因为 super(tf_saver.Saver, self)
returns tf_saver.Saver
的父 class,而不是您期望的 tf_saver.Saver
本身。
你应该打电话给
super(Saver_with_callback, self).__init__(**kwargs)
或 Python3,只需
super().__init__(**kwargs)