将数组插入另一个数组的所有位置

Insert array into all places in another array

对于两个数组,比方说,a = np.array([1,2,3,4])b = np.array([5,6]),有没有办法(如果有的话)在没有循环的情况下获得以下形式的二维数组 :

[[5 6 1 2 3 4]
 [1 5 6 2 3 4]
 [1 2 5 6 3 4]
 [1 2 3 5 6 4]
 [1 2 3 4 5 6]]

即在 a.

的所有可能位置插入 b

如果循环不可避免,如何计算最有效的方法(a可以很长,b的长度无关紧要)?

如何使用循环完成它的示例很简单:

a = np.array([1,2,3,4])
b = np.array([5,6])
rows = len(a) + 1
cols = len(a) + len(b)
res = np.empty([rows, cols])
for i in range(rows):
    res[i, i:len(b)+i] = b
    res[i, len(b)+i:] = a[i:]
    res[i, 0:i] = a[0:i]
print(rows.astype(int))

[[5 6 1 2 3 4]
 [1 5 6 2 3 4]
 [1 2 5 6 3 4]
 [1 2 3 5 6 4]
 [1 2 3 4 5 6]]

我认为您可以使用掩码将对角线值 b 添加到 res 数组,并使用 a[=14= 中的值添加数组中的其余单元格]

import numpy as np

a = np.array([1,2,3,4])
b = np.array([5,6])

rows = len(a) + 1
cols = len(a) + len(b)
res = np.empty([rows, cols])

# Frame for masking
frame = np.arange(rows * cols).reshape(rows, cols)

diag_mask = (frame % (res.shape[1] + 1)) < (b.shape[0])
res[diag_mask] = np.tile(b, res.shape[0])
res[~diag_mask] = np.tile(a, res.shape[0])

考虑使用 numba 加速。这恰好是 numba 最擅长的。对于你的例子,它可以提速近6倍:

from timeit import timeit

import numpy as np
from numba import njit


a = np.arange(1, 5)
b = np.array([5, 6])


def fill(a, b):
    rows = a.size + 1
    cols = a.size + b.size
    res = np.empty((rows, cols))
    for i in range(rows):
        res[i, i:b.size + i] = b
        res[i, b.size + i:] = a[i:]
        res[i, :i] = a[:i]
    return res


if __name__ == '__main__':
    print('before:', timeit(lambda: fill(a, b)))
    fill = njit(fill)
    print('after:', timeit(lambda: fill(a, b)))

输出:

before: 9.488150399993174
after: 1.6149254000047222

您可以根据预期的形状创建一个零数组,然后用 b 值填充所需的索引,最后用所需形状的 a 数组的平铺填充剩余的零值如:

shape, size = a.shape[0] + 1, a.shape[0] + b.shape[0]
ind = np.lib.stride_tricks.sliding_window_view(np.arange(size), b.shape[0])
# [[0 1]
#  [1 2]
#  [2 3]
#  [3 4]
#  [4 5]]

zero_arr = np.zeros((shape, size), dtype=np.float64) 
zero_arr[np.arange(shape)[:, None], ind] = b
# [[5 6 0 0 0 0]
#  [0 5 6 0 0 0]
#  [0 0 5 6 0 0]
#  [0 0 0 5 6 0]
#  [0 0 0 0 5 6]]

zero_arr[zero_arr == 0] = np.tile(a, shape)
# [[5 6 1 2 3 4]
#  [1 5 6 2 3 4]
#  [1 2 5 6 3 4]
#  [1 2 3 5 6 4]
#  [1 2 3 4 5 6]]

更新

如果我们创建基于 zero_arr 的大型数组 ,此方法在性能方面将击败 tax evader 方法b 数组中未包含的任何值。例如。如果我们在 b 中有 0,那么 zero_arr == 0 将误导解决方案。如果 b 只包含正值,一种可能的方法是使用 -np.oneszero_arr == -1。如果知道哪个任意值不在使用上述索引方法中,我们可以使用 np.fillnp.full 创建此数组;可以使用 np.arangenp.random 选择此值,并检查以查找不在 b 中的值。但更全面的如下:

