如何使用 numba 加速 python 函数
How to speed up a python function with numba
我正在尝试使用 numba 加快 Floyd-Steinberg's dithering algorithm 的实施速度。在阅读初学者指南后,我将 @jit
装饰器添加到我的代码中:
def findClosestColour(pixel):
colors = np.array([[255, 255, 255], [255, 0, 0], [0, 0, 255], [255, 255, 0], [0, 128, 0], [253, 134, 18]])
distances = np.sum(np.abs(pixel[:, np.newaxis].T - colors), axis=1)
shortest = np.argmin(distances)
closest_color = colors[shortest]
return closest_color
@jit(nopython=True) # Set "nopython" mode for best performance, equivalent to @njit
def floydDither(img_array):
height, width, _ = img_array.shape
for y in range(0, height-1):
for x in range(1, width-1):
old_pixel = img_array[y, x, :]
new_pixel = findClosestColour(old_pixel)
img_array[y, x, :] = new_pixel
quant_error = new_pixel - old_pixel
img_array[y, x+1, :] = img_array[y, x+1, :] + quant_error * 7/16
img_array[y+1, x-1, :] = img_array[y+1, x-1, :] + quant_error * 3/16
img_array[y+1, x, :] = img_array[y+1, x, :] + quant_error * 5/16
img_array[y+1, x+1, :] = img_array[y+1, x+1, :] + quant_error * 1/16
return img_array
但是,我收到以下错误:
Untyped global name 'findClosestColour': Cannot determine Numba type of <class 'function'>
我想我明白 numba 不知道 findClosestColour
的类型,但我刚开始使用 numba,不知道如何处理错误。
这是我用来测试函数的代码:
image = cv2.imread('logo.jpeg')
img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
im_out = floydDither(img)
这是我使用的测试图像。
首先,无法从 Numba nopython jitted 函数(又名 njit 函数)调用纯Python 函数。这是因为 Numba 需要在编译时跟踪类型以生成高效的二进制文件。
此外,Numba 无法编译表达式 pixel[:, np.newaxis].T
,因为 np.newaxis
似乎还不受支持(可能是因为 np.newaxis
是 None
)。您可以改用pixel.reshape(3, -1).T
。
请注意,您应该注意类型,因为当两个变量都是 np.uint8
类型时执行 a - b
会导致可能的 溢出 (例如。 0 - 1 == 255
,甚至更令人惊讶:0 - 256 = 65280
当 b
是字面整数且 a
类型为 np.uint8
)。请注意,数组是就地计算的,像素是在
之前写入的
尽管 Numba 做得很好,但生成的代码效率不会很高。您可以使用循环自己迭代颜色以找到最小索引。这要好一些,因为它不会生成 许多小的临时数组 。您还可以指定类型,以便 Numba 提前编译函数。话虽如此。这也使代码变得更底层等等 verbose/harder-to-maintain.
这是一个优化的实现:
@nb.njit('int32[::1](uint8[::1])')
def nb_findClosestColour(pixel):
colors = np.array([[255, 255, 255], [255, 0, 0], [0, 0, 255], [255, 255, 0], [0, 128, 0], [253, 134, 18]], dtype=np.int32)
r,g,b = pixel.astype(np.int32)
r2,g2,b2 = colors[0]
minDistance = np.abs(r-r2) + np.abs(g-g2) + np.abs(b-b2)
shortest = 0
for i in range(1, colors.shape[0]):
r2,g2,b2 = colors[i]
distance = np.abs(r-r2) + np.abs(g-g2) + np.abs(b-b2)
if distance < minDistance:
minDistance = distance
shortest = i
return colors[shortest]
@nb.njit('uint8[:,:,::1](uint8[:,:,::1])')
def nb_floydDither(img_array):
assert(img_array.shape[2] == 3)
height, width, _ = img_array.shape
for y in range(0, height-1):
for x in range(1, width-1):
old_pixel = img_array[y, x, :]
new_pixel = nb_findClosestColour(old_pixel)
img_array[y, x, :] = new_pixel
quant_error = new_pixel - old_pixel
img_array[y, x+1, :] = img_array[y, x+1, :] + quant_error * 7/16
img_array[y+1, x-1, :] = img_array[y+1, x-1, :] + quant_error * 3/16
img_array[y+1, x, :] = img_array[y+1, x, :] + quant_error * 5/16
img_array[y+1, x+1, :] = img_array[y+1, x+1, :] + quant_error * 1/16
return img_array
原始版本快 14 倍,而最后一个版本快 19 倍。
我正在尝试使用 numba 加快 Floyd-Steinberg's dithering algorithm 的实施速度。在阅读初学者指南后,我将 @jit
装饰器添加到我的代码中:
def findClosestColour(pixel):
colors = np.array([[255, 255, 255], [255, 0, 0], [0, 0, 255], [255, 255, 0], [0, 128, 0], [253, 134, 18]])
distances = np.sum(np.abs(pixel[:, np.newaxis].T - colors), axis=1)
shortest = np.argmin(distances)
closest_color = colors[shortest]
return closest_color
@jit(nopython=True) # Set "nopython" mode for best performance, equivalent to @njit
def floydDither(img_array):
height, width, _ = img_array.shape
for y in range(0, height-1):
for x in range(1, width-1):
old_pixel = img_array[y, x, :]
new_pixel = findClosestColour(old_pixel)
img_array[y, x, :] = new_pixel
quant_error = new_pixel - old_pixel
img_array[y, x+1, :] = img_array[y, x+1, :] + quant_error * 7/16
img_array[y+1, x-1, :] = img_array[y+1, x-1, :] + quant_error * 3/16
img_array[y+1, x, :] = img_array[y+1, x, :] + quant_error * 5/16
img_array[y+1, x+1, :] = img_array[y+1, x+1, :] + quant_error * 1/16
return img_array
但是,我收到以下错误:
Untyped global name 'findClosestColour': Cannot determine Numba type of <class 'function'>
我想我明白 numba 不知道 findClosestColour
的类型,但我刚开始使用 numba,不知道如何处理错误。
这是我用来测试函数的代码:
image = cv2.imread('logo.jpeg')
img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
im_out = floydDither(img)
这是我使用的测试图像。
首先,无法从 Numba nopython jitted 函数(又名 njit 函数)调用纯Python 函数。这是因为 Numba 需要在编译时跟踪类型以生成高效的二进制文件。
此外,Numba 无法编译表达式 pixel[:, np.newaxis].T
,因为 np.newaxis
似乎还不受支持(可能是因为 np.newaxis
是 None
)。您可以改用pixel.reshape(3, -1).T
。
请注意,您应该注意类型,因为当两个变量都是 np.uint8
类型时执行 a - b
会导致可能的 溢出 (例如。 0 - 1 == 255
,甚至更令人惊讶:0 - 256 = 65280
当 b
是字面整数且 a
类型为 np.uint8
)。请注意,数组是就地计算的,像素是在
尽管 Numba 做得很好,但生成的代码效率不会很高。您可以使用循环自己迭代颜色以找到最小索引。这要好一些,因为它不会生成 许多小的临时数组 。您还可以指定类型,以便 Numba 提前编译函数。话虽如此。这也使代码变得更底层等等 verbose/harder-to-maintain.
这是一个优化的实现:
@nb.njit('int32[::1](uint8[::1])')
def nb_findClosestColour(pixel):
colors = np.array([[255, 255, 255], [255, 0, 0], [0, 0, 255], [255, 255, 0], [0, 128, 0], [253, 134, 18]], dtype=np.int32)
r,g,b = pixel.astype(np.int32)
r2,g2,b2 = colors[0]
minDistance = np.abs(r-r2) + np.abs(g-g2) + np.abs(b-b2)
shortest = 0
for i in range(1, colors.shape[0]):
r2,g2,b2 = colors[i]
distance = np.abs(r-r2) + np.abs(g-g2) + np.abs(b-b2)
if distance < minDistance:
minDistance = distance
shortest = i
return colors[shortest]
@nb.njit('uint8[:,:,::1](uint8[:,:,::1])')
def nb_floydDither(img_array):
assert(img_array.shape[2] == 3)
height, width, _ = img_array.shape
for y in range(0, height-1):
for x in range(1, width-1):
old_pixel = img_array[y, x, :]
new_pixel = nb_findClosestColour(old_pixel)
img_array[y, x, :] = new_pixel
quant_error = new_pixel - old_pixel
img_array[y, x+1, :] = img_array[y, x+1, :] + quant_error * 7/16
img_array[y+1, x-1, :] = img_array[y+1, x-1, :] + quant_error * 3/16
img_array[y+1, x, :] = img_array[y+1, x, :] + quant_error * 5/16
img_array[y+1, x+1, :] = img_array[y+1, x+1, :] + quant_error * 1/16
return img_array
原始版本快 14 倍,而最后一个版本快 19 倍。