TensorFlow:从多个检查点恢复变量
TensorFlow: Restoring variables from from multiple checkpoints
我有以下情况:
我有 2 个模型用 2 个单独的脚本编写:
模型A由变量a1
、a2
、a3
组成,写成A.py
模型B由变量b1
、b2
、b3
组成,写成B.py
在A.py
和B.py
中,我有一个tf.train.Saver
保存所有局部变量的检查点,我们调用检查点文件ckptA
和ckptB
分别
我现在想制作一个使用 a1
和 b1
的模型 C。我可以通过使用 var_scope(b1
也是如此)在 A 和 C 中使用完全相同的 a1
变量名。
问题是如何将 a1
和 b1
从 ckptA
和 ckptB
加载到模型 C 中?例如,以下是否可行?
saver.restore(session, ckptA_location)
saver.restore(session, ckptB_location)
如果您尝试恢复同一个会话两次,是否会出现错误?它会抱怨没有为额外变量分配 "slots"(b2
、b3
、a2
、a3
),还是会简单地恢复变量它可以,并且只有在 C 中有一些其他未初始化的变量时才会抱怨?
我现在正在尝试编写一些代码来对此进行测试,但我很乐意看到解决此问题的规范方法,因为在尝试重新使用一些预训练的权重时经常会遇到这种情况。
谢谢!
如果您尝试使用保护程序(默认代表所有六个变量)从不包含保护程序代表的所有变量的检查点恢复,您将得到 tf.errors.NotFoundError
。 (但是请注意,您可以在同一会话中多次调用 Saver.restore()
,对于变量的任何子集,只要所有请求的变量都存在于相应的文件中即可。)
规范方法是定义两个单独的tf.train.Saver
实例,覆盖完全包含在单个检查点中的每个变量子集。例如:
saver_a = tf.train.Saver([a1])
saver_b = tf.train.Saver([b1])
saver_a.restore(session, ckptA_location)
saver_b.restore(session, ckptB_location)
根据您的代码构建方式,如果您在本地范围内有指向名为 a1
和 b1
的 tf.Variable
对象的指针,您可以在此处停止阅读。
另一方面,如果变量 a1
和 b1
定义在不同的文件中,您可能需要做一些创造性的事情来检索指向这些变量的指针。虽然不太理想,但人们通常做的是使用一个公共前缀,例如如下(假设变量名分别为 "a1:0"
和 "b1:0"
):
saver_a = tf.train.Saver([v for v in tf.all_variables() if v.name == "a1:0"])
saver_b = tf.train.Saver([v for v in tf.all_variables() if v.name == "b1:0"])
最后一点:您不必付出巨大的努力来确保变量在 A 和 C 中具有相同的名称。您可以将名称到 Variable
的字典作为第一个参数传递到 tf.train.Saver
构造函数,从而将检查点文件中的名称重新映射到代码中的 Variable
对象。如果 A.py
和 B.py
具有相似名称的变量,或者如果在 C.py
中您希望将这些文件中的模型代码组织在 tf.name_scope()
.[=31= 中,这将有所帮助]
我有以下情况:
我有 2 个模型用 2 个单独的脚本编写:
模型A由变量
a1
、a2
、a3
组成,写成A.py
模型B由变量
b1
、b2
、b3
组成,写成B.py
在A.py
和B.py
中,我有一个tf.train.Saver
保存所有局部变量的检查点,我们调用检查点文件ckptA
和ckptB
分别
我现在想制作一个使用 a1
和 b1
的模型 C。我可以通过使用 var_scope(b1
也是如此)在 A 和 C 中使用完全相同的 a1
变量名。
问题是如何将 a1
和 b1
从 ckptA
和 ckptB
加载到模型 C 中?例如,以下是否可行?
saver.restore(session, ckptA_location)
saver.restore(session, ckptB_location)
如果您尝试恢复同一个会话两次,是否会出现错误?它会抱怨没有为额外变量分配 "slots"(b2
、b3
、a2
、a3
),还是会简单地恢复变量它可以,并且只有在 C 中有一些其他未初始化的变量时才会抱怨?
我现在正在尝试编写一些代码来对此进行测试,但我很乐意看到解决此问题的规范方法,因为在尝试重新使用一些预训练的权重时经常会遇到这种情况。
谢谢!
如果您尝试使用保护程序(默认代表所有六个变量)从不包含保护程序代表的所有变量的检查点恢复,您将得到 tf.errors.NotFoundError
。 (但是请注意,您可以在同一会话中多次调用 Saver.restore()
,对于变量的任何子集,只要所有请求的变量都存在于相应的文件中即可。)
规范方法是定义两个单独的tf.train.Saver
实例,覆盖完全包含在单个检查点中的每个变量子集。例如:
saver_a = tf.train.Saver([a1])
saver_b = tf.train.Saver([b1])
saver_a.restore(session, ckptA_location)
saver_b.restore(session, ckptB_location)
根据您的代码构建方式,如果您在本地范围内有指向名为 a1
和 b1
的 tf.Variable
对象的指针,您可以在此处停止阅读。
另一方面,如果变量 a1
和 b1
定义在不同的文件中,您可能需要做一些创造性的事情来检索指向这些变量的指针。虽然不太理想,但人们通常做的是使用一个公共前缀,例如如下(假设变量名分别为 "a1:0"
和 "b1:0"
):
saver_a = tf.train.Saver([v for v in tf.all_variables() if v.name == "a1:0"])
saver_b = tf.train.Saver([v for v in tf.all_variables() if v.name == "b1:0"])
最后一点:您不必付出巨大的努力来确保变量在 A 和 C 中具有相同的名称。您可以将名称到 Variable
的字典作为第一个参数传递到 tf.train.Saver
构造函数,从而将检查点文件中的名称重新映射到代码中的 Variable
对象。如果 A.py
和 B.py
具有相似名称的变量,或者如果在 C.py
中您希望将这些文件中的模型代码组织在 tf.name_scope()
.[=31= 中,这将有所帮助]