无法使用 Numba 优化分形代码

Unable to optimize Fractal code with Numba

我正在编写代码来可视化 Mandelbrot 集和其他分形。下面是 运行 中的代码片段。代码 运行 完全没问题,但我正在尝试优化它以更快地制作更高分辨率的图像。我已经尝试在 fractal() 上使用缓存,以及来自 Numba 的 @jit@njit。缓存导致崩溃(我假设是内存溢出)并且 @jit 只是将我的程序执行速度减慢了 6 倍。我也知道有很多数学方法可以使我的代码 运行 更快,正如我在维基百科页面上看到的那样,但我想看看我是否可以获得上述方法之一或其他替代方法。

为了连续创建多个图像(制作缩放动画,就像这个)我已经实现了多处理(这似乎是一次 运行 9 个进程)但我不知道如何在创建单个高分辨率图像时执行相同的操作。

这是我的代码片段:

import numpy as np
import cv2
import cmath
import math

# pick the fractal
def fractal(z,c):
# Mandelbrot
    if fractal_type == 0:
        return z**d + c
# Burning Ship
    if fractal_type == 1:
        return complex(abs(z.real), abs(z.imag))**d + c

#naive escape time algorithm
def naive_escape(arr):
    h = arr[0]
    w = arr[1]
    d = arr[2]
    zoom = pow(1.5, arr[3]) * pow(10,int(np.log10(h)))
    x_cen = arr[4]
    y_cen = arr[5]

    for i in range(w):
        sys.stdout.write("\r{0:03}%".format(np.round(i/w * 100, 4)))
        sys.stdout.flush()

        for j in range(h):
            it = 0
        #coordinates
            cx = i - int(w/2)
            cy = j - int(h/2)
        #scaling
            sx = (cx / (zoom)) + x_cen
            sy = (cy / (zoom)) - y_cen

            c = complex(sx,sy)
            z = complex(0.0,0.0)

            while ((z.real)**2 + (z.imag)**2 <= 2**d) and (it < max_it):
                z = fractal(z,c)
                it += 1

            img[j][i] = color_dict[it]

    sys.stdout.write("\n")

    name = "fractal"

    cv2.imwrite("{}.png".format(name), img)
    print("\n{} created!\n".format(name), fractal_type)


我应该澄清一下,着色函数 naive_escape() 采用数组输入的原因是因为我实现了多处理。由于 map() 在 multiprocessing 中只允许我们将函数映射到一个输入,我只传递一个包含所有输入值的数组。

上面粘贴的代码是从一个更大的文件中摘录的片段,因此如有任何语法错误,请原谅。

如果能帮助我加快代码速度,我们将不胜感激!

This older answer 专门处理矢量化,但可以进行一些额外的优化。

你可以从 Numpy 向量化入手,方便但速度不快:

@np.vectorize
def mandelbrot_numpy(c: complex, max_it: int) -> int:
    z = c
    for i in range(max_it):
        if abs(z) > 2:
            return i
        z = z**2 + c
    return 0

或者 Numba 向量化,它提高了一个数量级的速度:

@nb.vectorize([nb.u2(nb.c16, nb.i8)])
def mandelbrot_numba(c: complex, max_it: int) -> int:
    z = c
    for i in range(max_it):
        if abs(z) > 2:
            return i
        z = z**2 + c
    return 0

然后您可以应用一些常用的优化:

@nb.vectorize([nb.u2(nb.c16, nb.u2)])
def mandelbrot_numba_opt(c: complex, max_it: int) -> int:
    x = cx = c.real
    y = cy = c.imag
    for i in range(max_it):
        x2 = x*x
        y2 = y*y
        if x2 + y2 > 4:
            return i
        y = (x+x)*y + cy
        x = x2 - y2 + cx
    return 0

您还可以将其并行化(在此示例中按行):

@nb.njit([nb.u2[:,:](nb.c16[:,:], nb.u2)], parallel=True)
def mandelbrot_parallel(c: np.ndarray, max_it: int) -> np.ndarray:
    result = np.zeros_like(c, dtype=nb.u2)
    for row in nb.prange(len(c)):
        result[row] = mandelbrot_numba_opt(c[row], max_it)
    return result

1000x1000 阵列上的一些计时:

N = 1000
x = np.linspace(-2, 2, N).reshape((1, -1))
y = x.T
c = x + 1j * y

%timeit mandelbrot_numpy(c, 99)
1.59 s ± 40.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit mandelbrot_numba(c, 99)
100 ms ± 406 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit mandelbrot_numba_opt(c, 99)
35 ms ± 140 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit mandelbrot_parallel(c, 99)
10.9 ms ± 64.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)