在 PyTorch 中计算欧氏距离而不是矩阵乘法

In PyTorch calc Euclidean distance instead of matrix multiplication

假设我们有 2 个矩阵:

mat = torch.randn([20, 7]) * 100
mat2 = torch.randn([7, 20]) * 100

n, m = mat.shape

最简单的常用矩阵乘法如下所示:

def mat_vec_dot_product(mat, vect):
    n, m = mat.shape
    
    res = torch.zeros([n])
    for i in range(n):
        for j in range(m):
            res[i] += mat[i][j] * vect[j]
        
    return res

res = torch.zeros([n, n])
for k in range(n):
    res[:, k] = mat_vec_dot_product(mat, mat2[:, k])
    

但是如果我需要应用 L2 范数而不是点积怎么办?接下来是代码:

def mat_vec_l2_mult(mat, vect):
    n, m = mat.shape
    
    res = torch.zeros([n])
    for i in range(n):
        for j in range(m):
            res[i] += (mat[i][j] - vect[j]) ** 2
            
    res = res.sqrt()
        
    return res

for k in range(n):
    res[:, k] = mat_vec_l2_mult(mat, mat2[:, k])

我们能否使用 Torch 或任何其他库以最佳方式做到这一点?因为天真的 O(n^3) Python 代码运行起来真的很慢。

首先,PyTorch中的矩阵乘法有一个built-in运算符:@。 因此,要将 mat 和 mat2 相乘,您只需执行以下操作:

mat @ mat2

(应该可以,假设尺寸一致)。

现在,要计算您似乎在第二个块中计算的差平方和(SSD,或差的 L2 范数),您可以使用一个简单的技巧。 由于平方 L2 范数 ||m_i - v||^2(其中 m_i 是矩阵 M 的第 i 行,v 是向量)等于点积 <m_i - v, m_i-v> - 从您获得的点积的线性度:<m_i,m_i> - 2<m_i,v> + <v,v> 因此您可以通过计算一次每行的平方 L2 范数,从向量 v 计算 M 中每一行的 SSD ,一次是每行和向量之间的点积,一次是向量的 L2 范数。这可以在 O(n^2) 中完成。 但是,对于 2 个矩阵之间的 SSD,您仍然会得到 O(n^3)。尽管可以通过矢量化操作而不是使用循环来进行改进。 这是 2 个矩阵的简单实现:

def mat_mat_l2_mult(mat,mat2):
    rows_norm = (torch.norm(mat, dim=1, p=2, keepdim=True)**2).repeat(1,mat2.shape[1])
    cols_norm = (torch.norm(mat2, dim=0, p=2, keepdim=True)**2).repeat(mat.shape[0], 1)
    rows_cols_dot_product = mat @ mat2
    ssd = rows_norm -2*rows_cols_dot_product + cols_norm
    return ssd.sqrt()

mat = torch.randn([20, 7])
mat2 = torch.randn([7,20])
print(mat_mat_l2_mult(mat, mat2))

生成的矩阵将在每个单元格 i,j 中包含 mat 中每一行 ij 中每一列之间差异的 L2 范数27=].

对 L2 范数 - 欧氏距离使用 torch.cdist

res = torch.cdist(mat, mat2.permute(1,0), p=2)

在这里,我使用 permutemat2 的 dim 从 7,20 交换为 20,7