Stable-Baselines3: logging reward with a custom gym environment
I have this custom callback to log the reward in my custom vectorized environment, but the reward always shows up in the console as [0] and is not logged to TensorBoard at all:
class TensorboardCallback(BaseCallback):
    """
    Custom callback for plotting additional values in tensorboard.
    """

    def __init__(self, verbose=0):
        super(TensorboardCallback, self).__init__(verbose)

    def _on_step(self) -> bool:
        self.logger.record('reward', self.training_env.get_attr('total_reward'))
        return True
This is part of the main function:
model = PPO(
    "MlpPolicy", env,
    learning_rate=3e-4,
    policy_kwargs=policy_kwargs,
    verbose=1,
    tensorboard_log="./tensorboard/")

# as the environment is not serializable, we need to set a new instance of the environment
loaded_model = model = PPO.load("model", env=env)
loaded_model.set_env(env)
# and continue training
loaded_model.learn(1e+6, callback=TensorboardCallback())
You need to add [0] as an index, so where you wrote self.logger.record('reward', self.training_env.get_attr('total_reward')) you just need to use self.logger.record('reward', self.training_env.get_attr('total_reward')[0]) instead:
class TensorboardCallback(BaseCallback):
    """
    Custom callback for plotting additional values in tensorboard.
    """

    def __init__(self, verbose=0):
        super(TensorboardCallback, self).__init__(verbose)

    def _on_step(self) -> bool:
        self.logger.record('reward', self.training_env.get_attr('total_reward')[0])
        return True
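To see why the [0] is needed: in Stable-Baselines3, VecEnv.get_attr returns a list with one entry per parallel sub-environment, while logger.record expects a scalar per key. Below is a minimal plain-Python sketch (no SB3 required; FakeVecEnv is a made-up stand-in, not an SB3 class) illustrating the shape of the returned value:

```python
class FakeVecEnv:
    """Hypothetical stand-in for a vectorized env with parallel sub-envs."""

    def __init__(self, rewards):
        self.rewards = rewards

    def get_attr(self, name):
        # Mirrors the shape of SB3's VecEnv.get_attr: a list,
        # one entry per sub-environment.
        return list(self.rewards)


env = FakeVecEnv([3.5, 1.2, 7.8])

values = env.get_attr('total_reward')
print(values)     # [3.5, 1.2, 7.8] -- a list, not a scalar
print(values[0])  # 3.5 -- a scalar, suitable for logger.record
```

If you train with several parallel environments, logging only index [0] tracks a single sub-environment; averaging the list (e.g. sum(values) / len(values)) before recording may be more representative.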