如何使用 argmax 将 return 索引到无法重新整形为矩阵的多维 ndarray 中?
How to use argmax to return indices into multidimensional ndarray that cannot be re-shaped into a matrix?
给定一个包含 2 个 9x9 图像的数组,其中 2 个通道形状如下:
img1 = img1 = np.arange(162).reshape(9,9,2).copy()
img2 = img1 * 2
batch = np.array([img1, img2])
我需要将每个图像切片为 3x3x2(步长=3)个区域,然后定位并替换每个切片的最大元素。对于上面的示例,这些元素是:
(:, 2, 2, :)
(:, 2, 5, :)
(:, 2, 8, :)
(:, 5, 2, :)
(:, 5, 5, :)
(:, 5, 8, :)
(:, 8, 2, :)
(:, 8, 5, :)
(:, 8, 8, :)
到目前为止我的解决方案是这样的:
batch_size, _, _, channels = batch.shape
region_size = 3
# For the (0, 0) region
region_slice = (slice(batch_size), slice(region_size), slice(region_size), slice(channels))
region = batch[region_slice]
new_values = np.arange(batch_size * channels)
# Flatten each channel of an image
region_3d = region.reshape(batch_size, region_size ** 2, channels)
region_3d_argmax = region_3d.argmax(axis=1)
region_argmax = (
np.repeat(np.arange(batch_size), channels),
*np.unravel_index(region_3d_argmax.ravel(), (region_size, region_size)),
np.tile(np.arange(channels), batch_size)
)
# Find indices of max element for each channel
region_3d_argmax = region_3d.argmax(axis=1)
# Manually unravel indices
region_argmax = (
np.repeat(np.arange(batch_size), channels),
*np.unravel_index(region_3d_argmax.ravel(), (region_size, region_size)),
np.tile(np.arange(channels), batch_size)
)
batch[region_slice][region_argmax] = new_values
这段代码有两个问题:
- 重塑
region
可以 return 复制而不是视图
- 手动解开
执行此操作的更好方法是什么?
有合并轴
更好的方法(在内存和性能效率上)是使用 advanced-indexing
创建适当的索引元组 -
m,n = idx.shape
indexer = np.arange(m)[:,None],idx,np.arange(n)
batch_3d[indexer].flat = ...# perform replacement with 1D array
可以通过将替换数组重塑为索引形状来以不同方式编写最后一步(如果尚未如此,则跳过)-
batch_3d[indexer] = replacement_array.reshape(m,n)
我们也可以使用内置的np.put_along_axis
和p
作为替换数组-
np.put_along_axis(batch_3d,idx[:,None,:],p.reshape(m,1,n),axis=1)
注意:post中使用的idx
是从idx = batch_3d.argmax(axis=1)
生成的idx
,因此我们跳过了manually unravel indices
步骤。
不合并轴
我们将定义辅助函数来实现基于 argmax 的多轴替换,而不合并不相邻的轴,因为它们会强制复制。
def indexer_skip_one_axis(a, axis):
return tuple(slice(None) if i!=axis else None for i in range(a.ndim))
def argmax_along_axes(a, axis):
# a is input array
# axis is tuple of axes along which argmax indices are to be computed
argmax1 = (a.argmax(axis[0]))[indexer_skip_one_axis(a,axis[0])]
val_argmax1 = np.take_along_axis(a,argmax1,axis=axis[0])
argmax2 = (val_argmax1.argmax(axis[1]))[indexer_skip_one_axis(a,axis[1])]
val_argmax2 = np.take_along_axis(argmax1,argmax2,axis=axis[1])
r = list(np.ix_(*[np.arange(i) for i in a.shape]))
r[axis[0]] = val_argmax2
r[axis[1]] = argmax2
return tuple(r)
因此,要解决我们的案例,进行所有替换将是 -
m,n,r,s = batch.shape
batch6D = batch.reshape(m,n//3,3,r//3,3,s)
batch6D[argmax_along_axes(batch6D, axis=(2,4))] = new_values.reshape(2,1,1,1,1,2)
out = batch6D.reshape(m,n,r,s)
给定一个包含 2 个 9x9 图像的数组,其中 2 个通道形状如下:
img1 = img1 = np.arange(162).reshape(9,9,2).copy()
img2 = img1 * 2
batch = np.array([img1, img2])
我需要将每个图像切片为 3x3x2(步长=3)个区域,然后定位并替换每个切片的最大元素。对于上面的示例,这些元素是:
(:, 2, 2, :)
(:, 2, 5, :)
(:, 2, 8, :)
(:, 5, 2, :)
(:, 5, 5, :)
(:, 5, 8, :)
(:, 8, 2, :)
(:, 8, 5, :)
(:, 8, 8, :)
到目前为止我的解决方案是这样的:
batch_size, _, _, channels = batch.shape
region_size = 3
# For the (0, 0) region
region_slice = (slice(batch_size), slice(region_size), slice(region_size), slice(channels))
region = batch[region_slice]
new_values = np.arange(batch_size * channels)
# Flatten each channel of an image
region_3d = region.reshape(batch_size, region_size ** 2, channels)
region_3d_argmax = region_3d.argmax(axis=1)
region_argmax = (
np.repeat(np.arange(batch_size), channels),
*np.unravel_index(region_3d_argmax.ravel(), (region_size, region_size)),
np.tile(np.arange(channels), batch_size)
)
# Find indices of max element for each channel
region_3d_argmax = region_3d.argmax(axis=1)
# Manually unravel indices
region_argmax = (
np.repeat(np.arange(batch_size), channels),
*np.unravel_index(region_3d_argmax.ravel(), (region_size, region_size)),
np.tile(np.arange(channels), batch_size)
)
batch[region_slice][region_argmax] = new_values
这段代码有两个问题:
- 重塑
region
可以 return 复制而不是视图 - 手动解开
执行此操作的更好方法是什么?
有合并轴
更好的方法(在内存和性能效率上)是使用 advanced-indexing
创建适当的索引元组 -
m,n = idx.shape
indexer = np.arange(m)[:,None],idx,np.arange(n)
batch_3d[indexer].flat = ...# perform replacement with 1D array
可以通过将替换数组重塑为索引形状来以不同方式编写最后一步(如果尚未如此,则跳过)-
batch_3d[indexer] = replacement_array.reshape(m,n)
我们也可以使用内置的np.put_along_axis
和p
作为替换数组-
np.put_along_axis(batch_3d,idx[:,None,:],p.reshape(m,1,n),axis=1)
注意:post中使用的idx
是从idx = batch_3d.argmax(axis=1)
生成的idx
,因此我们跳过了manually unravel indices
步骤。
不合并轴
我们将定义辅助函数来实现基于 argmax 的多轴替换,而不合并不相邻的轴,因为它们会强制复制。
def indexer_skip_one_axis(a, axis):
return tuple(slice(None) if i!=axis else None for i in range(a.ndim))
def argmax_along_axes(a, axis):
# a is input array
# axis is tuple of axes along which argmax indices are to be computed
argmax1 = (a.argmax(axis[0]))[indexer_skip_one_axis(a,axis[0])]
val_argmax1 = np.take_along_axis(a,argmax1,axis=axis[0])
argmax2 = (val_argmax1.argmax(axis[1]))[indexer_skip_one_axis(a,axis[1])]
val_argmax2 = np.take_along_axis(argmax1,argmax2,axis=axis[1])
r = list(np.ix_(*[np.arange(i) for i in a.shape]))
r[axis[0]] = val_argmax2
r[axis[1]] = argmax2
return tuple(r)
因此,要解决我们的案例,进行所有替换将是 -
m,n,r,s = batch.shape
batch6D = batch.reshape(m,n//3,3,r//3,3,s)
batch6D[argmax_along_axes(batch6D, axis=(2,4))] = new_values.reshape(2,1,1,1,1,2)
out = batch6D.reshape(m,n,r,s)