我怎样才能有效地 modify/make 成对距离矩阵?

How can I efficiently modify/make pairwise distance matrix?

    x_norm = (x**2).sum(1).view(-1, 1)
    if y is not None:
        y_norm = (y**2).sum(1).view(1, -1)
    else:
        y = x
        y_norm = x_norm.view(1, -1)
    dist = (x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1)))
    return dist

上面是一段代码,用于计算x(M点)和y(N点)之间的成对距离矩阵(M*N)。

当两点之间的距离大于特定值时,我希望制作具有0元素的成对距离矩阵'T'。

遇到这种情况,我该怎么办?

谢谢

我想你在找 torch.where:

new_dist = troch.where(dist > T, dist, 0.)