如何在pytorch中批处理矩阵向量乘法(一个矩阵,多个向量)而不复制内存中的矩阵
How to batch matrix-vector multiplication (one matrix, many vectors) in pytorch without duplicating the matrix in memory
我有 n
个大小为 d
的向量和一个 d x d
矩阵 J
。我想用每个 n
向量计算 J
的 n
矩阵向量乘法。
为此,我正在使用 pytorch 的 expand()
来获得 J
的 broadcast,但似乎在计算矩阵向量积时, pytorch 在内存中实例化一个完整的 n x d x d
张量。例如以下代码
device = torch.device("cuda:0")
n = 100_000_000
d = 10
x = torch.randn(n, d, dtype=torch.float32, device=device)
J = torch.randn(d, d, dtype=torch.float32, device=device).expand(n, d, d)
y = torch.sign(torch.matmul(J, x[..., None])[..., 0])
加注
RuntimeError: CUDA out of memory. Tried to allocate 37.25 GiB (GPU 0; 11.00 GiB total capacity; 3.73 GiB already allocated; 5.69 GiB free; 3.73 GiB reserved in total by PyTorch)
这意味着 pytorch 不必要地尝试为矩阵 J
的 n
副本分配 space
如何在不耗尽 GPU 内存的情况下以矢量化方式执行此任务(矩阵很小,所以我不想循环遍历每个矩阵-矢量乘法)?
我认为这会解决问题:
import torch
x = torch.randn(n, d)
J = torch.randn(d, d) # no need to expand
y = torch.matmul(J, x.T).T
正在使用您的表达式进行验证:
Jex = J.expand(n, d, d)
y1 = torch.matmul(Jex, x[..., None])[..., 0]
y = torch.matmul(J, x.T).T
torch.allclose(y1, y) # using allclose for float values
# tensor(True)
我有 n
个大小为 d
的向量和一个 d x d
矩阵 J
。我想用每个 n
向量计算 J
的 n
矩阵向量乘法。
为此,我正在使用 pytorch 的 expand()
来获得 J
的 broadcast,但似乎在计算矩阵向量积时, pytorch 在内存中实例化一个完整的 n x d x d
张量。例如以下代码
device = torch.device("cuda:0")
n = 100_000_000
d = 10
x = torch.randn(n, d, dtype=torch.float32, device=device)
J = torch.randn(d, d, dtype=torch.float32, device=device).expand(n, d, d)
y = torch.sign(torch.matmul(J, x[..., None])[..., 0])
加注
RuntimeError: CUDA out of memory. Tried to allocate 37.25 GiB (GPU 0; 11.00 GiB total capacity; 3.73 GiB already allocated; 5.69 GiB free; 3.73 GiB reserved in total by PyTorch)
这意味着 pytorch 不必要地尝试为矩阵 J
n
副本分配 space
如何在不耗尽 GPU 内存的情况下以矢量化方式执行此任务(矩阵很小,所以我不想循环遍历每个矩阵-矢量乘法)?
我认为这会解决问题:
import torch
x = torch.randn(n, d)
J = torch.randn(d, d) # no need to expand
y = torch.matmul(J, x.T).T
正在使用您的表达式进行验证:
Jex = J.expand(n, d, d)
y1 = torch.matmul(Jex, x[..., None])[..., 0]
y = torch.matmul(J, x.T).T
torch.allclose(y1, y) # using allclose for float values
# tensor(True)