Efficiently insert alternate rows and columns in Python

I need to interleave a PyTorch tensor (which behaves much like a numpy array) with alternating rows and columns of zeros. Like this:

Input =>  [[ 1,2,3],
           [ 4,5,6],
           [ 7,8,9]]

output => [[ 1,0,2,0,3],
           [ 0,0,0,0,0],
           [ 4,0,5,0,6],
           [ 0,0,0,0,0],
           [ 7,0,8,0,9]] 

I am currently using the following suggested implementation:

import numpy as np

def insert_zeros(a, N=1):
    # a : input array
    # N : number of zeros to be inserted between consecutive rows and cols
    out = np.zeros((N + 1) * np.array(a.shape) - N, dtype=a.dtype)
    out[::N + 1, ::N + 1] = a
    return out
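
For reference, calling it on the example from the question (a minimal check, using the 3x3 input shown above):

a = np.array([[1, 2, 3],
              [4, 5, 6],
              [7, 8, 9]])
print(insert_zeros(a, N=1))
# [[1 0 2 0 3]
#  [0 0 0 0 0]
#  [4 0 5 0 6]
#  [0 0 0 0 0]
#  [7 0 8 0 9]]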

The answer works perfectly, except that I need to perform this operation many times, on many arrays, and the time it takes has become the bottleneck. Most of the time is spent in the strided slicing.

For what it's worth, the matrices I am using it on are 4D, an example size being 32x18x16x16, and I only insert the alternate rows/cols in the last two dimensions.

So my question is: is there another implementation with the same functionality but a shorter runtime?

I have found several ways to achieve this result, and the indexing method seems to be consistently the fastest.

There may be some room for improvement in the other methods, though, since I tried to generalize them from 1D to 2D with an arbitrary number of leading dimensions, and may not have done that in the best way.

Edit: added another approach using numpy, but it is not faster.

Performance test (CPU):

In [4]: N, C, H, W = 11, 5, 128, 128
   ...: x = torch.rand(N, C, H, W)
   ...: k = 3
   ...:
   ...: x1 = interleave_index(x, k)
   ...: x2 = interleave_view(x, k)
   ...: x3 = interleave_einops(x, k)
   ...: x4 = interleave_convtranspose(x, k)
   ...: x5 = interleave_numpy(x, k)
   ...:
   ...: assert torch.all(x1 == x2)
   ...: assert torch.all(x2 == x3)
   ...: assert torch.all(x3 == x4)
   ...: assert torch.all(x4 == x5)
   ...:
   ...: %timeit interleave_index(x, k)
   ...: %timeit interleave_view(x, k)
   ...: %timeit interleave_einops(x, k)
   ...: %timeit interleave_convtranspose(x, k)
   ...: %timeit interleave_numpy(x, k)

9.51 ms ± 2.21 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
12.6 ms ± 4.98 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
23.3 ms ± 4.19 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
62.5 ms ± 19.5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
50.6 ms ± 809 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Performance test (GPU):

(the numpy method was not tested)

...: ...
...: x = torch.rand(N, C, H, W, device="cuda")
...: ...
260 µs ± 1.92 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
861 µs ± 6.77 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
912 µs ± 14.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
429 µs ± 5.08 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Implementations:

import numpy as np
import torch
import torch.nn.functional as F
import einops


def interleave_index(x, k):
    *cdims, Hin, Win = x.shape
    Hout = (k + 1) * (Hin - 1) + 1
    Wout = (k + 1) * (Win - 1) + 1
    out = x.new_zeros(*cdims, Hout, Wout)
    out[..., :: k + 1, :: k + 1] = x
    return out


def interleave_view(x, k):
    """
    From
    https://discuss.pytorch.org/t/how-to-interleave-two-tensors-along-certain-dimension/11332/4
    """
    *cdims, Hin, Win = x.shape
    Hout = (k + 1) * (Hin - 1) + 1
    Wout = (k + 1) * (Win - 1) + 1
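    # Stack x with k zero copies along a new trailing dim, then merge the
    # last two dims with view() to interleave; trim the k trailing zeros.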
    zeros = [torch.zeros_like(x)] * k
    out = torch.stack([x, *zeros], dim=-1).view(*cdims, Hin, Wout + k)[..., :-k]
    zeros = [torch.zeros_like(out)] * k
    out = torch.stack([out, *zeros], dim=-2).view(*cdims, Hout + k, Wout)[..., :-k, :]
    return out


def interleave_einops(x, k):
    """
    From
    https://discuss.pytorch.org/t/how-to-interleave-two-tensors-along-certain-dimension/11332/6
    """
    zeros = [torch.zeros_like(x)] * k
    out = einops.rearrange([x, *zeros], "t ... h w -> ... h (w t)")[..., :-k]
    zeros = [torch.zeros_like(out)] * k
    out = einops.rearrange([out, *zeros], "t ... h w -> ... (h t) w")[..., :-k, :]
    return out


def interleave_convtranspose(x, k):
    """
    From
    https://github.com/pytorch/pytorch/issues/7911#issuecomment-515493009
    """
    C = x.shape[-3]
    # A grouped transposed convolution with per-channel 1x1 identity kernels
    # scatters each input element onto a stride-(k+1) grid, zero-filled between.
    weight = x.new_ones(C, 1, 1, 1)
    return F.conv_transpose2d(x, weight=weight, stride=k + 1, groups=C)


def interleave_numpy(x, k):
    """
    From 
    """
    pos = np.repeat(np.arange(1, x.shape[-1]), k)
    out = np.insert(x, pos, 0, axis=-1)
    pos = np.repeat(np.arange(1, x.shape[-2]), k)
    out = np.insert(out, pos, 0, axis=-2)
    return out
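
As a quick sanity check, the indexing variant reproduces the example from the question:

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
print(interleave_index(x, k=1))
# tensor([[1, 0, 2, 0, 3],
#         [0, 0, 0, 0, 0],
#         [4, 0, 5, 0, 6],
#         [0, 0, 0, 0, 0],
#         [7, 0, 8, 0, 9]])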

I am not familiar with PyTorch, but to speed up the code you provided, I think the JAX library would help a lot. So, given:

import numpy as np
import jax
import jax.numpy as jnp
from functools import partial

a = np.arange(10000).reshape(100, 100)
b = jnp.array(a)

@partial(jax.jit, static_argnums=1)
def new(a, N):
    out = jnp.zeros((N + 1) * np.array(a.shape) - N, dtype=a.dtype)
    out = out.at[::N+1,::N+1].set(a)
    return out

the runtime improves by roughly 10x on GPU. It depends on the array size and N (the bigger the size, the better the performance). You can see benchmarks against the 4 answers proposed so far on my Colab link (JAX beats the others).
I believe JAX can be one of the best libraries for your case if you can adapt it to your problem (which is possible).
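
For completeness, a minimal timing sketch (my own addition; it assumes a CUDA-enabled jaxlib, and block_until_ready() is needed because JAX dispatches kernels asynchronously):

new(b, 1).block_until_ready()            # warm-up call triggers JIT compilation
%timeit new(b, 1).block_until_ready()    # then time steady-state execution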

Encapsulated for any N, what about using numpy.kron with a 4D input?

a = np.arange(1, 19).reshape((1, 2, 3, 3))
print(a)
# array([[[[ 1,  2,  3],
#          [ 4,  5,  6],
#          [ 7,  8,  9]],
# 
#         [[10, 11, 12],
#          [13, 14, 15],
#          [16, 17, 18]]]])


def interleave_kron(a, N=1):
    n = N + 1
    return np.kron(
        a, np.hstack((1, np.zeros(pow(n, 2) - 1))).reshape((1, 1, n, n))
    )[..., :-N, :-N]

where np.hstack((1, np.zeros(pow(n, 2) - 1))).reshape((1, 1, n, n)) could be externalized/defaulted once and for all, for performance.
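
For instance, a minimal sketch of that externalization (make_kernel and interleave_kron_cached are my own, hypothetical names):

def make_kernel(N=1):
    # (1, 1, n, n) block with a single 1 in the top-left corner
    n = N + 1
    kernel = np.zeros((1, 1, n, n))
    kernel[..., 0, 0] = 1
    return kernel

_KERNEL = make_kernel(N=2)  # built once, reused across calls

def interleave_kron_cached(a, N=2, kernel=_KERNEL):
    # same as interleave_kron, but without rebuilding the kernel on each call
    return np.kron(a, kernel)[..., :-N, :-N]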

Then

>>> interleave_kron(a, N=2)
array([[[[ 1.,  0.,  0.,  2.,  0.,  0.,  3.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 4.,  0.,  0.,  5.,  0.,  0.,  6.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 7.,  0.,  0.,  8.,  0.,  0.,  9.]],

        [[10.,  0.,  0., 11.,  0.,  0., 12.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [13.,  0.,  0., 14.,  0.,  0., 15.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [16.,  0.,  0., 17.,  0.,  0., 18.]]]])


Since you know the size of the arrays beforehand, a first step of the optimization is to create the out array outside of the function. Then, try numba to jit-compile the function and work in-place on the out array. This achieved a 5x speedup over the numpy version you posted.

import numpy as np
from numba import njit

@njit
def insert_zeros_n(a, out, N=1):
    # scatter a onto every (N+1)-th row/col of out, in place;
    # out is pre-filled with zeros by the caller
    for i in range(a.shape[0]):
        for j in range(a.shape[1]):
            out[(N + 1) * i, (N + 1) * j] = a[i, j]

And call it with the specified N and a:
N = 1
a = np.arange(16*16).reshape(16, 16)
out = np.zeros((N + 1) * np.array(a.shape) - N, dtype=a.dtype)
insert_zeros_n(a, out, N)
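
Since the arrays in the question are 4D, with the interleaving only in the last two dimensions, the same idea generalizes (a sketch under that assumption; insert_zeros_4d is my own, untested name):

@njit
def insert_zeros_4d(a, out, N=1):
    # loop over the two leading dims and scatter each trailing 2D slice
    # onto every (N+1)-th row/col of out, in place
    n = N + 1
    for b in range(a.shape[0]):
        for c in range(a.shape[1]):
            for i in range(a.shape[2]):
                for j in range(a.shape[3]):
                    out[b, c, n * i, n * j] = a[b, c, i, j]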