根据条件替换 numpy 数组块

Replacing chunks of numpy array on condition

假设我有以下 numpy 数组,仅由 1 和 0 组成:

import numpy as np

example = np.array([0,1,1,0,1,0,0,1,1], dtype=np.uint8)

我想将所有元素分组到 3 的块中,并根据条件用单个值替换这些块。 假设我希望 [0,1,1] 变为 5,并且 [0,1,0] 变为 10。 因此,所需的输出将是:

[5,10,5]

块中 1 和 0 的所有可能组合都有一个相应的唯一值,应该替换该块。最快的方法是什么?

我建议您将数组重塑为 3 by something 数组。现在我们可以将每一行视为一个二进制数,它是您想要的值列表的索引。您将其转换为该数字并索引到值中。

arr = np.array([0,1,1,0,1,0,0,1,1], dtype=np.uint8).reshape(-1,3)

idx = 2**0*arr[:,0]+2**1*arr[:,1]+2**2*arr[:,2]

values = np.zeros(2**3)
values[0 *2**0+ 1 *2**1+ 1 *2**2] = 5
values[0 *2**0+ 1 *2**1+ 0 *2**2] = 10

values[idx]

这给出了

array([ 5., 10.,  5.])

或者,如果您更喜欢更简洁地编写转换,但不太基础(感谢@mozway 的想法):

def bin_vect_to_int(arr):
    bin_units = 2**np.arange(arr.shape[1])
    return np.dot(arr,bin_units)


arr = np.array([0,1,1,0,1,0,0,1,1,0,1,1], dtype=np.uint8).reshape(-1,3)
idx = binVecToInt(arr)

values = np.zeros(2**3)
values[bin_vect_to_int(np.array([[0,1,1]]))] = 5
values[bin_vect_to_int(np.array([[0,1,0]]))] = 10

values[idx]

您可以使用 shape(3, -1) 的连续数组视图,找到唯一出现的位置并在这些位置替换它们:

def view_ascontiguous(a): # a is array
    a = np.ascontiguousarray(a)
    void_dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
    return a.view(void_dt).ravel()

def replace(x, args, subs, viewer):
    u, inv = np.unique(viewer(x), return_inverse=True)
    idx = np.searchsorted(viewer(args), u)
    return subs[idx][inv]

>>> replace(x=np.array([1, 0, 1, 0, 0, 1, 1, 0, 1]).reshape(-1, 3),
        args=np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]]),
        subs=np.array([ 5, 57, 58, 44, 67, 17, 77,  1]),
        viewer=view_ascontiguous)
array([17, 57, 17])

注意这里 idx 表示唯一 contiguous blocks of length N in a Power Set {0, 1}^N.

的位置

如果 viewer(args)np.searchsorted 方法中将 args 映射到 [0, 1, 2, 3, ...],将其替换为 np.arange(len(args)) 有助于提高性能。

这个算法也可以用于更一般的问题:


你得到了 dtype=np.uint8M*N 值 0 和 1 的数组。你还得到了幂集 [0, 1]^N 之间的映射(所有可能的长度块 N 的 0 & 1) 和一些标量值。按照以下步骤查找 M 个值的数组:

  • 将给定的数组拆分为 M 个长度为 N
  • 的连续块
  • 使用给定的映射用标量值替换每个块

现在,有趣的部分:您可以使用自己的 viewer。需要将您传入的数组 args 映射到任何类型的升序索引,如下所示:

viewer=lambda arr: np.ravel_multi_index(arr.T, (2,2,2)) #0, 1, 2, 3, 4, 5, 6, 7
viewer=lambda arr: np.sum(arr * [4, 2, 1], axis=1) #0, 1, 2, 3, 4, 5, 6, 7
viewer=lambda arr: np.dot(arr, [4, 2, 1]) #0, 1, 2, 3, 4, 5, 6, 7

或者更有趣:

viewer=lambda arr: 2*np.dot(arr, [4, 2, 1]) + 1 #1, 3, 5, 7, 9, 11, 13, 15
viewer=lambda arr: np.vectorize(chr)(97+np.dot(arr, [4, 2, 1])) #a b c d e f g h

因为你也可以映射

[[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]]

你能想到的任何升序,比如 [1, 3, 5, 7, 9, 11, 13, 15]['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'] 结果还是一样。

进一步说明

感谢@MadPhysicist,

np.packbits(example.reshape(-1, N), axis=1, bitorder='little').ravel()

也可以解决问题。它假装是最快的解决方案,因为 np.packbitsnumpy 中得到了很好的优化。

如其他答案所示,您可以从重塑数组开始(实际上,您可能应该以正确的形状开始生成它,但这是另一个问题):

example = np.array([0, 1, 1, 0, 1, 0, 0, 1, 1], dtype=np.uint8)
data = example.reshape(-1, 3)

现在 运行 数组上的自定义 python 函数会变慢,但幸运的是 numpy 支持你。您可以使用 np.packbits 将每一行直接转换为数字:

data = np.packbits(data, axis=1, bitorder='little').ravel() # [6, 2, 6]

如果您希望 101 映射到 5 并且 110 映射到 6,您的工作就完成了。否则,您将需要提出一个映射。由于你有三位,所以映射数组中只需要8个数字:

mapping = np.array([7, 4, 3, 8, 124, 1, 5, 0])

您可以使用 data 作为直接进入 mapping 的索引。输出的类型为 mapping 但形状为 data:

result = mapping[data]  # [5, 3, 5]

您可以在一行中完成此操作:

mapping[np.packbits(example.reshape(-1, 3), axis=1, bitorder='little').ravel()]