Numba 和多维添加 - 不适用于 numpy.newaxis?
Numba and multidimensions additions - not working with numpy.newaxis?
尝试在 python 上加速 DP 算法,numba 似乎是一个合适的候选者。
我正在用一个提供 3 维数组的一维数组减去一个二维数组。然后,我沿第 3 个维度使用 .argmin()
来获得二维数组。这适用于 numpy,但不适用于 numba。
重现问题的玩具代码:
from numba import jit
import numpy as np
inflow = np.arange(1,0,-0.01) # Dim [T]
actions = np.arange(0,1,0.05) # Dim [M]
start_lvl = np.random.rand(500).reshape(-1,1)*49 # Dim [Nx1]
disc_lvl = np.arange(0,1000) # Dim [O]
@jit(nopython=True)
def my_func(disc_lvl, actions, start_lvl, inflow):
for i in range(0,100):
# Calculate new level at time i
new_lvl = start_lvl + inflow[i] + actions # Dim [N x M]
# For each new_level element, find closest discretized level
diff = (disc_lvl-new_lvl[:,:,np.newaxis]) # Dim [N x M x O]
idx_lvl = abs(diff).argmin(axis=2) # Dim [N x M]
return True
# function works fine without numba
success = my_func(disc_lvl, actions, start_lvl, inflow)
为什么上面的代码没有 运行?它在取出 @jit(nopython=True)
时起作用。
是否有工作轮使以下计算与 numba 一起工作?
我尝试了带有 numpy 重复和 expand_dims 的变体,以及明确定义 jit 函数的输入类型但没有成功。
您需要进行一些更改才能使其正常工作:
- 使用
arr[:, :, None]
添加维度:对于 Numba,它看起来像 getitem
所以更喜欢使用 reshape
- 使用
np.abs
而不是内置的 abs
- 带有
axis
关键字参数的 argmin
是 not implemented。更喜欢使用 Numba 旨在优化的循环。
所有这些都修复后,您可以 运行 jitted 函数:
from numba import jit
import numpy as np
inflow = np.arange(1,0,-0.01) # Dim [T]
actions = np.arange(0,1,0.05) # Dim [M]
start_lvl = np.random.rand(500).reshape(-1,1)*49 # Dim [Nx1]
disc_lvl = np.arange(0,1000) # Dim [O]
@jit(nopython=True)
def my_func(disc_lvl, actions, start_lvl, inflow):
for i in range(0,100):
# Calculate new level at time i
new_lvl = start_lvl + inflow[i] + actions # Dim [N x M]
# For each new_level element, find closest discretized level
new_lvl_3d = new_lvl.reshape(*new_lvl.shape, 1)
diff = np.abs(disc_lvl - new_lvl_3d) # Dim [N x M x O]
idx_lvl = np.empty(new_lvl.shape)
for i in range(diff.shape[0]):
for j in range(diff.shape[1]):
idx_lvl[i, j] = diff[i, j, :].argmin()
return True
# function works fine without numba
success = my_func(disc_lvl, actions, start_lvl, inflow)
在下面找到我的第一个 post 的更正代码,您可以使用和不使用 numba 库的 jitted 模式执行(通过删除以 @jit 开头的行)。对于这个例子,我观察到速度增加了 2 倍。
from numba import jit
import numpy as np
import datetime as dt
inflow = np.arange(1,0,-0.01) # Dim [T]
nbTime = np.shape(inflow)[0]
actions = np.arange(0,1,0.01) # Dim [M]
start_lvl = np.random.rand(500).reshape(-1,1)*49 # Dim [Nx1]
disc_lvl = np.arange(0,1000) # Dim [O]
@jit(nopython=True)
def my_func(nbTime, disc_lvl, actions, start_lvl, inflow):
# Initialize result
res = np.empty((nbTime,np.shape(start_lvl)[0],np.shape(actions)[0]))
for t in range(0,nbTime):
# Calculate new level at time t
new_lvl = start_lvl + inflow[t] + actions # Dim [N x M]
print(t)
# For each new_level element, find closest discretized level
new_lvl_3d = new_lvl.reshape(*new_lvl.shape, 1)
diff = np.abs(disc_lvl - new_lvl_3d) # Dim [N x M x O]
idx_lvl = np.empty(new_lvl.shape)
for i in range(diff.shape[0]):
for j in range(diff.shape[1]):
idx_lvl[i, j] = diff[i, j, :].argmin()
res[t,:,:] = idx_lvl
return res
# Call function and print running time
start_time = dt.datetime.now()
result = my_func(nbTime, disc_lvl, actions, start_lvl, inflow)
print('Execution time :',(dt.datetime.now() - start_time))
尝试在 python 上加速 DP 算法,numba 似乎是一个合适的候选者。
我正在用一个提供 3 维数组的一维数组减去一个二维数组。然后,我沿第 3 个维度使用 .argmin()
来获得二维数组。这适用于 numpy,但不适用于 numba。
重现问题的玩具代码:
from numba import jit
import numpy as np
inflow = np.arange(1,0,-0.01) # Dim [T]
actions = np.arange(0,1,0.05) # Dim [M]
start_lvl = np.random.rand(500).reshape(-1,1)*49 # Dim [Nx1]
disc_lvl = np.arange(0,1000) # Dim [O]
@jit(nopython=True)
def my_func(disc_lvl, actions, start_lvl, inflow):
for i in range(0,100):
# Calculate new level at time i
new_lvl = start_lvl + inflow[i] + actions # Dim [N x M]
# For each new_level element, find closest discretized level
diff = (disc_lvl-new_lvl[:,:,np.newaxis]) # Dim [N x M x O]
idx_lvl = abs(diff).argmin(axis=2) # Dim [N x M]
return True
# function works fine without numba
success = my_func(disc_lvl, actions, start_lvl, inflow)
为什么上面的代码没有 运行?它在取出 @jit(nopython=True)
时起作用。
是否有工作轮使以下计算与 numba 一起工作?
我尝试了带有 numpy 重复和 expand_dims 的变体,以及明确定义 jit 函数的输入类型但没有成功。
您需要进行一些更改才能使其正常工作:
- 使用
arr[:, :, None]
添加维度:对于 Numba,它看起来像getitem
所以更喜欢使用reshape
- 使用
np.abs
而不是内置的abs
- 带有
axis
关键字参数的argmin
是 not implemented。更喜欢使用 Numba 旨在优化的循环。
所有这些都修复后,您可以 运行 jitted 函数:
from numba import jit
import numpy as np
inflow = np.arange(1,0,-0.01) # Dim [T]
actions = np.arange(0,1,0.05) # Dim [M]
start_lvl = np.random.rand(500).reshape(-1,1)*49 # Dim [Nx1]
disc_lvl = np.arange(0,1000) # Dim [O]
@jit(nopython=True)
def my_func(disc_lvl, actions, start_lvl, inflow):
for i in range(0,100):
# Calculate new level at time i
new_lvl = start_lvl + inflow[i] + actions # Dim [N x M]
# For each new_level element, find closest discretized level
new_lvl_3d = new_lvl.reshape(*new_lvl.shape, 1)
diff = np.abs(disc_lvl - new_lvl_3d) # Dim [N x M x O]
idx_lvl = np.empty(new_lvl.shape)
for i in range(diff.shape[0]):
for j in range(diff.shape[1]):
idx_lvl[i, j] = diff[i, j, :].argmin()
return True
# function works fine without numba
success = my_func(disc_lvl, actions, start_lvl, inflow)
在下面找到我的第一个 post 的更正代码,您可以使用和不使用 numba 库的 jitted 模式执行(通过删除以 @jit 开头的行)。对于这个例子,我观察到速度增加了 2 倍。
from numba import jit
import numpy as np
import datetime as dt
inflow = np.arange(1,0,-0.01) # Dim [T]
nbTime = np.shape(inflow)[0]
actions = np.arange(0,1,0.01) # Dim [M]
start_lvl = np.random.rand(500).reshape(-1,1)*49 # Dim [Nx1]
disc_lvl = np.arange(0,1000) # Dim [O]
@jit(nopython=True)
def my_func(nbTime, disc_lvl, actions, start_lvl, inflow):
# Initialize result
res = np.empty((nbTime,np.shape(start_lvl)[0],np.shape(actions)[0]))
for t in range(0,nbTime):
# Calculate new level at time t
new_lvl = start_lvl + inflow[t] + actions # Dim [N x M]
print(t)
# For each new_level element, find closest discretized level
new_lvl_3d = new_lvl.reshape(*new_lvl.shape, 1)
diff = np.abs(disc_lvl - new_lvl_3d) # Dim [N x M x O]
idx_lvl = np.empty(new_lvl.shape)
for i in range(diff.shape[0]):
for j in range(diff.shape[1]):
idx_lvl[i, j] = diff[i, j, :].argmin()
res[t,:,:] = idx_lvl
return res
# Call function and print running time
start_time = dt.datetime.now()
result = my_func(nbTime, disc_lvl, actions, start_lvl, inflow)
print('Execution time :',(dt.datetime.now() - start_time))