Fastest way to find the maximum minimum value of 'connected' matrices

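A related question gives an answer for three matrices, but I am not sure how to apply that logic to an arbitrary number of pairwise-connected matrices: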

f(i, j, k, l, ...) = min(A(i, j), B(i,k), C(i,l), D(j,k), E(j,l), F(k,l), ...)

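where A, B, ... are matrices and i, j, ... are indices that range up to the dimensions of the matrices. If we consider n indices, there are n(n-1)/2 pairs and therefore n(n-1)/2 matrices. I want to find the (i,j,k,...) for which f(i,j,k,l,...) is maximized. I am currently doing this: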

import numpy as np
import itertools

#             i  j  k  l  ...
dimensions = [50,50,50,50]
n_dims = len(dimensions)

pairs = list(itertools.combinations(range(n_dims), 2))

# Construct the matrices A(i,j), B(i,k), ...
matrices = []
for pair in pairs:
    matrices.append(np.random.rand(dimensions[pair[0]], dimensions[pair[1]]))


# All the different i,j,k,l... combinations
combinations = itertools.product(*list(map(np.arange,dimensions)))
combinations = np.asarray(list(combinations))

# Find the maximum minimum
vals = []

for pair, matrix in zip(pairs, matrices):
    # Look up this matrix's entry for every index combination at once
    vals.append(matrix[combinations[:,pair[0]], combinations[:,pair[1]]])


f = np.min(vals,axis=0)

best_indices = combinations[np.argmax(f)]

print(best_indices, np.max(f))

[5 17 17 18] 0.932985854758534

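This is faster than looping over all (i, j, k, l, ...), but a lot of time is spent constructing the combinations and vals arrays. Is there another way to do this that (1) preserves the speed of numpy's matrix computation and (2) does not require building the memory-intensive vals array?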
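To put rough numbers on "memory-intensive", here is a back-of-the-envelope sketch for the dimensions used above. It assumes 8 bytes per element for both the integer index array and the float values, which is the usual NumPy default on 64-bit Linux but is an assumption, not something measured here.

```python
import math
from itertools import combinations

dimensions = [50, 50, 50, 50]
ITEM_BYTES = 8  # assumption: int64 indices and float64 values

n_tuples = math.prod(dimensions)                                 # number of (i, j, k, l) tuples
n_pairs = len(list(combinations(range(len(dimensions)), 2)))     # n(n-1)/2 matrices

# 'combinations' stores one row of n indices per tuple
combos_bytes = n_tuples * len(dimensions) * ITEM_BYTES
# 'vals' stores one value per matrix per tuple
vals_bytes = n_pairs * n_tuples * ITEM_BYTES

print(n_tuples, n_pairs)                      # 6250000 6
print(f"{combos_bytes / 1e6:.0f} MB")         # 200 MB
print(f"{vals_bytes / 1e6:.0f} MB")           # 300 MB
```

Both arrays grow with the product of all dimensions, so adding a fifth index of size 50 would multiply each figure by roughly 50 or more.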

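Here is a generalization of the 3D solution. I assume there are other (better?) ways of organizing the recursion, but this works well enough. It finishes a 6D example (product of the dimensions 9x10^6) in <10 ms.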

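Sample run. Note that occasionally the indices returned by the two methods do not match. This is because the maximizer is not always unique: sometimes different index combinations yield the same maximum of minima. Also note that at the very end we do a single run of a huge 6D example with 9x10^12 combinations. Brute force is no longer feasible there; the smart method takes about 10 seconds.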

trial 1
results identical True
results compatible True
brute force 276.8830654968042 ms
branch cut 9.971900499658659 ms
trial 2
results identical True
results compatible True
brute force 273.444719001418 ms
branch cut 9.236706099909497 ms
trial 3
results identical True
results compatible True
brute force 274.2998780013295 ms
branch cut 7.31226220013923 ms
trial 4
results identical True
results compatible True
brute force 273.0268925006385 ms
branch cut 6.956217200058745 ms
HUGE (100, 150, 200, 100, 150, 200) 9000000000000
branch cut 10246.754082996631 ms
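The non-uniqueness noted above is easy to reproduce on a tiny hand-built case. This is only an illustrative sketch: the three 2x2 matrices are made up, and the full f(i,j,k) is built by broadcasting, which is fine at this size.

```python
import numpy as np

A = np.array([[0.5, 0.1], [0.1, 0.1]])  # A(i, j)
B = np.array([[0.9, 0.8], [0.1, 0.1]])  # B(i, k)
C = np.array([[0.9, 0.7], [0.1, 0.1]])  # C(j, k)

# f(i, j, k) = min(A(i, j), B(i, k), C(j, k)), expanded by broadcasting
F = np.minimum(np.minimum(A[:, :, None], B[:, None, :]), C[None, :, :])

best = F.max()
# Every index tuple that attains the maximum
maximizers = [tuple(int(v) for v in idx) for idx in np.argwhere(F == best)]

print(best)        # 0.5
print(maximizers)  # [(0, 0, 0), (0, 0, 1)]
```

Here A(0,0) = 0.5 is the bottleneck, and both choices of k leave the B and C entries above it, so two different index tuples are equally valid answers.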

Code:

import numpy as np
import itertools as it
import functools as ft

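def bf(dims,pairs):
    # Brute force: broadcast every pair matrix to the full n-dimensional
    # index space, take the elementwise minimum and locate its maximum.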
    dims,pairs = np.array(dims),np.array(pairs,object)
    n,m = len(dims),len(pairs)
    IDX = np.empty((m,n),object)
    Y,X = np.triu_indices(n,1)
    IDX[np.arange(m),Y] = slice(None)
    IDX[np.arange(m),X] = slice(None)
    idx = np.unravel_index(
        ft.reduce(np.minimum,(p[(*i,)] for p,i in zip(pairs,IDX))).argmax(),dims)
    return ft.reduce(np.minimum,(
        p[I] for p,I in zip(pairs,it.combinations(idx,2)))),idx

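def cut(dims,pairs,offs=None):
    # Branch and cut: visit the entries of the first matrix in descending
    # order and recurse on the remaining dimensions. offs holds one vector
    # per current dimension with the bound already imposed by indices that
    # were fixed further up the recursion.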
    n = len(dims)
    if n<3:
        if n==2:
            A = pairs[0] if offs is None else np.minimum(
                pairs[0],np.minimum.outer(offs[0],offs[1]))
            idx = np.unravel_index(A.argmax(),dims)
            return A[idx],idx
        else:
            idx = offs[0].argmax()
            return offs[0][idx],(idx,)
    gmx = min(map(np.min,pairs))
    gidx = n * (0,)
    A = pairs[0] if offs is None else np.minimum(
        pairs[0],np.minimum.outer(offs[0],offs[1]))
    Y,X = np.unravel_index(A.argsort(axis=None)[::-1],dims[:2])
    for y,x in zip(Y,X):
        if A[y,x] <= gmx:
            return gmx,gidx
        coffs = [np.minimum(p1[y],p2[x])
                 for p1,p2 in zip(pairs[1:n-1],pairs[n-1:])]
        if offs is not None:
            coffs = [*map(np.minimum,coffs,offs[2:])]
        cmx,cidx = cut(dims[2:],pairs[2*n-3:],coffs)
        if cmx >= A[y,x]:
            return A[y,x],(y,x,*cidx)
        if gmx < cmx:
            gmx = min(A[y,x],cmx)
            gidx = y,x,*cidx
    return gmx,gidx

from timeit import timeit

IDX = 10,15,20,10,15,20

for rep in range(4):
    print("trial",rep+1)
    pairs = [np.random.rand(i,j) for i,j in it.combinations(IDX,2)]

    print("results identical",cut(IDX,pairs)==bf(IDX,pairs))
    # Compatible means the same optimal value, even if the indices differ
    print("results compatible",cut(IDX,pairs)[0]==bf(IDX,pairs)[0])
    print("brute force",timeit(lambda:bf(IDX,pairs),number=2)*500,"ms")
    print("branch cut",timeit(lambda:cut(IDX,pairs),number=10)*100,"ms")

IDX = 100,150,200,100,150,200
pairs = [np.random.rand(i,j) for i,j in it.combinations(IDX,2)]
print("HUGE",IDX,np.prod(IDX))
print("branch cut",timeit(lambda:cut(IDX,pairs),number=1)*1000,"ms")
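The pruning rule that cut relies on is easier to see with only three matrices. The following is a simplified sketch, not the code above: max_min_3d and its variable names are made up for illustration. It walks the entries of A from largest to smallest and stops as soon as A(i,j), which is an upper bound on f(i,j,k), can no longer beat the best value found so far.

```python
import numpy as np

def max_min_3d(A, B, C):
    """Maximize f(i, j, k) = min(A[i, j], B[i, k], C[j, k])."""
    best_val, best_idx = -np.inf, None
    # Flat positions of A's entries, largest first
    order = np.argsort(A, axis=None)[::-1]
    I, J = np.unravel_index(order, A.shape)
    for i, j in zip(I, J):
        a = A[i, j]
        if a <= best_val:
            break  # f <= A[i, j], so no later entry can improve on best_val
        inner = np.minimum(B[i], C[j])   # min(B[i, k], C[j, k]) for every k
        k = inner.argmax()
        val = min(a, inner[k])
        if val > best_val:
            best_val, best_idx = val, (int(i), int(j), int(k))
    return best_val, best_idx

rng = np.random.default_rng(0)
A, B, C = rng.random((10, 15)), rng.random((10, 20)), rng.random((15, 20))
print(max_min_3d(A, B, C))
```

With random inputs the loop usually stops after a small fraction of A's entries, which is where the speedup over brute force comes from. The recursive version above applies the same idea at every level, passing the bounds from already fixed indices down through offs.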