在 PyTorch 中计算欧氏距离而不是矩阵乘法
In PyTorch calc Euclidean distance instead of matrix multiplication
假设我们有 2 个矩阵:
mat = torch.randn([20, 7]) * 100
mat2 = torch.randn([7, 20]) * 100
n, m = mat.shape
最简单的常用矩阵乘法如下所示:
def mat_vec_dot_product(mat, vect):
n, m = mat.shape
res = torch.zeros([n])
for i in range(n):
for j in range(m):
res[i] += mat[i][j] * vect[j]
return res
res = torch.zeros([n, n])
for k in range(n):
res[:, k] = mat_vec_dot_product(mat, mat2[:, k])
但是如果我需要应用 L2 范数而不是点积怎么办?接下来是代码:
def mat_vec_l2_mult(mat, vect):
n, m = mat.shape
res = torch.zeros([n])
for i in range(n):
for j in range(m):
res[i] += (mat[i][j] - vect[j]) ** 2
res = res.sqrt()
return res
for k in range(n):
res[:, k] = mat_vec_l2_mult(mat, mat2[:, k])
我们能否使用 Torch 或任何其他库以最佳方式做到这一点?因为天真的 O(n^3) Python 代码运行起来真的很慢。
首先,PyTorch中的矩阵乘法有一个built-in运算符:@
。
因此,要将 mat 和 mat2 相乘,您只需执行以下操作:
mat @ mat2
(应该可以,假设尺寸一致)。
现在,要计算您似乎在第二个块中计算的差平方和(SSD,或差的 L2 范数),您可以使用一个简单的技巧。
由于平方 L2 范数 ||m_i - v||^2
(其中 m_i
是矩阵 M
的第 i 行,v
是向量)等于点积 <m_i - v, m_i-v>
- 从您获得的点积的线性度:<m_i,m_i> - 2<m_i,v> + <v,v>
因此您可以通过计算一次每行的平方 L2 范数,从向量 v
计算 M
中每一行的 SSD ,一次是每行和向量之间的点积,一次是向量的 L2 范数。这可以在 O(n^2)
中完成。
但是,对于 2 个矩阵之间的 SSD,您仍然会得到 O(n^3)
。尽管可以通过矢量化操作而不是使用循环来进行改进。
这是 2 个矩阵的简单实现:
def mat_mat_l2_mult(mat,mat2):
rows_norm = (torch.norm(mat, dim=1, p=2, keepdim=True)**2).repeat(1,mat2.shape[1])
cols_norm = (torch.norm(mat2, dim=0, p=2, keepdim=True)**2).repeat(mat.shape[0], 1)
rows_cols_dot_product = mat @ mat2
ssd = rows_norm -2*rows_cols_dot_product + cols_norm
return ssd.sqrt()
mat = torch.randn([20, 7])
mat2 = torch.randn([7,20])
print(mat_mat_l2_mult(mat, mat2))
生成的矩阵将在每个单元格 i,j
中包含 mat
中每一行 i
与 j
中每一列之间差异的 L2 范数27=].
对 L2 范数 - 欧氏距离使用 torch.cdist
res = torch.cdist(mat, mat2.permute(1,0), p=2)
在这里,我使用 permute
将 mat2
的 dim 从 7,20
交换为 20,7
假设我们有 2 个矩阵:
mat = torch.randn([20, 7]) * 100
mat2 = torch.randn([7, 20]) * 100
n, m = mat.shape
最简单的常用矩阵乘法如下所示:
def mat_vec_dot_product(mat, vect):
n, m = mat.shape
res = torch.zeros([n])
for i in range(n):
for j in range(m):
res[i] += mat[i][j] * vect[j]
return res
res = torch.zeros([n, n])
for k in range(n):
res[:, k] = mat_vec_dot_product(mat, mat2[:, k])
但是如果我需要应用 L2 范数而不是点积怎么办?接下来是代码:
def mat_vec_l2_mult(mat, vect):
n, m = mat.shape
res = torch.zeros([n])
for i in range(n):
for j in range(m):
res[i] += (mat[i][j] - vect[j]) ** 2
res = res.sqrt()
return res
for k in range(n):
res[:, k] = mat_vec_l2_mult(mat, mat2[:, k])
我们能否使用 Torch 或任何其他库以最佳方式做到这一点?因为天真的 O(n^3) Python 代码运行起来真的很慢。
首先,PyTorch中的矩阵乘法有一个built-in运算符:@
。
因此,要将 mat 和 mat2 相乘,您只需执行以下操作:
mat @ mat2
(应该可以,假设尺寸一致)。
现在,要计算您似乎在第二个块中计算的差平方和(SSD,或差的 L2 范数),您可以使用一个简单的技巧。
由于平方 L2 范数 ||m_i - v||^2
(其中 m_i
是矩阵 M
的第 i 行,v
是向量)等于点积 <m_i - v, m_i-v>
- 从您获得的点积的线性度:<m_i,m_i> - 2<m_i,v> + <v,v>
因此您可以通过计算一次每行的平方 L2 范数,从向量 v
计算 M
中每一行的 SSD ,一次是每行和向量之间的点积,一次是向量的 L2 范数。这可以在 O(n^2)
中完成。
但是,对于 2 个矩阵之间的 SSD,您仍然会得到 O(n^3)
。尽管可以通过矢量化操作而不是使用循环来进行改进。
这是 2 个矩阵的简单实现:
def mat_mat_l2_mult(mat,mat2):
rows_norm = (torch.norm(mat, dim=1, p=2, keepdim=True)**2).repeat(1,mat2.shape[1])
cols_norm = (torch.norm(mat2, dim=0, p=2, keepdim=True)**2).repeat(mat.shape[0], 1)
rows_cols_dot_product = mat @ mat2
ssd = rows_norm -2*rows_cols_dot_product + cols_norm
return ssd.sqrt()
mat = torch.randn([20, 7])
mat2 = torch.randn([7,20])
print(mat_mat_l2_mult(mat, mat2))
生成的矩阵将在每个单元格 i,j
中包含 mat
中每一行 i
与 j
中每一列之间差异的 L2 范数27=].
对 L2 范数 - 欧氏距离使用 torch.cdist
res = torch.cdist(mat, mat2.permute(1,0), p=2)
在这里,我使用 permute
将 mat2
的 dim 从 7,20
交换为 20,7