切片的 numpy 重塑如何工作

How does numpy reshaping of a slice works

我正在尝试从头开始重新实现 numpy,但我无法弄清楚切片是如何工作的。以这个数组为例:

> a = np.array(list(range(0,20)))

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19])

让我们将它重塑为

> b = a.reshape(5,4)

array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15],
       [16, 17, 18, 19]])

这里的进步是有道理的

> b.strides

(32, 8) # float64 has 8 bytes, so row jumps by 4 floats and column moves by 1

> b.shape

(5, 4)

我可以使用 4*row+col 表达式索引这个数组

for row in range(b.shape[0]):
    for col in range(b.shape[1]):
        assert a[4*row+col] == b[row,col]

现在让我们切一片

> c = b[1:,1:]

array([[ 5,  6,  7],
       [ 9, 10, 11],
       [13, 14, 15],
       [17, 18, 19]])

> c.strides

(32, 8) # It has the same strides, which sounds logical.

> c.shape

(4, 3)

我想 c 只是将偏移量存储到 a。所以索引到这个数组将等同于 offset + 4*row+col

for row in range(c.shape[0]):
    for col in range(c.shape[1]):
        assert a[5 + 4*row+col] == c[row,col]

但真正的魔法发生在我执行重塑时。

d = c.reshape(3,4)

array([[ 5,  6,  7,  9],
       [10, 11, 13, 14],
       [15, 17, 18, 19]])

请注意 81216 缺失(如预期)。

然而,当我查询步幅时

> d.strides

(32, 8)

> d.shape

(3, 4)

他们看起来一样。这些进步显然是在撒谎。没有办法以这样的步伐索引到 a。那么这种魔法是如何发生的呢?一定是少了什么信息。

这必须以某种巧妙的方式实现。我不敢相信 numpy 会使用一些棘手的包装器对象。特别是如果我们考虑在 pytorch/tensorflow 和这些框架 运行 中应该在 CUDA 上实现完全相同的机制,这显然不能支持具有虚拟 类 和中间抽象层的高级 C++ 魔法。一定有一些简单明显的技巧可以在 GPU 上 运行。

我喜欢查看 a.__array_interface__,其中包含一个数据缓冲区 'pointer'。从中您可以看到 c 偏移量。我怀疑 d 是副本,不再是视图。这从界面上应该很明显。

In [155]: a=np.arange(20)
In [156]: a.__array_interface__
Out[156]: 
{'data': (34619344, False),
 'strides': None,
 'descr': [('', '<i8')],
 'typestr': '<i8',
 'shape': (20,),
 'version': 3}
In [157]: b = a.reshape(5,4)
In [158]: b.__array_interface__
Out[158]: 
{'data': (34619344, False),       # same
 'strides': None,
 'descr': [('', '<i8')],
 'typestr': '<i8',
 'shape': (5, 4),
 'version': 3}
In [159]: c = b[1:,1:]
In [160]: c.__array_interface__
Out[160]: 
{'data': (34619384, False),       # offset by 32+8
 'strides': (32, 8),
 'descr': [('', '<i8')],
 'typestr': '<i8',
 'shape': (4, 3),
 'version': 3}
In [161]: d=c.reshape(3,4)
In [162]: d.__array_interface__
Out[162]: 
{'data': (36537904, False),       # totally different
 'strides': None,
 'descr': [('', '<i8')],
 'typestr': '<i8',
 'shape': (3, 4),
 'version': 3}