在 CuPy 中替换 numpy.apply_along_axis

Replacement for numpy.apply_along_axis in CuPy

我有一个基于 NumPy 的神经网络,我正在尝试将其移植到 CuPy。我有一个函数如下:

import numpy as np

def tensor_diag(x): return np.apply_along_axis(np.diag, -1, x)

# Usage: (x is a matrix, i.e. a 2-tensor)
def sigmoid_prime(x): return tensor_diag(sigmoid(x) * (1 - sigmoid(x)))

这可以使用 NumPy,但 CuPy 没有该函数的类似物(自 2020 年 5 月 8 日起不受支持)。我如何在 CuPy 中模拟这种行为?

In [284]: arr = np.arange(24).reshape(2,3,4)                                                           

np.diag 采用一维数组,returns 采用对角线上的值的二维。 apply_along_axis 只是迭代除最后一个维度之外的所有维度,并将最后一个数组一次传递给 diag:

In [285]: np.apply_along_axis(np.diag,-1,arr)                                                          
Out[285]: 
array([[[[ 0,  0,  0,  0],
         [ 0,  1,  0,  0],
         [ 0,  0,  2,  0],
         [ 0,  0,  0,  3]],

        [[ 4,  0,  0,  0],
         [ 0,  5,  0,  0],
         [ 0,  0,  6,  0],
         [ 0,  0,  0,  7]],

        [[ 8,  0,  0,  0],
         [ 0,  9,  0,  0],
         [ 0,  0, 10,  0],
         [ 0,  0,  0, 11]]],


       [[[12,  0,  0,  0],
         [ 0, 13,  0,  0],
         [ 0,  0, 14,  0],
         [ 0,  0,  0, 15]],

        [[16,  0,  0,  0],
         [ 0, 17,  0,  0],
         [ 0,  0, 18,  0],
         [ 0,  0,  0, 19]],

        [[20,  0,  0,  0],
         [ 0, 21,  0,  0],
         [ 0,  0, 22,  0],
         [ 0,  0,  0, 23]]]])
In [286]: _.shape                                                                                      
Out[286]: (2, 3, 4, 4)

我可以做同样的映射:

In [287]: res = np.zeros((2,3,4,4),int)                                                                
In [288]: res[:,:,np.arange(4),np.arange(4)] = arr                                                     

检查 apply 结果:

In [289]: np.allclose(_285, res)                                                                       
Out[289]: True

或者对于 apply 的更直接复制,使用 np.ndindex 生成所有 i,j 元组对以迭代 arr 的前两个维度:

In [298]: res = np.zeros((2,3,4,4),int)                                                                
In [299]: for ij in np.ndindex(2,3): 
     ...:     res[ij]=np.diag(arr[ij]) 
     ...:                                                                                              
In [300]: np.allclose(_285, res)                                                                       
Out[300]: True