Fastest way to find the maximum minimum value of three 'connected' matrices
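A linked post gives the answer for two matrices, but I'm not sure how to apply that logic to three pairwise connected matrices, since there is no 'free' index. I want to maximize the following function:
f(i, j, k) = min(A(i, j), B(j, k), C(i, k))
where A, B and C are matrices, and i, j and k are indices that range up to the respective dimensions of the matrices. I would like to find the (i, j, k) such that f(i, j, k) is maximized. I am currently doing this: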
import numpy as np
import itertools
I = 100
J = 150
K = 200
A = np.random.rand(I, J)
B = np.random.rand(J, K)
C = np.random.rand(I, K)
# All the different i,j,k
combinations = itertools.product(np.arange(I), np.arange(J), np.arange(K))
combinations = np.asarray(list(combinations))
A_vals = A[combinations[:,0], combinations[:,1]]
B_vals = B[combinations[:,1], combinations[:,2]]
C_vals = C[combinations[:,0], combinations[:,2]]
f = np.min([A_vals,B_vals,C_vals],axis=0)
best_indices = combinations[np.argmax(f)]
print(best_indices)
# Example output (varies with the random data):
# [ 49  14 136]
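This is faster than iterating over all (i, j, k), but a lot (indeed most) of the time is spent constructing the _vals arrays. That is unfortunate, because they contain many duplicate values: the same i, j and k appear multiple times. Is there a way to do this where (1) the speed of numpy's matrix computation is preserved, and (2) I don't have to construct the memory-intensive _vals arrays?
In other languages you could perhaps construct the matrices so that they contain pointers to A, B and C, but I don't see how to achieve this in Python.
Edit: see the follow-up question for more indices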
Rather than using itertools, you can "construct" the combinations using repeat and tile:
A_=np.repeat(A.reshape((-1,1)),K,axis=0).T
B_=np.tile(B.reshape((-1,1)),(I,1)).T
C_=np.tile(C,J).reshape((-1,1)).T
and pass them to np.min:
print((t:=np.argmax(np.min([A_,B_,C_],axis=0)) , t//(K*J),(t//K)%J, t%K,))
Using timeit, 10 repetitions of your code take around 18 seconds, while the numpy version takes only around 1 second.
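The repeat/tile arrays above still hold I*J*K values each. As a minimal sketch of one way to bound memory (not part of the original answer; the function name `argmax_min_rowwise` is made up for illustration), the same brute-force search can be run one row index i at a time, so that the largest temporary array has shape (J, K):

```python
import numpy as np

def argmax_min_rowwise(A, B, C):
    """Return (i, j, k) maximizing min(A[i, j], B[j, k], C[i, k]).

    Handles one row index i at a time, so the largest temporary
    array has shape (J, K) instead of (I, J, K).
    """
    I, J = A.shape
    K = B.shape[1]
    best_val = -np.inf
    best_idx = (0, 0, 0)
    for i in range(I):
        # slab[j, k] = min(A[i, j], B[j, k], C[i, k]) via broadcasting
        slab = np.minimum(np.minimum(A[i, :, None], B), C[i, None, :])
        j, k = divmod(int(slab.argmax()), K)
        if slab[j, k] > best_val:
            best_val = slab[j, k]
            best_idx = (i, j, k)
    return best_idx

rng = np.random.default_rng(0)
A = rng.random((100, 150))
B = rng.random((150, 200))
C = rng.random((100, 200))
print(argmax_min_rowwise(A, B, C))
```

This trades a Python-level loop over I for a much smaller working set; whether it is faster than building the full arrays depends on the sizes involved.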
We can either brute force it using numpy broadcasting or try a bit of smart branch cutting:
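import numpy as np

def bf(A,B,C):
    # Brute force: broadcast to an (I, J, K) array of minima, then take its argmax
    I,J = A.shape
    J,K = B.shape
    return np.unravel_index((np.minimum(np.minimum(A[:,:,None],C[:,None,:]),B[None,:,:])).argmax(),(I,J,K))

def cut(A,B,C):
    # gmx: best value found so far, starting from a lower bound on the optimum
    gmx = min(A.min(),B.min(),C.min())
    # gamx: indices of the best value so far; initialized so that the early
    # return below is defined even if the very first entry of A triggers it
    gamx = 0,0,0
    I,J = A.shape
    J,K = B.shape
    # (row, column) positions of A, sorted from largest to smallest value
    Y,X = np.unravel_index(A.argsort(axis=None)[::-1],A.shape)
    for y,x in zip(Y,X):
        if A[y,x] <= gmx:
            # no remaining entry of A can improve on gmx
            return gamx
        curr = np.minimum(B[x,:],C[y,:])
        camx = curr.argmax()
        cmx = curr[camx]
        if cmx >= A[y,x]:
            # A[y,x] is the binding minimum, and all later entries of A are smaller
            return y,x,camx
        if gmx < cmx:
            gmx = cmx
            gamx = y,x,camx
    return gamx

from timeit import timeit

I = 100
J = 150
K = 200

for rep in range(4):
    print("trial",rep+1)
    A = np.random.rand(I, J)
    B = np.random.rand(J, K)
    C = np.random.rand(I, K)
    print("results identical",cut(A,B,C)==bf(A,B,C))
    print("brute force",timeit(lambda:bf(A,B,C),number=2)*500,"ms")
    print("branch cut",timeit(lambda:cut(A,B,C),number=10)*100,"ms")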
It turns out that branch cutting is well worth it at the given sizes:
trial 1
results identical True
brute force 169.74265850149095 ms
branch cut 1.951422297861427 ms
trial 2
results identical True
brute force 180.37619898677804 ms
branch cut 2.1000938024371862 ms
trial 3
results identical True
brute force 181.6371419990901 ms
branch cut 1.999850495485589 ms
trial 4
results identical True
brute force 217.75578951928765 ms
branch cut 1.5871295996475965 ms
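How does branch cutting work?
We pick one array (say A) and sort it from largest to smallest. Then we go through the array one entry at a time, comparing each value with the appropriate values from the other arrays and keeping track of the running maximum of the minima. As soon as that maximum is no smaller than the remaining values in A, we are done. Since this typically happens quite early, the savings are large.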
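To make the early termination concrete, here is a small instrumented sketch of the procedure just described. It is an illustration rather than the answer's own code: the name `cut_counting` and the `visited` counter are additions, and it reports how many entries of A were examined before the search stopped.

```python
import numpy as np

def cut_counting(A, B, C):
    """Branch cutting that also counts the entries of A it examined."""
    best_val = min(A.min(), B.min(), C.min())  # lower bound on the optimum
    best_idx = (0, 0, 0)
    # (row, column) positions of A, largest value first
    rows, cols = np.unravel_index(A.argsort(axis=None)[::-1], A.shape)
    visited = 0
    for i, j in zip(rows, cols):
        a = A[i, j]
        if a <= best_val:
            break  # no remaining entry of A can beat best_val
        visited += 1
        curr = np.minimum(B[j, :], C[i, :])  # min(B[j, k], C[i, k]) for all k
        k = curr.argmax()
        if curr[k] >= a:
            # f(i, j, k) equals a, and every later entry of A is <= a
            best_idx = (int(i), int(j), int(k))
            break
        if curr[k] > best_val:
            best_val = curr[k]
            best_idx = (int(i), int(j), int(k))
    return best_idx, visited

rng = np.random.default_rng(0)
A = rng.random((100, 150))
B = rng.random((150, 200))
C = rng.random((100, 200))
idx, visited = cut_counting(A, B, C)
print(idx, "examined", visited, "of", A.size, "entries of A")
```

Each examined entry costs one pass over K values, so the total work is roughly `visited * K` comparisons plus the sort of A, compared with I*J*K for brute force.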
Building on the branch-cutting answer above: you can get a slight speedup (~20%) by using numba:
import numba
import numpy as np

# Note: bf and cut are the functions defined in the branch-cutting answer above.

@numba.jit(nopython=True)
def find_gamx(A, B, C, X, Y, gmx):
    # Compiled version of the loop in cut
    gamx = (0, 0, 0)
    for y, x in zip(Y, X):
        if A[y, x] <= gmx:
            return gamx
        curr = np.minimum(B[x, :], C[y, :])
        camx = curr.argmax()
        cmx = curr[camx]
        if cmx >= A[y, x]:
            return y, x, camx
        if gmx < cmx:
            gmx = cmx
            gamx = y, x, camx
    return gamx

def cut_numba(A, B, C):
    gmx = min(A.min(), B.min(), C.min())
    I, J = A.shape
    J, K = B.shape
    # The sort stays in plain numpy; only the loop is compiled
    Y, X = np.unravel_index(A.argsort(axis=None)[::-1], A.shape)
    gamx = find_gamx(A, B, C, X, Y, gmx)
    return gamx

from timeit import timeit

I = 100
J = 150
K = 200

for rep in range(40):
    print("trial", rep + 1)
    A = np.random.rand(I, J)
    B = np.random.rand(J, K)
    C = np.random.rand(I, K)
    print("results identical", cut(A, B, C) == bf(A, B, C))
    print("results identical", cut_numba(A, B, C) == bf(A, B, C))
    print("brute force", timeit(lambda: bf(A, B, C), number=2) * 500, "ms")
    print("branch cut", timeit(lambda: cut(A, B, C), number=10) * 100, "ms")
    print("branch cut_numba", timeit(lambda: cut_numba(A, B, C), number=10) * 100, "ms")
trial 1
results identical True
results identical True
brute force 38.774325 ms
branch cut 1.7196750999999955 ms
branch cut_numba 1.3950291999999864 ms
trial 2
results identical True
results identical True
brute force 38.77167049999996 ms
branch cut 1.8655760999999993 ms
branch cut_numba 1.4977325999999902 ms
trial 3
results identical True
results identical True
brute force 39.69611449999999 ms
branch cut 1.8876490000000024 ms
branch cut_numba 1.421615300000001 ms
trial 4
results identical True
results identical True
brute force 44.338816499999936 ms
branch cut 1.614051399999994 ms
branch cut_numba 1.3842962000000014 ms
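One caveat about the "results identical" checks: they compare index triples, and the maximizer need not be unique. Whenever A[i, j] is the binding minimum for more than one k, every such k gives the same value of f, and an argmax over the full (I, J, K) array may pick a different k than an argmax over min(B[j, :], C[i, :]). A minimal sketch of a value-based comparison, using a hand-made tie (the helper names `f_value` and `same_optimum` are made up for illustration):

```python
import numpy as np

def f_value(A, B, C, idx):
    """Objective min(A[i, j], B[j, k], C[i, k]) at idx = (i, j, k)."""
    i, j, k = idx
    return min(A[i, j], B[j, k], C[i, k])

def same_optimum(A, B, C, idx1, idx2):
    """True if both index triples attain the same objective value."""
    return bool(f_value(A, B, C, idx1) == f_value(A, B, C, idx2))

# Hand-made tie: A[0, 0] is the binding minimum for both values of k.
A = np.array([[0.8]])
B = np.array([[0.9, 0.95]])
C = np.array([[0.9, 0.95]])

first = (0, 0, 0)   # first maximizer in flattened (i, j, k) order
second = (0, 0, 1)  # argmax of min(B[0, :], C[0, :])
print(first == second)                       # False
print(same_optimum(A, B, C, first, second))  # True
```

With continuous random inputs such ties are rare at these sizes, which is consistent with the checks above printing True, but comparing values is the safer test in general.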