I get this error using PyTorch: RuntimeError: gather_out_cpu(): Expected dtype int64 for index

I'm trying to build an AI with PyTorch, but I get this error:

RuntimeError: gather_out_cpu(): Expected dtype int64 for index
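The message means the index tensor passed to gather does not have dtype int64. Below is a minimal reproduction. It assumes, as a guess, that the action indices were stored with torch.Tensor(...), which uses the default float dtype (float32) even for whole numbers. The question does not show how batch_action is built, and the values here are hypothetical.

```python
import torch

# Hypothetical action values: 2 states x 3 actions.
q_values = torch.tensor([[1.0, 2.0, 3.0],
                         [4.0, 5.0, 6.0]])

# torch.Tensor(...) builds a tensor of the default float dtype (float32),
# even from Python ints. (Assumption: this is how the actions got the wrong dtype.)
actions = torch.Tensor([2, 0])
print(actions.dtype)  # torch.float32

error_message = None
try:
    q_values.gather(1, actions.unsqueeze(1))
except RuntimeError as err:
    # gather() rejects an index tensor that is not int64.
    error_message = str(err)
print(error_message)
```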

This is my function:

import torch
import torch.nn.functional as F

def learn(self, batch_state, batch_next_state, batch_reward, batch_action):
    outputs = self.model(batch_state).gather(1, batch_action.unsqueeze(1)).squeeze(1)
    next_outputs = self.model(batch_next_state).detach().max(1)[0]
    target = self.gamma * next_outputs + batch_reward
    td_loss = F.smooth_l1_loss(outputs, target)
    self.optimizer.zero_grad()
    td_loss.backward(retain_graph=True)  # retain_variables was renamed to retain_graph
    self.optimizer.step()
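The first line of learn uses gather to pick, for each row of the batch, the value of the action that was actually taken. The small worked example below assumes the model returns one value per action, which is what the gamma/TD-loss code suggests. The numbers are hypothetical. It also shows why the index must be 2-D (unsqueeze) and how max(1)[0] differs.

```python
import torch

# Hypothetical model output: 3 states x 2 actions.
q_values = torch.tensor([[0.1, 0.9],
                         [0.7, 0.3],
                         [0.5, 0.2]])
# Action taken in each state, stored as an int64 index tensor.
batch_action = torch.tensor([0, 0, 1], dtype=torch.int64)

# unsqueeze(1): shape (3,) -> (3, 1), so the index has as many dims as q_values.
picked = q_values.gather(1, batch_action.unsqueeze(1))
print(picked.shape)  # torch.Size([3, 1])

# squeeze(1) drops the extra dim: one value per state (0.1, 0.7, 0.2).
outputs = picked.squeeze(1)

# By contrast, max(1)[0] keeps the largest value in each row (0.9, 0.7, 0.5).
best = q_values.max(1)[0]
```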

You need to convert the batch_action tensor to dtype int64 before passing it to torch.gather:

def learn(...):
    batch_action = batch_action.type(torch.int64)  # gather() needs an int64 index
    outputs = ...
    ...

# or
outputs = self.model(batch_state).gather(1, batch_action.type(torch.int64).unsqueeze(1)).squeeze(1)
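The sketch below compares the two ways to get an int64 index. It assumes the actions currently arrive as a float tensor, and the values are hypothetical. .long() is a shorter spelling of .type(torch.int64). Casting from float to an integer type truncates, which is harmless for whole-number action indices. Creating the tensor with torch.tensor(..., dtype=torch.long) in the first place avoids the cast altogether.

```python
import torch

# Hypothetical action values: 2 states x 3 actions.
q_values = torch.tensor([[1.0, 2.0, 3.0],
                         [4.0, 5.0, 6.0]])
float_actions = torch.Tensor([2, 0])  # float32: would trigger the error

# Fix 1: cast right before gather (as in the answer above).
as_int64 = float_actions.type(torch.int64)
as_long = float_actions.long()  # same conversion, shorter spelling

outputs = q_values.gather(1, as_int64.unsqueeze(1)).squeeze(1)
print(outputs)  # picks 3.0 from row 0 and 4.0 from row 1

# Fix 2: build the index tensor as int64 from the start.
# torch.tensor infers int64 from Python ints; dtype makes it explicit.
actions = torch.tensor([2, 0], dtype=torch.long)
outputs_2 = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
```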