Tensorflow:如何获取第k条对角线的值
Tensorflow: How to get the value of the k-th diagonal
在 PyTorch 中,函数 torch.diag()
获取张量的第 k 条对角线的值。
例如,a.diag(diagonal=1)
获取第 1 条对角线的值。不幸的是 diag_part()
似乎在 Tensorflow 中不起作用:
a = [[1, 2, 3],
[4, 5, 6],
[7, 8, 9]]
a.diag(diagonal=1)
tensor([2, 6])
a.diag(diagonal=2)
tensor([3])
是否有等效的功能?
TensorFlow 2 >= v2.2
您可以使用tf.linalg.diag_part
>>> a = tf.reshape(tf.range(1,10),(3,3))
>>> a
<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]], dtype=int32)>
>>> tf.linalg.diag_part(a,k=1)
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([2, 6], dtype=int32)>
>>> tf.linalg.diag_part(a,k=2)
<tf.Tensor: shape=(1,), dtype=int32, numpy=array([3], dtype=int32)>
TensorFlow 1.x 和 TensorFlow 2 <= v2.1
2020-11-26: 从 tf 1.15 和 tf2.1 开始,tf.linalg.diag_part
中生成超对角线和次对角线的代码似乎已被禁用。您可以直接使用 matrix_diag_part_v2
来获得所需的行为作为解决方法:
import tensorflow as tf
from tensorflow.python.ops.gen_array_ops import matrix_diag_part_v2
a = tf.reshape(tf.range(1,10),(3,3))
superdiag = matrix_diag_part_v2(a,k=1,padding_value=0)
superdiag2 = matrix_diag_part_v2(a,k=2,padding_value=0)
with tf.Session() as sess:
print(f"Matrix A : {sess.run(a)}")
print(f"Superdiagonal 1 : {sess.run(superdiag)}")
print(f"Superdiagonal 2 : {sess.run(superdiag2)}")
结果
Matrix A : [[1 2 3]
[4 5 6]
[7 8 9]]
Superdiagonal 1 : [2 6]
Superdiagonal 2 : [3]
2021-01-08:tf 1.15 中的错误优先级不高,未计划修复。 Source :
Yes. This is clearly a bug in 1.15. But It's definitely not something significant enough that we'd make a patch release for it, we only do patch releases for major bugs or security fixes.
2021-01-08:感谢 Krzysztof 指出,TF 版本 <= 2.1 也会出现在 TF1 中发现的相同问题。 matrix_diag_part_v2
解决方法也适用于 TF2.1 和 TF2.0。
在 PyTorch 中,函数 torch.diag()
获取张量的第 k 条对角线的值。
例如,a.diag(diagonal=1)
获取第 1 条对角线的值。不幸的是 diag_part()
似乎在 Tensorflow 中不起作用:
a = [[1, 2, 3],
[4, 5, 6],
[7, 8, 9]]
a.diag(diagonal=1)
tensor([2, 6])
a.diag(diagonal=2)
tensor([3])
是否有等效的功能?
TensorFlow 2 >= v2.2
您可以使用tf.linalg.diag_part
>>> a = tf.reshape(tf.range(1,10),(3,3))
>>> a
<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]], dtype=int32)>
>>> tf.linalg.diag_part(a,k=1)
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([2, 6], dtype=int32)>
>>> tf.linalg.diag_part(a,k=2)
<tf.Tensor: shape=(1,), dtype=int32, numpy=array([3], dtype=int32)>
TensorFlow 1.x 和 TensorFlow 2 <= v2.1
2020-11-26: 从 tf 1.15 和 tf2.1 开始,tf.linalg.diag_part
中生成超对角线和次对角线的代码似乎已被禁用。您可以直接使用 matrix_diag_part_v2
来获得所需的行为作为解决方法:
import tensorflow as tf
from tensorflow.python.ops.gen_array_ops import matrix_diag_part_v2
a = tf.reshape(tf.range(1,10),(3,3))
superdiag = matrix_diag_part_v2(a,k=1,padding_value=0)
superdiag2 = matrix_diag_part_v2(a,k=2,padding_value=0)
with tf.Session() as sess:
print(f"Matrix A : {sess.run(a)}")
print(f"Superdiagonal 1 : {sess.run(superdiag)}")
print(f"Superdiagonal 2 : {sess.run(superdiag2)}")
结果
Matrix A : [[1 2 3]
[4 5 6]
[7 8 9]]
Superdiagonal 1 : [2 6]
Superdiagonal 2 : [3]
2021-01-08:tf 1.15 中的错误优先级不高,未计划修复。 Source :
Yes. This is clearly a bug in 1.15. But It's definitely not something significant enough that we'd make a patch release for it, we only do patch releases for major bugs or security fixes.
2021-01-08:感谢 Krzysztof 指出,TF 版本 <= 2.1 也会出现在 TF1 中发现的相同问题。 matrix_diag_part_v2
解决方法也适用于 TF2.1 和 TF2.0。