使用 jax.numpy 时结果数组为零
Resulting array is zero when jax.numpy is used
我使用 numpy 编写了下面的代码并得到了正确的输出,如程序 1 所示。但是当我切换到 jax.numpy 作为 jnp(在程序 2 中)时,结果输出是一个零数组。我的 MWE 如下所示。我想知道我哪里计算错了?
PS:代码 运行 在不同的 python 文件中。
#Program 1 (using numpy as np):
import numpy as np
num_rows = 5
num_cols = 20
smf = np.array([np.inf, 0.1, 0.1, 0.1, 0.1])
par_init = np.array([1,2,3,4,5])
lb = np.array([0.1, 0.1, 0.1, 0.1, 0.1])
ub = np.array([10, 10, 10, 10, 10])
par = np.broadcast_to(par_init[:,None],(num_rows,num_cols))
kvals = np.where(np.isinf(smf), 1, num_cols)
kvals = np.insert(kvals, 0, 0)
kvals = np.cumsum(kvals)
par0_col = np.zeros(num_rows*num_cols - (num_cols-1) * np.sum(np.isinf(smf)))
lb_col = np.zeros(num_rows*num_cols - (num_cols-1) * np.sum(np.isinf(smf)))
ub_col = np.zeros(num_rows*num_cols- (num_cols-1) * np.sum(np.isinf(smf)))
for i in range(num_rows):
par0_col[kvals[i]:kvals[i+1]] = par[i, :kvals[i+1]-kvals[i]]
lb_col[kvals[i]:kvals[i+1]] = lb[i]
ub_col[kvals[i]:kvals[i+1]] = ub[i]
arr_1 = np.zeros(shape = (num_rows, num_cols))
arr_2 = np.zeros(shape = (num_rows, num_cols))
par_log = np.log10((par0_col - lb_col) / (1 - par0_col / ub_col))
k = 0
for i in range(num_rows):
arr_1[i, :] = (par_log[kvals[i]:kvals[i+1]])
arr_2[i, :] = 10**par_log[kvals[i]:kvals[i+1]]
print(arr_1)
# [[0. 0. 0. 0. 0. 0.
# 0. 0. 0. 0. 0. 0.
# 0. 0. 0. 0. 0. 0.
# 0. 0. ]
# [0.37566361 0.37566361 0.37566361 0.37566361 0.37566361 0.37566361
# 0.37566361 0.37566361 0.37566361 0.37566361 0.37566361 0.37566361
# 0.37566361 0.37566361 0.37566361 0.37566361 0.37566361 0.37566361
# 0.37566361 0.37566361]
# [0.61729996 0.61729996 0.61729996 0.61729996 0.61729996 0.61729996
# 0.61729996 0.61729996 0.61729996 0.61729996 0.61729996 0.61729996
# 0.61729996 0.61729996 0.61729996 0.61729996 0.61729996 0.61729996
# 0.61729996 0.61729996]
# [0.81291336 0.81291336 0.81291336 0.81291336 0.81291336 0.81291336
# 0.81291336 0.81291336 0.81291336 0.81291336 0.81291336 0.81291336
# 0.81291336 0.81291336 0.81291336 0.81291336 0.81291336 0.81291336
# 0.81291336 0.81291336]
# [0.99122608 0.99122608 0.99122608 0.99122608 0.99122608 0.99122608
# 0.99122608 0.99122608 0.99122608 0.99122608 0.99122608 0.99122608
# 0.99122608 0.99122608 0.99122608 0.99122608 0.99122608 0.99122608
# 0.99122608 0.99122608]]
# Program 2 (using jax.numpy as jnp):
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
smf = jnp.array([jnp.inf, 0.1, 0.1, 0.1, 0.1])
par_init = jnp.array([1.0,2.0,3.0,4.0,5.0])
lb = jnp.array([0.1, 0.1, 0.1, 0.1, 0.1])
ub = jnp.array([10.0, 10.0, 10.0, 10.0, 10.0])
par = jnp.broadcast_to(par_init[:,None],(num_rows,num_cols))
kvals = jnp.where(jnp.isinf(smf), 1, num_cols)
kvals = jnp.insert(kvals, 0, 0)
kvals = jnp.cumsum(kvals)
par0_col = jnp.zeros(num_rows*num_cols - (num_cols-1) * jnp.sum(jnp.isinf(smf)))
lb_col = jnp.zeros(num_rows*num_cols - (num_cols-1) * jnp.sum(jnp.isinf(smf)))
ub_col = jnp.zeros(num_rows*num_cols- (num_cols-1) * jnp.sum(jnp.isinf(smf)))
for i in range(num_rows):
par0_col.at[kvals[i]:kvals[i+1]].set(par[i, :kvals[i+1]-kvals[i]])
lb_col.at[kvals[i]:kvals[i+1]].set(lb[i])
ub_col.at[kvals[i]:kvals[i+1]].set(ub[i])
arr_1 = jnp.zeros(shape = (num_rows, num_cols))
arr_2 = jnp.zeros(shape = (num_rows, num_cols))
par_log = jnp.log10((par0_col - lb_col) / (1 - par0_col / ub_col))
for i in range(num_rows):
arr_1.at[i, :].set((par_log[kvals[i]:kvals[i+1]]))
arr_2.at[i, :].set(10**par_log[kvals[i]:kvals[i+1]])
print(arr_1)
# #[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
# [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
# [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
# [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
# [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
哦,我已经想通了。我需要明确分配。
问题是 ndarray.at
expressions 不操作 in-place,而是 return 修改后的值。
所以不是这个:
arr_1.at[i, :].set((par_log[kvals[i]:kvals[i+1]]))
arr_2.at[i, :].set(10**par_log[kvals[i]:kvals[i+1]])
你应该这样写:
arr_1 = arr_1.at[i, :].set((par_log[kvals[i]:kvals[i+1]]))
arr_2 = arr_2.at[i, :].set(10**par_log[kvals[i]:kvals[i+1]])
在 JAX sharp bits: in-place updates 阅读更多内容。
我使用 numpy 编写了下面的代码并得到了正确的输出,如程序 1 所示。但是当我切换到 jax.numpy 作为 jnp(在程序 2 中)时,结果输出是一个零数组。我的 MWE 如下所示。我想知道我哪里计算错了? PS:代码 运行 在不同的 python 文件中。
#Program 1 (using numpy as np):
import numpy as np
num_rows = 5
num_cols = 20
smf = np.array([np.inf, 0.1, 0.1, 0.1, 0.1])
par_init = np.array([1,2,3,4,5])
lb = np.array([0.1, 0.1, 0.1, 0.1, 0.1])
ub = np.array([10, 10, 10, 10, 10])
par = np.broadcast_to(par_init[:,None],(num_rows,num_cols))
kvals = np.where(np.isinf(smf), 1, num_cols)
kvals = np.insert(kvals, 0, 0)
kvals = np.cumsum(kvals)
par0_col = np.zeros(num_rows*num_cols - (num_cols-1) * np.sum(np.isinf(smf)))
lb_col = np.zeros(num_rows*num_cols - (num_cols-1) * np.sum(np.isinf(smf)))
ub_col = np.zeros(num_rows*num_cols- (num_cols-1) * np.sum(np.isinf(smf)))
for i in range(num_rows):
par0_col[kvals[i]:kvals[i+1]] = par[i, :kvals[i+1]-kvals[i]]
lb_col[kvals[i]:kvals[i+1]] = lb[i]
ub_col[kvals[i]:kvals[i+1]] = ub[i]
arr_1 = np.zeros(shape = (num_rows, num_cols))
arr_2 = np.zeros(shape = (num_rows, num_cols))
par_log = np.log10((par0_col - lb_col) / (1 - par0_col / ub_col))
k = 0
for i in range(num_rows):
arr_1[i, :] = (par_log[kvals[i]:kvals[i+1]])
arr_2[i, :] = 10**par_log[kvals[i]:kvals[i+1]]
print(arr_1)
# [[0. 0. 0. 0. 0. 0.
# 0. 0. 0. 0. 0. 0.
# 0. 0. 0. 0. 0. 0.
# 0. 0. ]
# [0.37566361 0.37566361 0.37566361 0.37566361 0.37566361 0.37566361
# 0.37566361 0.37566361 0.37566361 0.37566361 0.37566361 0.37566361
# 0.37566361 0.37566361 0.37566361 0.37566361 0.37566361 0.37566361
# 0.37566361 0.37566361]
# [0.61729996 0.61729996 0.61729996 0.61729996 0.61729996 0.61729996
# 0.61729996 0.61729996 0.61729996 0.61729996 0.61729996 0.61729996
# 0.61729996 0.61729996 0.61729996 0.61729996 0.61729996 0.61729996
# 0.61729996 0.61729996]
# [0.81291336 0.81291336 0.81291336 0.81291336 0.81291336 0.81291336
# 0.81291336 0.81291336 0.81291336 0.81291336 0.81291336 0.81291336
# 0.81291336 0.81291336 0.81291336 0.81291336 0.81291336 0.81291336
# 0.81291336 0.81291336]
# [0.99122608 0.99122608 0.99122608 0.99122608 0.99122608 0.99122608
# 0.99122608 0.99122608 0.99122608 0.99122608 0.99122608 0.99122608
# 0.99122608 0.99122608 0.99122608 0.99122608 0.99122608 0.99122608
# 0.99122608 0.99122608]]
# Program 2 (using jax.numpy as jnp):
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
smf = jnp.array([jnp.inf, 0.1, 0.1, 0.1, 0.1])
par_init = jnp.array([1.0,2.0,3.0,4.0,5.0])
lb = jnp.array([0.1, 0.1, 0.1, 0.1, 0.1])
ub = jnp.array([10.0, 10.0, 10.0, 10.0, 10.0])
par = jnp.broadcast_to(par_init[:,None],(num_rows,num_cols))
kvals = jnp.where(jnp.isinf(smf), 1, num_cols)
kvals = jnp.insert(kvals, 0, 0)
kvals = jnp.cumsum(kvals)
par0_col = jnp.zeros(num_rows*num_cols - (num_cols-1) * jnp.sum(jnp.isinf(smf)))
lb_col = jnp.zeros(num_rows*num_cols - (num_cols-1) * jnp.sum(jnp.isinf(smf)))
ub_col = jnp.zeros(num_rows*num_cols- (num_cols-1) * jnp.sum(jnp.isinf(smf)))
for i in range(num_rows):
par0_col.at[kvals[i]:kvals[i+1]].set(par[i, :kvals[i+1]-kvals[i]])
lb_col.at[kvals[i]:kvals[i+1]].set(lb[i])
ub_col.at[kvals[i]:kvals[i+1]].set(ub[i])
arr_1 = jnp.zeros(shape = (num_rows, num_cols))
arr_2 = jnp.zeros(shape = (num_rows, num_cols))
par_log = jnp.log10((par0_col - lb_col) / (1 - par0_col / ub_col))
for i in range(num_rows):
arr_1.at[i, :].set((par_log[kvals[i]:kvals[i+1]]))
arr_2.at[i, :].set(10**par_log[kvals[i]:kvals[i+1]])
print(arr_1)
# #[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
# [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
# [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
# [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
# [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
哦,我已经想通了。我需要明确分配。
问题是 ndarray.at
expressions 不操作 in-place,而是 return 修改后的值。
所以不是这个:
arr_1.at[i, :].set((par_log[kvals[i]:kvals[i+1]]))
arr_2.at[i, :].set(10**par_log[kvals[i]:kvals[i+1]])
你应该这样写:
arr_1 = arr_1.at[i, :].set((par_log[kvals[i]:kvals[i+1]]))
arr_2 = arr_2.at[i, :].set(10**par_log[kvals[i]:kvals[i+1]])
在 JAX sharp bits: in-place updates 阅读更多内容。