mxnet中如何计算bincount

How to calculate bincount in mxnet

我正在尝试使用 mxnet 为神经网络实现我的客户层。我想知道mxnet中是否有类似np.bincount的功能。如果没有,有没有一种方法可以计算它而不必将我的 mx.ndarray 转换为 numpy

MXNet 没有这样的功能。您可以通过编写这样的循环来实现它:

bins = []
for i in range(max_value):
    bins.append(nd.sum(my_array == i))
bins = nd.concat(bins)

请记住,如果您使用 numpy,不仅会从 GPU 上下文切换到 CPU 并减慢计算速度,而且您也无法对计算进行反向传播。