在 NumPy 中将带有索引的切片混合到 mgrid 输入
Mixing slices with indexes to mgrid input in NumPy
np.mgrid 接受切片元组,如 np.mgrid[1:3, 4:8]
或 np.mgrid[np.s_[1:3, 4:8]]
.
但是有没有办法在 mgrid 的元组参数中混合切片和索引数组?例如:
extended_mgrid(np.s_[1:3, 4:8] + (np.array([1,2,3]), np.array([7,8])))
应该给出与
相同的结果
np.mgrid[1:3, 4:8, 1:4, 7:9]
但通常元组中的索引数组可能无法表示为切片。
解决此任务需要能够创建索引的 N 维元组,提供使用 np.mgrid
的切片 + 索引的混合,如 .
任务已解决 of @hpaulj using np.meshgrid。
import numpy as np
def extended_mgrid(i):
res = np.meshgrid(*[(
np.arange(e.start or 0, e.stop, e.step or 1)
if type(e) is slice else e
) for e in {slice: (i,), np.ndarray: (i,), tuple: i}[type(i)]
], indexing = 'ij')
return np.stack(res, 0) if type(i) is tuple else res[0]
# Tests
a = np.mgrid[1:3]
b = extended_mgrid(np.s_[1:3])
assert np.array_equal(a, b), (a, b)
a = np.mgrid[(np.s_[1:3],)]
b = extended_mgrid((np.s_[1:3],))
assert np.array_equal(a, b), (a, b)
a = np.array([[[1,1],[2,2]],[[3,4],[3,4]]])
b = extended_mgrid((np.array([1,2]), np.array([3,4])))
assert np.array_equal(a, b), (a, b)
a = np.mgrid[1:3, 4:8, 1:4, 7:9]
b = extended_mgrid(np.s_[1:3, 4:8] + (np.array([1,2,3]), np.array([7,8])))
assert np.array_equal(a, b), (a, b)
np.mgrid 接受切片元组,如 np.mgrid[1:3, 4:8]
或 np.mgrid[np.s_[1:3, 4:8]]
.
但是有没有办法在 mgrid 的元组参数中混合切片和索引数组?例如:
extended_mgrid(np.s_[1:3, 4:8] + (np.array([1,2,3]), np.array([7,8])))
应该给出与
相同的结果np.mgrid[1:3, 4:8, 1:4, 7:9]
但通常元组中的索引数组可能无法表示为切片。
解决此任务需要能够创建索引的 N 维元组,提供使用 np.mgrid
的切片 + 索引的混合,如
任务已解决
import numpy as np
def extended_mgrid(i):
res = np.meshgrid(*[(
np.arange(e.start or 0, e.stop, e.step or 1)
if type(e) is slice else e
) for e in {slice: (i,), np.ndarray: (i,), tuple: i}[type(i)]
], indexing = 'ij')
return np.stack(res, 0) if type(i) is tuple else res[0]
# Tests
a = np.mgrid[1:3]
b = extended_mgrid(np.s_[1:3])
assert np.array_equal(a, b), (a, b)
a = np.mgrid[(np.s_[1:3],)]
b = extended_mgrid((np.s_[1:3],))
assert np.array_equal(a, b), (a, b)
a = np.array([[[1,1],[2,2]],[[3,4],[3,4]]])
b = extended_mgrid((np.array([1,2]), np.array([3,4])))
assert np.array_equal(a, b), (a, b)
a = np.mgrid[1:3, 4:8, 1:4, 7:9]
b = extended_mgrid(np.s_[1:3, 4:8] + (np.array([1,2,3]), np.array([7,8])))
assert np.array_equal(a, b), (a, b)