Numpy:没有循环的多轴上的argmax
Numpy: argmax over multiple axes without loop
我有一个 N 维数组(名为 A)。对于 A 的第一轴的每一行,我想获得沿 A 的其他轴的最大值的坐标。然后我将 return 一个二维数组,其中每行的最大值的坐标A的第一个轴。
我已经使用循环解决了我的问题,但我想知道是否有更有效的方法来执行此操作。我当前的解决方案(例如数组 A)如下:
import numpy as np
A=np.reshape(np.concatenate((np.arange(0,12),np.arange(0,-4,-1))),(4,2,2))
maxpos=np.empty(shape=(4,2))
for n in range(0, 4):
maxpos[n,:]=np.unravel_index(np.argmax(A[n,:,:]), A[n,:,:].shape)
在这里,我们将有:
A:
[[[ 0 1]
[ 2 3]]
[[ 4 5]
[ 6 7]]
[[ 8 9]
[10 11]]
[[ 0 -1]
[-2 -3]]]
maxpos:
[[ 1. 1.]
[ 1. 1.]
[ 1. 1.]
[ 0. 0.]]
如果有多个最大化器,我不介意选择哪个。
我试过np.apply_over_axes
,但没能return达到我想要的结果。
你可以这样做 -
# Reshape input array to a 2D array with rows being kept as with original array.
# Then, get idnices of max values along the columns.
max_idx = A.reshape(A.shape[0],-1).argmax(1)
# Get unravel indices corresponding to original shape of A
maxpos_vect = np.column_stack(np.unravel_index(max_idx, A[0,:,:].shape))
示例 运行 -
In [214]: # Input array
...: A = np.random.rand(5,4,3,7,8)
In [215]: # Setup output array and use original loopy code
...: maxpos=np.empty(shape=(5,4)) # 4 because ndims in A is 5
...: for n in range(0, 5):
...: maxpos[n,:]=np.unravel_index(np.argmax(A[n,:,:,:,:]), A[n,:,:,:,:].shape)
...:
In [216]: # Proposed approach
...: max_idx = A.reshape(A.shape[0],-1).argmax(1)
...: maxpos_vect = np.column_stack(np.unravel_index(max_idx, A[0,:,:].shape))
...:
In [219]: # Verify results
...: np.array_equal(maxpos.astype(int),maxpos_vect)
Out[219]: True
推广到 n 维数组
我们可以概括求解 n-dim 数组以获得最后 N
轴的 argmax
结合这样的东西 -
def argmax_lastNaxes(A, N):
s = A.shape
new_shp = s[:-N] + (np.prod(s[-N:]),)
max_idx = A.reshape(new_shp).argmax(-1)
return np.unravel_index(max_idx, s[-N:])
结果将是索引数组的元组。如果您需要最终输出为数组,我们可以使用 np.stack
或 np.concatenate
.
您可以使用列表理解
result = [np.unravel_index(np.argmax(r), r.shape) for r in a]
IMO 更具可读性,但速度不会比显式循环好多少。
只有在第一个维度实际上是非常大的维度时,主外循环在 Python 中这一事实才有意义。
如果是这种情况(即你有 1000 万个 2x2 矩阵),那么翻转会更快...
# true if 0,0 is not smaller than others
m00 = ((data[:,0,0] >= data[:,1,0]) &
(data[:,0,0] >= data[:,0,1]) &
(data[:,0,0] >= data[:,1,1]))
# true if 0,1 is not smaller than others
m01 = ((data[:,0,1] >= data[:,1,0]) &
(data[:,0,1] >= data[:,0,0]) &
(data[:,0,1] >= data[:,1,1]))
# true if 1,0 is not smaller than others
m10 = ((data[:,1,0] >= data[:,0,0]) &
(data[:,1,0] >= data[:,0,1]) &
(data[:,1,0] >= data[:,1,1]))
# true if 1,1 is not smaller than others
m11 = ((data[:,1,1] >= data[:,1,0]) &
(data[:,1,1] >= data[:,0,1]) &
(data[:,1,1] >= data[:,0,0]))
# choose which is max on equality
m01 &= ~m00
m10 &= ~(m00|m01)
m11 &= ~(m00|m01|m10)
# compute result
result = np.zeros((len(data), 2), np.int32)
result[:,1] |= m01|m11
result[:,0] |= m10|m11
在我的机器上,上面的代码大约快 50 倍(对于一百万个 2x2 矩阵)。
我有一个 N 维数组(名为 A)。对于 A 的第一轴的每一行,我想获得沿 A 的其他轴的最大值的坐标。然后我将 return 一个二维数组,其中每行的最大值的坐标A的第一个轴。
我已经使用循环解决了我的问题,但我想知道是否有更有效的方法来执行此操作。我当前的解决方案(例如数组 A)如下:
import numpy as np
A=np.reshape(np.concatenate((np.arange(0,12),np.arange(0,-4,-1))),(4,2,2))
maxpos=np.empty(shape=(4,2))
for n in range(0, 4):
maxpos[n,:]=np.unravel_index(np.argmax(A[n,:,:]), A[n,:,:].shape)
在这里,我们将有:
A:
[[[ 0 1]
[ 2 3]]
[[ 4 5]
[ 6 7]]
[[ 8 9]
[10 11]]
[[ 0 -1]
[-2 -3]]]
maxpos:
[[ 1. 1.]
[ 1. 1.]
[ 1. 1.]
[ 0. 0.]]
如果有多个最大化器,我不介意选择哪个。
我试过np.apply_over_axes
,但没能return达到我想要的结果。
你可以这样做 -
# Reshape input array to a 2D array with rows being kept as with original array.
# Then, get idnices of max values along the columns.
max_idx = A.reshape(A.shape[0],-1).argmax(1)
# Get unravel indices corresponding to original shape of A
maxpos_vect = np.column_stack(np.unravel_index(max_idx, A[0,:,:].shape))
示例 运行 -
In [214]: # Input array
...: A = np.random.rand(5,4,3,7,8)
In [215]: # Setup output array and use original loopy code
...: maxpos=np.empty(shape=(5,4)) # 4 because ndims in A is 5
...: for n in range(0, 5):
...: maxpos[n,:]=np.unravel_index(np.argmax(A[n,:,:,:,:]), A[n,:,:,:,:].shape)
...:
In [216]: # Proposed approach
...: max_idx = A.reshape(A.shape[0],-1).argmax(1)
...: maxpos_vect = np.column_stack(np.unravel_index(max_idx, A[0,:,:].shape))
...:
In [219]: # Verify results
...: np.array_equal(maxpos.astype(int),maxpos_vect)
Out[219]: True
推广到 n 维数组
我们可以概括求解 n-dim 数组以获得最后 N
轴的 argmax
结合这样的东西 -
def argmax_lastNaxes(A, N):
s = A.shape
new_shp = s[:-N] + (np.prod(s[-N:]),)
max_idx = A.reshape(new_shp).argmax(-1)
return np.unravel_index(max_idx, s[-N:])
结果将是索引数组的元组。如果您需要最终输出为数组,我们可以使用 np.stack
或 np.concatenate
.
您可以使用列表理解
result = [np.unravel_index(np.argmax(r), r.shape) for r in a]
IMO 更具可读性,但速度不会比显式循环好多少。
只有在第一个维度实际上是非常大的维度时,主外循环在 Python 中这一事实才有意义。
如果是这种情况(即你有 1000 万个 2x2 矩阵),那么翻转会更快...
# true if 0,0 is not smaller than others
m00 = ((data[:,0,0] >= data[:,1,0]) &
(data[:,0,0] >= data[:,0,1]) &
(data[:,0,0] >= data[:,1,1]))
# true if 0,1 is not smaller than others
m01 = ((data[:,0,1] >= data[:,1,0]) &
(data[:,0,1] >= data[:,0,0]) &
(data[:,0,1] >= data[:,1,1]))
# true if 1,0 is not smaller than others
m10 = ((data[:,1,0] >= data[:,0,0]) &
(data[:,1,0] >= data[:,0,1]) &
(data[:,1,0] >= data[:,1,1]))
# true if 1,1 is not smaller than others
m11 = ((data[:,1,1] >= data[:,1,0]) &
(data[:,1,1] >= data[:,0,1]) &
(data[:,1,1] >= data[:,0,0]))
# choose which is max on equality
m01 &= ~m00
m10 &= ~(m00|m01)
m11 &= ~(m00|m01|m10)
# compute result
result = np.zeros((len(data), 2), np.int32)
result[:,1] |= m01|m11
result[:,0] |= m10|m11
在我的机器上,上面的代码大约快 50 倍(对于一百万个 2x2 矩阵)。