在 Pytorch 中查找前 k 个匹配项
Finding the top k matches in Pytorch
我正在使用以下代码使用 pytorch 查找 topk 匹配项:
def find_top(self, x, y, n_neighbors, unit_vectors=False, cuda=False):
if not unit_vectors:
x = __to_unit_torch__(x, cuda=cuda)
y = __to_unit_torch__(y, cuda=cuda)
with torch.no_grad():
d = 1. - torch.matmul(x, y.transpose(0, 1))
values, indices = torch.topk(d, n_neighbors, dim=1, largest=False, sorted=True)
return indices.cpu().numpy()
不幸的是,它抛出以下错误:
values, indices = torch.topk(d, n_neighbors, dim=1, largest=False, sorted=True)
RuntimeError: invalid argument 5: k not in range for dimension at /pytorch/aten/src/THC/generic/THCTensorTopK.cu:23
d 的大小为 (1793,1)
。我错过了什么?
This error 当您调用 torch.topk
且 k
大于 类 的总数时,会发生
This error。减少你的争论,它应该 运行 没问题。
我正在使用以下代码使用 pytorch 查找 topk 匹配项:
def find_top(self, x, y, n_neighbors, unit_vectors=False, cuda=False):
if not unit_vectors:
x = __to_unit_torch__(x, cuda=cuda)
y = __to_unit_torch__(y, cuda=cuda)
with torch.no_grad():
d = 1. - torch.matmul(x, y.transpose(0, 1))
values, indices = torch.topk(d, n_neighbors, dim=1, largest=False, sorted=True)
return indices.cpu().numpy()
不幸的是,它抛出以下错误:
values, indices = torch.topk(d, n_neighbors, dim=1, largest=False, sorted=True)
RuntimeError: invalid argument 5: k not in range for dimension at /pytorch/aten/src/THC/generic/THCTensorTopK.cu:23
d 的大小为 (1793,1)
。我错过了什么?
This error 当您调用 torch.topk
且 k
大于 类 的总数时,会发生
This error。减少你的争论,它应该 运行 没问题。