如何在Tensorflow中实现预训练?如何部分使用检查点文件中保存的权重?
How to implement pre-training in Tensorflow? How to partially use saved weights from checkpoint file?
为方便讨论,对以下模型进行了简化。
假设我的训练集中有大约 40,000 张 512x512 图像。我正在尝试实施预训练,我的计划如下:
1.Train 一个神经网络(我们称它为 net_1),它接收 256x256 图像,并将经过训练的模型保存为 tensorflow 检查点文件格式。
net_1: input -> 3 conv2d -> maxpool2d -> 2 conv2d -> rmspool -> flatten -> dense
我们称这个结构为net_1_kernel:
net_1_kernel: 3 conv2d -> maxpool2d -> 3 conv2d
并调用剩余部分other_layers:
other_layers: rmspool -> flatten -> dense
因此我们可以用下面的形式表示net_1:
net_1: input -> net_1_kernel -> other_layers
2.Insert几层到net_1的结构,现在称之为net_2。它应该是这样的:
net_2: input -> net_1_kernel -> maxpool2d -> 3 conv2d -> other_layers
net_2 将采用 512x512 图像作为输入。
当我训练net_2时,我想使用net_1的检查点文件中保存的权重和偏差来初始化net_1_kernel部分net_2。我该怎么做?
我知道我可以加载检查点来预测测试数据。但在那种情况下,它将加载所有内容(net_1_kernel 和 other_layers)。我想要的是仅加载 net_1_kernel 并将其用于 net_2.
中的 weight/bias 初始化
我也知道我可以将检查点文件中的内容打印到txt,然后复制和粘贴以手动初始化权重和偏差。但是,这些权重和偏差中的数字太多了,这将是我最后的选择。
首先,您可以使用以下代码查看您保存的ckpt文件中所有检查点的列表。
from tensorflow.python.tools import inspect_checkpoint as chkp
chkp.print_tensors_in_checkpoint_file(file_name="file.ckpt", tensor_name="xxx", all_tensors=False, all_tensor_names=True)
请记住,当您恢复检查点文件时,它会恢复检查点文件中的所有变量。如果你必须保存和恢复特定的变量,你可以这样做:
- 列出要从
tf.trainable_variables()
中保存的所有变量
var = [v for v in tf.trainable_variables() if "net_1_kernel" in v.name]
saverAndRestore = tf.train.Saver(var)
- 现在您可以轻松保存或恢复var列表中的所有变量,如下所示:
saverAndRestore.save(sess_1,"net_1.ckpt")
这只会将列表 var 中的变量保存到 net_1.ckpt。
saverAndRestore.restore(sess_1,"net_1.ckpt")
这只会从 net_1.ckpt.
中恢复列表 var 中的变量
除上述之外,您所要做的就是name/scope您的变量,这样您就可以轻松地执行上面的第 1 步。
为方便讨论,对以下模型进行了简化。
假设我的训练集中有大约 40,000 张 512x512 图像。我正在尝试实施预训练,我的计划如下:
1.Train 一个神经网络(我们称它为 net_1),它接收 256x256 图像,并将经过训练的模型保存为 tensorflow 检查点文件格式。
net_1: input -> 3 conv2d -> maxpool2d -> 2 conv2d -> rmspool -> flatten -> dense
我们称这个结构为net_1_kernel:
net_1_kernel: 3 conv2d -> maxpool2d -> 3 conv2d
并调用剩余部分other_layers:
other_layers: rmspool -> flatten -> dense
因此我们可以用下面的形式表示net_1:
net_1: input -> net_1_kernel -> other_layers
2.Insert几层到net_1的结构,现在称之为net_2。它应该是这样的:
net_2: input -> net_1_kernel -> maxpool2d -> 3 conv2d -> other_layers
net_2 将采用 512x512 图像作为输入。
当我训练net_2时,我想使用net_1的检查点文件中保存的权重和偏差来初始化net_1_kernel部分net_2。我该怎么做?
我知道我可以加载检查点来预测测试数据。但在那种情况下,它将加载所有内容(net_1_kernel 和 other_layers)。我想要的是仅加载 net_1_kernel 并将其用于 net_2.
中的 weight/bias 初始化我也知道我可以将检查点文件中的内容打印到txt,然后复制和粘贴以手动初始化权重和偏差。但是,这些权重和偏差中的数字太多了,这将是我最后的选择。
首先,您可以使用以下代码查看您保存的ckpt文件中所有检查点的列表。
from tensorflow.python.tools import inspect_checkpoint as chkp
chkp.print_tensors_in_checkpoint_file(file_name="file.ckpt", tensor_name="xxx", all_tensors=False, all_tensor_names=True)
请记住,当您恢复检查点文件时,它会恢复检查点文件中的所有变量。如果你必须保存和恢复特定的变量,你可以这样做:
- 列出要从
tf.trainable_variables()
中保存的所有变量
var = [v for v in tf.trainable_variables() if "net_1_kernel" in v.name]
saverAndRestore = tf.train.Saver(var)
- 现在您可以轻松保存或恢复var列表中的所有变量,如下所示:
saverAndRestore.save(sess_1,"net_1.ckpt")
这只会将列表 var 中的变量保存到 net_1.ckpt。
saverAndRestore.restore(sess_1,"net_1.ckpt")
这只会从 net_1.ckpt.
中恢复列表 var 中的变量除上述之外,您所要做的就是name/scope您的变量,这样您就可以轻松地执行上面的第 1 步。