Tensorflow:如何获取第k条对角线的值

Tensorflow: How to get the value of the k-th diagonal

在 PyTorch 中,函数 torch.diag() 获取张量的第 k 条对角线的值。

例如,a.diag(diagonal=1) 获取第 1 条对角线的值。不幸的是 diag_part() 似乎在 Tensorflow 中不起作用:

a = [[1, 2, 3],
[4, 5, 6],
[7, 8, 9]]
a.diag(diagonal=1)
tensor([2, 6])
a.diag(diagonal=2)
tensor([3])

是否有等效的功能?

TensorFlow 2 >= v2.2

您可以使用tf.linalg.diag_part

>>> a = tf.reshape(tf.range(1,10),(3,3))
>>> a
<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]], dtype=int32)>
>>> tf.linalg.diag_part(a,k=1)
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([2, 6], dtype=int32)>
>>> tf.linalg.diag_part(a,k=2)
<tf.Tensor: shape=(1,), dtype=int32, numpy=array([3], dtype=int32)>

TensorFlow 1.x 和 TensorFlow 2 <= v2.1

2020-11-26: 从 tf 1.15 和 tf2.1 开始,tf.linalg.diag_part 中生成超对角线和次对角线的代码似乎已被禁用。您可以直接使用 matrix_diag_part_v2 来获得所需的行为作为解决方法:

import tensorflow as tf
from tensorflow.python.ops.gen_array_ops import matrix_diag_part_v2

a = tf.reshape(tf.range(1,10),(3,3))
superdiag = matrix_diag_part_v2(a,k=1,padding_value=0)
superdiag2 = matrix_diag_part_v2(a,k=2,padding_value=0)

with tf.Session() as sess:
    print(f"Matrix A : {sess.run(a)}")
    print(f"Superdiagonal 1 : {sess.run(superdiag)}")
    print(f"Superdiagonal 2 : {sess.run(superdiag2)}")

结果

Matrix A : [[1 2 3]
 [4 5 6]
 [7 8 9]]
Superdiagonal 1 : [2 6]
Superdiagonal 2 : [3]

2021-01-08:tf 1.15 中的错误优先级不高,未计划修复。 Source :

Yes. This is clearly a bug in 1.15. But It's definitely not something significant enough that we'd make a patch release for it, we only do patch releases for major bugs or security fixes.

2021-01-08:感谢 Krzysztof 指出,TF 版本 <= 2.1 也会出现在 TF1 中发现的相同问题。 matrix_diag_part_v2 解决方法也适用于 TF2.1 和 TF2.0。