在 tensorflow 2.x 中分配给张量的更复杂的方法是什么

What is more sophisticated way to assign to tensor in tensorflow 2.x

我只是想知道是否有更好的方法来更新 tf2 中的张量。假设我有 tensor_a = tf.ones(4,5,5) (batch_size, H, W) 并且我想用零 (index=1) 替换第二个样本的所有值。这就是我在不使用 Eager 执行模式的情况下设法做到这一点的方法:

tensor_a = tf.ones([4,5,5])
tensor_b = tf.zeros([1,5,5])
index=1
tensor_a = tf.concat([tensor_a[:index], tensor_b, tensor_a[index+1:]], axis=0)

我知道存在 tf.tensor_scatter_nd_update() 函数,但我不熟悉 meshgrids,在我看来,对于简单的切片分配操作,它们看起来有点难看。同样在某些情况下,一次更新具有多个索引的切片(例如样本 0,1 和 2 为零)会很方便。

Tensorflow 操作有时有点乱。

import tensorflow as tf

tensor = tf.ones([4, 5, 5])

tensor = tf.tensor_scatter_nd_update(
    tensor, [[1]], tf.zeros_like(tf.gather(tensor, [1])) 
)
<tf.Tensor: shape=(4, 5, 5), dtype=float32, numpy=
array([[[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]],
       [[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]],
       [[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]],
       [[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]]], dtype=float32)>