优化并删除 numpy 数组中的 for 循环
Optimize and remove for loop in numpy array
我想从下面的代码中删除 for 循环。在我的用例中,n
的值通常会大得多,其中 200-700 范围内的值并不少见,并且不方便将它们全部列出来,再添加一个循环只会使这更低效。
import numpy as np
n = 3
imgs = np.random.random((16,9,9,n))
transform = np.random.uniform(low=0.0, high=1.0, size=(n,n))
for img in imgs:
for channel in range(img.shape[2]):
temp1 = img[:,:,0]
temp2 = img[:,:,1]
temp3 = img[:,:,2]
temp = temp1 * transform[channel][0] + temp2 * transform[channel][1] + temp3 * transform[channel][2]
img[:,:,channel] = temp/3
任何指点将不胜感激。
我想你也许可以完全避免内部循环。
据我了解,您正在对 transform
矩阵与 [temp1 temp2 temp3]
矩阵的转置进行点积,然后将其除以 3。
以下是图像中相同的表示:
因此,所有这些实际上都可以在 for 循环本身之外完成。代码看起来像这样。 P.S。还修改了一些感觉不一致的地方的变量名
import numpy as np
n = 3
imgs = np.random.random((16,9,9,n))
transform = np.random.uniform(low=0.0, high=1.0, size=(n,n))
for img in imgs:
temp_arr = img[:,:, 0:3]
img[:,:, 0:3] = np.dot(temp_arr, np.transpose(transform))/3
将结果与您的结果进行比较,结果相同
切片可能无法按照您想象的方式应用,但使用 numba
您可以获得一些有用的东西,而无需添加另一个 for 循环。
此代码执行与您提供的代码相同的操作。
import numpy as np
from numba import njit
@njit
def mult_arrays_values( array, values, n):
temp = np.zeros((array.shape[0], array.shape[1], n))
for i in range(n):
temp[:,:,i] = array[:,:,i] * values[i]
return temp
n = 3
imgs = np.random.random((16,9,9,n))
transform = np.random.uniform(low=0.0, high=1.0, size=(n,n))
for img in imgs:
for channel in range(img.shape[2]):
temp = mult_arrays_values(img[:,:,:], transform[channel][:], n)
temp = np.add.reduce(temp , axis=2)
img[:,:,channel] = temp/n
您也可以使用numpy.einsum
In [77]: tr = np.arange(9).reshape((3,3))
...: imgs = np.arange(24).reshape((2,2,2,3)).transpose((1,0,2,3))
...: print('original imgs\n', imgs)
...: for img in imgs:
...: img[:,:,:] = np.dot(img, tr.T)
...: print('imgs after loop\n', imgs)
...: imgs = np.arange(24).reshape((2,2,2,3)).transpose((1,0,2,3))
...: print('result of einsum\n', np.einsum('ijkl,ml', imgs, tr))
original imgs
[[[[ 0 1 2]
[ 3 4 5]]
[[12 13 14]
[15 16 17]]]
[[[ 6 7 8]
[ 9 10 11]]
[[18 19 20]
[21 22 23]]]]
imgs after loop
[[[[ 5 14 23]
[ 14 50 86]]
[[ 41 158 275]
[ 50 194 338]]]
[[[ 23 86 149]
[ 32 122 212]]
[[ 59 230 401]
[ 68 266 464]]]]
result of einsum
[[[[ 5 14 23]
[ 14 50 86]]
[[ 41 158 275]
[ 50 194 338]]]
[[[ 23 86 149]
[ 32 122 212]]
[[ 59 230 401]
[ 68 266 464]]]]
我想从下面的代码中删除 for 循环。在我的用例中,n
的值通常会大得多,其中 200-700 范围内的值并不少见,并且不方便将它们全部列出来,再添加一个循环只会使这更低效。
import numpy as np
n = 3
imgs = np.random.random((16,9,9,n))
transform = np.random.uniform(low=0.0, high=1.0, size=(n,n))
for img in imgs:
for channel in range(img.shape[2]):
temp1 = img[:,:,0]
temp2 = img[:,:,1]
temp3 = img[:,:,2]
temp = temp1 * transform[channel][0] + temp2 * transform[channel][1] + temp3 * transform[channel][2]
img[:,:,channel] = temp/3
任何指点将不胜感激。
我想你也许可以完全避免内部循环。
据我了解,您正在对 transform
矩阵与 [temp1 temp2 temp3]
矩阵的转置进行点积,然后将其除以 3。
以下是图像中相同的表示:
因此,所有这些实际上都可以在 for 循环本身之外完成。代码看起来像这样。 P.S。还修改了一些感觉不一致的地方的变量名
import numpy as np
n = 3
imgs = np.random.random((16,9,9,n))
transform = np.random.uniform(low=0.0, high=1.0, size=(n,n))
for img in imgs:
temp_arr = img[:,:, 0:3]
img[:,:, 0:3] = np.dot(temp_arr, np.transpose(transform))/3
将结果与您的结果进行比较,结果相同
切片可能无法按照您想象的方式应用,但使用 numba
您可以获得一些有用的东西,而无需添加另一个 for 循环。
此代码执行与您提供的代码相同的操作。
import numpy as np
from numba import njit
@njit
def mult_arrays_values( array, values, n):
temp = np.zeros((array.shape[0], array.shape[1], n))
for i in range(n):
temp[:,:,i] = array[:,:,i] * values[i]
return temp
n = 3
imgs = np.random.random((16,9,9,n))
transform = np.random.uniform(low=0.0, high=1.0, size=(n,n))
for img in imgs:
for channel in range(img.shape[2]):
temp = mult_arrays_values(img[:,:,:], transform[channel][:], n)
temp = np.add.reduce(temp , axis=2)
img[:,:,channel] = temp/n
您也可以使用numpy.einsum
In [77]: tr = np.arange(9).reshape((3,3))
...: imgs = np.arange(24).reshape((2,2,2,3)).transpose((1,0,2,3))
...: print('original imgs\n', imgs)
...: for img in imgs:
...: img[:,:,:] = np.dot(img, tr.T)
...: print('imgs after loop\n', imgs)
...: imgs = np.arange(24).reshape((2,2,2,3)).transpose((1,0,2,3))
...: print('result of einsum\n', np.einsum('ijkl,ml', imgs, tr))
original imgs
[[[[ 0 1 2]
[ 3 4 5]]
[[12 13 14]
[15 16 17]]]
[[[ 6 7 8]
[ 9 10 11]]
[[18 19 20]
[21 22 23]]]]
imgs after loop
[[[[ 5 14 23]
[ 14 50 86]]
[[ 41 158 275]
[ 50 194 338]]]
[[[ 23 86 149]
[ 32 122 212]]
[[ 59 230 401]
[ 68 266 464]]]]
result of einsum
[[[[ 5 14 23]
[ 14 50 86]]
[[ 41 158 275]
[ 50 194 338]]]
[[[ 23 86 149]
[ 32 122 212]]
[[ 59 230 401]
[ 68 266 464]]]]