将数组插入另一个数组的所有位置
Insert array into all places in another array
对于两个数组,比方说,a = np.array([1,2,3,4])
和 b = np.array([5,6])
,有没有办法(如果有的话)在没有循环的情况下获得以下形式的二维数组 :
[[5 6 1 2 3 4]
[1 5 6 2 3 4]
[1 2 5 6 3 4]
[1 2 3 5 6 4]
[1 2 3 4 5 6]]
即在 a
.
的所有可能位置插入 b
如果循环不可避免,如何计算最有效的方法(a
可以很长,b
的长度无关紧要)?
如何使用循环完成它的示例很简单:
a = np.array([1,2,3,4])
b = np.array([5,6])
rows = len(a) + 1
cols = len(a) + len(b)
res = np.empty([rows, cols])
for i in range(rows):
res[i, i:len(b)+i] = b
res[i, len(b)+i:] = a[i:]
res[i, 0:i] = a[0:i]
print(rows.astype(int))
[[5 6 1 2 3 4]
[1 5 6 2 3 4]
[1 2 5 6 3 4]
[1 2 3 5 6 4]
[1 2 3 4 5 6]]
我认为您可以使用掩码将对角线值 b
添加到 res
数组,并使用 a
[=14= 中的值添加数组中的其余单元格]
import numpy as np
a = np.array([1,2,3,4])
b = np.array([5,6])
rows = len(a) + 1
cols = len(a) + len(b)
res = np.empty([rows, cols])
# Frame for masking
frame = np.arange(rows * cols).reshape(rows, cols)
diag_mask = (frame % (res.shape[1] + 1)) < (b.shape[0])
res[diag_mask] = np.tile(b, res.shape[0])
res[~diag_mask] = np.tile(a, res.shape[0])
考虑使用 numba 加速。这恰好是 numba 最擅长的。对于你的例子,它可以提速近6倍:
from timeit import timeit
import numpy as np
from numba import njit
a = np.arange(1, 5)
b = np.array([5, 6])
def fill(a, b):
rows = a.size + 1
cols = a.size + b.size
res = np.empty((rows, cols))
for i in range(rows):
res[i, i:b.size + i] = b
res[i, b.size + i:] = a[i:]
res[i, :i] = a[:i]
return res
if __name__ == '__main__':
print('before:', timeit(lambda: fill(a, b)))
fill = njit(fill)
print('after:', timeit(lambda: fill(a, b)))
输出:
before: 9.488150399993174
after: 1.6149254000047222
您可以根据预期的形状创建一个零数组,然后用 b
值填充所需的索引,最后用所需形状的 a
数组的平铺填充剩余的零值如:
shape, size = a.shape[0] + 1, a.shape[0] + b.shape[0]
ind = np.lib.stride_tricks.sliding_window_view(np.arange(size), b.shape[0])
# [[0 1]
# [1 2]
# [2 3]
# [3 4]
# [4 5]]
zero_arr = np.zeros((shape, size), dtype=np.float64)
zero_arr[np.arange(shape)[:, None], ind] = b
# [[5 6 0 0 0 0]
# [0 5 6 0 0 0]
# [0 0 5 6 0 0]
# [0 0 0 5 6 0]
# [0 0 0 0 5 6]]
zero_arr[zero_arr == 0] = np.tile(a, shape)
# [[5 6 1 2 3 4]
# [1 5 6 2 3 4]
# [1 2 5 6 3 4]
# [1 2 3 5 6 4]
# [1 2 3 4 5 6]]
更新
如果我们创建基于 zero_arr
的大型数组 ,此方法在性能方面将击败 tax evader 方法b
数组中未包含的任何值。例如。如果我们在 b
中有 0
,那么 zero_arr == 0
将误导解决方案。如果 b
只包含正值,一种可能的方法是使用 -np.ones
和 zero_arr == -1
。如果知道哪个任意值不在使用上述索引方法中,我们可以使用 np.fill
或 np.full
创建此数组;可以使用 np.arange
或 np.random
选择此值,并检查以查找不在 b
中的值。但更全面的如下:
shape, size = a.shape[0] + 1, a.shape[0] + b.shape[0]
ind_0 = np.lib.stride_tricks.sliding_window_view(np.arange(size), b.shape[0])
ind_1 = np.lib.stride_tricks.sliding_window_view(np.arange(b.shape[0], size + shape - 1), a.shape[0])
ind_1 = ind_1 % size
arr = np.zeros((shape, size), dtype=np.float64)
arange_ = np.arange(shape)[:, None]
arr[arange_, ind_0] = b
arr[arange_, np.sort(ind_1)] = np.broadcast_to(a, (shape, a.shape[0])) # or use np.tile
如果您对使用其他库没有任何限制(因为您同意使用 numba 库)并且可以 运行 GPU 上的代码,我相信 JAX 库将在更大的数组上击败 numba。 我已将 NumPy 编写的代码转换为 JAX jitted 形式,以了解该库如何在性能方面处理此类矩阵形式问题。除了评估和比较 JAX 和 numba 在此问题上的性能的基准之外,这些代码还有一个关于如何使用 jax numpy where (jnp.where
) 和 JAX jit decorator 的学习方面,我们必须在其中指定大小 statically 是可行的。另一个方面是关于创建 equivalent np.lib.stride_tricks.sliding_window_view
or np.lib.stride_tricks.as_strided
in jax jitted function by jax library。 evader 代码也进行了转换,但进行了一些更改(我认为 JAX 使用需要);我不知道我是否可以用更紧凑的形式(更短)来写它。我认为可以以更优化的形式重写编写的代码,这将获得更快的代码。
from functools import partial
import jax
from jax.config import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import jit, vmap
@partial(jit, static_argnums=(1,))
def moving_window(a, size: int):
starts = jnp.arange(len(a) - size + 1)
return vmap(lambda start: jax.lax.dynamic_slice(a, (start,), (size,)))(starts)
@jit
def jax_initial(a, b):
shape, size = a.shape[0] + 1, a.shape[0] + b.shape[0]
ind = moving_window(jnp.arange(size), b.shape[0])
arr = jnp.zeros((shape, size), dtype=jnp.float64)
arr = arr.at[jnp.arange(shape)[:, None], ind].set(b)
broad = jnp.ravel(jnp.broadcast_to(a, (shape, a.shape[0])))
idx = jnp.where(arr == 0, size=broad.size)
return arr.at[idx].set(broad)
@jit
def jax_comp(a, b):
shape, size = a.shape[0] + 1, a.shape[0] + b.shape[0]
ind_0 = moving_window(jnp.arange(size), b.shape[0])
ind_1 = moving_window(jnp.arange(b.shape[0], size + shape - 1), a.shape[0])
ind_1 = jnp.remainder(ind_1, size) # ind_1 = ind_1 % size
arr = jnp.zeros((shape, size), dtype=jnp.float64)
arange_ = jnp.arange(shape)[:, None]
arr = arr.at[arange_, ind_0].set(b)
arr = arr.at[arange_, jnp.sort(ind_1)].set(jnp.broadcast_to(a, (shape, a.shape[0])))
return arr
@jit
def jax_evader(a, b):
shape, size = a.shape[0] + 1, a.shape[0] + b.shape[0]
res = jnp.empty([shape, size])
frame = jnp.reshape(jnp.arange(shape * size), (shape, size))
diag_mask = (frame % (res.shape[1] + 1)) < (b.shape[0])
res_0 = jnp.ravel(jnp.broadcast_to(b, (shape, b.shape[0]))) # jnp.tile(b, res.shape[0])
res_1 = jnp.ravel(jnp.broadcast_to(a, (shape, a.shape[0]))) # jnp.tile(a, res.shape[0])
idx = jnp.where(diag_mask, size=res_0.size)
idx_v = jnp.where(~diag_mask, size=res_1.size)
res = res.at[idx].set(res_0)
return res.at[idx_v].set(res_1)
一些基准测试 temporary link
最快的代码:
我们可以将 numba 代码与签名并行,在我的系统上,运行 至少比 Pig 快 3 倍:
a = np.random.rand(10000)
b = np.array([5, 6, 7, 8, 9, 10, 11], dtype=np.int64)
@nb.njit('float64[::1], int64[::1]', parallel=True)
def fill_parallel(a, b):
rows = a.size + 1
cols = a.size + b.size
res = np.empty((rows, cols))
for i in nb.prange(rows):
res[i, i:b.size + i] = b
res[i, b.size + i:] = a[i:]
res[i, :i] = a[:i]
return res
numba 并行化代码是迄今为止最快的代码。
对于两个数组,比方说,a = np.array([1,2,3,4])
和 b = np.array([5,6])
,有没有办法(如果有的话)在没有循环的情况下获得以下形式的二维数组 :
[[5 6 1 2 3 4]
[1 5 6 2 3 4]
[1 2 5 6 3 4]
[1 2 3 5 6 4]
[1 2 3 4 5 6]]
即在 a
.
b
如果循环不可避免,如何计算最有效的方法(a
可以很长,b
的长度无关紧要)?
如何使用循环完成它的示例很简单:
a = np.array([1,2,3,4])
b = np.array([5,6])
rows = len(a) + 1
cols = len(a) + len(b)
res = np.empty([rows, cols])
for i in range(rows):
res[i, i:len(b)+i] = b
res[i, len(b)+i:] = a[i:]
res[i, 0:i] = a[0:i]
print(rows.astype(int))
[[5 6 1 2 3 4]
[1 5 6 2 3 4]
[1 2 5 6 3 4]
[1 2 3 5 6 4]
[1 2 3 4 5 6]]
我认为您可以使用掩码将对角线值 b
添加到 res
数组,并使用 a
[=14= 中的值添加数组中的其余单元格]
import numpy as np
a = np.array([1,2,3,4])
b = np.array([5,6])
rows = len(a) + 1
cols = len(a) + len(b)
res = np.empty([rows, cols])
# Frame for masking
frame = np.arange(rows * cols).reshape(rows, cols)
diag_mask = (frame % (res.shape[1] + 1)) < (b.shape[0])
res[diag_mask] = np.tile(b, res.shape[0])
res[~diag_mask] = np.tile(a, res.shape[0])
考虑使用 numba 加速。这恰好是 numba 最擅长的。对于你的例子,它可以提速近6倍:
from timeit import timeit
import numpy as np
from numba import njit
a = np.arange(1, 5)
b = np.array([5, 6])
def fill(a, b):
rows = a.size + 1
cols = a.size + b.size
res = np.empty((rows, cols))
for i in range(rows):
res[i, i:b.size + i] = b
res[i, b.size + i:] = a[i:]
res[i, :i] = a[:i]
return res
if __name__ == '__main__':
print('before:', timeit(lambda: fill(a, b)))
fill = njit(fill)
print('after:', timeit(lambda: fill(a, b)))
输出:
before: 9.488150399993174
after: 1.6149254000047222
您可以根据预期的形状创建一个零数组,然后用 b
值填充所需的索引,最后用所需形状的 a
数组的平铺填充剩余的零值如:
shape, size = a.shape[0] + 1, a.shape[0] + b.shape[0]
ind = np.lib.stride_tricks.sliding_window_view(np.arange(size), b.shape[0])
# [[0 1]
# [1 2]
# [2 3]
# [3 4]
# [4 5]]
zero_arr = np.zeros((shape, size), dtype=np.float64)
zero_arr[np.arange(shape)[:, None], ind] = b
# [[5 6 0 0 0 0]
# [0 5 6 0 0 0]
# [0 0 5 6 0 0]
# [0 0 0 5 6 0]
# [0 0 0 0 5 6]]
zero_arr[zero_arr == 0] = np.tile(a, shape)
# [[5 6 1 2 3 4]
# [1 5 6 2 3 4]
# [1 2 5 6 3 4]
# [1 2 3 5 6 4]
# [1 2 3 4 5 6]]
更新
如果我们创建基于 zero_arr
的大型数组 ,此方法在性能方面将击败 tax evader 方法b
数组中未包含的任何值。例如。如果我们在 b
中有 0
,那么 zero_arr == 0
将误导解决方案。如果 b
只包含正值,一种可能的方法是使用 -np.ones
和 zero_arr == -1
。如果知道哪个任意值不在使用上述索引方法中,我们可以使用 np.fill
或 np.full
创建此数组;可以使用 np.arange
或 np.random
选择此值,并检查以查找不在 b
中的值。但更全面的如下:
shape, size = a.shape[0] + 1, a.shape[0] + b.shape[0]
ind_0 = np.lib.stride_tricks.sliding_window_view(np.arange(size), b.shape[0])
ind_1 = np.lib.stride_tricks.sliding_window_view(np.arange(b.shape[0], size + shape - 1), a.shape[0])
ind_1 = ind_1 % size
arr = np.zeros((shape, size), dtype=np.float64)
arange_ = np.arange(shape)[:, None]
arr[arange_, ind_0] = b
arr[arange_, np.sort(ind_1)] = np.broadcast_to(a, (shape, a.shape[0])) # or use np.tile
如果您对使用其他库没有任何限制(因为您同意使用 numba 库)并且可以 运行 GPU 上的代码,我相信 JAX 库将在更大的数组上击败 numba。 我已将 NumPy 编写的代码转换为 JAX jitted 形式,以了解该库如何在性能方面处理此类矩阵形式问题。除了评估和比较 JAX 和 numba 在此问题上的性能的基准之外,这些代码还有一个关于如何使用 jax numpy where (jnp.where
) 和 JAX jit decorator 的学习方面,我们必须在其中指定大小 statically 是可行的。另一个方面是关于创建 equivalent np.lib.stride_tricks.sliding_window_view
or np.lib.stride_tricks.as_strided
in jax jitted function by jax library。 evader 代码也进行了转换,但进行了一些更改(我认为 JAX 使用需要);我不知道我是否可以用更紧凑的形式(更短)来写它。我认为可以以更优化的形式重写编写的代码,这将获得更快的代码。
from functools import partial
import jax
from jax.config import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import jit, vmap
@partial(jit, static_argnums=(1,))
def moving_window(a, size: int):
starts = jnp.arange(len(a) - size + 1)
return vmap(lambda start: jax.lax.dynamic_slice(a, (start,), (size,)))(starts)
@jit
def jax_initial(a, b):
shape, size = a.shape[0] + 1, a.shape[0] + b.shape[0]
ind = moving_window(jnp.arange(size), b.shape[0])
arr = jnp.zeros((shape, size), dtype=jnp.float64)
arr = arr.at[jnp.arange(shape)[:, None], ind].set(b)
broad = jnp.ravel(jnp.broadcast_to(a, (shape, a.shape[0])))
idx = jnp.where(arr == 0, size=broad.size)
return arr.at[idx].set(broad)
@jit
def jax_comp(a, b):
shape, size = a.shape[0] + 1, a.shape[0] + b.shape[0]
ind_0 = moving_window(jnp.arange(size), b.shape[0])
ind_1 = moving_window(jnp.arange(b.shape[0], size + shape - 1), a.shape[0])
ind_1 = jnp.remainder(ind_1, size) # ind_1 = ind_1 % size
arr = jnp.zeros((shape, size), dtype=jnp.float64)
arange_ = jnp.arange(shape)[:, None]
arr = arr.at[arange_, ind_0].set(b)
arr = arr.at[arange_, jnp.sort(ind_1)].set(jnp.broadcast_to(a, (shape, a.shape[0])))
return arr
@jit
def jax_evader(a, b):
shape, size = a.shape[0] + 1, a.shape[0] + b.shape[0]
res = jnp.empty([shape, size])
frame = jnp.reshape(jnp.arange(shape * size), (shape, size))
diag_mask = (frame % (res.shape[1] + 1)) < (b.shape[0])
res_0 = jnp.ravel(jnp.broadcast_to(b, (shape, b.shape[0]))) # jnp.tile(b, res.shape[0])
res_1 = jnp.ravel(jnp.broadcast_to(a, (shape, a.shape[0]))) # jnp.tile(a, res.shape[0])
idx = jnp.where(diag_mask, size=res_0.size)
idx_v = jnp.where(~diag_mask, size=res_1.size)
res = res.at[idx].set(res_0)
return res.at[idx_v].set(res_1)
一些基准测试 temporary link
最快的代码:
我们可以将 numba 代码与签名并行,在我的系统上,运行 至少比 Pig 快 3 倍:
a = np.random.rand(10000)
b = np.array([5, 6, 7, 8, 9, 10, 11], dtype=np.int64)
@nb.njit('float64[::1], int64[::1]', parallel=True)
def fill_parallel(a, b):
rows = a.size + 1
cols = a.size + b.size
res = np.empty((rows, cols))
for i in nb.prange(rows):
res[i, i:b.size + i] = b
res[i, b.size + i:] = a[i:]
res[i, :i] = a[:i]
return res
numba 并行化代码是迄今为止最快的代码。