How to clear the entire network structure in TensorFlow

I want to clear the memory/network after every training run. I tried the alternatives suggested online, but if I am interpreting my results correctly, they do not seem to work. I used tf.compat.v1.reset_default_graph() and tf.keras.backend.clear_session(), since these are the solutions most often recommended online.

import numpy as np
import tensorflow as tf

upper_limit = 2
lower_limit = -2

training_input = np.random.random([100, 5]) * (upper_limit - lower_limit) + lower_limit
training_output = np.random.random([100, 1]) * 10 * (upper_limit - lower_limit) + lower_limit

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(5,)),
    tf.keras.layers.Dense(12, activation='relu'),
    tf.keras.layers.Dense(1)
])

model.compile(loss="mse", optimizer=tf.keras.optimizers.Adam(learning_rate=0.01))

for layer in model.layers:
    print("layer weights before fitting: ", layer.get_weights(), "\n")

model.fit(training_input, training_output, epochs=5, batch_size=100, verbose=0)

for layer in model.layers:
    print("layer weights after fitting: ", layer.get_weights(), "\n")
print("\n")

tf.compat.v1.reset_default_graph()
tf.keras.backend.clear_session()

print("after clear", "\n")
for layer in model.layers:
    print(layer.get_weights(), "\n")

When I print the layer weights after attempting to clear the network, I get the same weight values as before clearing the session.

I think what you are actually looking for is resetting the model's weights, which has little to do with the session or the graph (with a few exceptions).
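One way to see why clear_session leaves the printed weights unchanged (in TF2's eager mode, at least): the variables are owned by the Python model object itself, not by the global state that clear_session resets. A minimal sketch, with an arbitrary toy model:

```python
import numpy as np
import tensorflow as tf

# arbitrary toy model, just to get some weights
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3),
])
w_before = model.layers[0].get_weights()[0].copy()  # the kernel matrix

tf.keras.backend.clear_session()  # resets Keras's global state (graph, layer name counters)

# the variables still live on the Python model object, untouched
w_after = model.layers[0].get_weights()[0]
print(np.array_equal(w_before, w_after))  # True
```

So to get fresh weights you have to act on the variables themselves, which is what the function below does.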

Resetting weights is currently a debated topic; you can find out how to do it in most cases here, but as you can see, nobody is planning to implement this feature at the moment.

For easier reference, I am posting the current proposal below:

def reset_weights(model):
    for layer in model.layers:
        if isinstance(layer, tf.keras.Model):  # if you're using a model as a layer
            reset_weights(layer)  # apply the function recursively
            continue

        # where are the initializers?
        if hasattr(layer, 'cell'):
            init_container = layer.cell
        else:
            init_container = layer

        for key, initializer in init_container.__dict__.items():
            if "initializer" not in key:  # is this attribute an initializer?
                continue  # if not, skip it

            # find the corresponding variable, e.g. the kernel or the bias
            if key == 'recurrent_initializer':  # special case check
                var = getattr(init_container, 'recurrent_kernel')
            else:
                var = getattr(init_container, key.replace("_initializer", ""))

            var.assign(initializer(var.shape, var.dtype))

Remember that if you did not set a seed, the weights will be different every time you call the reset.