shape, size = a.shape[0] + 1, a.shape[0] + b.shape[0]
ind_0 = np.lib.stride_tricks.sliding_window_view(np.arange(size), b.shape[0])
ind_1 = np.lib.stride_tricks.sliding_window_view(np.arange(b.shape[0], size + shape - 1), a.shape[0])
ind_1 = ind_1 % size
arr = np.zeros((shape, size), dtype=np.float64)
arange_ = np.arange(shape)[:, None]
arr[arange_, ind_0] = b
arr[arange_, np.sort(ind_1)] = np.broadcast_to(a, (shape, a.shape[0]))  # or use np.tile

如果您对使用其他库没有任何限制(因为您同意使用 numba 库)并且可以 运行 GPU 上的代码,我相信 JAX 库将在更大的数组上击败 numba。 我已将 NumPy 编写的代码转换为 JAX jitted 形式,以了解该库如何在性能方面处理此类矩阵形式问题。除了评估和比较 JAX 和 numba 在此问题上的性能的基准之外,这些代码还有一个关于如何使用 jax numpy where (jnp.where) 和 JAX jit decorator 的学习方面,我们必须在其中指定大小 statically 是可行的。另一个方面是关于创建 equivalent np.lib.stride_tricks.sliding_window_view or np.lib.stride_tricks.as_strided in jax jitted function by jax libraryevader 代码也进行了转换,但进行了一些更改(我认为 JAX 使用需要);我不知道我是否可以用更紧凑的形式(更短)来写它。我认为可以以更优化的形式重写编写的代码,这将获得更快的代码。

from functools import partial
import jax
from jax.config import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import jit, vmap

@partial(jit, static_argnums=(1,))
def moving_window(a, size: int):
    starts = jnp.arange(len(a) - size + 1)
    return vmap(lambda start: jax.lax.dynamic_slice(a, (start,), (size,)))(starts)

@jit
def jax_initial(a, b):
    shape, size = a.shape[0] + 1, a.shape[0] + b.shape[0]
    ind = moving_window(jnp.arange(size), b.shape[0])
    arr = jnp.zeros((shape, size), dtype=jnp.float64)
    arr = arr.at[jnp.arange(shape)[:, None], ind].set(b)
    broad = jnp.ravel(jnp.broadcast_to(a, (shape, a.shape[0])))
    idx = jnp.where(arr == 0, size=broad.size) 
    return arr.at[idx].set(broad)

@jit
def jax_comp(a, b):
    shape, size = a.shape[0] + 1, a.shape[0] + b.shape[0]
    ind_0 = moving_window(jnp.arange(size), b.shape[0])
    ind_1 = moving_window(jnp.arange(b.shape[0], size + shape - 1), a.shape[0])
    ind_1 = jnp.remainder(ind_1, size)                    # ind_1 = ind_1 % size
    arr = jnp.zeros((shape, size), dtype=jnp.float64)
    arange_ = jnp.arange(shape)[:, None]
    arr = arr.at[arange_, ind_0].set(b)
    arr = arr.at[arange_, jnp.sort(ind_1)].set(jnp.broadcast_to(a, (shape, a.shape[0])))
    return arr

@jit
def jax_evader(a, b):
    shape, size = a.shape[0] + 1, a.shape[0] + b.shape[0]
    res = jnp.empty([shape, size])
    frame = jnp.reshape(jnp.arange(shape * size), (shape, size))
    diag_mask = (frame % (res.shape[1] + 1)) < (b.shape[0])
    res_0 = jnp.ravel(jnp.broadcast_to(b, (shape, b.shape[0])))  # jnp.tile(b, res.shape[0])
    res_1 = jnp.ravel(jnp.broadcast_to(a, (shape, a.shape[0])))  # jnp.tile(a, res.shape[0])
    idx = jnp.where(diag_mask, size=res_0.size)
    idx_v = jnp.where(~diag_mask, size=res_1.size)
    res = res.at[idx].set(res_0)
    return res.at[idx_v].set(res_1)

一些基准测试 temporary link



最快的代码:

我们可以将 numba 代码与签名并行,在我的系统上,运行 至少比 Pig 快 3 倍:

a = np.random.rand(10000)
b = np.array([5, 6, 7, 8, 9, 10, 11], dtype=np.int64)

@nb.njit('float64[::1], int64[::1]', parallel=True)
def fill_parallel(a, b):
    rows = a.size + 1
    cols = a.size + b.size
    res = np.empty((rows, cols))
    for i in nb.prange(rows):
        res[i, i:b.size + i] = b
        res[i, b.size + i:] = a[i:]
        res[i, :i] = a[:i]
    return res

numba 并行化代码是迄今为止最快的代码。