将列表和标量列表转换为 PyTorch 张量列表会引发警告

Converting a list of lists and scalars to a list of PyTorch tensors throws warning

我正在将列表的列表转换为 PyTorch 张量并收到一条警告消息。转换本身并不困难。例如:

>>> import torch
>>> thing = [[1, 2, 3, 4, 5], [2, 3], 2, 3]
>>> thing_tensor = list(map(torch.tensor, thing))

我收到警告:

home/user1/files/module.py:1: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).

我想知道可能是什么原因。有没有其他方法可以将数据转换为我不知道的张量?谢谢

我试图重现您的警告,但没有成功。但是,如果我用张量替换 thing 中的列表,我可以通过创建得到相同的警告。

我将讨论为什么使用 x.clone().detach() 而不是 torch.tensor(x) 来制作副本更好:

在我的 pytorch 版本上使用 torch.tensor 将创建一个不再与计算图相关并且在内存中不占据相同位置的副本。但是,此行为可能会在未来的版本中发生变化,这就是为什么您应该使用将保持有效的命令。我将说明音符分离或在内存中占据相同位置所带来的问题。

不分离:

x = torch.tensor([0.],requires_grad=True)
y = x.clone()
y[0] = 1
z = 2 * y
z.backward()
print(x, x.grad)
tensor([0.], requires_grad=True) tensor([0.])

正如你所看到的,在对 y 进行计算时,x 的梯度正在更新,但是改变 y 的值不会改变 x 的值,因为它们不占用相同的内存space.

占用内存相同 space :

x = torch.tensor([0.],requires_grad=True)
y = x.detach().requires_grad_(True)
z = 2 * y
z.backward()
y[0] = 1
print(x, x.grad)
tensor([1.], requires_grad=True) None

在这种情况下,梯度不会更新,但改变 y 的值会改变 x 的值,因为它们占用相同的内存 space。

最佳实践:

正如警告所建议的,最佳做法是分离和克隆张量:

x = torch.tensor([0.],requires_grad=True)
y = x.clone().detach().requires_grad_(True)
z = 2 * y
z.backward()
y[0] = 1
print(x, x.grad)
tensor([0.], requires_grad=True) None

这确保了来自 y 的未来修改和计算不会影响 x

@StatisticDean 有一个很好的答案,我将针对您正在做的事情添加一点:

“我正在将列表的列表转换为 PyTorch 张量”——这根本不是正在发生的事情。您的示例代码将数字列表的列表转换为张量列表。打印出来thing_tensor,应该是:

[tensor([1, 2, 3, 4, 5]), tensor([2, 3]), tensor(2), tensor(3)] 

这是因为 map 在 top-level 列表的每个元素上调用 torch.tensor,创建单独的张量。此外,这运行没有任何错误。

可能发生的情况是,您首先尝试 torch.tensor(thing) 一次性转换列表列表,但出现错误 ValueError: expected sequence of length 5 at dim 1 (got 2)。原因是张量必须是矩形的——例如,对于二维张量,每个 row/column 应该是相同的大小。如果不更改某些元素的大小,您实际上无法将列表的列表转换为张量。

仅通过一次调用将列表的列表转换为单个张量的示例:

torch.tensor([[11,12,13],[21,22,23]])

工作正常,因为每行大小为 3,每列大小为 2;不需要地图。

它在我的 Pytorch 环境中运行良好。 我认为您遇到的警告是由于其他原因造成的,例如 Pytorch 版本或 python 版本等

This is my result to run your code without any correction.