cython 加速 3d 列表操作

cython to speed up a 3d list manipulation

有人可以帮我为这个例子创建一个 cython 代码吗?

我创建这个例子是因为我想创建一个更快的版本,作为我正在考虑的解决方案 and/or ,但如果有其他方法,请随时提出。

该代码旨在从前一个列表(随机生成的列表 1)开始生成一个 3d 列表(称为列表 2)。

它使用两个转换函数(transform 和 trasform1):一个仅执行一些随机数学运算(transform),另一个使用 list1 生成最终值以插入 list2。

我用

跟踪时间

代码如下:

import math
from random import seed
from random import random

global i_elements
global j_elements
global k_elements

i_elements,j_elements,k_elements =11,11,11

global list1
global list2
seed(1)
list1 = [[[random() for k in range(i_elements)] for j in range(j_elements)] for i in range(k_elements)]
list2 = [[[0 for k in range(i_elements)] for j in range(j_elements)] for i in range(k_elements)]

def transform(x, y, z):
    '''
    no-sense function performing some math
    '''
    a = 0
    b = 0
    sol = 0

    if x > y:
        a = math.sqrt(x ** 2) + math.atan(y ** 2 +1)
    else:
        b = math.sqrt(z ** 2) + math.atan(y ** 2 +1)

    if x > z:
        sol = math.sqrt(a*b)
    else:
        sol = math.sqrt(b**2)

    return sol

def transform2(a, b, c):
    '''
    transformation dependent on element in list1
    '''

    global list1, i_elements,j_elements,k_elements
    sol = 0

    for i in range(i_elements):
        for j in range(j_elements):
            for k in range(k_elements):
                temp = transform(i, j, k)
                if list1[i][j][k] > temp:
                    sol = temp*list1[i][j][k]*(a+1)**2
                else:
                    sol = temp + list1[i][j][k]**(b*c +1)

    return sol

def save_list():
    '''
    function to save my 3d list after the transform2
    '''
    global list2,i_elements, j_elements, k_elements
    for i in range(i_elements):
        for j in range(j_elements):
            for k in range(k_elements):
                list2[i][j][k] = transform2(i,j,k)

    return list2

def main():
    save_list()
    print('finish')

if __name__ == "__main__":
    import cProfile

    cProfile.run('main()')

输出为:

finish
         7087581 function calls in 2.628 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    2.628    2.628 <string>:1(<module>)
  1771561    1.453    0.000    1.932    0.000 code_to_speed_up.py:17(transform)
     1331    0.696    0.001    2.628    0.002 code_to_speed_up.py:37(transform2)
        1    0.001    0.001    2.628    2.628 code_to_speed_up.py:56(save_list)
        1    0.000    0.000    2.628    2.628 code_to_speed_up.py:68(main)
        1    0.000    0.000    2.628    2.628 {built-in method builtins.exec}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.print}
  1771561    0.178    0.000    0.178    0.000 {built-in method math.atan}
  3543122    0.301    0.000    0.301    0.000 {built-in method math.sqrt}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}

Numpy 可以用来解决这个问题,但是 transform2 很难被有效地向量化。然而,Cython 和 Numba 可以高效地做到这一点(Numba 有点像 Cython,但它是一个即时编译器,在这里使用起来更简单)。

单独使用 Cython 或 Numba 是不够的,因为无法有效地计算列表(由于引用计数、GIL、导致额外间接的低效内部表示等)。

请注意,使用全局变量不是一个好主意。它通常速度较慢,并且使您的代码更难优化和理解(这被视为一种糟糕的软件工程实践,尤其是在修改全局变量时,因为它会导致隐藏的依赖关系,或者换句话说,这是一种远距离的幽灵般的动作)。

这是一个使用 Numpy + Numba 的(几乎没有测试过的)示例:

import math
import numpy as np
import numba as nb

i_elements,j_elements,k_elements = 11,11,11
np.random.seed(1)
arr1 = np.random.rand(i_elements, j_elements, k_elements)
arr2 = np.zeros((i_elements, j_elements, k_elements))


@nb.njit
def transform(x, y, z):
    '''
    no-sense function performing some math
    '''
    a = 0
    b = 0
    sol = 0

    if x > y:
        a = math.sqrt(x ** 2) + math.atan(y ** 2 +1)
    else:
        b = math.sqrt(z ** 2) + math.atan(y ** 2 +1)

    if x > z:
        sol = math.sqrt(a*b)
    else:
        sol = math.sqrt(b**2)

    return sol

@nb.njit
def transform2(arr1, a, b, c):
    '''
    transformation dependent on element in arr1
    '''
    sol = 0

    for i in range(arr1.shape[0]):
        for j in range(arr1.shape[1]):
            for k in range(arr1.shape[2]):
                temp = transform(i, j, k)
                if arr1[i,j,k] > temp:
                    sol = temp * arr1[i,j,k] * (a+1)**2
                else:
                    sol = temp + arr1[i,j,k]**(b*c +1)

    return sol

# Giving a signature to Numba helps him to compile the function eagerly
# Read the doc for more information about this.
@nb.njit('(float64[:,:,::1],float64[:,:,::1])')
def save_list(arr1, arr2):
    '''
    function to save my 3d list after the transform2
    '''
    for i in range(arr2.shape[0]):
        for j in range(arr2.shape[1]):
            for k in range(arr2.shape[2]):
                arr2[i,j,k] = transform2(arr1, i,j,k)

    return arr2

def main():
    global arr1
    global arr2
    save_list(arr1, arr2)
    print('finish')

if __name__ == "__main__":
    import cProfile
    cProfile.run('main()')

请注意,尽管种子设置为相同的值,np.random.rand 可能会产生不同的结果,因为 Numpy 肯定使用不同的随机数生成器实现。

如果你真的想使用 Cython 而不是 Numba,那么你需要使用 Numpy 内存视图。有关详细信息,请阅读 Cython 文档的 Cython for NumPy users tutorial

在我的机器上 500 倍以上

请注意,save_list 中的 return 不是很有用,因为它是传入参数(对于全局变量也不是更有用)。另请注意,soltransform2 中分配,这意味着只有最后一次迭代很重要(编译器可以对其进行优化)。不过这很可疑:它看起来像一个错误,你当然想执行减少(例如 +=),特别是因为初始分配为 0。请检查结果是否正确。