Python 中多维矩阵的多维索引

Multidimensional Indexing of an multidemensional matrix in Python

我想用一个多维索引矩阵访问另一个多维矩阵。 我的问题是,由于广播问题(形状不匹配),np.newaxis 等方法无法正常工作。

我的数据矩阵的形状为 (5001, 3, 240, 16)。

import numpy as np

# n_examples, n_channels, n_pictures, n_meta_information
data = np.ones((5001, 3, 240, 16))

# select randomly 32 examples
batch_size = 32
possible_indices = np.arange(5001, dtype=np.int)
random_example_indices = np.random.choice(possible_indices, size=batch_size, replace=True)


# select all three channels
n_channels = 3
channel_indices = np.arange(n_channels)
#channel_indices = np.expand_dims(channel_indices , axis=0)
#channel_indices = np.repeat(channel_indices, batch_size, axis=0)

final_pictures_indices = []
for random_sample_idx in range(batch_size):
    # select a random start index and take 120 successive indices
    # is the same for all three channels
    start_index = np.random.randint(0, max(1, 240 - 120 + 1))
    end_index = start_index + 120
    pictures_indices = np.arange(start_index , end_index , dtype=np.int)
    final_pictures_indices.append(pictures_indices)

# batch_size x n_pictures
final_pictures_indices = np.array(final_pictures_indices)


# should have the shape: (32, 3, 120, 16)
result = data[random_example_indices[:, np.newaxis], channel_indices, final_pictures_indices].shape
print(result)

不幸的是,我收到以下错误:

result = data[random_example_indices[:, np.newaxis], channel_indices, final_pictures_indices].shape
IndexError: shape mismatch: indexing arrays could not be broadcast together with shapes (32,1) (3,) (32,120)

我也曾尝试将所有索引信息融合到一个矩阵中,但我遇到了无法堆叠具有不同形状的矩阵的问题。

谢谢你的每一个提示。

您的索引数组需要能够相互广播,即从右边开始,要么具有您预期的最终维度 1,要么不存在。所以:

result = data[random_example_indices[:, np.newaxis, np.newaxis],   # (32,   1,    1)
              channel_indices[:, np.newaxis],                      #       (3,    1)
              final_pictures_indices[:, np.newaxis, :].shape       # (32,   1,  120)

这应该会让你得到预期的形状。