张量流中的花式索引
Fancy indexing in tensorflow
我已经实现了一个带有自定义损失函数的 3D CNN (Ax' - y)^2
,其中 x' 是 CNN 的 3D 输出的扁平化和裁剪向量,y 是基本事实,A 是线性运算符接受一个 x 并输出一个 y。所以我需要一种方法来展平 3D 输出并在计算损失之前使用花式索引对其进行裁剪。
这是我尝试过的:
这是我要复制的 numpy 代码,
def flatten_crop(img_vol, indices, vol_shape, N):
"""
:param img_vol: shape (145, 59, 82, N)
:param indices: shape (396929,)
"""
nVx, nVy, nVz = vol_shape
voxels = np.reshape(img_vol, (nVx * nVy * nVz, N), order='F')
voxels = voxels[indices, :]
return voxels
我尝试使用 tf.nd_gather
执行相同的操作,但我无法将其概括为任意批量大小。这是我的批量大小为 1(或单个 3D 输出)的 tensorflow 代码:
voxels = tf.transpose(tf.reshape(tf.transpose(y_pred), (1, 145 * 59 * 82))) # to flatten and reshape using Fortran-like index order
voxels = tf.gather_nd(voxels, tf.stack([indices, tf.zeros(len(indices), dtype=tf.dtypes.int32)], axis=1)) # indexing
voxels = tf.reshape(voxels, (voxels.shape[0], 1))
目前我的自定义损失函数中有这段代码,我希望能够泛化到任意批量大小。此外,如果您有其他建议来实现这一点(例如自定义层而不是与损失函数集成),我洗耳恭听!
谢谢。
试试这个代码:
import tensorflow as tf
y_pred = tf.random.uniform((10, 145, 59, 82))
indices = tf.random.uniform((396929,), 0, 145*59*82, dtype=tf.int32)
voxels = tf.reshape(y_pred, (-1, 145 * 59 * 82)) # to flatten and reshape using Fortran-like index order
voxels = tf.gather(voxels, indices, axis=-1)
voxels = tf.transpose(voxels)
我已经实现了一个带有自定义损失函数的 3D CNN (Ax' - y)^2
,其中 x' 是 CNN 的 3D 输出的扁平化和裁剪向量,y 是基本事实,A 是线性运算符接受一个 x 并输出一个 y。所以我需要一种方法来展平 3D 输出并在计算损失之前使用花式索引对其进行裁剪。
这是我尝试过的: 这是我要复制的 numpy 代码,
def flatten_crop(img_vol, indices, vol_shape, N):
"""
:param img_vol: shape (145, 59, 82, N)
:param indices: shape (396929,)
"""
nVx, nVy, nVz = vol_shape
voxels = np.reshape(img_vol, (nVx * nVy * nVz, N), order='F')
voxels = voxels[indices, :]
return voxels
我尝试使用 tf.nd_gather
执行相同的操作,但我无法将其概括为任意批量大小。这是我的批量大小为 1(或单个 3D 输出)的 tensorflow 代码:
voxels = tf.transpose(tf.reshape(tf.transpose(y_pred), (1, 145 * 59 * 82))) # to flatten and reshape using Fortran-like index order
voxels = tf.gather_nd(voxels, tf.stack([indices, tf.zeros(len(indices), dtype=tf.dtypes.int32)], axis=1)) # indexing
voxels = tf.reshape(voxels, (voxels.shape[0], 1))
目前我的自定义损失函数中有这段代码,我希望能够泛化到任意批量大小。此外,如果您有其他建议来实现这一点(例如自定义层而不是与损失函数集成),我洗耳恭听!
谢谢。
试试这个代码:
import tensorflow as tf
y_pred = tf.random.uniform((10, 145, 59, 82))
indices = tf.random.uniform((396929,), 0, 145*59*82, dtype=tf.int32)
voxels = tf.reshape(y_pred, (-1, 145 * 59 * 82)) # to flatten and reshape using Fortran-like index order
voxels = tf.gather(voxels, indices, axis=-1)
voxels = tf.transpose(voxels)