不匹配,因为某些关键字不正确:dtype
didn't match because some of the keywords were incorrect: dtype
下面的代码是在class
中构建的实例方法
def get_samples_from_component(self,batchSize):
SMALL = torch.tensor(1e-10, dtype=torch.float64, device=local_device)
a_inv = torch.pow(self.kumar_a,-1)
b_inv = torch.pow(self.kumar_b,-1)
r1 = torch.tensor(SMALL, dtype=torch.float64,device=self.device)
r2 = torch.tensor(1-SMALL, dtype=torch.float64, device=self.device)
v_means = torch.mul(self.kumar_b, beta_fn(1.+a_inv, self.kumar_b)).to(device=self.device)
u = torch.distributions.uniform.Uniform(low=r1, high=r2).sample([1]).squeeze()
v_samples = torch.pow(1 - torch.pow(u, b_inv), a_inv).to(device=self.device)
if v_samples.ndim > 2:
v_samples = v_samples.squeeze()
v0 = v_samples[:, -1].pow(0).reshape(v_samples.shape[0], 1)
v1 = torch.cat([v_samples[:, :self.z_dim - 1], v0], dim=1)
n_samples = v1.size()[0]
n_dims = v1.size()[1]
components = torch.zeros((n_samples, n_dims)).to(device=self.device)
for k in range(n_dims):
if k == 0:
components[:, k] = v1[:, k]
else:
components[:, k] = v1[:, k] * torch.stack([(1 - v1[:, j]) for j in range(n_dims) if j < k]).prod(axis=0)
# ensure stick segments sum to 1
assert_almost_equal(torch.ones(n_samples,device=self.device).cpu().numpy(), components.sum(axis=1).detach().cpu().numpy(),
decimal=4, err_msg='stick segments do not sum to 1')
print(f'size of sticks: {components}')
components = torch.IntTensor(torch.argmax(torch.cat( self.compose_stick_segments(v_means),1) ,1), dtype=torch.long, device=self.device)
components = torch.cat( [torch.range(0, batchSize).unsqueeze(1), components.unsqueeze(1)], 1)
print(f'size of sticks: {components}')
all_z = []
for d in range(self.z_dim):
temp_z = torch.cat(1, [self.z_sample_list[k][:, d].unsqueeze(1) for k in range(self.K)])
all_z.append(gather_nd(temp_z, components).unsqueeze(1))
out = torch.cat( all_z,1)
return out
通过 运行 我的代码,我收到以下错误消息
components = torch.IntTensor(torch.argmax(torch.cat( self.compose_stick_segments(v_means),1) ,1), dtype=torch.long, device=self.device)
TypeError: new() received an invalid combination of arguments - got (Tensor, device=torch.device, dtype=torch.dtype), but expected one of:
* (*, torch.device device)
* (torch.Storage storage)
* (Tensor other)
* (tuple of ints size, *, torch.device device)
didn't match because some of the keywords were incorrect: dtype
* (object data, *, torch.device device)
didn't match because some of the keywords were incorrect: dtype
如果有人对此错误提出解决方案,我将不胜感激。
v_means
已经是一个张量,尝试简单地删除张量 re-implementation in:
components = torch.IntTensor(torch.argmax(torch.cat( self.compose_stick_segments(v_means),1) ,1), dtype=torch.long, device=self.device)
至:
components = torch.argmax(torch.cat( self.compose_stick_segments(v_means),1) ,1)
或者简单地删除数据类型,因为它似乎无论如何都将其转换为整数:
components = torch.IntTensor(torch.argmax(torch.cat( self.compose_stick_segments(v_means),1) ,1), device=self.device)
下面的代码是在class
中构建的实例方法def get_samples_from_component(self,batchSize):
SMALL = torch.tensor(1e-10, dtype=torch.float64, device=local_device)
a_inv = torch.pow(self.kumar_a,-1)
b_inv = torch.pow(self.kumar_b,-1)
r1 = torch.tensor(SMALL, dtype=torch.float64,device=self.device)
r2 = torch.tensor(1-SMALL, dtype=torch.float64, device=self.device)
v_means = torch.mul(self.kumar_b, beta_fn(1.+a_inv, self.kumar_b)).to(device=self.device)
u = torch.distributions.uniform.Uniform(low=r1, high=r2).sample([1]).squeeze()
v_samples = torch.pow(1 - torch.pow(u, b_inv), a_inv).to(device=self.device)
if v_samples.ndim > 2:
v_samples = v_samples.squeeze()
v0 = v_samples[:, -1].pow(0).reshape(v_samples.shape[0], 1)
v1 = torch.cat([v_samples[:, :self.z_dim - 1], v0], dim=1)
n_samples = v1.size()[0]
n_dims = v1.size()[1]
components = torch.zeros((n_samples, n_dims)).to(device=self.device)
for k in range(n_dims):
if k == 0:
components[:, k] = v1[:, k]
else:
components[:, k] = v1[:, k] * torch.stack([(1 - v1[:, j]) for j in range(n_dims) if j < k]).prod(axis=0)
# ensure stick segments sum to 1
assert_almost_equal(torch.ones(n_samples,device=self.device).cpu().numpy(), components.sum(axis=1).detach().cpu().numpy(),
decimal=4, err_msg='stick segments do not sum to 1')
print(f'size of sticks: {components}')
components = torch.IntTensor(torch.argmax(torch.cat( self.compose_stick_segments(v_means),1) ,1), dtype=torch.long, device=self.device)
components = torch.cat( [torch.range(0, batchSize).unsqueeze(1), components.unsqueeze(1)], 1)
print(f'size of sticks: {components}')
all_z = []
for d in range(self.z_dim):
temp_z = torch.cat(1, [self.z_sample_list[k][:, d].unsqueeze(1) for k in range(self.K)])
all_z.append(gather_nd(temp_z, components).unsqueeze(1))
out = torch.cat( all_z,1)
return out
通过 运行 我的代码,我收到以下错误消息
components = torch.IntTensor(torch.argmax(torch.cat( self.compose_stick_segments(v_means),1) ,1), dtype=torch.long, device=self.device)
TypeError: new() received an invalid combination of arguments - got (Tensor, device=torch.device, dtype=torch.dtype), but expected one of:
* (*, torch.device device)
* (torch.Storage storage)
* (Tensor other)
* (tuple of ints size, *, torch.device device)
didn't match because some of the keywords were incorrect: dtype
* (object data, *, torch.device device)
didn't match because some of the keywords were incorrect: dtype
如果有人对此错误提出解决方案,我将不胜感激。
v_means
已经是一个张量,尝试简单地删除张量 re-implementation in:
components = torch.IntTensor(torch.argmax(torch.cat( self.compose_stick_segments(v_means),1) ,1), dtype=torch.long, device=self.device)
至:
components = torch.argmax(torch.cat( self.compose_stick_segments(v_means),1) ,1)
或者简单地删除数据类型,因为它似乎无论如何都将其转换为整数:
components = torch.IntTensor(torch.argmax(torch.cat( self.compose_stick_segments(v_means),1) ,1), device=self.device)