如何从keras中的张量中提取非零值

How can I extract nonzero values from tensor in keras

我正在尝试在 Python 中的自定义损失函数中操作一些数据 Tensorflow.keras

考虑以下示例:

b = tf.constant([[0, 3, 1], [0, 5, 2]])

我想擦除零列,或提取非零列,这样最终结果将是一个张量

[[3,1], [5,2]]

我尝试使用 tf.where,使用遮罩,但它不保持形状,它只是 return 具有非零值的一维张量。 此外,我需要它适用于任意数量的行,唯一固定的是列数。

这将选择总和 > 0 的所有列:

tf.transpose(tf.gather_nd(tf.transpose(b), tf.where(tf.reduce_sum(b, axis=0)>0)))