优化并删除 numpy 数组中的 for 循环

Optimize and remove for loop in numpy array

我想从下面的代码中删除 for 循环。在我的用例中,n 的值通常会大得多,其中 200-700 范围内的值并不少见,并且不方便将它们全部列出来,再添加一个循环只会使这更低效。

import numpy as np
n = 3
imgs = np.random.random((16,9,9,n))
transform = np.random.uniform(low=0.0, high=1.0, size=(n,n))
for img in imgs:
    for channel in range(img.shape[2]):
        temp1 = img[:,:,0]
        temp2 = img[:,:,1]
        temp3 = img[:,:,2]

        temp = temp1 * transform[channel][0] + temp2 * transform[channel][1] + temp3 * transform[channel][2]

        img[:,:,channel] = temp/3

任何指点将不胜感激。

我想你也许可以完全避免内部循环。

据我了解,您正在对 transform 矩阵与 [temp1 temp2 temp3] 矩阵的转置进行点积,然后将其除以 3。

以下是图像中相同的表示:

因此,所有这些实际上都可以在 for 循环本身之外完成。代码看起来像这样。 P.S。还修改了一些感觉不一致的地方的变量名

import numpy as np

n = 3
imgs = np.random.random((16,9,9,n))
transform = np.random.uniform(low=0.0, high=1.0, size=(n,n))

for img in imgs:
    temp_arr = img[:,:, 0:3]
    img[:,:, 0:3] = np.dot(temp_arr, np.transpose(transform))/3

将结果与您的结果进行比较,结果相同

切片可能无法按照您想象的方式应用,但使用 numba 您可以获得一些有用的东西,而无需添加另一个 for 循环。

此代码执行与您提供的代码相同的操作。

import numpy as np
from numba import njit

@njit
def mult_arrays_values( array, values, n):
    temp = np.zeros((array.shape[0], array.shape[1], n))
    for i in range(n):
        temp[:,:,i] = array[:,:,i] * values[i]
    return temp

n = 3
imgs = np.random.random((16,9,9,n))
transform = np.random.uniform(low=0.0, high=1.0, size=(n,n))
for img in imgs:
    for channel in range(img.shape[2]):
        temp = mult_arrays_values(img[:,:,:], transform[channel][:], n)
        temp = np.add.reduce(temp , axis=2)
        img[:,:,channel] = temp/n

您也可以使用numpy.einsum

In [77]: tr = np.arange(9).reshape((3,3))
    ...: imgs = np.arange(24).reshape((2,2,2,3)).transpose((1,0,2,3))
    ...: print('original imgs\n', imgs)
    ...: for img in imgs:
    ...:     img[:,:,:] = np.dot(img, tr.T)
    ...: print('imgs after loop\n', imgs)
    ...: imgs = np.arange(24).reshape((2,2,2,3)).transpose((1,0,2,3))
    ...: print('result of einsum\n', np.einsum('ijkl,ml', imgs, tr))
original imgs
 [[[[ 0  1  2]
   [ 3  4  5]]

  [[12 13 14]
   [15 16 17]]]


 [[[ 6  7  8]
   [ 9 10 11]]

  [[18 19 20]
   [21 22 23]]]]
imgs after loop
 [[[[  5  14  23]
   [ 14  50  86]]

  [[ 41 158 275]
   [ 50 194 338]]]


 [[[ 23  86 149]
   [ 32 122 212]]

  [[ 59 230 401]
   [ 68 266 464]]]]
result of einsum
 [[[[  5  14  23]
   [ 14  50  86]]

  [[ 41 158 275]
   [ 50 194 338]]]


 [[[ 23  86 149]
   [ 32 122 212]]

  [[ 59 230 401]
   [ 68 266 464]]]]