Pytorch 复制张量的首选方式

Pytorch preferred way to copy a tensor

似乎有几种方法可以在 Pytorch 中创建张量的副本,包括

y = tensor.new_tensor(x) #a

y = x.clone().detach() #b

y = torch.empty_like(x).copy_(x) #c

y = torch.tensor(x) #d
根据我执行 ad 时收到的 UserWarning,

b 明显优于 ad。为什么它是首选?表现?我认为它的可读性较差。

有什么原因 for/against 使用 c 吗?

Pytorch '1.1.0' 现在推荐 #b 并显示 #d 的警告

根据 Pytorch documentation #a 和 #b 是等价的。它还说

The equivalents using clone() and detach() are recommended.

所以如果你想复制一个张量并从你应该使用的计算图中分离出来

y = x.clone().detach()

因为这是最干净、最易读的方式。所有其他版本都有一些隐藏的逻辑,也不是 100% 清楚计算图和梯度传播会发生什么。

关于#c:对于实际完成的事情来说似乎有点复杂,并且还可能引入一些开销,但我不确定。

编辑:既然评论中有人问为什么不直接使用 .clone().

来自pytorch docs

Unlike copy_(), this function is recorded in the computation graph. Gradients propagating to the cloned tensor will propagate to the original tensor.

因此,在 .clone() returns 数据副本的同时,它会保留计算图并在其中记录克隆操作。如前所述,这将导致传播到克隆张量的梯度也传播到原始张量。此行为可能导致错误并且并不明显。由于这些可能的副作用,如果明确需要此行为,则只能通过 .clone() 克隆张量。为了避免这些副作用,添加了 .detach() 以断开计算图与克隆张量的连接。

由于通常对于复制操作,人们想要一个不会导致不可预见的副作用的干净副本,因此复制张量的首选方法是 .clone().detach()

TL;DR

使用.clone().detach()(或者最好是.detach().clone()

If you first detach the tensor and then clone it, the computation path is not copied, the other way around it is copied and then abandoned. Thus, .detach().clone() is very slightly more efficient.-- pytorch forums

因为它的功能稍快且明确。


使用perflot,我绘制了复制pytorch张量的各种方法的时间。

y = tensor.new_tensor(x) # method a

y = x.clone().detach() # method b

y = torch.empty_like(x).copy_(x) # method c

y = torch.tensor(x) # method d

y = x.detach().clone() # method e

x轴是创建的tensor的维度,y轴是时间。该图是线性比例的。正如您可以清楚地看到,与其他三种方法相比,tensor()new_tensor() 需要更多时间。

注意: 在多次运行中,我注意到在 b、c、e 中,任何方法的时间都可能最短。 a和d也是一样。但是,方法 b、c、e 的时间始终低于 a 和 d。

import torch
import perfplot

perfplot.show(
    setup=lambda n: torch.randn(n),
    kernels=[
        lambda a: a.new_tensor(a),
        lambda a: a.clone().detach(),
        lambda a: torch.empty_like(a).copy_(a),
        lambda a: torch.tensor(a),
        lambda a: a.detach().clone(),
    ],
    labels=["new_tensor()", "clone().detach()", "empty_like().copy()", "tensor()", "detach().clone()"],
    n_range=[2 ** k for k in range(15)],
    xlabel="len(a)",
    logx=False,
    logy=False,
    title='Timing comparison for copying a pytorch tensor',
)

检查张量是否被复制的一个例子:

import torch
def samestorage(x,y):
    if x.storage().data_ptr()==y.storage().data_ptr():
        print("same storage")
    else:
        print("different storage")
a = torch.ones((1,2), requires_grad=True)
print(a)
b = a
c = a.data
d = a.detach()
e = a.data.clone()
f = a.clone()
g = a.detach().clone()
i = torch.empty_like(a).copy_(a)
j = torch.tensor(a) # UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).


print("a:",end='');samestorage(a,a)
print("b:",end='');samestorage(a,b)
print("c:",end='');samestorage(a,c)
print("d:",end='');samestorage(a,d)
print("e:",end='');samestorage(a,e)
print("f:",end='');samestorage(a,f)
print("g:",end='');samestorage(a,g)
print("i:",end='');samestorage(a,i)

输出:

tensor([[1., 1.]], requires_grad=True)
a:same storage
b:same storage
c:same storage
d:same storage
e:different storage
f:different storage
g:different storage
i:different storage
j:different storage

如果不同的存储出现,张量被复制。 PyTorch 有将近 100 种不同的构造函数,所以你可以添加更多的方法。

如果我需要复制张量,我会使用 copy(),这也会复制 AD 相关信息,所以如果我需要删除 AD 相关信息,我会使用:

y = x.clone().detach()