交换批处理轴对 pytorch 的性能有影响吗?

Swapping the batch axis has effect on the performance in pytorch?

我知道批次维度通常是零轴,我想这是有原因的:批次中每个项目的底层内存都是连续的。

如果我在第一个轴上有另一个维度,我的模型调用的函数会变得更简单,这样我就可以使用 x[k] 而不是 x[:, k]

算术运算的结果似乎保持相同的内存布局

x = torch.ones(2,3,4).transpose(0,1)
y = torch.ones_like(x)
u = (x + 1)
v = (x + y)
print(x.stride(), u.stride(), v.stride())

当我创建额外的变量时,我使用 torch.zeros 创建它们,然后转置,这样最大的步幅也到达轴 1。

例如

a,b,c = torch.zeros(
         (3, x.shape[1], ADDITIONAL_DIM, x.shape[0]) + x.shape[2:]
).transpose(1,2)

将创建三个具有相同批量大小的张量 x.shape[1]。 就内存位置而言,拥有

会有什么不同
a,b,c = torch.zeros(
  (x.shape[1], 3, ADDITIONAL_DIM, x.shape[0]) + x.shape[2:]
).permute(1,2,0, ...)

相反。

我应该关心这个吗?

TLDR;切片看似包含的信息较少......但实际上与原始张量共享相同的存储缓冲区。由于 permute 不影响底层内存布局,因此这两个操作本质上是等价的。


这两个本质上是一样的,底层数据存储缓冲区保持不变,只有元数据 你如何与之交互该缓冲区(步幅和形状)发生变化。

让我们看一个简单的例子:

>>> x = torch.ones(2,3,4).transpose(0,1)
>>> x_ptr = x.data_ptr()

>>> x.shape, x.stride(), x_ptr
(3, 2, 4), (4, 12, 1), 94674451667072

我们将 'base' 张量的数据指针保存在 x_ptr:

  1. 在第二个轴上切片:

    >>> y = x[:, 0]
    
    >>> y.shape, y.stride(), x_ptr == y.data_ptr()
    (3, 4), (4, 1), True
    

    如您所见,xx[:, k] 共享相同的存储空间。

  2. 排列前两个轴然后在第一个轴上切片:

    >>> z = x.permute(1, 0, 2)[0]
    
    >>> z.shape, z.stride(), x_ptr == z.data_ptr()
    (3, 4), (4, 1), True
    

    在这里,您再次注意到 x.data_ptrz.data_ptr 相同。


事实上,您甚至可以使用 torch.as_strided:

yx 的表示
>>> torch.as_strided(y, size=x.shape, stride=x.stride())
tensor([[[1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.]]])

z相同:

>>> torch.as_strided(z, size=x.shape, stride=x.stride())

两者都将 return xcopy 因为 torch.as_strided 正在为新创建的张量分配内存。这两行只是为了说明我们如何仍然可以从 x 的切片 'get back' 到 x,我们可以通过更改张量的元数据来恢复表观内容。