在 Pytorch 中计算欧几里德范数.. 麻烦理解和实现

Calculating Euclidian Norm in Pytorch.. Trouble understanding an implementation

我看到另一个 Whosebug 线程讨论了计算欧几里得范数的各种实现,但我无法看到 why/how 特定的实现有效。

代码在 MMD 度量的实现中找到:https://github.com/josipd/torch-two-sample/blob/master/torch_two_sample/statistics_diff.py

这是一些开始的样板文件:

import torch
sample_1, sample_2 = torch.ones((10,2)), torch.zeros((10,2))

然后下一部分是我们从上面的代码中提取的内容..我不确定为什么样本被连接在一起..

sample_12 = torch.cat((sample_1, sample_2), 0)
distances = pdist(sample_12, sample_12, norm=2)

然后传递给 pdist 函数:

def pdist(sample_1, sample_2, norm=2, eps=1e-5):
    r"""Compute the matrix of all squared pairwise distances.
    Arguments
    ---------
    sample_1 : torch.Tensor or Variable
        The first sample, should be of shape ``(n_1, d)``.
    sample_2 : torch.Tensor or Variable
        The second sample, should be of shape ``(n_2, d)``.
    norm : float
        The l_p norm to be used.
    Returns
    -------
    torch.Tensor or Variable
        Matrix of shape (n_1, n_2). The [i, j]-th entry is equal to
        ``|| sample_1[i, :] - sample_2[j, :] ||_p``."""

这里我们进入计算的重点

    n_1, n_2 = sample_1.size(0), sample_2.size(0)
    norm = float(norm)
    if norm == 2.:
        norms_1 = torch.sum(sample_1**2, dim=1, keepdim=True)
        norms_2 = torch.sum(sample_2**2, dim=1, keepdim=True)
        norms = (norms_1.expand(n_1, n_2) +
             norms_2.transpose(0, 1).expand(n_1, n_2))
        distances_squared = norms - 2 * sample_1.mm(sample_2.t())
        return torch.sqrt(eps + torch.abs(distances_squared))

我很困惑为什么要这样计算欧几里德范数。任何见解将不胜感激

让我们逐步浏览一下这个代码块。欧氏距离的定义,即L2范数是

让我们考虑最简单的情况。我们有两个样本,

样本 a 有两个向量 [a00, a01][a10, a11]。样本 b 相同。让我们首先计算 norm

n1, n2 = a.size(0), b.size(0)  # here both n1 and n2 have the value 2
norm1 = torch.sum(a**2, dim=1)
norm2 = torch.sum(b**2, dim=1)

现在我们得到

接下来,我们有 norms_1.expand(n_1, n_2)norms_2.transpose(0, 1).expand(n_1, n_2)

注意b是转置的。两者相加得到 norm

sample_1.mm(sample_2.t()),就是两个矩阵相乘。

因此,手术后

distances_squared = norms - 2 * sample_1.mm(sample_2.t())

你得到

最后,最后一步是对矩阵中的每个元素求平方根。