Numpy - 沿轴的所有唯一组合或路径的乘积
Numpy - Product of All Unique Combinations or Paths along an axis
给定一个二维数组,我想找到沿最后一个轴的所有唯一“组合”或“路径”的乘积。
如果起始数组是
[[1, 2, 3, 4, 5, 6],
[7, 8, 9, 10, 11, 12]]
那么最后一个轴上的唯一组合将是
[[1, 2, 3, 4, 5, 6],
[1, 2, 3, 4, 5, 12],
[1, 2, 3, 4, 11, 6],
[1, 2, 3, 4, 11, 12],
...,
[7, 8, 9, 10, 11, 6],
[7, 8, 9, 10, 11, 12]]
然后就可以申请了numpy.prod(, axis=-1)
。但是,是否有一种 faster/more 无需手动迭代即可检索这些唯一组合的有效方法?
一种解决方案是使用 itertools
生成所有可能的索引,然后使用 Numpy 间接索引来提取实际路径:
a = np.array([[1, 2, 3, 4, 5, 6],
[7, 8, 9, 10, 11, 12]])
rowCount = a.shape[0]**a.shape[1]
x = np.tile(np.arange(a.shape[1]), rowCount).reshape(-1, a.shape[1])
y = np.array(list(itertools.product(*[range(2)]*6)))
result = a[y, x]
product = np.prod(a[y, x], axis=-1)
请注意,此解决方案对于较大的数组不是很有效。事实上,正如评论中所说,可能路径的数量呈指数增长(n**m
其中 n, m = a.shape
),因此所需的内存 space 增长在 O(n**m * m)
。甚至输出的大小 product
也会根据经验增长 (n**m
)。
您可以在 numpy
中使用此笛卡尔积的推广:
a = np.array([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]])
np.stack(np.meshgrid(*a.T), axis=-1).reshape(-1, len(a.T))
输出:
array([[ 1, 2, 3, 4, 5, 6],
[ 1, 2, 3, 4, 5, 12],
[ 1, 2, 3, 4, 11, 6],
...
[ 7, 8, 9, 10, 5, 12],
[ 7, 8, 9, 10, 11, 6],
[ 7, 8, 9, 10, 11, 12]])
我认为,如果您只想要最终产品,最好不要生成完整的组合数组。此方法迭代计算产品,从而减少内存使用。
a = np.array([[1, 2, 3, 4, 5, 6],
[7, 8, 9, 10, 11, 12]])
from itertools import product
m, n = a.shape
combs = product(*[range(m)]*n)
products = [np.product(a[comb, range(n)]) for comb in combs]
print(products)
[720, 1440, 1584, 3168, 1800, 3600, ... etc.
或使用np.fromiter
products = np.fromiter((np.product(a[comb, range(n)]) for comb in combs),
dtype=float, count=m**n)
这是另一种方法,它根据 a
中的列数递归查找所有产品。输出顺序与其他答案不同;下面的基准显示比下一个最快的答案快 50 倍。
def products(arr):
# if there is one column left, return the last column
if arr.shape[1] == 1:
return arr[:,0]
mid = arr.shape[1] // 2
# return the flattened outer product of recursive calls
return np.outer(products(arr[:,:mid]), products(arr[:,mid:])).ravel()
a = np.array([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]])
result = products(a)
# [ 720 1440 1584 3168 1800 3600 3960 7920 2160 4320
# 4752 9504 5400 10800 11880 23760 2880 5760 6336 12672
# 7200 14400 15840 31680 8640 17280 19008 38016 21600 43200
# 47520 95040 5040 10080 11088 22176 12600 25200 27720 55440
# 15120 30240 33264 66528 37800 75600 83160 166320 20160 40320
# 44352 88704 50400 100800 110880 221760 60480 120960 133056 266112
# 151200 302400 332640 665280]
基准(4 行,10 列):
import itertools
def with_meshgrid(a):
b = np.stack(np.meshgrid(*a.T), axis=-1).reshape(-1, len(a.T))
res = np.prod(b, axis=-1)
return res
def with_itertools1(a):
m, n = a.shape
combs = itertools.product(*[range(m)]*n)
products = [np.product(a[comb, range(n)]) for comb in combs]
return products
def with_itertools2(a):
rowCount = a.shape[0]**a.shape[1]
x = np.tile(np.arange(a.shape[1]), rowCount).reshape(-1, a.shape[1])
y = np.array(list(itertools.product(*[range(a.shape[0])]*a.shape[1])))
result = a[y, x]
product = np.prod(a[y, x], axis=-1)
return product
a = np.random.randint(0, 100, size=(4, 10))
# %timeit products(a) # 3.69 ms
# %timeit with_meshgrid(a) # 201 ms
# %timeit with_itertools1(a) # timeout
# %timeit with_itertools2(a) # 3.1 s
assert (np.sort(products(b)) == np.sort(with_meshgrid(b))).all()
给定一个二维数组,我想找到沿最后一个轴的所有唯一“组合”或“路径”的乘积。
如果起始数组是
[[1, 2, 3, 4, 5, 6],
[7, 8, 9, 10, 11, 12]]
那么最后一个轴上的唯一组合将是
[[1, 2, 3, 4, 5, 6],
[1, 2, 3, 4, 5, 12],
[1, 2, 3, 4, 11, 6],
[1, 2, 3, 4, 11, 12],
...,
[7, 8, 9, 10, 11, 6],
[7, 8, 9, 10, 11, 12]]
然后就可以申请了numpy.prod(, axis=-1)
。但是,是否有一种 faster/more 无需手动迭代即可检索这些唯一组合的有效方法?
一种解决方案是使用 itertools
生成所有可能的索引,然后使用 Numpy 间接索引来提取实际路径:
a = np.array([[1, 2, 3, 4, 5, 6],
[7, 8, 9, 10, 11, 12]])
rowCount = a.shape[0]**a.shape[1]
x = np.tile(np.arange(a.shape[1]), rowCount).reshape(-1, a.shape[1])
y = np.array(list(itertools.product(*[range(2)]*6)))
result = a[y, x]
product = np.prod(a[y, x], axis=-1)
请注意,此解决方案对于较大的数组不是很有效。事实上,正如评论中所说,可能路径的数量呈指数增长(n**m
其中 n, m = a.shape
),因此所需的内存 space 增长在 O(n**m * m)
。甚至输出的大小 product
也会根据经验增长 (n**m
)。
您可以在 numpy
中使用此笛卡尔积的推广:
a = np.array([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]])
np.stack(np.meshgrid(*a.T), axis=-1).reshape(-1, len(a.T))
输出:
array([[ 1, 2, 3, 4, 5, 6],
[ 1, 2, 3, 4, 5, 12],
[ 1, 2, 3, 4, 11, 6],
...
[ 7, 8, 9, 10, 5, 12],
[ 7, 8, 9, 10, 11, 6],
[ 7, 8, 9, 10, 11, 12]])
我认为,如果您只想要最终产品,最好不要生成完整的组合数组。此方法迭代计算产品,从而减少内存使用。
a = np.array([[1, 2, 3, 4, 5, 6],
[7, 8, 9, 10, 11, 12]])
from itertools import product
m, n = a.shape
combs = product(*[range(m)]*n)
products = [np.product(a[comb, range(n)]) for comb in combs]
print(products)
[720, 1440, 1584, 3168, 1800, 3600, ... etc.
或使用np.fromiter
products = np.fromiter((np.product(a[comb, range(n)]) for comb in combs),
dtype=float, count=m**n)
这是另一种方法,它根据 a
中的列数递归查找所有产品。输出顺序与其他答案不同;下面的基准显示比下一个最快的答案快 50 倍。
def products(arr):
# if there is one column left, return the last column
if arr.shape[1] == 1:
return arr[:,0]
mid = arr.shape[1] // 2
# return the flattened outer product of recursive calls
return np.outer(products(arr[:,:mid]), products(arr[:,mid:])).ravel()
a = np.array([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]])
result = products(a)
# [ 720 1440 1584 3168 1800 3600 3960 7920 2160 4320
# 4752 9504 5400 10800 11880 23760 2880 5760 6336 12672
# 7200 14400 15840 31680 8640 17280 19008 38016 21600 43200
# 47520 95040 5040 10080 11088 22176 12600 25200 27720 55440
# 15120 30240 33264 66528 37800 75600 83160 166320 20160 40320
# 44352 88704 50400 100800 110880 221760 60480 120960 133056 266112
# 151200 302400 332640 665280]
基准(4 行,10 列):
import itertools
def with_meshgrid(a):
b = np.stack(np.meshgrid(*a.T), axis=-1).reshape(-1, len(a.T))
res = np.prod(b, axis=-1)
return res
def with_itertools1(a):
m, n = a.shape
combs = itertools.product(*[range(m)]*n)
products = [np.product(a[comb, range(n)]) for comb in combs]
return products
def with_itertools2(a):
rowCount = a.shape[0]**a.shape[1]
x = np.tile(np.arange(a.shape[1]), rowCount).reshape(-1, a.shape[1])
y = np.array(list(itertools.product(*[range(a.shape[0])]*a.shape[1])))
result = a[y, x]
product = np.prod(a[y, x], axis=-1)
return product
a = np.random.randint(0, 100, size=(4, 10))
# %timeit products(a) # 3.69 ms
# %timeit with_meshgrid(a) # 201 ms
# %timeit with_itertools1(a) # timeout
# %timeit with_itertools2(a) # 3.1 s
assert (np.sort(products(b)) == np.sort(with_meshgrid(b))).all()