torch.einsum 的内存使用情况

Memory usage of torch.einsum

我一直在尝试调试某个模型,该模型在重复几次的层中使用 torch.einsum 运算符。

在尝试分析训练期间模型的 GPU 内存使用情况时,我注意到某个 Einsum 操作显着增加了内存使用量。我正在处理多维矩阵。操作是torch.einsum('b q f n, b f n d -> b q f d', A, B).

另外值得一提的是:

我一直想知道为什么这个操作会使用这么多内存,以及为什么在该层类型的每次迭代后内存仍然分配。

变量“x”确实被覆盖了,但是张量数据保存在内存中(也称为层的activation)以供以后在反向传递中使用。

因此,您实际上是在为 torch.einsum 的结果分配新的内存数据,但您不会替换 x 的内存,即使它看似已被覆盖。


要将其传递给测试,您可以在 torch.no_grad() 上下文管理器(这些激活不会保存在内存中)下计算前向传递,并查看与标准相比的内存使用差异推理。