Numpy 3d 数组索引
Numpy 3d array indexing
我在下面的示例中有一个 3d numpy 数组 (n_samples x num_components x 2) n_samples = 5 和 num_components = 7.
我有另一个数组 (indices) 这是每个样本的 selected 组件,形状为 (n_samples ,).
我想从给定索引的数据数组中 select 得到的数组是 n_samples x 2.
代码如下:
import numpy as np
np.random.seed(77)
data=np.random.randint(low=0, high=10, size=(5, 7, 2))
indices = np.array([0, 1, 6, 4, 5])
#how can I select indices from the data array?
例如,对于数据 0,selected 分量应该是第 0 个,对于数据 1,selected 分量应该是 1。
请注意,我不能使用任何 for 循环,因为我在 Theano 中使用它,解决方案应该完全基于 numpy。
要获取组件 #0,请使用
data[:, 0]
即我们得到轴 0(样本)上的每个条目,轴 1(组件)上只有条目 #0,以及其余轴上的所有内容。
这可以很容易地概括为
data[:, indices]
到select所有相关组件。
但是OP真正想要的只是这个数组的对角线,即(data[0, indices[0]], (data[1, indices[1]]), ...)
高维数组的对角线可以使用diagonal
函数提取:
>>> np.diagonal(data[:, indices])
array([[7, 7, 4, 8, 5],
[4, 3, 5, 2, 8]])
(您可能需要转置结果。)
您有多种方法可以做到这一点,但这是我的循环推荐:
selection = np.array([ datum[indices[k]] for k,datum in enumerate(data)])
生成的数组 selection
具有所需的形状。
这是您要找的吗?
In [36]: data[np.arange(data.shape[0]),indices,:]
Out[36]:
array([[7, 4],
[7, 3],
[4, 5],
[8, 2],
[5, 8]])
我在下面的示例中有一个 3d numpy 数组 (n_samples x num_components x 2) n_samples = 5 和 num_components = 7.
我有另一个数组 (indices) 这是每个样本的 selected 组件,形状为 (n_samples ,).
我想从给定索引的数据数组中 select 得到的数组是 n_samples x 2.
代码如下:
import numpy as np
np.random.seed(77)
data=np.random.randint(low=0, high=10, size=(5, 7, 2))
indices = np.array([0, 1, 6, 4, 5])
#how can I select indices from the data array?
例如,对于数据 0,selected 分量应该是第 0 个,对于数据 1,selected 分量应该是 1。
请注意,我不能使用任何 for 循环,因为我在 Theano 中使用它,解决方案应该完全基于 numpy。
要获取组件 #0,请使用
data[:, 0]
即我们得到轴 0(样本)上的每个条目,轴 1(组件)上只有条目 #0,以及其余轴上的所有内容。
这可以很容易地概括为
data[:, indices]
到select所有相关组件。
但是OP真正想要的只是这个数组的对角线,即(data[0, indices[0]], (data[1, indices[1]]), ...)
高维数组的对角线可以使用diagonal
函数提取:
>>> np.diagonal(data[:, indices])
array([[7, 7, 4, 8, 5],
[4, 3, 5, 2, 8]])
(您可能需要转置结果。)
您有多种方法可以做到这一点,但这是我的循环推荐:
selection = np.array([ datum[indices[k]] for k,datum in enumerate(data)])
生成的数组 selection
具有所需的形状。
这是您要找的吗?
In [36]: data[np.arange(data.shape[0]),indices,:]
Out[36]:
array([[7, 4],
[7, 3],
[4, 5],
[8, 2],
[5, 8]])