torch.flatten() 和 nn.Flatten() 之间的区别
Difference between torch.flatten() and nn.Flatten()
torch.flatten()
和torch.nn.Flatten()
有什么区别?
PyTorch 提供三种形式的展平
作为张量方法(oop 样式)torch.Tensor.flatten
直接应用于张量:x.flatten()
.
作为函数(函数形式)torch.flatten
应用为:torch.flatten(x)
.
作为模块(层nn.Module
)nn.Flatten()
。通常用于模型定义。
这三个是相同的并且共享相同的实现,唯一的区别是 nn.Flatten
默认将 start_dim
设置为 1
以避免压扁第一个轴(通常是批处理轴) ).而其他两个从 axis=0
变平到 axis=-1
- 即 整个张量 - 如果没有给出参数。
你可以把 torch.flatten()
的工作想象成简单地对张量进行展平操作,没有任何附加条件。你给一个张量,它变平,然后 returns 它。仅此而已。
相反,nn.Flatten()
复杂得多(即,它是一个神经网络层)。作为面向对象,它继承自 nn.Module
,尽管它 internally uses the plain tensor.flatten() OP in the forward()
method 用于展平张量。你可以把它想象成 torch.flatten()
.
的语法糖
重要区别:一个显着的区别是 torch.flatten()
总是 return 作为结果的一维张量,前提是输入至少是 1D 或更大,而 nn.Flatten()
always return 是 2D 张量,前提是输入至少是 2D 或更大(随着一维张量作为输入,它会抛出一个 IndexError).
比较:
torch.flatten()
是一个 API 而 nn.Flatten()
是一个神经网络层。
torch.flatten()
是一个 python 函数 而 nn.Flatten()
是一个 python class.
因为以上几点,nn.Flatten()
comes with lot of methods and attributes
torch.flatten()
可以在野外使用(例如,用于简单的张量 OP),而 nn.Flatten()
预计将在 nn.Sequential()
块中用作一个层数。
torch.flatten()
没有关于计算图的信息,除非它被卡在其他图形感知块中(tensor.requires_grad
标志设置为 True
),而 nn.Flatten()
总是被 autograd 跟踪。
torch.flatten()
无法接受和处理(例如,linear/conv1D)层作为输入,而 nn.Flatten()
主要用于处理这些神经网络层。
同时 torch.flatten()
和 nn.Flatten()
return views 输入张量。因此,对结果的任何修改也会影响输入张量。 (见下方代码)
代码演示:
# input tensors to work with
In [109]: t1 = torch.arange(12).reshape(3, -1)
In [110]: t2 = torch.arange(12, 24).reshape(3, -1)
In [111]: t3 = torch.arange(12, 36).reshape(3, 2, -1) # 3D tensor
用torch.flatten()
展平:
In [113]: t1flat = torch.flatten(t1)
In [114]: t1flat
Out[114]: tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
# modification to the flattened tensor
In [115]: t1flat[-1] = -1
# input tensor is also modified; thus flattening is a view.
In [116]: t1
Out[116]:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, -1]])
用nn.Flatten()
展平:
In [123]: nnfl = nn.Flatten()
In [124]: t3flat = nnfl(t3)
# note that the result is 2D, as opposed to 1D with torch.flatten
In [125]: t3flat
Out[125]:
tensor([[12, 13, 14, 15, 16, 17, 18, 19],
[20, 21, 22, 23, 24, 25, 26, 27],
[28, 29, 30, 31, 32, 33, 34, 35]])
# modification to the result
In [126]: t3flat[-1, -1] = -1
# input tensor also modified. Thus, flattened result is a view.
In [127]: t3
Out[127]:
tensor([[[12, 13, 14, 15],
[16, 17, 18, 19]],
[[20, 21, 22, 23],
[24, 25, 26, 27]],
[[28, 29, 30, 31],
[32, 33, 34, -1]]])
花絮:torch.flatten()
是 nn.Flatten()
及其 brethren nn.Unflatten()
since it existed from the very beginning. Then, there was a legitimate use-case for nn.Flatten()
, since this is a common requirement for almost all ConvNets (just before the softmax or elsewhere). So it was added later on in the PR #22245 的前身。
torch.flatten()
和torch.nn.Flatten()
有什么区别?
PyTorch 提供三种形式的展平
作为张量方法(oop 样式)
torch.Tensor.flatten
直接应用于张量:x.flatten()
.作为函数(函数形式)
torch.flatten
应用为:torch.flatten(x)
.作为模块(层
nn.Module
)nn.Flatten()
。通常用于模型定义。
这三个是相同的并且共享相同的实现,唯一的区别是 nn.Flatten
默认将 start_dim
设置为 1
以避免压扁第一个轴(通常是批处理轴) ).而其他两个从 axis=0
变平到 axis=-1
- 即 整个张量 - 如果没有给出参数。
你可以把 torch.flatten()
的工作想象成简单地对张量进行展平操作,没有任何附加条件。你给一个张量,它变平,然后 returns 它。仅此而已。
相反,nn.Flatten()
复杂得多(即,它是一个神经网络层)。作为面向对象,它继承自 nn.Module
,尽管它 internally uses the plain tensor.flatten() OP in the forward()
method 用于展平张量。你可以把它想象成 torch.flatten()
.
重要区别:一个显着的区别是 torch.flatten()
总是 return 作为结果的一维张量,前提是输入至少是 1D 或更大,而 nn.Flatten()
always return 是 2D 张量,前提是输入至少是 2D 或更大(随着一维张量作为输入,它会抛出一个 IndexError).
比较:
torch.flatten()
是一个 API 而nn.Flatten()
是一个神经网络层。torch.flatten()
是一个 python 函数 而nn.Flatten()
是一个 python class.因为以上几点,
nn.Flatten()
comes with lot of methods and attributestorch.flatten()
可以在野外使用(例如,用于简单的张量 OP),而nn.Flatten()
预计将在nn.Sequential()
块中用作一个层数。torch.flatten()
没有关于计算图的信息,除非它被卡在其他图形感知块中(tensor.requires_grad
标志设置为True
),而nn.Flatten()
总是被 autograd 跟踪。torch.flatten()
无法接受和处理(例如,linear/conv1D)层作为输入,而nn.Flatten()
主要用于处理这些神经网络层。同时
torch.flatten()
和nn.Flatten()
return views 输入张量。因此,对结果的任何修改也会影响输入张量。 (见下方代码)
代码演示:
# input tensors to work with
In [109]: t1 = torch.arange(12).reshape(3, -1)
In [110]: t2 = torch.arange(12, 24).reshape(3, -1)
In [111]: t3 = torch.arange(12, 36).reshape(3, 2, -1) # 3D tensor
用torch.flatten()
展平:
In [113]: t1flat = torch.flatten(t1)
In [114]: t1flat
Out[114]: tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
# modification to the flattened tensor
In [115]: t1flat[-1] = -1
# input tensor is also modified; thus flattening is a view.
In [116]: t1
Out[116]:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, -1]])
用nn.Flatten()
展平:
In [123]: nnfl = nn.Flatten()
In [124]: t3flat = nnfl(t3)
# note that the result is 2D, as opposed to 1D with torch.flatten
In [125]: t3flat
Out[125]:
tensor([[12, 13, 14, 15, 16, 17, 18, 19],
[20, 21, 22, 23, 24, 25, 26, 27],
[28, 29, 30, 31, 32, 33, 34, 35]])
# modification to the result
In [126]: t3flat[-1, -1] = -1
# input tensor also modified. Thus, flattened result is a view.
In [127]: t3
Out[127]:
tensor([[[12, 13, 14, 15],
[16, 17, 18, 19]],
[[20, 21, 22, 23],
[24, 25, 26, 27]],
[[28, 29, 30, 31],
[32, 33, 34, -1]]])
花絮:torch.flatten()
是 nn.Flatten()
及其 brethren nn.Unflatten()
since it existed from the very beginning. Then, there was a legitimate use-case for nn.Flatten()
, since this is a common requirement for almost all ConvNets (just before the softmax or elsewhere). So it was added later on in the PR #22245 的前身。