DQN predicts same action value for every state (cart pole)

I am trying to implement DQN. As a warm-up, I want to solve CartPole-v0 with an MLP consisting of two hidden layers plus the input and output layers. The input is a 4-element array [cart position, cart velocity, pole angle, pole angular velocity], and the output is an action value for each action (left or right). I am not implementing the exact DQN from the "Playing Atari with Deep Reinforcement Learning" paper (no frame stacking for the input, etc.). I have also made some non-standard choices, such as storing done and the target network's action-value prediction in the experience replay, but those choices shouldn't affect learning.
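
Experience and ReplayMemory are small helper types that I don't show below. Roughly, assuming a namedtuple record plus a bounded deque with uniform random sampling (a sketch of the shape, not my exact code), they look like this:

import random
from collections import deque, namedtuple

# Sketch of the replay structures assumed by the learning code below.
Experience = namedtuple(
    "Experience",
    ["state", "action", "reward", "new_state", "done", "target_value_pred"],
    defaults=(None,),  # target_value_pred is optional, so the same record also works for the fixed version later in this post
)

class ReplayMemory:
    def __init__(self, capacity=10_000):
        self.buffer = deque(maxlen=capacity)

    def append(self, experience):
        self.buffer.append(experience)

    def sample(self, batch_size):
        # Uniform random minibatch of stored transitions.
        return random.sample(list(self.buffer), batch_size)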

In any case, I am having trouble getting this to work. No matter how long I train the agent, it predicts a higher value for one action than for the other, e.g. Q(s, Right) > Q(s, Left) for all states s. Below are my learning code, my network definition, and some of the results I get from training.

class DQN:
    def __init__(self, env, steps_per_episode=200):
        self.env = env
        self.agent_network = MlpPolicy(self.env)
        self.target_network = MlpPolicy(self.env)
        self.target_network.load_state_dict(self.agent_network.state_dict())
        self.target_network.eval()
        self.optimizer = torch.optim.RMSprop(
            self.agent_network.parameters(), lr=0.005, momentum=0.95
        )
        self.replay_memory = ReplayMemory()
        self.gamma = 0.99
        self.steps_per_episode = steps_per_episode
        self.random_policy_stop = 1000
        self.start_learning_time = 1000
        self.batch_size = 32

    def learn(self, episodes):
        time = 0
        for episode in tqdm(range(episodes)):
            state = self.env.reset()
            for step in range(self.steps_per_episode):
                if time < self.random_policy_stop:
                    action = self.env.action_space.sample()
                else:
                    action = select_action(self.env, time, state, self.agent_network)
                new_state, reward, done, _ = self.env.step(action)
                target_value_pred = predict_target_value(
                    new_state, reward, done, self.target_network, self.gamma
                )
                experience = Experience(
                    state, action, reward, new_state, done, target_value_pred
                )
                self.replay_memory.append(experience)
                if time > self.start_learning_time:  # learning step
                    experience_batch = self.replay_memory.sample(self.batch_size)
                    target_preds = extract_value_predictions(experience_batch)
                    agent_preds = agent_batch_preds(
                        experience_batch, self.agent_network
                    )
                    loss = torch.square(agent_preds - target_preds).sum()
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                if time % 1_000 == 0:  # how frequently to update target net
                    self.target_network.load_state_dict(self.agent_network.state_dict())
                    self.target_network.eval()

                state = new_state
                time += 1

                if done:
                    break

def agent_batch_preds(experience_batch: list, agent_network: MlpPolicy):
    """
    Calculate the agent action value estimates using the old states and the
    actual actions that the agent took at that step.
    """
    old_states = extract_old_states(experience_batch)
    actions = extract_actions(experience_batch)
    agent_preds = agent_network(old_states)
    # Pick Q(s_i, a_i) for each transition in the batch; equivalent to
    # agent_preds.gather(1, actions.unsqueeze(1)).squeeze(1).
    experienced_action_values = agent_preds.index_select(1, actions).diag()
    return experienced_action_values


def extract_actions(experience_batch: list) -> list:
    """
    Extract the list of actions from experience replay batch and torchify
    """
    actions = [exp.action for exp in experience_batch]
    actions = torch.tensor(actions)
    return actions


class MlpPolicy(nn.Module):
    """
    This class implements the MLP which will be used as the Q network. I only
    intend to solve classic control problems with this.
    """

    def __init__(self, env):
        super(MlpPolicy, self).__init__()
        self.env = env
        self.input_dim = self.env.observation_space.shape[0]
        self.output_dim = self.env.action_space.n
        self.fc1 = nn.Linear(self.input_dim, 32)
        self.fc2 = nn.Linear(32, 128)
        self.fc3 = nn.Linear(128, 32)
        self.fc4 = nn.Linear(32, self.output_dim)

    def forward(self, x):
        if type(x) != torch.Tensor:
            x = torch.tensor(x).float()
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x
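
The remaining helpers (select_action, predict_target_value, extract_value_predictions, extract_old_states) are omitted. Roughly, assuming epsilon-greedy exploration with a decaying epsilon and the usual one-step bootstrapped target, they look something like the following sketch (not my exact code):

import random
import torch

def select_action(env, time, state, agent_network):
    # Epsilon-greedy over the agent network's action values, with an assumed
    # linearly decaying epsilon schedule.
    epsilon = max(0.05, 1.0 - time / 10_000)
    if random.random() < epsilon:
        return env.action_space.sample()
    with torch.no_grad():
        return int(agent_network(state).argmax().item())

def predict_target_value(new_state, reward, done, target_network, gamma):
    # One-step target computed at collection time:
    # r if terminal, else r + gamma * max_a' Q_target(s', a').
    if done:
        return reward
    with torch.no_grad():
        return reward + gamma * target_network(new_state).max().item()

def extract_old_states(experience_batch):
    return torch.tensor([exp.state for exp in experience_batch]).float()

def extract_value_predictions(experience_batch):
    return torch.tensor([exp.target_value_pred for exp in experience_batch]).float()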

Learning results (plot not shown):

Here I can see that one action is always valued more highly than the other (Q(s, Right) > Q(s, Left)). It is also clear that the network is predicting the same action values for every state.

Does anyone know what is going on? I have done a lot of debugging and careful rereading of the original paper (I also considered "normalizing" the observation space, even though the velocities can be unbounded), and at this point I may be missing something obvious. I can include more code for the helper functions if that would be useful.

There is nothing wrong with the network definition. It turns out the learning rate was too high; lowering it to 0.00025 (as in the original Nature paper that introduced DQN) produced an agent that can solve CartPole-v0.
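
In the code above, that amounts to changing the optimizer construction in __init__ to:

        self.optimizer = torch.optim.RMSprop(
            self.agent_network.parameters(), lr=0.00025, momentum=0.95
        )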

That said, the learning algorithm was also incorrect. In particular, I was using the wrong target action-value predictions. Note that the algorithm listed above does not use the most recent version of the target network when making those predictions. As training goes on, this leads to poor results because the agent is learning from stale target data. The fix is to put (s, a, r, s', done) into the replay memory and then compute the target predictions with the latest version of the target network when the minibatch is sampled, i.e. y = r + gamma * max_a' Q_target(s', a') * (1 - done). See the code below for the updated learning loop.

    def learn(self, episodes):
        time = 0
        for episode in tqdm(range(episodes)):
            state = self.env.reset()
            for step in range(self.steps_per_episode):
                if time < self.random_policy_stop:
                    action = self.env.action_space.sample()
                else:
                    action = select_action(self.env, time, state, self.agent_network)
                new_state, reward, done, _ = self.env.step(action)
                experience = Experience(state, action, reward, new_state, done)
                self.replay_memory.append(experience)
                if time > self.start_learning_time:  # learning step.
                    experience_batch = self.replay_memory.sample(self.batch_size)
                    target_preds = target_batch_preds(
                        experience_batch, self.target_network, self.gamma
                    )
                    agent_preds = agent_batch_preds(
                        experience_batch, self.agent_network
                    )
                    loss = torch.square(agent_preds - target_preds).sum()
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                if time % 1_000 == 0:  # how frequently to update target net
                    self.target_network.load_state_dict(self.agent_network.state_dict())
                    self.target_network.eval()

                state = new_state
                time += 1
                if done:
                    break
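
target_batch_preds is not shown above; a minimal sketch, assuming the same Experience fields as before, would be:

import torch

def target_batch_preds(experience_batch, target_network, gamma):
    # Recompute targets with the current target network at sampling time:
    # y = r + gamma * max_a' Q_target(s', a'), with the bootstrap term masked
    # out on terminal transitions. (Sketch; the exact helper may differ.)
    rewards = torch.tensor([exp.reward for exp in experience_batch]).float()
    dones = torch.tensor([float(exp.done) for exp in experience_batch])
    new_states = torch.tensor([exp.new_state for exp in experience_batch]).float()
    with torch.no_grad():
        next_values = target_network(new_states).max(dim=1).values
    return rewards + gamma * next_values * (1.0 - dones)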