如何使用 tf.train.Checkpoint 在 tensorflow 2.0 中保存和加载选定变量和所有变量?
How to save and load selected and all variables in tensorflow 2.0 using tf.train.Checkpoint?
如何将如下所示的 tensorflow 2.0 中的选定变量保存在一个文件中,并使用 tf.train.Checkpoint 将它们加载到另一个代码中定义的一些变量中?
class manyVariables:
def __init__(self):
self.initList = [None]*100
for i in range(100):
self.initList[i] = tf.Variable(tf.random.normal([5,5]))
self.makeSomeMoreVariables()
def makeSomeMoreVariables(self):
self.moreList = [None]*10
for i in range(10):
self.moreList[i] = tf.Variable(tf.random.normal([3,3]))
def saveVariables(self):
# how to save self.initList's 3,55 and 60th elements and self.moreList's 4th element
此外,请展示如何保存所有变量并使用 tf.train.Checkpoint 重新加载。提前致谢。
在下面的代码中,我使用您选择的名称将一个名为变量的数组保存到一个 .txt 文件中。该文件将与您的 python 文件位于同一文件夹中。 open 函数中的 'wb' 表示截断写入(因此删除文件中以前的所有内容)并使用字节格式。我使用 pickle 来处理 saving/parsing 列表。
import pickle
def saveVariables(self, variables): #where 'variables' is a list of variables
with open("nameOfYourFile.txt", 'wb+') as file:
pickle.dump(variables, file)
def retrieveVariables(self, filename):
variables = []
with open(str(filename), 'rb') as file:
variables = pickle.load(file)
return variables
要将特定内容保存到文件中,只需将其添加为 saveVariables 中的变量参数,如下所示:
myVariables = [initList[2], initList[54], initList[59], moreList[3]]
saveVariables(myVariables)
从具有特定名称的文本文件中检索变量:
myVariables = retrieveVariables("theNameOfYourFile.txt")
thirdEl = myVariables[0]
fiftyFifthEl = myVariables[1]
SixtiethEl = myVariables[2]
fourthEl = myVariables[3]
您可以在 class 中的任何位置添加这些函数。
然而,为了能够在您的示例中访问 initList/moreList,您应该从它们的函数中 return 它们(就像我对 variables
列表所做的那样) 或将它们设为全局。
我不确定这是否是您的意思,但是您可以专门为要保存和恢复的变量创建一个 tf.train.Checkpoint
对象。请参阅以下示例:
import tensorflow as tf
class manyVariables:
def __init__(self):
self.initList = [None]*100
for i in range(100):
self.initList[i] = tf.Variable(tf.random.normal([5,5]))
self.makeSomeMoreVariables()
self.ckpt = self.makeCheckpoint()
def makeSomeMoreVariables(self):
self.moreList = [None]*10
for i in range(10):
self.moreList[i] = tf.Variable(tf.random.normal([3,3]))
def makeCheckpoint(self):
return tf.train.Checkpoint(
init3=self.initList[3], init55=self.initList[55],
init60=self.initList[60], more4=self.moreList[4])
def saveVariables(self):
self.ckpt.save('./ckpt')
def restoreVariables(self):
status = self.ckpt.restore(tf.train.latest_checkpoint('.'))
status.assert_consumed() # Optional check
# Create variables
v1 = manyVariables()
# Assigned fixed values
for i, v in enumerate(v1.initList):
v.assign(i * tf.ones_like(v))
for i, v in enumerate(v1.moreList):
v.assign(100 + i * tf.ones_like(v))
# Save them
v1.saveVariables()
# Create new variables
v2 = manyVariables()
# Check initial values
print(v2.initList[2].numpy())
# [[-1.9110833 0.05956204 -1.1753829 -0.3572553 -0.95049495]
# [ 0.31409055 1.1262076 0.47890127 -0.1699607 0.4409122 ]
# [-0.75385517 -0.13847834 0.97012395 0.42515194 -1.4371008 ]
# [ 0.44205236 0.86158335 0.6919655 -2.5156968 0.16496429]
# [-1.241602 -0.15177743 0.5603795 -0.3560254 -0.18536267]]
print(v2.initList[3].numpy())
# [[-3.3441594 -0.18425298 -0.4898144 -1.2330629 0.08798431]
# [ 1.5002227 0.99475247 0.7817361 0.3849587 -0.59548247]
# [-0.57121766 -1.277224 0.6957546 -0.67618763 0.0510064 ]
# [ 0.85491985 0.13310803 -0.93152267 0.10205163 0.57520276]
# [-1.0606447 -0.16966362 -1.0448577 0.56799036 -0.90726566]]
# Restore them
v2.restoreVariables()
# Check values after restoring
print(v2.initList[2].numpy())
# [[-1.9110833 0.05956204 -1.1753829 -0.3572553 -0.95049495]
# [ 0.31409055 1.1262076 0.47890127 -0.1699607 0.4409122 ]
# [-0.75385517 -0.13847834 0.97012395 0.42515194 -1.4371008 ]
# [ 0.44205236 0.86158335 0.6919655 -2.5156968 0.16496429]
# [-1.241602 -0.15177743 0.5603795 -0.3560254 -0.18536267]]
print(v2.initList[3].numpy())
# [[3. 3. 3. 3. 3.]
# [3. 3. 3. 3. 3.]
# [3. 3. 3. 3. 3.]
# [3. 3. 3. 3. 3.]
# [3. 3. 3. 3. 3.]]
如果你想保存列表中的所有变量,你可以用这样的东西替换 makeCheckpoint
:
def makeCheckpoint(self):
return tf.train.Checkpoint(
**{f'init{i}': v for i, v in enumerate(self.initList)},
**{f'more{i}': v for i, v in enumerate(self.moreList)})
请注意,您可以有 "nested" 个检查点,因此,更一般地说,您可以有一个为变量列表创建检查点的函数,例如:
def listCheckpoint(varList):
# Use 'item{}'.format(i) if using Python <3.6
return tf.train.Checkpoint(**{f'item{i}': v for i, v in enumerate(varList)})
那么你就可以拥有这个:
def makeCheckpoint(self):
return tf.train.Checkpoint(init=listCheckpoint(self.initList),
more=listCheckpoint(self.moreList))
如何将如下所示的 tensorflow 2.0 中的选定变量保存在一个文件中,并使用 tf.train.Checkpoint 将它们加载到另一个代码中定义的一些变量中?
class manyVariables:
def __init__(self):
self.initList = [None]*100
for i in range(100):
self.initList[i] = tf.Variable(tf.random.normal([5,5]))
self.makeSomeMoreVariables()
def makeSomeMoreVariables(self):
self.moreList = [None]*10
for i in range(10):
self.moreList[i] = tf.Variable(tf.random.normal([3,3]))
def saveVariables(self):
# how to save self.initList's 3,55 and 60th elements and self.moreList's 4th element
此外,请展示如何保存所有变量并使用 tf.train.Checkpoint 重新加载。提前致谢。
在下面的代码中,我使用您选择的名称将一个名为变量的数组保存到一个 .txt 文件中。该文件将与您的 python 文件位于同一文件夹中。 open 函数中的 'wb' 表示截断写入(因此删除文件中以前的所有内容)并使用字节格式。我使用 pickle 来处理 saving/parsing 列表。
import pickle
def saveVariables(self, variables): #where 'variables' is a list of variables
with open("nameOfYourFile.txt", 'wb+') as file:
pickle.dump(variables, file)
def retrieveVariables(self, filename):
variables = []
with open(str(filename), 'rb') as file:
variables = pickle.load(file)
return variables
要将特定内容保存到文件中,只需将其添加为 saveVariables 中的变量参数,如下所示:
myVariables = [initList[2], initList[54], initList[59], moreList[3]]
saveVariables(myVariables)
从具有特定名称的文本文件中检索变量:
myVariables = retrieveVariables("theNameOfYourFile.txt")
thirdEl = myVariables[0]
fiftyFifthEl = myVariables[1]
SixtiethEl = myVariables[2]
fourthEl = myVariables[3]
您可以在 class 中的任何位置添加这些函数。
然而,为了能够在您的示例中访问 initList/moreList,您应该从它们的函数中 return 它们(就像我对 variables
列表所做的那样) 或将它们设为全局。
我不确定这是否是您的意思,但是您可以专门为要保存和恢复的变量创建一个 tf.train.Checkpoint
对象。请参阅以下示例:
import tensorflow as tf
class manyVariables:
def __init__(self):
self.initList = [None]*100
for i in range(100):
self.initList[i] = tf.Variable(tf.random.normal([5,5]))
self.makeSomeMoreVariables()
self.ckpt = self.makeCheckpoint()
def makeSomeMoreVariables(self):
self.moreList = [None]*10
for i in range(10):
self.moreList[i] = tf.Variable(tf.random.normal([3,3]))
def makeCheckpoint(self):
return tf.train.Checkpoint(
init3=self.initList[3], init55=self.initList[55],
init60=self.initList[60], more4=self.moreList[4])
def saveVariables(self):
self.ckpt.save('./ckpt')
def restoreVariables(self):
status = self.ckpt.restore(tf.train.latest_checkpoint('.'))
status.assert_consumed() # Optional check
# Create variables
v1 = manyVariables()
# Assigned fixed values
for i, v in enumerate(v1.initList):
v.assign(i * tf.ones_like(v))
for i, v in enumerate(v1.moreList):
v.assign(100 + i * tf.ones_like(v))
# Save them
v1.saveVariables()
# Create new variables
v2 = manyVariables()
# Check initial values
print(v2.initList[2].numpy())
# [[-1.9110833 0.05956204 -1.1753829 -0.3572553 -0.95049495]
# [ 0.31409055 1.1262076 0.47890127 -0.1699607 0.4409122 ]
# [-0.75385517 -0.13847834 0.97012395 0.42515194 -1.4371008 ]
# [ 0.44205236 0.86158335 0.6919655 -2.5156968 0.16496429]
# [-1.241602 -0.15177743 0.5603795 -0.3560254 -0.18536267]]
print(v2.initList[3].numpy())
# [[-3.3441594 -0.18425298 -0.4898144 -1.2330629 0.08798431]
# [ 1.5002227 0.99475247 0.7817361 0.3849587 -0.59548247]
# [-0.57121766 -1.277224 0.6957546 -0.67618763 0.0510064 ]
# [ 0.85491985 0.13310803 -0.93152267 0.10205163 0.57520276]
# [-1.0606447 -0.16966362 -1.0448577 0.56799036 -0.90726566]]
# Restore them
v2.restoreVariables()
# Check values after restoring
print(v2.initList[2].numpy())
# [[-1.9110833 0.05956204 -1.1753829 -0.3572553 -0.95049495]
# [ 0.31409055 1.1262076 0.47890127 -0.1699607 0.4409122 ]
# [-0.75385517 -0.13847834 0.97012395 0.42515194 -1.4371008 ]
# [ 0.44205236 0.86158335 0.6919655 -2.5156968 0.16496429]
# [-1.241602 -0.15177743 0.5603795 -0.3560254 -0.18536267]]
print(v2.initList[3].numpy())
# [[3. 3. 3. 3. 3.]
# [3. 3. 3. 3. 3.]
# [3. 3. 3. 3. 3.]
# [3. 3. 3. 3. 3.]
# [3. 3. 3. 3. 3.]]
如果你想保存列表中的所有变量,你可以用这样的东西替换 makeCheckpoint
:
def makeCheckpoint(self):
return tf.train.Checkpoint(
**{f'init{i}': v for i, v in enumerate(self.initList)},
**{f'more{i}': v for i, v in enumerate(self.moreList)})
请注意,您可以有 "nested" 个检查点,因此,更一般地说,您可以有一个为变量列表创建检查点的函数,例如:
def listCheckpoint(varList):
# Use 'item{}'.format(i) if using Python <3.6
return tf.train.Checkpoint(**{f'item{i}': v for i, v in enumerate(varList)})
那么你就可以拥有这个:
def makeCheckpoint(self):
return tf.train.Checkpoint(init=listCheckpoint(self.initList),
more=listCheckpoint(self.moreList))