Model and Weights do not load from checkpoint
I am training a reinforcement learning model using OpenAI gym's CartPole environment. Although the .h5 files for my weights and model appear in the target directory, running tf.train.get_checkpoint_state("C:/Users/dgt/Documents") returns None.
Here is my entire code -
## Slightly modified from the following repository - https://github.com/gsurma/cartpole
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import random
import gym
import numpy as np
import tensorflow as tf
from collections import deque
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint

ENV_NAME = "CartPole-v1"

GAMMA = 0.95
LEARNING_RATE = 0.001

MEMORY_SIZE = 1000000
BATCH_SIZE = 20

EXPLORATION_MAX = 1.0
EXPLORATION_MIN = 0.01
EXPLORATION_DECAY = 0.995

checkpoint_path = "training_1/cp.ckpt"


class DQNSolver:

    def __init__(self, observation_space, action_space):
        # save_dir = args.save_dir
        # self.save_dir = save_dir
        # if not os.path.exists(save_dir):
        #     os.makedirs(save_dir)
        self.exploration_rate = EXPLORATION_MAX

        self.action_space = action_space
        self.memory = deque(maxlen=MEMORY_SIZE)

        self.model = Sequential()
        self.model.add(Dense(24, input_shape=(observation_space,), activation="relu"))
        self.model.add(Dense(24, activation="relu"))
        self.model.add(Dense(self.action_space, activation="linear"))
        self.model.compile(loss="mse", optimizer=Adam(lr=LEARNING_RATE))

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        if np.random.rand() < self.exploration_rate:
            return random.randrange(self.action_space)
        q_values = self.model.predict(state)
        return np.argmax(q_values[0])

    def experience_replay(self):
        if len(self.memory) < BATCH_SIZE:
            return
        batch = random.sample(self.memory, BATCH_SIZE)
        for state, action, reward, state_next, terminal in batch:
            q_update = reward
            if not terminal:
                q_update = (reward + GAMMA * np.amax(self.model.predict(state_next)[0]))
            q_values = self.model.predict(state)
            q_values[0][action] = q_update
            self.model.fit(state, q_values, verbose=0)
        self.exploration_rate *= EXPLORATION_DECAY
        self.exploration_rate = max(EXPLORATION_MIN, self.exploration_rate)


def cartpole():
    env = gym.make(ENV_NAME)
    #score_logger = ScoreLogger(ENV_NAME)
    observation_space = env.observation_space.shape[0]
    action_space = env.action_space.n
    dqn_solver = DQNSolver(observation_space, action_space)

    checkpoint = tf.train.get_checkpoint_state("C:/Users/dgt/Documents")
    print('checkpoint:', checkpoint)
    if checkpoint and checkpoint.model_checkpoint_path:
        dqn_solver.model = keras.models.load_model('cartpole.h5')
        dqn_solver.model = model.load_weights('cartpole_weights.h5')

    run = 0
    i = 0
    while i < 2:
        i = i + 1
        #total = 0
        run += 1
        state = env.reset()
        state = np.reshape(state, [1, observation_space])
        step = 0
        while True:
            step += 1
            #env.render()
            action = dqn_solver.act(state)
            state_next, reward, terminal, info = env.step(action)
            #total += reward
            reward = reward if not terminal else -reward
            state_next = np.reshape(state_next, [1, observation_space])
            dqn_solver.remember(state, action, reward, state_next, terminal)
            state = state_next
            dqn_solver.model.save('cartpole.h5')
            dqn_solver.model.save_weights('cartpole_weights.h5')
            if terminal:
                print("Run: " + str(run) + ", exploration: " + str(dqn_solver.exploration_rate) + ", score: " + str(step))
                #score_logger.add_score(step, run)
                break
            dqn_solver.experience_replay()


if __name__ == "__main__":
    cartpole()
Both the cartpole_weights.h5 and cartpole.h5 files appear in my target directory. However, I believe another file named 'checkpoint' should also appear, and my understanding is that its absence is why my code does not run.
First, the code will not run at all if you have not yet saved the weights/model, so I commented out the lines below and ran the script once to generate the files.
checkpoint = tf.train.get_checkpoint_state(".")
print('checkpoint:', checkpoint)
if checkpoint and checkpoint.model_checkpoint_path:
    dqn_solver.model = tf.keras.models.load_model('cartpole.h5')
    dqn_solver.model.load_weights('cartpole_weights.h5')
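As for the missing 'checkpoint' file: tf.train.get_checkpoint_state() looks for an index file literally named checkpoint, which TensorFlow only writes when weights are saved in the TensorFlow checkpoint format. Saving to .h5 files (HDF5), as this script does, never creates that file, which is why the call returns None. If you really want something get_checkpoint_state() can find, a minimal sketch (reusing the script's otherwise unused checkpoint_path; saving in this format is an assumption, not something the original code does) would be:

# Sketch only: saving without an .h5 extension uses the TF checkpoint format,
# which writes cp.ckpt.index / cp.ckpt.data-* plus the "checkpoint" index file
# that tf.train.get_checkpoint_state() reads.
dqn_solver.model.save_weights(checkpoint_path)        # checkpoint_path = "training_1/cp.ckpt"
ckpt = tf.train.get_checkpoint_state("training_1")
print(ckpt.model_checkpoint_path if ckpt else None)   # e.g. "training_1/cp.ckpt"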
Note that I also modified those loading lines compared to the post - the original had some syntax errors. In particular, this line from the post

dqn_solver.model = model.load_weights('cartpole_weights.h5')

was probably causing the problem, because model.load_weights('file') mutates the model in place rather than returning a model.
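To make the difference concrete:

# Wrong: load_weights() updates the model in place and does not return a model
# (and `model` is not even defined here), so this assignment clobbers dqn_solver.model.
dqn_solver.model = model.load_weights('cartpole_weights.h5')

# Right: call it on the model whose weights you want to overwrite.
dqn_solver.model.load_weights('cartpole_weights.h5')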
Then I tested whether the model weights were being saved/loaded correctly. To do this, you can run

dqn_solver = DQNSolver(observation_space, action_space)
dqn_solver.model.trainable_variables

to see the (randomly initialized) weights when the model is first created. Then you can load the weights with

dqn_solver.model = tf.keras.models.load_model('cartpole.h5')

or

dqn_solver.model.load_weights('cartpole_weights.h5')

and then look at trainable_variables again to make sure they differ from the initial weights and that the two loading methods give equivalent weights.
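A quick numeric version of the same check (just a sketch, assuming the imports and the two files saved by the script above, and that observation_space/action_space are defined as earlier):

# Sketch: compare the weights numerically instead of eyeballing trainable_variables.
fresh = DQNSolver(observation_space, action_space)
initial = [w.copy() for w in fresh.model.get_weights()]      # random initialization

fresh.model.load_weights('cartpole_weights.h5')              # overwritten in place
from_weights = fresh.model.get_weights()

from_model = tf.keras.models.load_model('cartpole.h5').get_weights()

print(all(np.allclose(a, b) for a, b in zip(initial, from_weights)))      # False: weights changed
print(all(np.allclose(a, b) for a, b in zip(from_weights, from_model)))   # True: both files agree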
When you save the model, it saves the full architecture - the exact configuration of the layers. When you save only the weights, it just saves the list of tensors you can see in trainable_variables.
Note that when you load_weights, they have to be loaded into the exact architecture the weights came from, otherwise it will not work correctly. So if you change the model architecture in DQNSolver and then try to load_weights saved from the old model, it will not work. If you load_model instead, it resets the model to the exact saved architecture and also sets the weights.
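For example (purely hypothetical, not part of the original script): if the hidden layers in DQNSolver were changed from 24 to 32 units, loading the old weight file would fail, while loading the full model would simply rebuild the saved 24-unit architecture:

# Hypothetical: suppose DQNSolver now builds Dense(32, ...) hidden layers instead of Dense(24, ...).
dqn_solver = DQNSolver(observation_space, action_space)

# Fails: the tensors in cartpole_weights.h5 no longer match the layer shapes,
# so this raises a ValueError about incompatible shapes/layers.
# dqn_solver.model.load_weights('cartpole_weights.h5')

# Works: the saved architecture (24-unit layers) is reconstructed from the file
# and its weights are set, replacing the 32-unit model entirely.
dqn_solver.model = tf.keras.models.load_model('cartpole.h5')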
Edit - the entire modified script
## Slightly modified from the following repository - https://github.com/gsurma/cartpole
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import random
import gym
import numpy as np
import tensorflow as tf
from collections import deque
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint

ENV_NAME = "CartPole-v1"

GAMMA = 0.95
LEARNING_RATE = 0.001

MEMORY_SIZE = 1000000
BATCH_SIZE = 20

EXPLORATION_MAX = 1.0
EXPLORATION_MIN = 0.01
EXPLORATION_DECAY = 0.995

checkpoint_path = "training_1/cp.ckpt"


class DQNSolver:

    def __init__(self, observation_space, action_space):
        # save_dir = args.save_dir
        # self.save_dir = save_dir
        # if not os.path.exists(save_dir):
        #     os.makedirs(save_dir)
        self.exploration_rate = EXPLORATION_MAX

        self.action_space = action_space
        self.memory = deque(maxlen=MEMORY_SIZE)

        self.model = Sequential()
        self.model.add(Dense(24, input_shape=(observation_space,), activation="relu"))
        self.model.add(Dense(24, activation="relu"))
        self.model.add(Dense(self.action_space, activation="linear"))
        self.model.compile(loss="mse", optimizer=Adam(lr=LEARNING_RATE))

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        if np.random.rand() < self.exploration_rate:
            return random.randrange(self.action_space)
        q_values = self.model.predict(state)
        return np.argmax(q_values[0])

    def experience_replay(self):
        if len(self.memory) < BATCH_SIZE:
            return
        batch = random.sample(self.memory, BATCH_SIZE)
        for state, action, reward, state_next, terminal in batch:
            q_update = reward
            if not terminal:
                q_update = (reward + GAMMA * np.amax(self.model.predict(state_next)[0]))
            q_values = self.model.predict(state)
            q_values[0][action] = q_update
            self.model.fit(state, q_values, verbose=0)
        self.exploration_rate *= EXPLORATION_DECAY
        self.exploration_rate = max(EXPLORATION_MIN, self.exploration_rate)


def cartpole():
    env = gym.make(ENV_NAME)
    #score_logger = ScoreLogger(ENV_NAME)
    observation_space = env.observation_space.shape[0]
    action_space = env.action_space.n
    dqn_solver = DQNSolver(observation_space, action_space)

    # checkpoint = tf.train.get_checkpoint_state(".")
    # print('checkpoint:', checkpoint)
    # if checkpoint and checkpoint.model_checkpoint_path:
    #     dqn_solver.model = tf.keras.models.load_model('cartpole.h5')
    #     dqn_solver.model.load_weights('cartpole_weights.h5')

    run = 0
    i = 0
    while i < 2:
        i = i + 1
        #total = 0
        run += 1
        state = env.reset()
        state = np.reshape(state, [1, observation_space])
        step = 0
        while True:
            step += 1
            #env.render()
            action = dqn_solver.act(state)
            state_next, reward, terminal, info = env.step(action)
            #total += reward
            reward = reward if not terminal else -reward
            state_next = np.reshape(state_next, [1, observation_space])
            dqn_solver.remember(state, action, reward, state_next, terminal)
            state = state_next
            dqn_solver.model.save('cartpole.h5')
            dqn_solver.model.save_weights('cartpole_weights.h5')
            if terminal:
                print("Run: " + str(run) + ", exploration: " + str(dqn_solver.exploration_rate) + ", score: " + str(step))
                #score_logger.add_score(step, run)
                break
            dqn_solver.experience_replay()


if __name__ == "__main__":
    cartpole()
#%% to load saved results
env = gym.make(ENV_NAME)
#score_logger = ScoreLogger(ENV_NAME)
observation_space = env.observation_space.shape[0]
action_space = env.action_space.n
dqn_solver = DQNSolver(observation_space, action_space)
dqn_solver.model = tf.keras.models.load_model('cartpole.h5') # or
dqn_solver.model.load_weights('cartpole_weights.h5')