PyTorch 的 BoolTensor 在 Tensorflow 2.x 中的等价物是什么?
What is the equivalent of PyTorch's BoolTensor in Tensorflow 2.x?
Tensorflow 中是否有 Pytorch 的 BoolTensor 等价物,假设我在 Pytorch 中有以下用法,我想迁移到 Tensorflow
done_mask = torch.BoolTensor(dones.values).to(device)
next_state_values[done_mask] = 0.0
什么是dones
?
假设它是一个 0/1 张量,你可以像这样将它转换为 Bool 张量:
tf.cast(dones,tf.bool)
但是,如果你想给张量赋值,你不能那样做。
我推荐的一种方法是乘以 1/0 的矩阵:
next_state_values *= tf.cast(dones!=1,next_state_values.dtype)
我不推荐的另一种方法是使用 tf.tensor_scatter_nd_update,因为它在使用渐变时会出现一些问题。对于您的情况,那将是:
indices = tf.where(dones==1)
next_state_values = tf.tensor_scatter_nd_update(next_state_values ,indices,2*tf.zeros(len(indices)))
Tensorflow 中是否有 Pytorch 的 BoolTensor 等价物,假设我在 Pytorch 中有以下用法,我想迁移到 Tensorflow
done_mask = torch.BoolTensor(dones.values).to(device)
next_state_values[done_mask] = 0.0
什么是dones
?
假设它是一个 0/1 张量,你可以像这样将它转换为 Bool 张量:
tf.cast(dones,tf.bool)
但是,如果你想给张量赋值,你不能那样做。
我推荐的一种方法是乘以 1/0 的矩阵:
next_state_values *= tf.cast(dones!=1,next_state_values.dtype)
我不推荐的另一种方法是使用 tf.tensor_scatter_nd_update,因为它在使用渐变时会出现一些问题。对于您的情况,那将是:
indices = tf.where(dones==1)
next_state_values = tf.tensor_scatter_nd_update(next_state_values ,indices,2*tf.zeros(len(indices)))