TensorFlow中如何清除整个网络结构
How to clear the entire network structure in TensorFlow
我想在每次完成训练后清除内存/网络。我使用了在线建议的替代方案,但如果我正确解释我的结果,它们似乎不起作用。我使用 tf.compat.v1.reset_default_graph()
和 tf.keras.backend.clear_session()
,因为它们大多是在线推荐的。
import numpy as np
import random
import tensorflow as tf
from tensorflow import keras
from tensorflow.python.keras import backend as K
upper_limit = 2
lower_limit = -2
training_input= np.random.random ([100,5])*(upper_limit - lower_limit) + lower_limit
training_output = np.random.random ([100,1]) *10*(upper_limit - lower_limit) + lower_limit
model = tf.keras.Sequential([
tf.keras.layers.Flatten(input_shape=(5,)),
tf.keras.layers.Dense(12, activation='relu'),
tf.keras.layers.Dense(1)
])
model.compile(loss="mse",optimizer = tf.keras.optimizers.Adam(learning_rate=0.01))
for layer in model.layers:
print("layer weights before fitting: ",layer.get_weights(),"\n") # weights
model.fit(training_input, training_output, epochs=5, batch_size=100,verbose=0)
for layer in model.layers:
print("layer weights after fitting: ",layer.get_weights(),"\n") # weights
print("\n")
tf.compat.v1.reset_default_graph()
tf.keras.backend.clear_session()
print("after clear","\n")
for layer in model.layers:
print(layer.get_weights(),"\n") # weights
当我尝试清除网络后打印层权重时,我得到的权重值与清除会话之前相同。
我认为您正在寻找的是重置模型的权重,这与会话或图表并没有真正的关系(有一些例外)。
权重的重置目前是一个有争议的话题,您可以在大多数情况下找到如何做到这一点here但是正如您所见,今天没有人计划实现此功能
为了便于访问我post下面的当前命题
def reset_weights(model):
for layer in model.layers:
if isinstance(layer, tf.keras.Model): #if you're using a model as a layer
reset_weights(layer) #apply function recursively
continue
#where are the initializers?
if hasattr(layer, 'cell'):
init_container = layer.cell
else:
init_container = layer
for key, initializer in init_container.__dict__.items():
if "initializer" not in key: #is this item an initializer?
continue #if no, skip it
# find the corresponding variable, like the kernel or the bias
if key == 'recurrent_initializer': #special case check
var = getattr(init_container, 'recurrent_kernel')
else:
var = getattr(init_container, key.replace("_initializer", ""))
var.assign(initializer(var.shape, var.dtype))
请记住,如果您没有定义种子,则每次调用重置时权重都会不同
我想在每次完成训练后清除内存/网络。我使用了在线建议的替代方案,但如果我正确解释我的结果,它们似乎不起作用。我使用 tf.compat.v1.reset_default_graph()
和 tf.keras.backend.clear_session()
,因为它们大多是在线推荐的。
import numpy as np
import random
import tensorflow as tf
from tensorflow import keras
from tensorflow.python.keras import backend as K
upper_limit = 2
lower_limit = -2
training_input= np.random.random ([100,5])*(upper_limit - lower_limit) + lower_limit
training_output = np.random.random ([100,1]) *10*(upper_limit - lower_limit) + lower_limit
model = tf.keras.Sequential([
tf.keras.layers.Flatten(input_shape=(5,)),
tf.keras.layers.Dense(12, activation='relu'),
tf.keras.layers.Dense(1)
])
model.compile(loss="mse",optimizer = tf.keras.optimizers.Adam(learning_rate=0.01))
for layer in model.layers:
print("layer weights before fitting: ",layer.get_weights(),"\n") # weights
model.fit(training_input, training_output, epochs=5, batch_size=100,verbose=0)
for layer in model.layers:
print("layer weights after fitting: ",layer.get_weights(),"\n") # weights
print("\n")
tf.compat.v1.reset_default_graph()
tf.keras.backend.clear_session()
print("after clear","\n")
for layer in model.layers:
print(layer.get_weights(),"\n") # weights
当我尝试清除网络后打印层权重时,我得到的权重值与清除会话之前相同。
我认为您正在寻找的是重置模型的权重,这与会话或图表并没有真正的关系(有一些例外)。
权重的重置目前是一个有争议的话题,您可以在大多数情况下找到如何做到这一点here但是正如您所见,今天没有人计划实现此功能
为了便于访问我post下面的当前命题
def reset_weights(model):
for layer in model.layers:
if isinstance(layer, tf.keras.Model): #if you're using a model as a layer
reset_weights(layer) #apply function recursively
continue
#where are the initializers?
if hasattr(layer, 'cell'):
init_container = layer.cell
else:
init_container = layer
for key, initializer in init_container.__dict__.items():
if "initializer" not in key: #is this item an initializer?
continue #if no, skip it
# find the corresponding variable, like the kernel or the bias
if key == 'recurrent_initializer': #special case check
var = getattr(init_container, 'recurrent_kernel')
else:
var = getattr(init_container, key.replace("_initializer", ""))
var.assign(initializer(var.shape, var.dtype))
请记住,如果您没有定义种子,则每次调用重置时权重都会不同