如何提取 MXNet 中的典型行?
How to extract typical rows in MXNet?
这些是数据(批量大小 2)和批量索引
import mxnet as mx
data=mx.nd.array(range(24)).reshape(2,3,4)
index=mx.nd.array([[0,1],[1,2]])
如何获取选中的数据?我尝试了 Pick
和 take
函数,但不知道该怎么做。
似乎gather_nd
有效
mx.nd.gather_nd(data,mx.nd.array([[0,0,1,1],[0,1,1,2]])).reshape(2,2,4)
这些是数据(批量大小 2)和批量索引
import mxnet as mx
data=mx.nd.array(range(24)).reshape(2,3,4)
index=mx.nd.array([[0,1],[1,2]])
如何获取选中的数据?我尝试了 Pick
和 take
函数,但不知道该怎么做。
似乎gather_nd
有效
mx.nd.gather_nd(data,mx.nd.array([[0,0,1,1],[0,1,1,2]])).reshape(2,2,4)