如何提取 MXNet 中的典型行?

How to extract typical rows in MXNet?

这些是数据(批量大小 2)和批量索引

import mxnet as mx
data=mx.nd.array(range(24)).reshape(2,3,4)
index=mx.nd.array([[0,1],[1,2]])

如何获取选中的数据?我尝试了 Picktake 函数,但不知道该怎么做。

似乎gather_nd有效

mx.nd.gather_nd(data,mx.nd.array([[0,0,1,1],[0,1,1,2]])).reshape(2,2,4)