Numpy ndarray 的动态轴索引

Dynamic axis indexing of Numpy ndarray

我想获得 3D 阵列给定方向的 2D 切片,其中 direction(或要提取切片的轴)由另一个变量给出。

假设 idx 3D 数组中 2D 切片的索引,direction 获取该 2D 切片的轴,初始方法是:

if direction == 0:
    return A[idx, :, :]
elif direction == 1:
    return A[:, idx, :]
else:
    return A[:, :, idx]

我很确定一定有一种方法可以做到这一点而无需执行条件,或者至少不用原始 python。 numpy 有这方面的捷径吗?

到目前为止我找到的更好的解决方案(动态执行)依赖于转置运算符:

# for 3 dimensions [0,1,2] and direction == 1 --> [1, 0, 2]
tr = [direction] + range(A.ndim)
del tr[direction+1]

return np.transpose(A, tr)[idx]

但我想知道是否有任何 better/easier/faster 函数用于此,因为对于 3D,转置代码看起来几乎比 3 if/elif 更糟糕。它对 ND 的泛化效果更好,N 越大,相比之下代码越漂亮,但对于 3D 是完全一样的。

转置很便宜(按时间)。有 numpy 函数使用它将操作轴(或多个轴)移动到已知位置 - 通常是形状列表的前端或末尾。 tensordot 是我想到的。

其他函数构造索引元组。它们可能从列表或数组开始,以便于操作,然后将其变成元组以供应用。例如

I = [slice(None)]*A.ndim
I[axis] = idx
A[tuple(I)]

np.apply_along_axis 做了类似的事情。查看此类函数的代码很有启发性。

我想 numpy 函数的编写者最担心的是它是否运行稳定,其次是速度,最后是它是否漂亮。一个函数就可以埋下各种丑陋的代码!


tensordot 结尾为

at = a.transpose(newaxes_a).reshape(newshape_a)
bt = b.transpose(newaxes_b).reshape(newshape_b)
res = dot(at, bt)
return res.reshape(olda + oldb)

前面的代码计算了 newaxes_..newshape....

apply_along_axis构造一个(0...,:,0...)索引元组

i = zeros(nd, 'O')
i[axis] = slice(None, None)
i.put(indlist, ind)
....arr[tuple(i.tolist())]

要动态索引维度,可以使用swapaxes,如下图:

a = np.arange(7 * 8 * 9).reshape((7, 8, 9))

axis = 1
idx = 2

np.swapaxes(a, 0, axis)[idx]

运行时比较

自然法(非动态):

%timeit a[:, idx, :]
300 ns ± 1.58 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

交换轴:

%timeit np.swapaxes(a, 0, axis)[idx]
752 ns ± 4.54 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

具有列表理解的索引:

%timeit a[[idx if i==axis else slice(None) for i in range(a.ndim)]]