如何对多维张量使用 tf.gather_nd
How to use tf.gather_nd on a multi-dimensional tensor
如果我有一个多维张量,我不太明白应该如何使用 tf.gather_nd() 沿某个轴选取元素。举一个小例子(如果能解决这个简单的例子,我那个更复杂的原始问题也就解决了):假设我有一张 RGB 图像,想沿通道维选取每个像素的最小值(如果数据顺序为 (B,H,W,C),通道维就是最后一个维度)。我知道这可以用 tf.reduce_min(x, axis=-1) 来完成,但我想知道能否用 tf.argmin() 和 tf.gather_nd() 做同样的事情?
# Reproduction of the question: take the per-pixel minimum over the channel
# axis with tf.argmin + tf.gather_nd and compare it with tf.reduce_min.
# The last line fails on purpose; that failure is what is being asked about.
from skimage import data
import tensorflow as tf
import numpy as np  # not used below
# Load RGB image from skimage, cast it to float32 and add a leading batch axis,
# giving order (B,H,W,C) with B == 1
image = data.astronaut()
image = tf.cast(image, tf.float32)
image = tf.expand_dims(image, axis=0)
# Way 1: minimum over the channel axis at every pixel -> shape (B,H,W)
min_along_channels_1 = tf.reduce_min(image, axis=-1)
# Way 2: find the index of the minimum channel at every pixel, then gather
# those values.
# The goal is that min_along_channels_1 is equal to min_along_channels_2
idxs = tf.argmin(image, axis=-1)  # shape (B,H,W): channel indices only, no (b,h,w) coordinates
# This call fails: tf.gather_nd reads the LAST axis of `idxs` as a coordinate
# tuple into `image`. Here that axis is W (512 for this image), so every
# "coordinate" would have 512 components, while `image` only has rank 4.
min_along_channels_2 = tf.gather_nd(image, idxs) # This line gives error :(
您需要借助 tf.meshgrid:它由两个一维数组构造出一个矩形网格,给出张量第一维和第二维的索引。这是因为 tf.gather_nd 会把索引张量的最后一维当作指向被取值张量的完整坐标,也就是说,它需要知道每个值在所有维度上的确切位置;而 tf.argmin 只给出了通道维上的索引,缺少前面各维的坐标。下面是一个简化的例子:
import tensorflow as tf

# Toy input: a random (1, 4, 4, 3) tensor in (B, H, W, C) order with the
# batch axis dropped, which leaves (H, W, C).
image = tf.squeeze(tf.random.normal((1, 4, 4, 3)), axis=0)
height, width = image.shape[0], image.shape[1]

# Position of the minimum along the channel axis at each (h, w) -> (H, W).
idx = tf.argmin(image, axis=-1)

# Coordinate grid holding every (h, w) pair -> (H, W, 2).
row_ids = tf.range(height, dtype=tf.int64)
col_ids = tf.range(width, dtype=tf.int64)
ij = tf.stack(tf.meshgrid(row_ids, col_ids, indexing='ij'), axis=-1)

# Append the channel index so that each entry is a complete (h, w, c)
# coordinate. tf.gather_nd expects exactly this in the last axis of its
# indices, and returns one scalar per (h, w).
gather_indices = tf.concat([ij, idx[..., tf.newaxis]], axis=-1)
result = tf.gather_nd(image, gather_indices)

print('First option -->', tf.reduce_min(image, axis=-1))
print('Second option -->', result)
First option --> tf.Tensor(
[[-0.53245485 -0.29117298 -0.64434254 -0.8209638 ]
[-0.9386176 -0.5993224 -0.597746 -1.5392851 ]
[-0.5478666 -1.5280861 -1.0344954 -1.920418 ]
[-0.5580688 -1.425873 -1.9276617 -1.0668412 ]], shape=(4, 4), dtype=float32)
Second option --> tf.Tensor(
[[-0.53245485 -0.29117298 -0.64434254 -0.8209638 ]
[-0.9386176 -0.5993224 -0.597746 -1.5392851 ]
[-0.5478666 -1.5280861 -1.0344954 -1.920418 ]
[-0.5580688 -1.425873 -1.9276617 -1.0668412 ]], shape=(4, 4), dtype=float32)
将同样的方法应用到你的例子中:
from skimage import data
import tensorflow as tf
import numpy as np

