在pytorch中,用方括号索引和"index_select"有什么区别?

In pytorch, what is the difference between indexing with square brackets and "index_select"?

假设有两个 pytorch 张量 a,即 float32,形状为 [M, N],和 b,即 int64,形状为 [K]b 中的值在 [0, M-1] 范围内,因此下行给出了一个新的张量 c,索引为 b

c = a[b]    # [K, N] tensor whose i-th row is a[b[i]], with `IndexBackward`

但是在我的一个项目中,这一行总是报如下错误(用torch.autograd.detect_anomaly()检测到:

  with torch.autograd.detect_anomaly():
[W python_anomaly_mode.cpp:104] Warning: Error detected in IndexBackward. Traceback of forward call that caused the error:
...
File "/home/user/project/model/network.py", line 60, in index_points
    c = a[b]
 (function _print_stack)

Traceback (most recent call last):
  File "main.py", line 589, in <module>
    main()
  File "main.py", line 439, in main
    train_stats = train(
  File "/home/user/project/train_eval.py", line 866, in train
    total_loss.backward()
  File "/home/user/.local/lib/python3.8/site-packages/torch/_tensor.py", line 255, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/user/.local/lib/python3.8/site-packages/torch/autograd/__init__.py", line 147, in backward
    Variable._execution_engine.run_backward(
RuntimeError: merge_sort: failed to synchronize: cudaErrorIllegalAddress: an illegal memory access was encountered

请注意,上面的 c = a[b] 不是唯一出现的 所述错误,而只是许多其他带有方括号索引的行之一。

但是,当我从

更改索引样式时,问题神奇地消失了
c = a[b]

c = a.index_select(0, b)

我不明白为什么用方括号索引会导致非法内存访问,但这让我有足够的理由相信方括号索引和 index_select 的实现方式不同。理解这一点可能是解释这一点的关键。另外,由于项目相当大而不是 public,我不能在这里分享确切的代码。您可以将上面的内容视为背景,并关注方括号索引和 index_select 的不同之处。谢谢!


附加信息:

torch.index_select returns 将索引字段复制到新内存位置的新张量 (docs)。

torch.Tensor.select 或切片 returns 原始张量 (docs) 的 view

在没有看到更多代码的情况下,很难说出为什么这种特殊的功能差异会导致上述错误。