TensorFlow中基于布尔掩码的部分更新张量

Partial update tensor based on boolean mask in TensorFlow

我想根据某些条件更新部分张量。

我知道 TensorFlow 张量是不可变的,所以创建一个新的张量对我来说没问题。 我尝试了 tensor_scatter_nd_update 方法,但我无法使其工作

这是我想在用 NumPy 编写的 TensorFlow 中复制的代码。

import numpy as np

a = np.random.random((1, 3))
b = np.array([[0, 1, 0]])

c = np.zeros_like(a)
mask = b == 1
c[mask] = np.log(a[mask])

在 TensorFlow 中,我们不会更新实际上是不可变对象的张量。相反,我们像在函数式语言中一样从其他张量创建新的张量。

import tensorflow as tf

a = tf.random.uniform(shape=(1, 3))
b = tf.constant([[0, 1, 0]], dtype=tf.int32)

c = tf.zeros_like(a)
mask = b == 1
c_updated = tf.where(mask, tf.math.log(a), c)
# [[ 0.      , -4.175911,  0.      ]]```