amin 中的条件

Conditionals within amin

我有两个数组 UV,都是形状 (f, m)。我将在下面的示例中设置 f = 4,m = 3。

我想提取U每一列的最小值,前提是V中对应的值是非负的,即对于第j^列,我想return U[i,j] 的最小值使得 V[i,j] > 0.

我的第一次尝试是:

import numpy as np

U = np.array([[1,2,3],[4,5,6],[7,8,9],[10,11,12]])
V = np.array([[1,-1,1],[1,1,1],[-1,-1,1],[-1,-1,1]])

mask = (V > 0)

np.amin(U[mask], axis = 0)

但是这个 returns 1(整个数组的最小值),而不是 [1,5,3],我正在寻找的条件,按列的最小值。

我的问题似乎是 U[mask] 变平为形状 (1, 7),这破坏了 (4, 3) 结构,并使得搜索列最小值变得不可能(显然) .

我有没有办法修改此代码,以便我可以 return 我正在寻找的列最小值?

可能不是最漂亮的解决方案,但它确实有效 ;-)

mask = np.where(V[:,:] < 0, np.inf, 1)
x = np.amin(U*mask, axis = 1)

这听起来像是 masked arrays 的任务:

np.amin(np.ma.masked_array(U, V <= 0), axis=0)

让我们比较一下所提出方法的性能:

import numpy as np
from time_stats import compare_calls

U = np.array([[1,2,3],[4,5,6],[7,8,9],[10,11,12]])
V = np.array([[1,-1,1],[1,1,1],[-1,-1,1],[-1,-1,1]])


def masked(U=U, V=V):
    return np.amin(U[mask], axis = 0)

def where1(U=U, V=V):
    mask = np.where(V[:,:] < 0, np.inf, 1)
    return np.amin(U*mask, axis = 1)

def where2(U=U, V=V):
    np.where(V>0, U, np.iinfo(int).max).min(axis=0)


r = compare_calls(['masked()', 'where1()', 'where2()'], globals=globals())
print(r)
r.hist()

# masked() : 0.0001 s/call median, 9.7e-05 ... 0.00016 IQR
# where1() : 1e-05 s/call median, 1e-05 ... 1.1e-05 IQR
where2() : 9.6e-06 s/call median, 9.1e-06 ... 1e-05 IQR

对于这个矩阵大小,使用 where 显然比掩码数组快 :) 矩阵越大,差异越小,但@PaulPanzer 的解决方案总是最快的。

例如对于 1000x1000 矩阵:

# masked() : 0.015 s/call median, 0.015 ... 0.016 IQR
# where1() : 0.017 s/call median, 0.017 ... 0.02 IQR
# where2() : 0.011 s/call median, 0.01 ... 0.013 IQR

您可以将 whereiinfo 一起使用:

np.where(V>0, U, np.iinfo(int).max).min(axis=0)
# array([1, 5, 3], dtype=int64)

np.inf 不是整数,因此会强制进行不希望的向上转换。

np.where(V>0, U, np.inf).min(axis=0)
# array([1., 5., 3.])

Step-by-step:

np.iinfo(int)
# iinfo(min=-9223372036854775808, max=9223372036854775807, dtype=int64)

np.where(V>0, U, np.iinfo(int).max)
# array([[                  1, 9223372036854775807,                   3],
#        [                  4,                   5,                   6],
#        [9223372036854775807, 9223372036854775807,                   9],
#        [9223372036854775807, 9223372036854775807,                  12]],
#       dtype=int64)