didn't match because some of the keywords were incorrect: dtype

The code below is an instance method defined in a class:
def get_samples_from_component(self,batchSize):
    SMALL = torch.tensor(1e-10, dtype=torch.float64, device=local_device)
    a_inv = torch.pow(self.kumar_a,-1)
    b_inv = torch.pow(self.kumar_b,-1)
    r1    = torch.tensor(SMALL, dtype=torch.float64,device=self.device)
    r2    = torch.tensor(1-SMALL, dtype=torch.float64, device=self.device)
    v_means = torch.mul(self.kumar_b, beta_fn(1.+a_inv, self.kumar_b)).to(device=self.device)
    u       = torch.distributions.uniform.Uniform(low=r1, high=r2).sample([1]).squeeze()
    v_samples  = torch.pow(1 - torch.pow(u, b_inv), a_inv).to(device=self.device)
    if v_samples.ndim > 2:
        v_samples = v_samples.squeeze()
    v0 = v_samples[:, -1].pow(0).reshape(v_samples.shape[0], 1)
    v1 = torch.cat([v_samples[:, :self.z_dim - 1], v0], dim=1)
    n_samples = v1.size()[0]
    n_dims = v1.size()[1]
    components = torch.zeros((n_samples, n_dims)).to(device=self.device)

    for k in range(n_dims):
        if k == 0:
            components[:, k] = v1[:, k]
        else:
            components[:, k] = v1[:, k] * torch.stack([(1 - v1[:, j]) for j in range(n_dims) if j < k]).prod(axis=0)
    # ensure stick segments sum to 1
    assert_almost_equal(torch.ones(n_samples,device=self.device).cpu().numpy(), components.sum(axis=1).detach().cpu().numpy(),
                        decimal=4, err_msg='stick segments do not sum to 1')
    print(f'size of sticks: {components}')


    components = torch.IntTensor(torch.argmax(torch.cat( self.compose_stick_segments(v_means),1) ,1), dtype=torch.long, device=self.device)
    components = torch.cat( [torch.range(0, batchSize).unsqueeze(1), components.unsqueeze(1)], 1)
    print(f'size of sticks: {components}')
    all_z = []
    for d in range(self.z_dim):
        temp_z = torch.cat(1, [self.z_sample_list[k][:, d].unsqueeze(1) for k in range(self.K)])
        all_z.append(gather_nd(temp_z, components).unsqueeze(1))
    out       = torch.cat( all_z,1)
    return out 

When I run my code, I get the following error message:

components = torch.IntTensor(torch.argmax(torch.cat( self.compose_stick_segments(v_means),1) ,1), dtype=torch.long, device=self.device)
TypeError: new() received an invalid combination of arguments - got (Tensor, device=torch.device, dtype=torch.dtype), but expected one of:
 * (*, torch.device device)
 * (torch.Storage storage)
 * (Tensor other)
 * (tuple of ints size, *, torch.device device)
      didn't match because some of the keywords were incorrect: dtype
 * (object data, *, torch.device device)
      didn't match because some of the keywords were incorrect: dtype

I would appreciate it if anyone could suggest a solution to this error.

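v_means is already a tensor, and torch.argmax also returns a tensor (with dtype torch.long, which is what you want here), so try simply removing the redundant re-wrapping with torch.IntTensor in: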

components = torch.IntTensor(torch.argmax(torch.cat( self.compose_stick_segments(v_means),1) ,1), dtype=torch.long, device=self.device)

to:

components = torch.argmax(torch.cat( self.compose_stick_segments(v_means),1) ,1)
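As a quick check, the sketch below uses a random tensor as a hypothetical stand-in for torch.cat(self.compose_stick_segments(v_means), 1), since compose_stick_segments is not shown in the question. It illustrates that torch.argmax already returns a torch.long tensor on the same device as its input, so no extra constructor is needed.

```python
import torch

# Hypothetical stand-in for torch.cat(self.compose_stick_segments(v_means), 1):
# a (batch, K) tensor of stick weights.
weights = torch.rand(4, 3)

# argmax over dim 1 gives one component index per row.
components = torch.argmax(weights, 1)

print(components.dtype)   # torch.int64, i.e. torch.long
print(components.device)  # same device as `weights`
```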

Or simply remove the dtype argument, since torch.IntTensor seems to convert the data to integers anyway:

components = torch.IntTensor(torch.argmax(torch.cat( self.compose_stick_segments(v_means),1) ,1), device=self.device)
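For reference, here is a minimal sketch that reproduces the error outside the class and then shows the usual way to change the dtype and device of an existing tensor with Tensor.to. Legacy type-specific constructors such as torch.IntTensor do not accept a dtype keyword, which is exactly what the traceback reports. The device variable here is an assumption standing in for self.device; it falls back to the CPU when CUDA is unavailable.

```python
import torch

# Assumed stand-in for self.device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
idx = torch.argmax(torch.rand(4, 3, device=device), 1)

# Reproduce the error: the legacy constructor rejects the dtype keyword.
raised = False
try:
    torch.IntTensor(idx, dtype=torch.long, device=device)
except TypeError as err:
    raised = True
    print("TypeError:", str(err).splitlines()[0])

# Idiomatic conversion: Tensor.to() accepts both dtype and device.
components = idx.to(dtype=torch.long, device=device)
print(components.dtype, components.device)
```

Since idx is already torch.long on the right device, the .to() call is effectively a no-op here. It is only needed when you genuinely want a different dtype or device.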