如何从 ND4j 中的 NDArray select 一组给定的索引,类似于 numpy 的数组数据 [arrayIndex]?

How to select a given set of indexes from an NDArray in ND4j similarly to numpy's arraydata[arrayIndex]?

我正在使用 ND4j(当前版本 1.0.0-beta5)在 Java 中开发一个严重依赖数组操作的科学应用程序。在我的整个管道中,我需要动态 select [2,195102] 矩阵的一个非连续子集(更精确地说,有几个 tens/hundreds 列)。 知道如何在此框架中实现这一目标吗?

简而言之,我正在尝试实现这个python/numpy操作:

import numpy as np
arrayData = np.array([[1, 5, 0, 6, 2, 0, 9, 0, 5, 2],
       [3, 6, 1, 0, 4, 3, 1, 4, 8, 1]])
arrayIndex = np.array((1,5,6))
res  = arrayData[:, arrayIndex]
# res value is
# array([[5, 0, 9],
#        [6, 3, 1]])

到目前为止,我设法 select 使用 NDArray.getColumns function (along with the NDArray.data().asInt() from the indexArray to provide the values of the index). The problem is that the documentation explicitelly states, regarding the retrieval of information during a computation, "Note that THIS SHOULD NOT BE USED FOR SPEED" (see the documentation of NDArray.ToIntMatrix() 所需的列来查看完整消息 - 不同的方法,相同的操作)。

我查看了 NDArray.get() and none seem to fit the bill. I suppose that NDArray.getWhere() 的不同原型可能会起作用 - 如果它像我假设的那样只有 returns 满足条件的元素 - 但到目前为止,不成功在利用它。在解释所需的 arguments/usage 时,文档相对较少。

感谢大家的宝贵时间和帮助:)

编辑(2019 年 4 月 11 日): 关于我尝试过的一些精确度。我玩 NDArray.get() 并使用索引:

INDArray arrayData = Nd4j.create(new int[]
                    {1, 5, 0, 6, 2, 0, 9, 0, 5, 2,
                     3, 6, 1, 0, 4, 3, 1, 4, 8, 1},   new long[]{2, 10}, DataType.INT);
INDArray arrayIndex = Nd4j.create(new int[]{1, 5, 6}, new long[]{1,  3}, DataType.INT);

INDArray colSelection = null;

//index free version
colSelection = arrayData.getColumns(arrayIndex.toIntVector());
/*
* colSelection value is
* [[5, 0, 9],
*  [6, 3, 1]]
* but the toIntVector() call pulls the data from the back-end storage
* and re-inject them. That is presumed to be slow.
*  -   2 columns selected (arrayIndex = {1, 5}),        ==> 4001 ms for 100000 iterations
*  -   3 columns selected (arrayIndex = {1, 5, 6}),     ==> 5339 ms for 100000 iterations
*  -   4 columns selected (arrayIndex = {1, 5, 6 ,2}),  ==> 7016 ms for 100000 iterations
*/

//index version
colSelection = arrayData.get(NDArrayIndex.all(), NDArrayIndex.indices(arrayIndex.toLongVector()));
/*
* Same result, but same problem regarding toLongVector() this time around.
*  -   2 columns selected (arrayIndex = {1, 5}),        ==> 3200 ms for 100000 iterations
*  -   3 columns selected (arrayIndex = {1, 5, 6}),     ==> 4269 ms for 100000 iterations
*  -   4 columns selected (arrayIndex = {1, 5, 6 ,2}),  ==> 5252 ms for 100000 iterations
*/

//weird but functional version (that I just discovered)
colSelection = arrayData.transpose().get(arrayIndex); // the transpose operation is necessary to not hit an IllegalArgumentException: Illegal slice 5
// note that transposing the arrayIndex leads to an IllegalArgumentException: Illegal slice 6 (as it is trying to select the element at the line idx 1, column 5, depth 6, which does not exist)
/*
* colSelection value is
* [5, 6, 0, 3, 9, 1]
* The array is flattened... calling a reshape(arrayData.shape()[0],arrayIndex.shape()[1]) yields
* [[5, 6, 0],
*  [3, 9, 1]]
* which is wrong.
*/
colSelection = colSelection.reshape(arrayIndex.shape()[1],arrayData.shape()[0]).transpose();
/* yields the right result
* [[5, 0, 9],
*  [6, 3, 1]]
* While this seems to be the correct way to handle the memory the performance are low:
*  -   2 columns selected (arrayIndex = {1, 5}),        ==> 8225 ms for 100000 iterations
*  -   3 columns selected (arrayIndex = {1, 5, 6}),     ==> 8980 ms for 100000 iterations
*  -   4 columns selected (arrayIndex = {1, 5, 6 ,2}),  ==> 9453 ms for 100000 iterations
Plus, this is very roundabout method for such a "simple" operation
* if the repacking of the data is commented out, the timing become:
*  -   2 columns selected (arrayIndex = {1, 5}),        ==> 6987 ms for 100000 iterations
*  -   3 columns selected (arrayIndex = {1, 5, 6}),     ==> 7976 ms for 100000 iterations
*  -   4 columns selected (arrayIndex = {1, 5, 6 ,2}),  ==> 8336 ms for 100000 iterations
*/

在不知道我是什么机器的情况下,这些速度似乎还不错 运行,但等效的 python 代码产生:

那些 java 实现最多比 python-numpy 慢 20 倍。

org.nd4j.linalg.api.ndarray.INDArray arr = org.nd4j.linalg.factory.Nd4j.create(new double[][]{
                {1, 5, 0, 6, 2, 0, 9, 0, 5, 2},
                {3, 6, 1, 0, 4, 3, 1, 4, 8, 1}
        });

        org.nd4j.linalg.indexing.INDArrayIndex indices[] = {
                org.nd4j.linalg.indexing.NDArrayIndex.all(),
                new org.nd4j.linalg.indexing.SpecifiedIndex(1,5,6)
        };

        org.nd4j.linalg.api.ndarray.INDArray selected = arr.get(indices);
        System.out.println(selected);
    }

这应该适合你。这打印: SLF4J:无法加载 class "org.slf4j.impl.StaticLoggerBinder"。 SLF4J:默认为无操作 (NOP) 记录器实现 SLF4J:有关详细信息,请参阅 http://www.slf4j.org/codes.html#StaticLoggerBinder

[[    5.0000,         0,    9.0000], 
 [    6.0000,    3.0000,    1.0000]]

进程已完成,退出代码为 0