PyTorch 的 BoolTensor 在 Tensorflow 2.x 中的等价物是什么?

What is the equivalent of PyTorch's BoolTensor in Tensorflow 2.x?

Tensorflow 中是否有 Pytorch 的 BoolTensor 等价物,假设我在 Pytorch 中有以下用法,我想迁移到 Tensorflow

done_mask = torch.BoolTensor(dones.values).to(device)
next_state_values[done_mask] = 0.0

什么是dones? 假设它是一个 0/​​1 张量,你可以像这样将它转换为 Bool 张量:

tf.cast(dones,tf.bool)

但是,如果你想给张量赋值,你不能那样做。

我推荐的一种方法是乘以 1/0 的矩阵:

next_state_values *= tf.cast(dones!=1,next_state_values.dtype)

我不推荐的另一种方法是使用 tf.tensor_scatter_nd_update,因为它在使用渐变时会出现一些问题。对于您的情况,那将是:

indices = tf.where(dones==1)
next_state_values = tf.tensor_scatter_nd_update(next_state_values ,indices,2*tf.zeros(len(indices)))