How to do scatter and gather operations in numpy?
I want to implement TensorFlow's or PyTorch's scatter and gather operations in NumPy. I have been scratching my head for a while. Any pointers are greatly appreciated!
Assuming ref and indices are numpy arrays:
Scatter update:
ref[indices] = updates # tf.scatter_update(ref, indices, updates)
ref[:, indices] = updates # tf.scatter_update(ref, indices, updates, axis=1)
ref[..., indices, :] = updates # tf.scatter_update(ref, indices, updates, axis=-2)
ref[..., indices] = updates # tf.scatter_update(ref, indices, updates, axis=-1)
Gather:
ref[indices] # tf.gather(ref, indices)
ref[:, indices] # tf.gather(ref, indices, axis=1)
ref[..., indices, :] # tf.gather(ref, indices, axis=-2)
ref[..., indices] # tf.gather(ref, indices, axis=-1)
For more information, see the numpy docs on indexing.
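For instance, the axis-0 pair of equivalences above can be checked on a small array:

```python
import numpy as np

ref = np.zeros((3, 4), dtype=int)
indices = np.array([0, 2])
updates = np.array([[1, 1, 1, 1],
                    [2, 2, 2, 2]])

# scatter update along axis 0: rows 0 and 2 are overwritten
ref[indices] = updates
# gather along axis 0: rows 0 and 2 are read back
gathered = ref[indices]
```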
For scatter, instead of using slice assignment as suggested by @DomJack, it is often better to use np.add.at, because unlike slice assignment it has well-defined behavior in the presence of duplicate indices.
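A small demonstration of the difference with duplicate indices:

```python
import numpy as np

ref = np.zeros(3)
indices = np.array([0, 0, 1])   # index 0 appears twice
updates = np.array([1.0, 2.0, 3.0])

# Slice assignment: which duplicate "wins" is not guaranteed in general
ref[indices] = updates

# np.add.at accumulates every update, so duplicates are well defined:
# both updates to index 0 are summed
acc = np.zeros(3)
np.add.at(acc, indices, updates)
```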
The scatter method turned out to be far more work than I expected. I did not find any ready-made function for it in NumPy. I am sharing it here for the benefit of anyone who may need to implement it with NumPy.
(P.S. self is the target or output of the method.)
def scatter_numpy(self, dim, index, src):
    """
    Writes all values from the Tensor src into self at the indices specified in the index Tensor.

    :param dim: The axis along which to index
    :param index: The indices of elements to scatter
    :param src: The source element(s) to scatter
    :return: self
    """
    if not np.issubdtype(index.dtype, np.integer):
        raise TypeError("The values of index must be integers")
    if self.ndim != index.ndim:
        raise ValueError("Index should have the same number of dimensions as output")
    if dim >= self.ndim or dim < -self.ndim:
        raise IndexError("dim is out of range")
    if dim < 0:
        # Not sure why scatter should accept dim < 0, but that is the behavior in PyTorch's scatter
        dim = self.ndim + dim
    idx_xsection_shape = index.shape[:dim] + index.shape[dim + 1:]
    self_xsection_shape = self.shape[:dim] + self.shape[dim + 1:]
    if idx_xsection_shape != self_xsection_shape:
        raise ValueError("Except for dimension " + str(dim) +
                         ", all dimensions of index and output should be the same size")
    if (index >= self.shape[dim]).any() or (index < 0).any():
        raise IndexError("The values of index must be between 0 and (self.shape[dim] - 1)")

    def make_slice(arr, dim, i):
        slc = [slice(None)] * arr.ndim
        slc[dim] = i
        return tuple(slc)

    # We use the index and dim parameters to create idx,
    # a NumPy advanced index for scattering src into self
    idx = [[*np.indices(idx_xsection_shape).reshape(index.ndim - 1, -1),
            index[make_slice(index, dim, i)].reshape(1, -1)[0]] for i in range(index.shape[dim])]
    idx = list(np.concatenate(idx, axis=1))
    idx.insert(dim, idx.pop())

    if not np.isscalar(src):
        if index.shape[dim] > src.shape[dim]:
            raise IndexError("Dimension " + str(dim) + " of index can not be bigger than that of src")
        src_xsection_shape = src.shape[:dim] + src.shape[dim + 1:]
        if idx_xsection_shape != src_xsection_shape:
            raise ValueError("Except for dimension " +
                             str(dim) + ", all dimensions of index and src should be the same size")
        # src_idx is a NumPy advanced index for selecting the elements of src
        src_idx = list(idx)
        src_idx.pop(dim)
        src_idx.insert(dim, np.repeat(np.arange(index.shape[dim]), np.prod(idx_xsection_shape)))
        self[tuple(idx)] = src[tuple(src_idx)]
    else:
        self[tuple(idx)] = src
    return self
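As a sanity check on the intended semantics, torch-style scatter for the 2-D case can be written as an explicit loop (a reference version for checking results, not an efficient implementation):

```python
import numpy as np

def scatter_ref_2d(target, dim, index, src):
    # Reference semantics of torch-style scatter for 2-D arrays:
    #   dim == 0: out[index[i][j]][j] = src[i][j]
    #   dim == 1: out[i][index[i][j]] = src[i][j]
    out = target.copy()
    for i in range(index.shape[0]):
        for j in range(index.shape[1]):
            if dim == 0:
                out[index[i, j], j] = src[i, j]
            else:
                out[i, index[i, j]] = src[i, j]
    return out

target = np.zeros((3, 3), dtype=int)
index = np.array([[0, 1], [2, 0]])
src = np.array([[10, 20], [30, 40]])
result = scatter_ref_2d(target, 0, index, src)
```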
gather
There may be a simpler solution, but this is what I settled on.
(Here self is the ndarray from which values are gathered.)
def gather_numpy(self, dim, index):
    """
    Gathers values along an axis specified by dim.

    For a 3-D tensor the output is specified by:
        out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
        out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
        out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

    :param dim: The axis along which to index
    :param index: A tensor of indices of elements to gather
    :return: tensor of gathered values
    """
    idx_xsection_shape = index.shape[:dim] + index.shape[dim + 1:]
    self_xsection_shape = self.shape[:dim] + self.shape[dim + 1:]
    if idx_xsection_shape != self_xsection_shape:
        raise ValueError("Except for dimension " + str(dim) +
                         ", all dimensions of index and self should be the same size")
    if not np.issubdtype(index.dtype, np.integer):
        raise TypeError("The values of index must be integers")
    data_swapped = np.swapaxes(self, 0, dim)
    index_swapped = np.swapaxes(index, 0, dim)
    # note: np.choose supports at most 32 choices along the moved axis
    gathered = np.choose(index_swapped, data_swapped)
    return np.swapaxes(gathered, 0, dim)
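A quick check of the gather semantics, with the swapaxes/np.choose steps of the function above inlined for brevity:

```python
import numpy as np

data = np.array([[1, 2], [3, 4]])
index = np.array([[0, 0], [1, 0]])
dim = 1

# same steps as gather_numpy: move dim to the front, choose, move back
data_sw = np.swapaxes(data, 0, dim)
index_sw = np.swapaxes(index, 0, dim)
gathered = np.swapaxes(np.choose(index_sw, data_sw), 0, dim)
# out[i][j] = data[i][index[i][j]] for dim == 1
```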
I did much the same.
def gather(a, dim, index):
    # one index array per axis: `index` along dim, broadcastable aranges elsewhere
    expanded_index = tuple(index if dim == i
                           else np.arange(a.shape[i]).reshape([-1 if i == j else 1 for j in range(a.ndim)])
                           for i in range(a.ndim))
    return a[expanded_index]

def scatter(a, dim, index, b):  # modifies a in place
    expanded_index = tuple(index if dim == i
                           else np.arange(a.shape[i]).reshape([-1 if i == j else 1 for j in range(a.ndim)])
                           for i in range(a.ndim))
    a[expanded_index] = b
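The trick is building an open mesh of index arrays, one per axis, and letting broadcasting do the rest. A self-contained demonstration of that construction:

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])
index = np.array([[0, 0], [1, 0]])
dim = 1

# Build one index array per axis: the given `index` for `dim`,
# a broadcastable arange (an open mesh) for every other axis
expanded = tuple(
    index if i == dim
    else np.arange(a.shape[i]).reshape([-1 if i == j else 1 for j in range(a.ndim)])
    for i in range(a.ndim)
)
gathered = a[expanded]      # gather: out[i][j] = a[i][index[i][j]]

b = np.zeros_like(a)
b[expanded] = gathered      # scatter the gathered values back into b
```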
For the gather operation: np.take()
https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.take.html
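Note that np.take selects whole slices per axis (like tf.gather), rather than per-element indices (like torch.gather). For example:

```python
import numpy as np

a = np.array([[10, 20, 30],
              [40, 50, 60]])
# pick columns 2 and 0, in that order, from every row
cols = np.take(a, [2, 0], axis=1)
```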
If you just want the same functionality rather than implementing it from scratch,
numpy.insert() is a close contender for the scatter_(dim, index, src) operation in PyTorch, but it only handles a single dimension.
The scatter_nd operation can be implemented with np's ufunc .at methods.
According to TF's scatter_nd documentation:
Calling tf.scatter_nd(indices, values, shape) is identical to tensor_scatter_add(tf.zeros(shape, values.dtype), indices, values).
Hence, you can reproduce tf.scatter_nd with np.add.at applied to an np.zeros array; see the MVCE below:
import tensorflow as tf
tf.enable_eager_execution()  # Remove this line if working in TF2
import numpy as np

def scatter_nd_numpy(indices, updates, shape):
    target = np.zeros(shape, dtype=updates.dtype)
    indices = tuple(indices.reshape(-1, indices.shape[-1]).T)
    updates = updates.ravel()
    np.add.at(target, indices, updates)
    return target

indices = np.array([[[0, 0], [0, 1]], [[1, 0], [1, 1]]])
updates = np.array([[1, 2], [3, 4]])
shape = (2, 3)

scattered_tf = tf.scatter_nd(indices, updates, shape).numpy()
scattered_np = scatter_nd_numpy(indices, updates, shape)

assert np.allclose(scattered_tf, scattered_np)
Note: as pointed out by @denis, the solution above differs when some indices are repeated; this can be solved by using a counter and keeping only the last of each repeated index.
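One possible way to implement that "keep the last of each repeated index" fix is to deduplicate the flattened indices (a sketch; scatter_nd_last is a hypothetical name, not a library function):

```python
import numpy as np

def scatter_nd_last(indices, updates, shape):
    # flatten the n-d indices so duplicates can be compared directly
    flat = np.ravel_multi_index(tuple(indices.reshape(-1, indices.shape[-1]).T), shape)
    upd = updates.ravel()
    # np.unique returns the first occurrence; reversing makes "first" = last
    _, first_in_reversed = np.unique(flat[::-1], return_index=True)
    keep = len(flat) - 1 - first_in_reversed
    target = np.zeros(shape, dtype=updates.dtype).ravel()
    target[flat[keep]] = upd[keep]
    return target.reshape(shape)

indices = np.array([[0, 0], [0, 0], [1, 1]])  # (0, 0) is repeated
updates = np.array([1, 2, 3])
result = scatter_nd_last(indices, updates, (2, 2))
# the later update (2) wins at position (0, 0)
```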
There are two built-in numpy functions that can satisfy your request.
You can use np.take_along_axis to implement torch.gather, and np.put_along_axis to implement torch.scatter.
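A minimal example of both, where the index array follows torch.gather conventions (one index per output element along the chosen axis):

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])
index = np.array([[0, 0], [1, 0]])

# equivalent of torch.gather(a, 1, index)
gathered = np.take_along_axis(a, index, axis=1)

# equivalent of torch-style scatter: write values at the indexed positions
out = np.zeros_like(a)
np.put_along_axis(out, index, gathered, axis=1)
```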