python numpy maxpool:给定一个数组和来自 argmax 的索引,returns 最大值

python numpy maxpool: given an array and indices from argmax, returns max values

假设我有一个名为 view:

的数组
array([[[[ 7,  9],
         [10, 11]],

        [[19, 18],
         [20, 16]]],


       [[[24,  5],
         [ 6, 10]],

        [[18, 11],
         [45, 12]]]])

你可能从 maxpooling 知道,这是原始输入的视图,内核大小为 2x2:

[[ 7,  9],  [[19, 18],
 [10, 11]],  [20, 16]]], ....

目标是找到最大值及其索引。但是,argmax 仅适用于单轴,所以我需要 flatten view,即使用 flatten=view.reshape(2,2,4):

array([[[ 7,  9, 10, 11], [19, 18, 20, 16]],

       [[24,  5,  6, 10], [18, 11, 45, 12]]])

现在,在 的帮助下,我可以使用 inds = flatten.argmax(-1):

找到最大值的索引
array([[3, 2],
       [0, 2]])

和最大值:

i, j = np.indices(flatten.shape[:-1])
flatten[i, j, inds]

>>> array([[11, 20],
           [24, 45]])

问题
当我 flatten view 数组时出现问题。由于 view 数组是原始数组的视图,即 view = as_strided(original, newshape, newstrides),因此 vieworiginal 共享相同的数据。但是,reshape 会破坏它,因此 view 上的任何更改都不会反映在 original 上。这在反向传播过程中是有问题的。

我的问题
给定数组 view 和索引 ind,我想将 view 中的最大值更改为 1000,而不使用整形或任何破坏 'bond' 之间的操作 vieworiginal。感谢您的帮助!!!

可重现的例子

import numpy as np
from numpy.lib.stride_tricks import as_strided

original=np.array([[[7,9,19,18],[10,11,20,16]],[[24,5,18,11],[6,10,45,12]]],dtype=np.float64)
view=as_strided(original, shape=(2,1,2,2,2),strides=(64,32*2,8*2,32,8))

我想将 view 中每个内核的最大值更改为 1000,这可以反映在 original 上,即如果我 运行 view[0,0,0,0,0]=1000,那么view和original的第一个元素都是1000.

这个怎么样:

import numpy as np
view = np.array(
    [[[[ 7,  9],
       [10, 11]],
      [[19, 18],
       [20, 16]]],
     [[[24,  5],
       [ 6, 10]],
      [[18, 11],
       [45, 12]]]]
)
# Getting the indices of the max values
max0 = view.max(-2)
idx2 = view.argmax(-2)
idx2 = idx2.reshape(-1, idx2.shape[1])
max1 = max0.max(-1)
idx3 = max0.argmax(-1).flatten()
idx2 = idx2[np.arange(idx3.size), idx3]

idx0 = np.arange(view.shape[0]).repeat(view.shape[1])
idx1 = np.arange(view.shape[1]).reshape(1, -1).repeat(view.shape[0], 0).flatten()

# Replacing the maximal vlues with 1000
view[idx0, idx1, idx2, idx3] = 1000
print(f'view = \n{view}')

输出:

view = 
[[[[   7    9]
   [  10 1000]]

  [[  19   18]
   [1000   16]]]


 [[[1000    5]
   [   6   10]]

  [[  18   11]
   [1000   12]]]]

基本上,idx{n}是前两个维度中包含的每个矩阵的最后两个维度中的最大值的索引。