在 Tensorflow 中查找子矩阵的秩
Find rank of a sub-matrix in Tensorflow
我有一个形状为 [4, 4, 2, 2]
的矩阵 g
,我需要在其中找到 g[0, 0]
、g[1, 1]
、g[2, 2]
和 [=16 的秩=] 都是 2x2
矩阵。我使用了 tf.rank
运算符,但它将 g
视为单个数组并计算秩和 returns 整个矩阵的单个值。我需要的是对应 g[i, j]
的 2x2
矩阵。
以下是 MWE:
import tensorflow as tf
import numpy as np
a = np.array([
[[[ 0., 0.], [ 0., 0.]], [[-1., -1.], [-1., -1.]], [[-2., -2.], [-2., -2.]], [[-3., -3.], [-3., -3.]]],
[[[ 1., 1.], [ 1., 1.]], [[ 0., 0.], [ 0., 0.]], [[-1., -1.], [-1., -1.]], [[-2., -2.], [-2., -2.]]],
[[[ 2., 2.], [ 2., 2.]], [[ 1., 1.], [ 1., 1.]], [[ 0., 0.], [ 0., 0.]], [[-1., -1.], [-1., -1.]]],
[[[ 3., 3.], [ 3., 3.]], [[ 2., 2.], [ 2., 2.]], [[ 1., 1.], [ 1., 1.]], [[ 0., 0.], [ 0., 0.]]]
])
rank = tf.rank(a) # Returns a number
除了使用 for
循环之外,还有什么方法可以得到这个等级矩阵吗?谢谢。
我认为 TensorFlow 中没有计算矩阵秩的函数。一种可能性是使用 tf.linalg.svd
并计算非零奇异值的数量:
import tensorflow as tf
EPS = 1e-6
a = tf.ones((4, 4, 2, 2), tf.float32)
s = tf.linalg.svd(a, full_matrices=False, compute_uv=False)
r = tf.math.count_nonzero(tf.abs(s) > EPS, axis=-1)
print(r.numpy())
# [[1 1 1 1]
# [1 1 1 1]
# [1 1 1 1]
# [1 1 1 1]]
我有一个形状为 [4, 4, 2, 2]
的矩阵 g
,我需要在其中找到 g[0, 0]
、g[1, 1]
、g[2, 2]
和 [=16 的秩=] 都是 2x2
矩阵。我使用了 tf.rank
运算符,但它将 g
视为单个数组并计算秩和 returns 整个矩阵的单个值。我需要的是对应 g[i, j]
的 2x2
矩阵。
以下是 MWE:
import tensorflow as tf
import numpy as np
a = np.array([
[[[ 0., 0.], [ 0., 0.]], [[-1., -1.], [-1., -1.]], [[-2., -2.], [-2., -2.]], [[-3., -3.], [-3., -3.]]],
[[[ 1., 1.], [ 1., 1.]], [[ 0., 0.], [ 0., 0.]], [[-1., -1.], [-1., -1.]], [[-2., -2.], [-2., -2.]]],
[[[ 2., 2.], [ 2., 2.]], [[ 1., 1.], [ 1., 1.]], [[ 0., 0.], [ 0., 0.]], [[-1., -1.], [-1., -1.]]],
[[[ 3., 3.], [ 3., 3.]], [[ 2., 2.], [ 2., 2.]], [[ 1., 1.], [ 1., 1.]], [[ 0., 0.], [ 0., 0.]]]
])
rank = tf.rank(a) # Returns a number
除了使用 for
循环之外,还有什么方法可以得到这个等级矩阵吗?谢谢。
我认为 TensorFlow 中没有计算矩阵秩的函数。一种可能性是使用 tf.linalg.svd
并计算非零奇异值的数量:
import tensorflow as tf
EPS = 1e-6
a = tf.ones((4, 4, 2, 2), tf.float32)
s = tf.linalg.svd(a, full_matrices=False, compute_uv=False)
r = tf.math.count_nonzero(tf.abs(s) > EPS, axis=-1)
print(r.numpy())
# [[1 1 1 1]
# [1 1 1 1]
# [1 1 1 1]
# [1 1 1 1]]