如何优化这些嵌套循环?

How to optimize these nested loops?

我正在尝试计算一种颜色(或最接近 17 种颜色之一的颜色)出现在图像像素中的次数(给定为 300x300x3 np 数组,浮点值 [0,1]) .我已经写了这个但是它似乎非常低效:

for w in range(300):
    for h in range(300):
        colordistance = float('inf')
        colorindex = 0
        for c in range(17):
            r1 = color[c, 0]
            g1 = color[c, 1]
            b1 = color[c, 2]
            r2 = img[w, h, 0]
            g2 = img[w, h, 1]
            b2 = img[w, h, 2]
            distance = math.sqrt(
                ((r2-r1)*0.3)**2 + ((g2-g1)*0.59)**2 + ((b2-b1)*0.11)**2)
            if distance < colordistance:
                colordistance = distance
                colorindex = c
        colorcounters[colorindex] = colorcounters[colorindex] + 1

有什么办法可以提高这个位的效率吗?我已经在为外循环使用多处理。

你提到你正在使用 numpy,所以你应该尽可能避免迭代。我的矢量化实现速度大约快 40 倍。我对您的代码做了一些更改,以便它们可以使用相同的数组,这样我就可以验证正确性。这可能会影响速度。

import numpy as np
import time
import math

num_images = 1
hw = 300 # height and width
nc = 17  # number of colors
img = np.random.random((num_images, hw, hw, 1, 3))
colors = np.random.random((1, 1, 1, nc, 3))

## NUMPY IMPLEMENTATION
t = time.time()

dist_sq = np.sum(((img - colors) * [0.3, 0.59, 0.11]) ** 2, axis=4)  # calculate (distance * coefficients) ** 2
largest_color = np.argmin(dist_sq, axis=3)  # find the minimum
color_counters = np.unique(largest_color, return_counts=True) # count
print(color_counters)
# should return an object like [[1, 2, 3, ... 17], [count1, count2, count3, ...]]

print("took {} s".format(time.time() - t))

## REFERENCE IMPLEMENTATION
t = time.time()
colorcounters = [0 for i in range(nc)]
for i in range(num_images):
    for h in range(hw):
        for w in range(hw):
            colordistance = float('inf')
            colorindex = 0
            for c in range(nc):
                r1 = colors[0, 0, 0, c, 0]
                g1 = colors[0, 0, 0, c, 1]
                b1 = colors[0, 0, 0, c, 2]
                r2 = img[i, w, h, 0, 0]
                g2 = img[i, w, h, 0, 1]
                b2 = img[i, w, h, 0, 2]

                # dist_sq
                distance = math.sqrt(((r2-r1)*0.3)**2 + ((g2-g1)*0.59)**2 + ((b2-b1)*0.11)**2)  # not using sqrt gives a 14% improvement

                # largest_color
                if distance < colordistance:
                    colordistance = distance
                    colorindex = c
            # color_counters
            colorcounters[colorindex] = colorcounters[colorindex] + 1
print(colorcounters)
print("took {} s".format(time.time() - t))