SymPy:提取矩阵的下三角部分

SymPy: Extract the lower triangular part of a matrix

我正在尝试提取 SymPy 矩阵的下三角部分。由于我在 SymPy 中找不到 tril 方法,我定义了:

def tril (M):
    m = M.copy()
    for row_index in range (m.rows):
        for col_index in range (row_index + 1, m.cols):
            m[row_index, col_index] = 0
    return (m)

似乎有效:

是否有更优雅的方法来提取 SymPy 矩阵的下三角部分?

.copy()是确保原始矩阵完整性的推荐方法吗?

你可以简单地使用 numpy 函数:

import numpy as np    
np.tril(M)

*当然,如下所述,您应该转换回 sympy.Matrix(np.tril(M))。但这取决于你接下来要做什么。

In [99]: M
Out[99]: 
⎡a  b  c⎤
⎢       ⎥
⎢d  e  f⎥
⎢       ⎥
⎣g  h  i⎦

另一个答案建议使用 np.tril 函数:

In [100]: np.tril(M)
Out[100]: 
array([[a, 0, 0],
       [d, e, 0],
       [g, h, i]], dtype=object)

M 转换为 numpy 数组 - 对象 dtype 因为符号。结果也是一个numpy数组。

你的函数returns一个sympy.Matrix.

In [101]: def tril (M):
     ...:     m = M.copy()
     ...:     for row_index in range (m.rows):
     ...:         for col_index in range (row_index + 1, m.cols):
     ...:             m[row_index, col_index] = 0
     ...:     return (m)
     ...: 

In [102]: tril(M)
Out[102]: 
⎡a  0  0⎤
⎢       ⎥
⎢d  e  0⎥
⎢       ⎥
⎣g  h  i⎦

作为一般规则,混合使用 sympynumpy 会导致混淆,如果不是错误的话。 numpy 最适合数字工作。它可以像 symbols 一样处理 non-numeric 个对象,但数学是 hit-or-miss.

np.tri... 函数建立在 np.tri 函数之上:

In [114]: np.tri(3).astype(int)
Out[114]: 
array([[1, 0, 0],
       [1, 1, 0],
       [1, 1, 1]])

我们可以从中创建一个符号矩阵:

In [115]: m1 = Matrix(np.tri(3).astype(int))

In [116]: m1
Out[116]: 
⎡1  0  0⎤
⎢       ⎥
⎢1  1  0⎥
⎢       ⎥
⎣1  1  1⎦

并做element-wise乘法:

In [117]: M.multiply_elementwise(m1)
Out[117]: 
⎡a  0  0⎤
⎢       ⎥
⎢d  e  0⎥
⎢       ⎥
⎣g  h  i⎦

np.tri 通过比较列数组和行来工作:

In [123]: np.arange(3)[:,None]>=np.arange(3)
Out[123]: 
array([[ True, False, False],
       [ True,  True, False],
       [ True,  True,  True]])

In [124]: _.astype(int)
Out[124]: 
array([[1, 0, 0],
       [1, 1, 0],
       [1, 1, 1]])

另一个答案建议 lower_triangular。看它的代码很有意思:

    def entry(i, j):
        return self[i, j] if i + k >= j else self.zero

    return self._new(self.rows, self.cols, entry)

它对每个元素应用 i>=j 测试。 _new 必须迭代行和列。

在 SymPy 中,M.lower_triangular(k) 将给出第 k 个对角线下方的下三角元素。默认为 k=0.