# Load the RGB image, cast it to float32 and add a batch axis -> (B, H, W, C).
image = data.astronaut()
image = tf.cast(image, tf.float32)
image = tf.expand_dims(image, axis=0)

# Way 1: minimum over the channel axis at every pixel -> (B, H, W).
min_along_channels_1 = tf.reduce_min(image, axis=-1)

# Way 2: tf.argmin + tf.gather_nd.
# argmin only yields the channel index of the minimum at each (b, h, w).
idx = tf.argmin(image, axis=-1)

# tf.gather_nd treats the last axis of its indices as a full coordinate into
# `image`, so every output element needs (b, h, w, c). Build the (b, h, w) part
# with a meshgrid over the three leading axes. Squeezing the batch axis away
# would only work for a batch of one and would leave a result shaped differently
# from min_along_channels_1. This version works for any batch size and keeps the
# shapes identical. tf.shape (rather than the static .shape) also works when
# some dimensions are only known at run time.
dims = tf.shape(image, out_type=tf.int64)
bhw = tf.stack(tf.meshgrid(
    tf.range(dims[0], dtype=tf.int64),
    tf.range(dims[1], dtype=tf.int64),
    tf.range(dims[2], dtype=tf.int64),
    indexing='ij'), axis=-1)                                               # (B, H, W, 3)
gather_indices = tf.concat([bhw, tf.expand_dims(idx, axis=-1)], axis=-1)  # (B, H, W, 4)
min_along_channels_2 = tf.gather_nd(image, gather_indices)                # (B, H, W)

print(tf.equal(min_along_channels_1, min_along_channels_2))
# The printed tensor is truncated with "...", so also check every element explicitly.
assert bool(tf.reduce_all(tf.equal(min_along_channels_1, min_along_channels_2)))
tf.Tensor(
[[[ True True True ... True True True]
[ True True True ... True True True]
[ True True True ... True True True]
...
[ True True True ... True True True]
[ True True True ... True True True]
[ True True True ... True True True]]], shape=(1, 512, 512), dtype=bool)
如果我有一个多维张量,我不太明白应该如何使用 tf.gather_nd() 沿某个轴选取元素。举一个小例子(如果能解决这个简单的例子,我那个更复杂的原始问题也就解决了):假设我有一张 RGB 图像,想沿通道维选取每个像素的最小值(如果数据顺序为 (B,H,W,C),通道维就是最后一个维度)。我知道这可以用 tf.reduce_min(x, axis=-1) 来完成,但我想知道能否用 tf.argmin() 和 tf.gather_nd() 做同样的事情?
# Reproduction of the question: take the per-pixel minimum over the channel
# axis with tf.argmin + tf.gather_nd and compare it with tf.reduce_min.
# The last line fails on purpose; that failure is what is being asked about.
from skimage import data
import tensorflow as tf
import numpy as np  # not used below
# Load RGB image from skimage, cast it to float32 and add a leading batch axis,
# giving order (B,H,W,C) with B == 1
image = data.astronaut()
image = tf.cast(image, tf.float32)
image = tf.expand_dims(image, axis=0)
# Way 1: minimum over the channel axis at every pixel -> shape (B,H,W)
min_along_channels_1 = tf.reduce_min(image, axis=-1)
# Way 2: find the index of the minimum channel at every pixel, then gather
# those values.
# The goal is that min_along_channels_1 is equal to min_along_channels_2
idxs = tf.argmin(image, axis=-1)  # shape (B,H,W): channel indices only, no (b,h,w) coordinates
# This call fails: tf.gather_nd reads the LAST axis of `idxs` as a coordinate
# tuple into `image`. Here that axis is W (512 for this image), so every
# "coordinate" would have 512 components, while `image` only has rank 4.
min_along_channels_2 = tf.gather_nd(image, idxs) # This line gives error :(
您需要借助 tf.meshgrid:它由两个一维数组构造出一个矩形网格,给出张量第一维和第二维的索引。这是因为 tf.gather_nd 会把索引张量的最后一维当作指向被取值张量的完整坐标,也就是说,它需要知道每个值在所有维度上的确切位置;而 tf.argmin 只给出了通道维上的索引,缺少前面各维的坐标。下面是一个简化的例子:
import tensorflow as tf

