Numpy - 沿轴的所有唯一组合或路径的乘积

Numpy - Product of All Unique Combinations or Paths along an axis

给定一个二维数组,我想找到沿最后一个轴的所有唯一“组合”或“路径”的乘积。

如果起始数组是

[[1, 2, 3, 4, 5, 6],
 [7, 8, 9, 10, 11, 12]]

那么最后一个轴上的唯一组合将是

[[1, 2, 3, 4, 5, 6],
 [1, 2, 3, 4, 5, 12],
 [1, 2, 3, 4, 11, 6],
 [1, 2, 3, 4, 11, 12],
 ...,
 [7, 8, 9, 10, 11, 6],
 [7, 8, 9, 10, 11, 12]]

然后就可以申请了numpy.prod(, axis=-1)。但是,是否有一种 faster/more 无需手动迭代即可检索这些唯一组合的有效方法?

一种解决方案是使用 itertools 生成所有可能的索引,然后使用 Numpy 间接索引来提取实际路径:

a = np.array([[1, 2, 3,  4,  5,  6],
              [7, 8, 9, 10, 11, 12]])

rowCount = a.shape[0]**a.shape[1]
x = np.tile(np.arange(a.shape[1]), rowCount).reshape(-1, a.shape[1])
y = np.array(list(itertools.product(*[range(2)]*6)))
result = a[y, x]
product = np.prod(a[y, x], axis=-1)

请注意,此解决方案对于较大的数组不是很有效。事实上,正如评论中所说,可能路径的数量呈指数增长n**m 其中 n, m = a.shape),因此所需的内存 space 增长在 O(n**m * m)。甚至输出的大小 product 也会根据经验增长 (n**m)。

您可以在 numpy 中使用此笛卡尔积的推广:

a = np.array([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]])
np.stack(np.meshgrid(*a.T), axis=-1).reshape(-1, len(a.T))

输出:

array([[ 1,  2,  3,  4,  5,  6],
       [ 1,  2,  3,  4,  5, 12],
       [ 1,  2,  3,  4, 11,  6],
                            ...
       [ 7,  8,  9, 10,  5, 12],
       [ 7,  8,  9, 10, 11,  6],
       [ 7,  8,  9, 10, 11, 12]])

我认为,如果您只想要最终产品,最好不要生成完整的组合数组。此方法迭代计算产品,从而减少内存使用。

a = np.array([[1, 2, 3,  4,  5,  6],
              [7, 8, 9, 10, 11, 12]])

from itertools import product

m, n = a.shape
combs = product(*[range(m)]*n)
products = [np.product(a[comb, range(n)]) for comb in combs]
print(products)

[720, 1440, 1584, 3168, 1800, 3600, ... etc.

或使用np.fromiter

products = np.fromiter((np.product(a[comb, range(n)]) for comb in combs), 
                       dtype=float, count=m**n)

这是另一种方法,它根据 a 中的列数递归查找所有产品。输出顺序与其他答案不同;下面的基准显示比下一个最快的答案快 50 倍。

def products(arr):
  # if there is one column left, return the last column
  if arr.shape[1] == 1:
    return arr[:,0]
  mid = arr.shape[1] // 2
  # return the flattened outer product of recursive calls
  return np.outer(products(arr[:,:mid]), products(arr[:,mid:])).ravel()

a = np.array([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]])
result = products(a)
# [   720   1440   1584   3168   1800   3600   3960   7920   2160   4320
#    4752   9504   5400  10800  11880  23760   2880   5760   6336  12672
#    7200  14400  15840  31680   8640  17280  19008  38016  21600  43200
#   47520  95040   5040  10080  11088  22176  12600  25200  27720  55440
#   15120  30240  33264  66528  37800  75600  83160 166320  20160  40320
#   44352  88704  50400 100800 110880 221760  60480 120960 133056 266112
#  151200 302400 332640 665280]

基准(4 行,10 列):

import itertools

def with_meshgrid(a):  
  b = np.stack(np.meshgrid(*a.T), axis=-1).reshape(-1, len(a.T))
  res = np.prod(b, axis=-1)
  return res

def with_itertools1(a):
  m, n = a.shape
  combs = itertools.product(*[range(m)]*n)
  products = [np.product(a[comb, range(n)]) for comb in combs]
  return products

def with_itertools2(a):
  rowCount = a.shape[0]**a.shape[1]
  x = np.tile(np.arange(a.shape[1]), rowCount).reshape(-1, a.shape[1])
  y = np.array(list(itertools.product(*[range(a.shape[0])]*a.shape[1])))
  result = a[y, x]
  product = np.prod(a[y, x], axis=-1)
  return product

a = np.random.randint(0, 100, size=(4, 10))
# %timeit products(a)        # 3.69 ms
# %timeit with_meshgrid(a)   # 201 ms
# %timeit with_itertools1(a) # timeout
# %timeit with_itertools2(a) # 3.1 s
assert (np.sort(products(b)) == np.sort(with_meshgrid(b))).all()