交错形状不匹配的 NumPy 数组
Interleaving NumPy arrays with mismatching shapes
我想沿特定轴交错具有不同维度的多个 numpy 数组。特别是,我有一个形状为 (_, *dims)
的数组列表,沿第一个轴变化,我想将其交错以获得另一个形状为 (_, *dims)
的数组。例如,给定输入
a1 = np.array([[11,12], [41,42]])
a2 = np.array([[21,22], [51,52], [71,72], [91,92], [101,102]])
a3 = np.array([[31,32], [61,62], [81,82]])
interweave(a1,a2,a3)
所需的输出将是
np.array([[11,12], [21,22], [31,32], [41,42], [51,52], [61,62], [71,72], [81,82], [91,92], [101,102]]
在以前的帖子(例如 )的帮助下,当数组沿第一个维度匹配时,我已经开始工作了:
import numpy as np
def interweave(*arrays, stack_axis=0, weave_axis=1):
final_shape = list(arrays[0].shape)
final_shape[stack_axis] = -1
# stack up arrays along the "weave axis", then reshape back to desired shape
return np.concatenate(arrays, axis=weave_axis).reshape(final_shape)
不幸的是,如果输入形状在第一个维度上不匹配,上面会抛出异常,因为我们必须沿着与不匹配的轴不同的轴连接。事实上,我在这里看不到任何有效使用连接的方法,因为沿着不匹配的轴连接会破坏我们生成所需输出所需的信息。
我的另一个想法是用空条目填充输入数组,直到它们的形状沿第一个维度匹配,然后在一天结束时删除空条目。虽然这可行,但我不确定如何最好地实施它,而且似乎一开始就没有必要。
这是一个主要基于 NumPy
的方法,它还使用 zip_longest
用填充值交错数组:
def interleave(*a):
# zip_longest filling values with as many NaNs as
# values in second axis
l = *zip_longest(*a, fillvalue=[np.nan]*a[0].shape[1]),
# build a 2d array from the list
out = np.concatenate(l)
# return non-NaN values
return out[~np.isnan(out[:,0])]
a1 = np.array([[11,12], [41,42]])
a2 = np.array([[21,22], [51,52], [71,72], [91,92], [101,102]])
a3 = np.array([[31,32], [61,62], [81,82]])
interleave(a1,a2,a3)
array([[ 11., 12.],
[ 21., 22.],
[ 31., 32.],
[ 41., 42.],
[ 51., 52.],
[ 61., 62.],
[ 71., 72.],
[ 81., 82.],
[ 91., 92.],
[101., 102.]])
您可能正在寻找 np.choose
。使用正确构造的索引,您可以一次调用结果:
def interweave(*arrays, axis=0):
arrays = [np.moveaxis(a, axis, 0) for a in arrays]
m = len(arrays)
n = max(map(len, arrays))
index = [k for i, k in (divmod(x, m) for x in range(m * n)) if i < len(arrays[k])]
return np.moveaxis(np.choose(index, arrays), 0, axis)
range(m * n)
是输出的大小 space 如果所有数组的大小都相同。 divmod
计算交错的元素和从中选择它的数组。由于数组太短而缺失的元素被跳过,所以结果只从数组中选择有效元素。
可能有更好的方法来制作索引,但这只是一个例子。你必须 move 堆栈轴到第一个位置,因为 choose
沿着第一个轴。
我继续概括了 yatu 对我在实践中面临的情况的回答,其中维数是任意的。这是我所拥有的:
import numpy as np
from itertools import zip_longest
def interleave(*a):
#creating padding array of NaNs
fill_shape = a[0].shape[1:]
fill_array = np.full(fill_shape,np.nan)
l = *zip_longest(*a, fillvalue=fill_array),
# build a 2d array from the list
out = np.concatenate(l)
# return non-NaN values
tup = (0,)*(len(out.shape)-1)
return out[~np.isnan(out[(...,)+tup])]
正在测试:
b1 = np.array(
[
[[111,112,113],[121,122,123]],
[[411,412,413],[421,422,423]]
])
b2=np.array(
[
[[211,212,213],[221,222,223]],
[[511,512,513],[521,522,523]],
[[711,712,713],[721,722,712]],
[[911,912,913],[921,922,923]],
[[1011,1012,1013],[1021,1022,1023]]
])
b3=np.array(
[
[[311,312,313],[321,322,323]],
[[611,612,613],[621,622,623]],
[[811,812,813],[821,822,823]]
])
In [1]: interleave(b1,b2,b3)
Out [1]: [[[ 111. 112. 113.]
[ 121. 122. 123.]]
[[ 211. 212. 213.]
[ 221. 222. 223.]]
[[ 311. 312. 313.]
[ 321. 322. 323.]]
[[ 411. 412. 413.]
[ 421. 422. 423.]]
[[ 511. 512. 513.]
[ 521. 522. 523.]]
[[ 611. 612. 613.]
[ 621. 622. 623.]]
[[ 711. 712. 713.]
[ 721. 722. 712.]]
[[ 811. 812. 813.]
[ 821. 822. 823.]]
[[ 911. 912. 913.]
[ 921. 922. 923.]]
[[1011. 1012. 1013.]
[1021. 1022. 1023.]]]
欢迎提出任何建议!特别是,在我的应用程序中,space,而不是时间,是限制因素,所以我想知道是否有一种方法可以使用更少的内存来做到这一点(数据集沿合并轴很大)。
我想沿特定轴交错具有不同维度的多个 numpy 数组。特别是,我有一个形状为 (_, *dims)
的数组列表,沿第一个轴变化,我想将其交错以获得另一个形状为 (_, *dims)
的数组。例如,给定输入
a1 = np.array([[11,12], [41,42]])
a2 = np.array([[21,22], [51,52], [71,72], [91,92], [101,102]])
a3 = np.array([[31,32], [61,62], [81,82]])
interweave(a1,a2,a3)
所需的输出将是
np.array([[11,12], [21,22], [31,32], [41,42], [51,52], [61,62], [71,72], [81,82], [91,92], [101,102]]
在以前的帖子(例如
import numpy as np
def interweave(*arrays, stack_axis=0, weave_axis=1):
final_shape = list(arrays[0].shape)
final_shape[stack_axis] = -1
# stack up arrays along the "weave axis", then reshape back to desired shape
return np.concatenate(arrays, axis=weave_axis).reshape(final_shape)
不幸的是,如果输入形状在第一个维度上不匹配,上面会抛出异常,因为我们必须沿着与不匹配的轴不同的轴连接。事实上,我在这里看不到任何有效使用连接的方法,因为沿着不匹配的轴连接会破坏我们生成所需输出所需的信息。
我的另一个想法是用空条目填充输入数组,直到它们的形状沿第一个维度匹配,然后在一天结束时删除空条目。虽然这可行,但我不确定如何最好地实施它,而且似乎一开始就没有必要。
这是一个主要基于 NumPy
的方法,它还使用 zip_longest
用填充值交错数组:
def interleave(*a):
# zip_longest filling values with as many NaNs as
# values in second axis
l = *zip_longest(*a, fillvalue=[np.nan]*a[0].shape[1]),
# build a 2d array from the list
out = np.concatenate(l)
# return non-NaN values
return out[~np.isnan(out[:,0])]
a1 = np.array([[11,12], [41,42]])
a2 = np.array([[21,22], [51,52], [71,72], [91,92], [101,102]])
a3 = np.array([[31,32], [61,62], [81,82]])
interleave(a1,a2,a3)
array([[ 11., 12.],
[ 21., 22.],
[ 31., 32.],
[ 41., 42.],
[ 51., 52.],
[ 61., 62.],
[ 71., 72.],
[ 81., 82.],
[ 91., 92.],
[101., 102.]])
您可能正在寻找 np.choose
。使用正确构造的索引,您可以一次调用结果:
def interweave(*arrays, axis=0):
arrays = [np.moveaxis(a, axis, 0) for a in arrays]
m = len(arrays)
n = max(map(len, arrays))
index = [k for i, k in (divmod(x, m) for x in range(m * n)) if i < len(arrays[k])]
return np.moveaxis(np.choose(index, arrays), 0, axis)
range(m * n)
是输出的大小 space 如果所有数组的大小都相同。 divmod
计算交错的元素和从中选择它的数组。由于数组太短而缺失的元素被跳过,所以结果只从数组中选择有效元素。
可能有更好的方法来制作索引,但这只是一个例子。你必须 move 堆栈轴到第一个位置,因为 choose
沿着第一个轴。
我继续概括了 yatu 对我在实践中面临的情况的回答,其中维数是任意的。这是我所拥有的:
import numpy as np
from itertools import zip_longest
def interleave(*a):
#creating padding array of NaNs
fill_shape = a[0].shape[1:]
fill_array = np.full(fill_shape,np.nan)
l = *zip_longest(*a, fillvalue=fill_array),
# build a 2d array from the list
out = np.concatenate(l)
# return non-NaN values
tup = (0,)*(len(out.shape)-1)
return out[~np.isnan(out[(...,)+tup])]
正在测试:
b1 = np.array(
[
[[111,112,113],[121,122,123]],
[[411,412,413],[421,422,423]]
])
b2=np.array(
[
[[211,212,213],[221,222,223]],
[[511,512,513],[521,522,523]],
[[711,712,713],[721,722,712]],
[[911,912,913],[921,922,923]],
[[1011,1012,1013],[1021,1022,1023]]
])
b3=np.array(
[
[[311,312,313],[321,322,323]],
[[611,612,613],[621,622,623]],
[[811,812,813],[821,822,823]]
])
In [1]: interleave(b1,b2,b3)
Out [1]: [[[ 111. 112. 113.]
[ 121. 122. 123.]]
[[ 211. 212. 213.]
[ 221. 222. 223.]]
[[ 311. 312. 313.]
[ 321. 322. 323.]]
[[ 411. 412. 413.]
[ 421. 422. 423.]]
[[ 511. 512. 513.]
[ 521. 522. 523.]]
[[ 611. 612. 613.]
[ 621. 622. 623.]]
[[ 711. 712. 713.]
[ 721. 722. 712.]]
[[ 811. 812. 813.]
[ 821. 822. 823.]]
[[ 911. 912. 913.]
[ 921. 922. 923.]]
[[1011. 1012. 1013.]
[1021. 1022. 1023.]]]
欢迎提出任何建议!特别是,在我的应用程序中,space,而不是时间,是限制因素,所以我想知道是否有一种方法可以使用更少的内存来做到这一点(数据集沿合并轴很大)。