torch.addmm 收到无效的参数组合

torch.addmm received an invalid combination of arguments

在pytorch的官网上看到了下面的代码和答案:

>> a = torch.randn(4, 4)
>> a

0.0692  0.3142  1.2513 -0.5428
0.9288  0.8552 -0.2073  0.6409
1.0695 -0.0101 -2.4507 -1.2230
0.7426 -0.7666  0.4862 -0.6628
torch.FloatTensor of size 4x4]

>>> torch.max(a, 1)
(
 1.2513
 0.9288
 1.0695
 0.7426
[torch.FloatTensor of size 4]
,
 2
 0
 0
 0
[torch.LongTensor of size 4]
)

我知道第一个结果对应每行的最大数量,但是我没有得到第二个张量(LongTensor)

我尝试了其他随机示例,在 pytorch.max 之后,我找到了这些结果

0.9477  1.0090  0.8348 -1.3513
-0.4861  1.2581  0.3972  1.5751
-1.2277 -0.6201 -1.0553  0.6069
 0.1688  0.1373  0.6544 -0.7784
[torch.FloatTensor of size 4x4]

(
 1.0090
 1.5751
 0.6069
 0.6544
[torch.FloatTensor of size 4]
, 
 1
 3
 3
 2
[torch.LongTensor of size 4]
)

谁能告诉我这些 LongTensor 数据到底是什么意思?我认为这是张量之间的奇怪转换,但是在对浮点张量进行简单转换后,我发现它只是削减了小数位

谢谢

它只是告诉 index 原始张量中最大元素沿查询维度的

例如

0.9477  1.0090  0.8348 -1.3513
-0.4861  1.2581  0.3972  1.5751
-1.2277 -0.6201 -1.0553  0.6069
 0.1688  0.1373  0.6544 -0.7784
[torch.FloatTensor of size 4x4]

# torch.max(a, 1)
(
 1.0090
 1.5751
 0.6069
 0.6544
[torch.FloatTensor of size 4]
, 
 1
 3
 3
 2
[torch.LongTensor of size 4]
)

在上面的例子中torch.LongTensor

11.0090 在你原来的张量中的索引 (torch.FloatTensor)
31.5751 在你原来的张量中的索引 (torch.FloatTensor)
30.6069 在你原来的张量中的索引 (torch.FloatTensor)
20.6544 在你原来的张量中的索引 (torch.FloatTensor)

沿着 维度 1


相反,如果您请求 torch.max(a, 0)torch.LongTensor 中的条目将对应于原始张量 维度中最大元素的 indices 0.