进行Python组操作

Make Python group operations

def f1(x):
    for i in range(1, 100):
        x *= 2
        x /= 3.14159
        x *= i**.25
    return x

def f2(x):
    for i in range(1, 100):
        x *= 2 / 3.14159 * i**.25
    return x

两个函数的计算完全相同,但 f1 需要 3 倍的时间来计算,即使使用 @numba.njit 也是如此。 Python 是否可以识别编译中的等价性,就像它以其他方式优化 dis 一样,例如丢弃未使用的作业?

请注意,我知道浮点运算关心顺序,因此这两个函数的输出可能略有不同,但如果有的话 more 对数组值的单独编辑是 less 准确,所以这是一个二合一的优化。


x = np.random.randn(10000, 1000)
%timeit f1(x.copy())        # 2.68 s ± 50.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit f2(x.copy())        # 894 ms ± 36.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit njit(f1)(x.copy())  # 2.59 s ± 65.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit njit(f2)(x.copy())  # 901 ms ± 41.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

可能用 jit 做不到。我尝试了 api 中指定的 fastmath 和 nogil kwarg:https://numba.pydata.org/numba-doc/latest/reference/jit-compilation.html

f0 仍然比 f1 在摆脱溢出或非正规数后稍慢。 plot

from timeit import default_timer as timer
import numpy as np
import matplotlib.pyplot as plt
import numba as nb


def f0(x):
    for i in range(1, 1000):
        x *= 3.000001
        x /= 3
    return x


def f1(x):
    for i in range(1, 1000):
        x *= 3.000001 / 3
    return x


def timing(f, **kwarg):
    x = np.ones(1000, dtype=np.float32)
    times = []
    n_iter = list(range(100, 1000, 100))
    f2 = nb.njit(f, **kwarg)
    for i in n_iter:
        print(i)
        s = timer()
        for j in range(i):
            f2(x)
        e = timer()
        times.append(e - s)
    print(x)
    m, b = np.polyfit(n_iter, times, 1)
    return times, m, b, n_iter


def main():
    results = []
    for fastmath in [True, False]:
        for i, f in enumerate([f0, f1]):
            kwarg = {
                "fastmath": fastmath,
                "nogil": True
            }
            r1, m, b, n_iter = timing(f, **kwarg)
            label = "f%d with %s" % (i, kwarg)
            plt.plot(n_iter, r1, label=label)
            results.append((m, b, label))
    for m, b, kwarg in results:
        print(m * 1e5, b, kwarg)
    plt.legend(loc="upper left")
    plt.xlabel("n iterations")
    plt.ylabel("timing")
    plt.show()
    plt.close()


if __name__ == '__main__':
    main()

使用numba.jit 可能是目前您将获得的此类功能的最佳优化。您可能还想尝试 pypy 并进行一些基准比较。

不过,我想指出为什么这两个函数 不是 等价的,因此您不应该期望 f1 减少到 f2

f1的操作顺序如下:

x1 = (x * 2)            # First binary operation
x2 = (x1 / 3.14159      # Second binary operation
x3 = x2 * (i ** 0.25)   # Third and fourth binary operation

# Order: Multiplication, division, exponent, multiplication

这与 f2 不同:

x *= ((2 / 3.14159) * (i ** 0.25))
#  ^     ^          ^     ^
#  |     |          |     |
#  4     1          3     2

# Order: Division, exponent, multiplication, multiplication

floating-point arithmetic is not associative 以来,这些可能不会产生相同的结果。出于这个原因,编译器或解释器进行您期望的优化是错误的,除非它是为了优化 floating-point 精度。

我不知道 Python 工具可以进行这种特定类型的优化。