在 TensorFlow 中获取矩阵的对角线

Get the diagonal of a matrix in TensorFlow

有没有办法在TensorFlow中提取方阵的对角线?也就是说,对于这样的矩阵:

[
 [0, 1, 2],
 [3, 4, 5],
 [6, 7, 8]
]

我要获取元素:[0, 4, 8]

在 numpy 中,这非常简单 np.diag:

在TensorFlow中,有一个diag function,但它只是在对角线上的参数中指定的元素形成一个新矩阵,这不是我想要的。

我可以想象这是如何通过跨越来完成的...但是我在 TensorFlow 中看不到张量的跨越。

这可能是一种解决方法,但有效。

>> sess = tensorflow.InteractiveSession()
>> x = tensorflow.Variable([[1,2,3],[4,5,6],[7,8,9]])
>> x.initializer.run()
>> z = tensorflow.pack([x[i,i] for i in range(3)])
>> z.eval()
array([1, 5, 9], dtype=int32)

目前可以使用 tf.diag_part 提取对角线元素。这是他们的例子:

"""
'input' is [[1, 0, 0, 0],
            [0, 2, 0, 0],
            [0, 0, 3, 0],
            [0, 0, 0, 4]]
"""

tf.diag_part(input) ==> [1, 2, 3, 4]

旧答案(当 diag_part 时)不可用(如果您想实现现在不可用的东西,仍然相关):

查看math operations and tensor transformations后,好像不存在这样的操作。即使你可以用矩阵乘法提取这些数据,它也不会有效(得到对角线是 O(n))。

你有三种方法,从易到难。

  1. 计算张量,用numpy提取对角线,用TF构建一个变量
  2. 按照 Anurag 建议的方式使用 tf.pack(也使用 tf.shape
  3. 提取值 3
  4. 自己写op in C++,重建TF,原生使用。

使用gather操作。

x = tensorflow.Variable([[1,2,3],[4,5,6],[7,8,9]])
x_flat = tf.reshape(x, [-1])  # flatten the matrix
x_diag = tf.gather(x, [0, 3, 6])

根据上下文,掩码可能是“取消”矩阵对角线元素的好方法,尤其是如果您打算减少它的话:

mask = tf.diag(tf.ones([n]))
y = tf.mul(mask,y)
cost = -tf.reduce_sum(y)

使用 tensorflow 0.8 可以使用 tf.diag_part() 提取对角线元素(参见 documentation

更新

对于 tensorflow >= r1.12 其 tf.linalg.tensor_diag_part(参见 documentation

使用 tf.diag_part()

with tf.Session() as sess:
    x = tf.ones(shape=[3, 3])
    x_diag = tf.diag_part(x)
    print(sess.run(x_diag ))