用 Numba 优化整组元组的字典?

Optimizing dict of set of tuple of ints with Numba?

我正在学习如何使用 Numba(虽然我已经相当熟悉 Cython)。我应该如何加速这段代码?注意函数 returns 一个由整数的二元组集合组成的字典。我正在使用 IPython 笔记本。我更喜欢 Numba 而不是 Cython。

@autojit
def generateadj(width,height):
    adj = {}
    for y in range(height):
        for x in range(width):
            s = set()
            if x>0:
                s.add((x-1,y))
            if x<width-1:
                s.add((x+1,y))
            if y>0:
                s.add((x,y-1))
            if y<height-1:
                s.add((x,y+1))
            adj[x,y] = s
    return adj

我设法用 Cython 写了这个,但我不得不放弃数据的结构方式。我不喜欢这个。我在 Numba 文档的某个地方读到它可以处理列表、元组等基本内容

%%cython
import numpy as np

def generateadj(int width, int height):
    cdef int[:,:,:,:] adj = np.zeros((width,height,4,2), np.int32)
    cdef int count

    for y in range(height):
        for x in range(width):
            count = 0
            if x>0:
                adj[x,y,count,0] = x-1
                adj[x,y,count,1] = y
                count += 1
            if x<width-1:
                adj[x,y,count,0] = x+1
                adj[x,y,count,1] = y
                count += 1
            if y>0:
                adj[x,y,count,0] = x
                adj[x,y,count,1] = y-1
                count += 1
            if y<height-1:
                adj[x,y,count,0] = x
                adj[x,y,count,1] = y+1
                count += 1
            for i in range(count,4):
                adj[x,y,i] = adj[x,y,0]
    return adj

虽然 numba 支持 Python 数据结构 dicts 和 sets,但它在 object 模式.在 numba 词汇表中,对象模式定义为:

A Numba compilation mode that generates code that handles all values as Python objects and uses the Python C API to perform all operations on those objects. Code compiled in object mode will often run no faster than Python interpreted code, unless the Numba compiler can take advantage of loop-jitting.

因此,在编写 numba 代码时,您需要坚持使用数组等内置数据类型。这里有一些代码可以做到这一点:

@jit
def gen_adj_loop(width, height, adj):
    i = 0
    for x in range(width):
        for y in range(height):
            if x > 0:
                adj[i,0] = x
                adj[i,1] = y
                adj[i,2] = x - 1
                adj[i,3] = y
                i += 1

            if x < width - 1:
                adj[i,0] = x
                adj[i,1] = y
                adj[i,2] = x + 1
                adj[i,3] = y
                i += 1

            if y > 0:
                adj[i,0] = x
                adj[i,1] = y
                adj[i,2] = x
                adj[i,3] = y - 1
                i += 1

            if y < height - 1:
                adj[i,0] = x
                adj[i,1] = y
                adj[i,2] = x
                adj[i,3] = y + 1
                i += 1
    return

这需要一个数组 adj。每行的形式为 x y adj_x adj_y。所以对于 (3,4) 处的像素,我们有四行:

3 4 2 4
3 4 4 4
3 4 3 3
3 4 3 5

我们可以将上面的函数包装在另一个函数中:

@jit
def gen_adj(width, height):
    # each pixel has four neighbors, but some of these neighbors are
    # off the grid -- 2*width + 2*height of them to be exact
    n_entries = width*height*4 - 2*width - 2*height
    adj = np.zeros((n_entries, 4), dtype=int)
    gen_adj_loop(width, height, adj)

这个功能非常快,但是不完整。我们必须将 adj 转换为您问题中形式的字典。问题是这是一个非常缓慢的过程。我们必须遍历 adj 数组并将每个条目添加到 Python 字典中。 numba.

无法解决这个问题

所以底线是:结果是元组字典的要求确实限制了您可以优化此代码的程度。