交换批处理轴对 pytorch 的性能有影响吗?
Swapping the batch axis has effect on the performance in pytorch?
我知道批次维度通常是零轴,我想这是有原因的:批次中每个项目的底层内存都是连续的。
如果我在第一个轴上有另一个维度,我的模型调用的函数会变得更简单,这样我就可以使用 x[k]
而不是 x[:, k]
。
算术运算的结果似乎保持相同的内存布局
x = torch.ones(2,3,4).transpose(0,1)
y = torch.ones_like(x)
u = (x + 1)
v = (x + y)
print(x.stride(), u.stride(), v.stride())
当我创建额外的变量时,我使用 torch.zeros
创建它们,然后转置,这样最大的步幅也到达轴 1。
例如
a,b,c = torch.zeros(
(3, x.shape[1], ADDITIONAL_DIM, x.shape[0]) + x.shape[2:]
).transpose(1,2)
将创建三个具有相同批量大小的张量 x.shape[1]
。
就内存位置而言,拥有
会有什么不同
a,b,c = torch.zeros(
(x.shape[1], 3, ADDITIONAL_DIM, x.shape[0]) + x.shape[2:]
).permute(1,2,0, ...)
相反。
我应该关心这个吗?
TLDR;切片看似包含的信息较少......但实际上与原始张量共享相同的存储缓冲区。由于 permute 不影响底层内存布局,因此这两个操作本质上是等价的。
这两个本质上是一样的,底层数据存储缓冲区保持不变,只有元数据 即你如何与之交互该缓冲区(步幅和形状)发生变化。
让我们看一个简单的例子:
>>> x = torch.ones(2,3,4).transpose(0,1)
>>> x_ptr = x.data_ptr()
>>> x.shape, x.stride(), x_ptr
(3, 2, 4), (4, 12, 1), 94674451667072
我们将 'base' 张量的数据指针保存在 x_ptr
:
在第二个轴上切片:
>>> y = x[:, 0]
>>> y.shape, y.stride(), x_ptr == y.data_ptr()
(3, 4), (4, 1), True
如您所见,x
和 x[:, k]
共享相同的存储空间。
排列前两个轴然后在第一个轴上切片:
>>> z = x.permute(1, 0, 2)[0]
>>> z.shape, z.stride(), x_ptr == z.data_ptr()
(3, 4), (4, 1), True
在这里,您再次注意到 x.data_ptr
与 z.data_ptr
相同。
事实上,您甚至可以使用 torch.as_strided
:
从 y
到 x
的表示
>>> torch.as_strided(y, size=x.shape, stride=x.stride())
tensor([[[1., 1., 1., 1.],
[1., 1., 1., 1.]],
[[1., 1., 1., 1.],
[1., 1., 1., 1.]],
[[1., 1., 1., 1.],
[1., 1., 1., 1.]]])
与z
相同:
>>> torch.as_strided(z, size=x.shape, stride=x.stride())
两者都将 return x
的 copy 因为 torch.as_strided
正在为新创建的张量分配内存。这两行只是为了说明我们如何仍然可以从 x
的切片 'get back' 到 x
,我们可以通过更改张量的元数据来恢复表观内容。
我知道批次维度通常是零轴,我想这是有原因的:批次中每个项目的底层内存都是连续的。
如果我在第一个轴上有另一个维度,我的模型调用的函数会变得更简单,这样我就可以使用 x[k]
而不是 x[:, k]
。
算术运算的结果似乎保持相同的内存布局
x = torch.ones(2,3,4).transpose(0,1)
y = torch.ones_like(x)
u = (x + 1)
v = (x + y)
print(x.stride(), u.stride(), v.stride())
当我创建额外的变量时,我使用 torch.zeros
创建它们,然后转置,这样最大的步幅也到达轴 1。
例如
a,b,c = torch.zeros(
(3, x.shape[1], ADDITIONAL_DIM, x.shape[0]) + x.shape[2:]
).transpose(1,2)
将创建三个具有相同批量大小的张量 x.shape[1]
。
就内存位置而言,拥有
a,b,c = torch.zeros(
(x.shape[1], 3, ADDITIONAL_DIM, x.shape[0]) + x.shape[2:]
).permute(1,2,0, ...)
相反。
我应该关心这个吗?
TLDR;切片看似包含的信息较少......但实际上与原始张量共享相同的存储缓冲区。由于 permute 不影响底层内存布局,因此这两个操作本质上是等价的。
这两个本质上是一样的,底层数据存储缓冲区保持不变,只有元数据 即你如何与之交互该缓冲区(步幅和形状)发生变化。
让我们看一个简单的例子:
>>> x = torch.ones(2,3,4).transpose(0,1)
>>> x_ptr = x.data_ptr()
>>> x.shape, x.stride(), x_ptr
(3, 2, 4), (4, 12, 1), 94674451667072
我们将 'base' 张量的数据指针保存在 x_ptr
:
在第二个轴上切片:
>>> y = x[:, 0] >>> y.shape, y.stride(), x_ptr == y.data_ptr() (3, 4), (4, 1), True
如您所见,
x
和x[:, k]
共享相同的存储空间。排列前两个轴然后在第一个轴上切片:
>>> z = x.permute(1, 0, 2)[0] >>> z.shape, z.stride(), x_ptr == z.data_ptr() (3, 4), (4, 1), True
在这里,您再次注意到
x.data_ptr
与z.data_ptr
相同。
事实上,您甚至可以使用 torch.as_strided
:
y
到 x
的表示
>>> torch.as_strided(y, size=x.shape, stride=x.stride())
tensor([[[1., 1., 1., 1.],
[1., 1., 1., 1.]],
[[1., 1., 1., 1.],
[1., 1., 1., 1.]],
[[1., 1., 1., 1.],
[1., 1., 1., 1.]]])
与z
相同:
>>> torch.as_strided(z, size=x.shape, stride=x.stride())
两者都将 return x
的 copy 因为 torch.as_strided
正在为新创建的张量分配内存。这两行只是为了说明我们如何仍然可以从 x
的切片 'get back' 到 x
,我们可以通过更改张量的元数据来恢复表观内容。