numpy take 无法使用切片进行索引

numpy take can't index using slice

根据 take 的 numpy 文档 它与“花式”索引(使用数组索引数组)做同样的事情。但是,如果您需要沿着给定轴的元素,它会更容易使用。

但是,与 "fancy" 或常规 numpy 索引不同,似乎不支持使用切片作为索引:

In [319]: A = np.arange(20).reshape(4, 5)

In [320]: A[..., 1:4]
Out[320]: 
array([[ 1,  2,  3],
       [ 6,  7,  8],
       [11, 12, 13],
       [16, 17, 18]])

In [321]: np.take(A, slice(1, 4), axis=-1)
TypeError: long() argument must be a string or a number, not 'slice'

使用仅在运行时已知的轴上的切片来索引数组的最佳方法是什么?

我认为你的意思是:

In [566]: np.take(A, slice(1,4))
...
TypeError: int() argument must be a string or a number, not 'slice'

但是

np.take(A, np.r_[1:4])

就像 A[1:4]

一样工作

np.insertnp.apply_along_axis 等函数通过构造可能包含标量、切片和数组的索引元组来实现通用性。

ind = tuple([slice(1,4)])  # ndim terms to match array
A[ind]

np.tensordot是使用np.transpose将动作轴移动到最后的例子(供np.dot使用)。

另一个技巧是通过重塑将所有 'surplus' 轴折叠成一个。然后重塑回来。

According to the numpy docs for take it does the same thing as “fancy” indexing (indexing arrays using arrays).

np.take 的第二个参数必须是 类数组 (数组、列表、元组等),而不是 slice目的。您可以构建一个索引数组或列表来执行您想要的切片:

a = np.arange(24).reshape(2, 3, 4)

np.take(a, slice(1, 4, 2), 2)
# TypeError: long() argument must be a string or a number, not 'slice'

np.take(a, range(1, 4, 2), 2)
# array([[[ 1,  3],
#         [ 5,  7],
#         [ 9, 11]],

#        [[13, 15],
#         [17, 19],
#         [21, 23]]])

What is the best way to index an array using slices along an axis only known at runtime?

我经常喜欢做的是使用np.rollaxis将要索引的轴设置为第一个轴,进行索引,然后将其回滚到原来的位置。

例如,假设我想要 3D 阵列沿其第 3 轴的奇数切片:

sliced1 = a[:, :, 1::2]

如果我想在运行时指定要沿其切片的轴,我可以这样做:

n = 2    # axis to slice along

sliced2 = np.rollaxis(np.rollaxis(a, n, 0)[1::2], 0, n + 1)

assert np.all(sliced1 == sliced2)

稍微拆开那一行:

# roll the nth axis to the 0th position
np.rollaxis(a, n, 0)

# index odd-numbered slices along the 0th axis
np.rollaxis(a, n, 0)[1::2]

# roll the 0th axis back so that it lies before position n + 1 (note the '+ 1'!)
np.rollaxis(np.rollaxis(a, n, 0)[1::2], 0, n + 1)

最有效的方法似乎是A[(slice(None),) * axis + (slice(1, 4),)]:

In [19]: import numpy as np
    ...: x = np.random.normal(0, 1, (50, 50, 50))
    ...: s = slice(10, 20)
    ...: axis = 2
    ...: 
    ...: 

In [20]: timeit np.rollaxis(np.rollaxis(x, axis, 0)[s], 0, axis + 1)
2.32 µs ± 15.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

In [21]: timeit x.take(np.arange(x.shape[axis])[s], axis)
28.5 µs ± 38.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [22]: timeit x[(slice(None),) * axis + (s,)]
321 ns ± 0.341 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)