张量流中的花式索引

Fancy indexing in tensorflow

我已经实现了一个带有自定义损失函数的 3D CNN (Ax' - y)^2,其中 x' 是 CNN 的 3D 输出的扁平化和裁剪向量,y 是基本事实,A 是线性运算符接受一个 x 并输出一个 y。所以我需要一种方法来展平 3D 输出并在计算损失之前使用花式索引对其进行裁剪。

这是我尝试过的: 这是我要复制的 numpy 代码,

def flatten_crop(img_vol, indices, vol_shape, N):
    """
    :param img_vol: shape (145, 59, 82, N)
    :param indices: shape (396929,)
    """
    nVx, nVy, nVz = vol_shape
    voxels = np.reshape(img_vol, (nVx * nVy * nVz, N), order='F')
    voxels = voxels[indices, :]
    return voxels

我尝试使用 tf.nd_gather 执行相同的操作,但我无法将其概括为任意批量大小。这是我的批量大小为 1(或单个 3D 输出)的 tensorflow 代码:

voxels = tf.transpose(tf.reshape(tf.transpose(y_pred), (1, 145 * 59 * 82)))   # to flatten and reshape using Fortran-like index order
voxels = tf.gather_nd(voxels, tf.stack([indices, tf.zeros(len(indices), dtype=tf.dtypes.int32)], axis=1))    # indexing
voxels = tf.reshape(voxels, (voxels.shape[0], 1))

目前我的自定义损失函数中有这段代码,我希望能够泛化到任意批量大小。此外,如果您有其他建议来实现这一点(例如自定义层而不是与损失函数集成),我洗耳恭听!

谢谢。

试试这个代码:

import tensorflow as tf
y_pred = tf.random.uniform((10, 145, 59, 82))
indices = tf.random.uniform((396929,), 0, 145*59*82, dtype=tf.int32)
voxels = tf.reshape(y_pred, (-1, 145 * 59 * 82))   # to flatten and reshape using Fortran-like index order
voxels = tf.gather(voxels, indices, axis=-1)
voxels = tf.transpose(voxels)