Python:使用迭代方法实现的向量化计算

Python: Vectorize Calculation Implemented using Iterative Approach

我正在尝试执行一些计算,但我不知道如何矢量化我的代码而不使用循环。

让我解释一下:我有一个 01 的矩阵 M[N,C]。另一个矩阵 Y[N,1] 包含 [0,C-1] 的值(我的 类)。另一个矩阵 ds[N,M] 这是我的数据集。

我的输出矩阵的大小为 grad[M,C],应按如下方式计算:我将对 grad[:,0] 进行解释,其他列的逻辑相同。

对于ds中的每一行(样本),如果Y[that sample] != 0(输出矩阵的当前列)和M[that sample, 0] > 0,则grad[:,0] += ds[that sample]

如果Y[that sample] == 0,则grad[:,0] -= (ds[that sample] * <Num of non zeros in M[that sample,:]>)

这是我的迭代方法:

    for i in range(M.size(dim=1)):
        for j in range(ds.size(dim=0)):
            if y[j] == i:
                grad[:,i] = grad[:,i] - (ds[j,:].T * sum(M[j,:]))
            else:
                if M[j,i] > 0:
                    grad[:,i] = grad[:,i] + ds[j,:].T 

由于您正在处理三个维度 nmc(小写以避免歧义),更改所有张量的形状可能很有用到 (n, m, c),通过在缺失的维度上复制它们的值(例如 M(m, c) 变成 M(n, m, c))。

但是,您可以跳过显式复制并使用广播,因此解压缩丢失的维度就足够了(例如 M(m, c) 变为 M(1, m, c)

考虑到这些因素,你的代码向量化如下

cond = y.unsqueeze(2) == torch.arange(M.size(dim=1)).unsqueeze(0)
pos = ds.unsqueeze(2) * M.unsqueeze(1) * cond
neg = ds.unsqueeze(2) * M.unsqueeze(1).sum(dim=0, keepdim=True) * ~cond
grad += (pos - neg).sum(dim=0)

这里有一个小测试来检查解决方案的有效性

import torch

n, m, c = 11, 5, 7

y = torch.randint(c, size=(n, 1))
ds = torch.rand(n, m)
M = torch.randint(2, size=(n, c))
grad = torch.rand(m, c)


def slow_grad(y, ds, M, grad):
    for i in range(M.size(dim=1)):
        for j in range(ds.size(dim=0)):
            if y[j] == i:
                grad[:,i] = grad[:,i] - (ds[j,:].T * sum(M[j,:]))
            else:
                if M[j,i] > 0:
                    grad[:,i] = grad[:,i] + ds[j,:].T
    return grad


def fast_grad(y, ds, M, grad):
    cond = y.unsqueeze(2) == torch.arange(M.size(dim=1)).unsqueeze(0)
    pos = ds.unsqueeze(2) * M.unsqueeze(1) * cond
    neg = ds.unsqueeze(2) * M.unsqueeze(1).sum(dim=0, keepdim=True) * ~cond
    grad += (pos - neg).sum(dim=0)
    return grad
  
# Assert equality of all elements function outputs, throws an exception if false
assert torch.all(slow_grad(y, ds, M, grad) == fast_grad(y, ds, M, grad))

也可以随意测试其他案例!