# Toy input: a random (1, 4, 4, 3) tensor in (B, H, W, C) order with the
# batch axis dropped, which leaves (H, W, C).
image = tf.squeeze(tf.random.normal((1, 4, 4, 3)), axis=0)
height, width = image.shape[0], image.shape[1]

# Position of the minimum along the channel axis at each (h, w) -> (H, W).
idx = tf.argmin(image, axis=-1)

# Coordinate grid holding every (h, w) pair -> (H, W, 2).
row_ids = tf.range(height, dtype=tf.int64)
col_ids = tf.range(width, dtype=tf.int64)
ij = tf.stack(tf.meshgrid(row_ids, col_ids, indexing='ij'), axis=-1)

# Append the channel index so that each entry is a complete (h, w, c)
# coordinate. tf.gather_nd expects exactly this in the last axis of its
# indices, and returns one scalar per (h, w).
gather_indices = tf.concat([ij, idx[..., tf.newaxis]], axis=-1)
result = tf.gather_nd(image, gather_indices)

print('First option -->', tf.reduce_min(image, axis=-1))
print('Second option -->', result)
First option --> tf.Tensor(
[[-0.53245485 -0.29117298 -0.64434254 -0.8209638 ]
[-0.9386176 -0.5993224 -0.597746 -1.5392851 ]
[-0.5478666 -1.5280861 -1.0344954 -1.920418 ]
[-0.5580688 -1.425873 -1.9276617 -1.0668412 ]], shape=(4, 4), dtype=float32)
Second option --> tf.Tensor(
[[-0.53245485 -0.29117298 -0.64434254 -0.8209638 ]
[-0.9386176 -0.5993224 -0.597746 -1.5392851 ]
[-0.5478666 -1.5280861 -1.0344954 -1.920418 ]
[-0.5580688 -1.425873 -1.9276617 -1.0668412 ]], shape=(4, 4), dtype=float32)
将同样的方法应用到你的例子中:
from skimage import data
import tensorflow as tf
import numpy as np

# Load the RGB image, cast it to float32 and add a batch axis -> (B, H, W, C).
image = data.astronaut()
image = tf.cast(image, tf.float32)
image = tf.expand_dims(image, axis=0)

# Way 1: minimum over the channel axis at every pixel -> (B, H, W).
min_along_channels_1 = tf.reduce_min(image, axis=-1)

# Way 2: tf.argmin + tf.gather_nd.
# argmin only yields the channel index of the minimum at each (b, h, w).
idx = tf.argmin(image, axis=-1)

# tf.gather_nd treats the last axis of its indices as a full coordinate into
# `image`, so every output element needs (b, h, w, c). Build the (b, h, w) part
# with a meshgrid over the three leading axes. Squeezing the batch axis away
# would only work for a batch of one and would leave a result shaped differently
# from min_along_channels_1. This version works for any batch size and keeps the
# shapes identical. tf.shape (rather than the static .shape) also works when
# some dimensions are only known at run time.
dims = tf.shape(image, out_type=tf.int64)
bhw = tf.stack(tf.meshgrid(
    tf.range(dims[0], dtype=tf.int64),
    tf.range(dims[1], dtype=tf.int64),
    tf.range(dims[2], dtype=tf.int64),
    indexing='ij'), axis=-1)                                               # (B, H, W, 3)
gather_indices = tf.concat([bhw, tf.expand_dims(idx, axis=-1)], axis=-1)  # (B, H, W, 4)
min_along_channels_2 = tf.gather_nd(image, gather_indices)                # (B, H, W)

print(tf.equal(min_along_channels_1, min_along_channels_2))
# The printed tensor is truncated with "...", so also check every element explicitly.
assert bool(tf.reduce_all(tf.equal(min_along_channels_1, min_along_channels_2)))
tf.Tensor(
[[[ True True True ... True True True]
[ True True True ... True True True]
[ True True True ... True True True]
...
[ True True True ... True True True]
[ True True True ... True True True]
[ True True True ... True True True]]], shape=(1, 512, 512), dtype=bool)