找到整数列表列表最大值的最快方法

fastest way to find maximum of list of list of integers

假设我有 a: list[list[int]] = [[1, 2, 3], [4, 5, 6], [1, 7, 1]].

max_a: int = max([max(tmp_list) for tmp_list in a]) 是最佳方式吗?或者有没有更快的方法?

我将处理一个包含大约 10 个元素的 8 个列表的列表。每次启动算法我都会找max 160,000次左右

我会使用 Python 标准库中的 itertools.chain.from_iterable

from itertools import chain
max_a = max(chain.from_iterable(a))
print(max_a)

至少在我的系统上它比问题中的方法(用 timeit 测量)更快,但在 Python 的不同版本中可能会有所不同。有趣的事实:如果我修改问题中的代码以使用生成器表达式,它会变慢。

首先,Is premature optimization really the root of all evil? 现在,关于优化的事情是它可以非常 input-reliant。即使在理论上某些方法更适合大输入尺寸,将环境旋转到 运行 的成本也可能超过使用更复杂的方法所获得的好处。下面是我为测试不同算法而创建的示例代码。

当您提供 运行 数据大小时,似乎简单的 double max 最适合这项工作,而使用 numpy 数组总是较慢。我看到一些性能提升的可能性是 numba 的 JIT 与 np.array 相结合。它在您的样本量下速度较慢,但​​是当问题增长时它变得越来越有效。最重要的是,后续调用的效率令人难以置信,因为它们已经编译,比所有其他选项高出几个数量级。

如您所见,如果您希望获得实际的性能增益,算法不能在空白中进行比较,应该在实际工作数据上逐个选择和测试。

import random
import numpy as np
from numba import njit
from itertools import chain
n, m = 8, 10
ls = [[random.uniform(0, 1) for _ in range(n)] for _ in range(m)]

k = 160000
from functools import wraps
from time import time

def timing(f):
    @wraps(f)
    def wrap(*args, **kw):
        ts = time()
        result = f(*args, **kw)
        te = time()
        print(f"func:{f.__name__} took {te-ts}s, {result=}")
        return result
    return wrap

@timing
def test1(k, ls):
    for i in range(k):
        max(max(ls))
    return max(max(ls))

npls = np.array(ls)

@timing
def test2(k, npls):
    for i in range(k):
        npls.max()
    return npls.max()

@timing
@njit
def test3(k, npls):
    for i in range(k-1):
        npls.max()
    return npls.max()
    
@timing
def test4(k, ls):
    for i in range(k):
        max(chain.from_iterable(ls))
    return max(chain.from_iterable(ls))

test1(k, ls)
test2(k, npls)
test3(k, npls)
test4(k, ls)

ls = [[random.uniform(0, 1) for _ in range(n)] for _ in range(m)]
npls = np.array(ls)

test3(k, npls)

更多ways/benchmarks:

2.30 us  2.30 us  2.31 us  chained
2.85 us  2.86 us  2.86 us  self
2.88 us  2.89 us  2.91 us  self2
2.99 us  3.00 us  3.03 us  mapmax
3.45 us  3.45 us  3.45 us  listcomp
3.46 us  3.54 us  3.54 us  genexp

代码(Try it online!):

def listcomp(a):
    return max([max(tmp_list) for tmp_list in a])

def genexp(a):
    return max(max(tmp_list) for tmp_list in a)

def mapmax(a):
    return max(map(max, a))

def chained(a):
    return max(chain.from_iterable(a))

def self(a):
    maxi = -1
    for b in a:
        for c in b:
            if c > maxi:
                maxi = c
    return maxi

def self2(a):
    maxi = a[0][0]
    for b in a:
        for c in b:
            if c > maxi:
                maxi = c
    return maxi

funcs = [listcomp, genexp, mapmax, chained, self, self2]

from timeit import repeat
import random
from bisect import insort
from collections import deque
from itertools import chain

tests = 100
A = [[random.choices(range(1000), k=10) for _ in range(8)]
     for _ in range(tests)]

expect = list(map(funcs[0], A))
for func in funcs:
    result = list(map(func, A))
    assert result == expect, func.__name__

times = {func: [] for func in funcs}
for _ in range(10):
    random.shuffle(funcs)
    for func in funcs:
        time = min(repeat(lambda: deque(map(func, A), 0), number=1)) / tests
        insort(times[func], time)
for func in sorted(funcs, key=times.get):
    print(*('%.2f us ' % (t * 1e6) for t in times[func][:3]), func.__name__)