如何用 sympy 分解 2x2 仿射矩阵?

How to decompose a 2x2 affine matrix with sympy?

我试图用 sympy 来演示以下 stackexchange 帖子中给出的仿射矩阵分解:

https://math.stackexchange.com/questions/612006/decomposing-an-affine-transformation

我设置了两个矩阵 A_params 和 A_matrix,其中 A_matrix 代表原始矩阵值,A_params 是由其基础参数构造的矩阵。

import sympy
import itertools as it
import ubelt as ub
# Assumptions shared by every parameter symbol below: all are declared real.
domain = {'real': True}

# Decomposition parameters: rotation angle, per-axis scales, and shear.
theta = sympy.symbols('theta', **domain)
# The two scales additionally carry the ``nonzero`` assumption.
sx, sy = sympy.symbols('sx, sy', nonzero=True, **domain)
m = sympy.symbols('m', **domain)

S = sympy.Matrix([  # scale
    [sx,  0],
    [ 0, sy]])

H = sympy.Matrix([  # shear
    [1, m],
    [0, 1]])

R = sympy.Matrix([  # rotation
    [sympy.cos(theta), -sympy.sin(theta)],
    [sympy.sin(theta),  sympy.cos(theta)]])


# Parametric form of the matrix: the product rotation @ shear @ scale.
A_params = sympy.simplify((R @ H @ S))
# Raw-entry form of the matrix: one real symbol per element.
a11, a12, a21, a22 = sympy.symbols(
    'a11, a12, a21, a22', real=True)
A_matrix = sympy.Matrix([[a11, a12], [a21, a22]])


# Print each matrix next to its label.
print(ub.hzcat(['A_matrix = ', sympy.pretty(A_matrix)]))
print(ub.hzcat(['A_params = ', sympy.pretty(A_params)]))
A_matrix = ⎡a₁₁  a₁₂⎤
           ⎢        ⎥
           ⎣a₂₁  a₂₂⎦
A_params = ⎡sx⋅cos(θ)  sy⋅(m⋅cos(θ) - sin(θ))⎤
           ⎢                                 ⎥
           ⎣sx⋅sin(θ)  sy⋅(m⋅sin(θ) + cos(θ))⎦

据我了解,我应该能够简单地令这两个矩阵相等,然后求解感兴趣的参数。但是,我得到了意想不到的结果。

首先,如果我只是尝试求解“sx”,我得不到结果。

# Option 1: hand the solver a single matrix-valued equality.
mat_equation = sympy.Eq(A_matrix, A_params)
soln_sx = sympy.solve(mat_equation, sx)
print('soln_sx = {!r}'.format(soln_sx))

# Option 2: hand the solver one scalar equation per matrix entry, pairing the
# entries of both matrices in the same flattened (row by row) order.
lhs_iter = it.chain(*A_matrix.tolist())
rhs_iter = it.chain(*A_params.tolist())
equations = list(map(sympy.Eq, lhs_iter, rhs_iter))
soln_sx = sympy.solve(equations, sx)
print('soln_sx = {!r}'.format(soln_sx))
soln_sx = []
soln_sx = []

但是如果我尝试同时求解所有变量,我会得到一个结果,但它不符合我的预期:

# Solve for every decomposition parameter jointly rather than one at a time.
solve_for = (sx, theta, sy, m)
solutions = sympy.solve(mat_equation, *solve_for)
# Only the first solution is reported; its entries are matched to the
# requested symbols by position.
for symbol, sol in zip(solve_for, solutions[0]):
    sol = sympy.simplify(sol)
    print('sol({!r}) = {!r}'.format(symbol, sol))
sol(sx) = -(a11**2 + a11*sqrt(a11**2 + a21**2) + a21**2)/(a11 + sqrt(a11**2 + a21**2))
sol(theta) = -2*atan((a11 + sqrt(a11**2 + a21**2))/a21)
sol(sy) = (-8*a11**6*a22 + 8*a11**5*a12*a21 - 8*a11**5*a22*sqrt(a11**2 + a21**2) + 8*a11**4*a12*a21*sqrt(a11**2 + a21**2) - 12*a11**4*a21**2*a22 + 12*a11**3*a12*a21**3 - 8*a11**3*a21**2*a22*sqrt(a11**2 + a21**2) + 8*a11**2*a12*a21**3*sqrt(a11**2 + a21**2) - 4*a11**2*a21**4*a22 + 4*a11*a12*a21**5 - a11*a21**4*a22*sqrt(a11**2 + a21**2) + a12*a21**5*sqrt(a11**2 + a21**2))/(8*a11**6 + 8*a11**5*sqrt(a11**2 + a21**2) + 16*a11**4*a21**2 + 12*a11**3*a21**2*sqrt(a11**2 + a21**2) + 9*a11**2*a21**4 + 4*a11*a21**4*sqrt(a11**2 + a21**2) + a21**6)
sol(m) = (a11*a12 + a21*a22)/(a11*a22 - a12*a21)

折腾了一段时间之后,我想看看自己至少能否验证 stackexchange 上给出的解法。所以我把它写成了符号形式:

# This is the guided solution by Stéphane Laurent
# The x-scale is the norm of the first column and the angle is that column's
# direction (the first column of A_params is [sx*cos(theta), sx*sin(theta)]).
recon_sx = sympy.sqrt(a11 * a11 + a21 * a21)
recon_theta = sympy.atan2(a21, a11)
recon_sin_t = sympy.sin(recon_theta)
recon_cos_t = sympy.cos(recon_theta)

# Intended to recover the product m * sy from the second column.
# NOTE(review): the sin and cos factors are swapped in this expression; the
# working form is ``a12 * recon_cos_t + a22 * recon_sin_t``.
recon_msy = a12 * recon_sin_t + a22 * recon_cos_t

# Two expressions for sy: one divides by sin(theta), the other by cos(theta).
# The sin-based one is selected whenever sin(theta) != 0.
condition2 = sympy.simplify(sympy.Eq(recon_sin_t, 0))
condition1 = sympy.simplify(sympy.Not(condition2))
sy_cond1 = (recon_msy * recon_cos_t - a12) / recon_sin_t
sy_cond2 = (a22 - recon_msy * recon_sin_t) / recon_cos_t

recon_sy = sympy.Piecewise((sy_cond1, condition1), (sy_cond2, condition2))

# The shear follows by dividing the product m * sy by sy.
recon_m = recon_msy / recon_sy

recon_S = sympy.Matrix([  # scale
    [recon_sx,  0],
    [ 0, recon_sy]])

recon_H = sympy.Matrix([  # shear
    [1, recon_m],
    [0, 1]])

recon_R = sympy.Matrix([  # rotation
    [sympy.cos(recon_theta), -sympy.sin(recon_theta)],
    [sympy.sin(recon_theta),  sympy.cos(recon_theta)]])

# Recombine the components
# If the decomposition is right, this product should simplify back to A_matrix.
A_recon = sympy.simplify((recon_R @ recon_H @ recon_S))
print(ub.hzcat(['A_recon = ', sympy.pretty(A_recon)]))

结果与我期望的很接近,但它似乎没有一直简化到可以用程序验证的程度。

A_recon = ⎡     ⎧                                       a₂₁            ⎤
          ⎢     ⎪            a₁₂              for ──────────────── ≠ 0 ⎥
          ⎢     ⎪                                    _____________     ⎥
          ⎢     ⎪                                   ╱    2      2      ⎥
          ⎢a₁₁  ⎨                                 ╲╱  a₁₁  + a₂₁       ⎥
          ⎢     ⎪                                                      ⎥
          ⎢     ⎪a₁₁⋅a₂₂ + a₁₂⋅a₂₁ - a₂₁⋅a₂₂                           ⎥
          ⎢     ⎪───────────────────────────         otherwise         ⎥
          ⎢     ⎩            a₁₁                                       ⎥
          ⎢                                                            ⎥
          ⎢     ⎧-a₁₁⋅a₁₂ + a₁₁⋅a₂₂ + a₁₂⋅a₂₁            a₂₁           ⎥
          ⎢     ⎪────────────────────────────  for ──────────────── ≠ 0⎥
          ⎢     ⎪            a₂₁                      _____________    ⎥
          ⎢a₂₁  ⎨                                    ╱    2      2     ⎥
          ⎢     ⎪                                  ╲╱  a₁₁  + a₂₁      ⎥
          ⎢     ⎪                                                      ⎥
          ⎣     ⎩            a₂₂                      otherwise        ⎦

我的想法是条件分支把结果搞乱了,所以我分别尝试了两种情况:

# Case 1: drop the Piecewise and always use the expression that divides by
# sin(theta).
recon_sy2 = sy_cond1
recon_m2 = recon_msy / recon_sy2

recon_S2 = sympy.Matrix([  # scale
    [recon_sx,  0],
    [ 0, recon_sy2]])

recon_H2 = sympy.Matrix([  # shear
    [1, recon_m2],
    [0, 1]])


# Case 2: always use the expression that divides by cos(theta).
recon_sy3 = sy_cond2
recon_m3 = recon_msy / recon_sy3

recon_S3 = sympy.Matrix([  # scale
    [recon_sx,  0],
    [ 0, recon_sy3]])

recon_H3 = sympy.Matrix([  # shear
    [1, recon_m3],
    [0, 1]])


# Recombine the components
# Both cases reuse the same rotation matrix; only the scale and shear differ.
A_recon2 = sympy.simplify((recon_R @ recon_H2 @ recon_S2))
A_recon3 = sympy.simplify((recon_R @ recon_H3 @ recon_S3))
print('')
print(ub.hzcat(['A_recon2 = ', sympy.pretty(A_recon2)]))
print('')
print(ub.hzcat(['A_recon3 = ', sympy.pretty(A_recon3)]))
A_recon2 = ⎡a₁₁              a₁₂             ⎤
           ⎢                                 ⎥
           ⎢     -a₁₁⋅a₁₂ + a₁₁⋅a₂₂ + a₁₂⋅a₂₁⎥
           ⎢a₂₁  ────────────────────────────⎥
           ⎣                 a₂₁             ⎦

A_recon3 = ⎡     a₁₁⋅a₂₂ + a₁₂⋅a₂₁ - a₂₁⋅a₂₂⎤
           ⎢a₁₁  ───────────────────────────⎥
           ⎢                 a₁₁            ⎥
           ⎢                                ⎥
           ⎣a₂₁              a₂₂            ⎦

但这似乎不允许进一步简化。

我不太明白 a22 和 a12 要如何分别从上、下两个方程中得出;如果这个分解是正确的,它们应该能够得出,但这些结果让我担心事实并非如此。

所以我的问题有两个:

  1. 任何 sympy 大师都可以帮助我获得分解工作的基本解决方案吗?

  2. 所引用的 SE 帖子中的分解是不是错了?还是我遗漏了某个使简化成为可能的约束?如果是这样,我该如何加上这个约束?

更新

当所有变量都联合求解时,我可以通过对 sympy.solve 返回的表达式使用 sympy.radsimp 来更进一步(仍然不确定为什么它不能单独求解 sx)。

# Solve for all parameters jointly, asking for one dict per solution.
solve_for = (sx, theta, sy, m)
solutions = sympy.solve(mat_equation, *solve_for, dict=True)
# Unused solve() keyword arguments, kept for reference:
# minimal=True, quick=True, cubics=False, quartics=False, quintics=False, check=False
# Simplification passes applied, in this order, to every solved expression.
simplify_passes = (sympy.radsimp, sympy.trigsimp, sympy.simplify, sympy.radsimp)
for sol in solutions:
    for sym, symsol0 in sol.items():
        symsol = symsol0
        for simplify_pass in simplify_passes:
            symsol = simplify_pass(symsol)
        # Header: the symbol and the repr of its simplified solution.
        print(
            '\n=====',
            'sym = {!r}'.format(sym),
            'symsol  = {!r}'.format(symsol),
            '--',
            sep='\n',
        )
        sympy.pretty_print(symsol, wrap_line=False)
        # Footer closing the per-symbol report.
        print('--', '=====\n', sep='\n')
=====
sym = sx
symsol  = -sqrt(a11**2 + a21**2)
--
    _____________
   ╱    2      2 
-╲╱  a₁₁  + a₂₁  
--
=====


=====
sym = theta
symsol  = 2*atan((a11 + sqrt(a11**2 + a21**2))/a21)
--
      ⎛         _____________⎞
      ⎜        ╱    2      2 ⎟
      ⎜a₁₁ + ╲╱  a₁₁  + a₂₁  ⎟
2⋅atan⎜──────────────────────⎟
      ⎝         a₂₁          ⎠
--
=====


=====
sym = m
symsol  = (a11*a12 + a21*a22)/(a11*a22 - a12*a21)
--
a₁₁⋅a₁₂ + a₂₁⋅a₂₂
─────────────────
a₁₁⋅a₂₂ - a₁₂⋅a₂₁
--
=====


=====
sym = sy
symsol  = (-a11*a22*sqrt(a11**2 + a21**2) + a12*a21*sqrt(a11**2 + a21**2))/(a11**2 + a21**2)
--
             _____________              _____________
            ╱    2      2              ╱    2      2 
- a₁₁⋅a₂₂⋅╲╱  a₁₁  + a₂₁   + a₁₂⋅a₂₁⋅╲╱  a₁₁  + a₂₁  
─────────────────────────────────────────────────────
                        2      2                     
                     a₁₁  + a₂₁                      
--
=====

不过这里 sx 的解更接近我想要的(虽然它是负根,我认为这在技术上是正确的,但我的印象是 sympy 只处理主根)。

主要问题仍然悬而未决。 (虽然我更有信心原来的 SE post 是正确的)。

而且结果似乎表明“m”的分母是行列式,这很有趣(分子是两列的点积)。

更新2

我开始认为 sympy 或 SE 帖子中存在一些错误。我开始进行数值检查,它给出了我认为无法调和的误差(即旋转后相同)。

数值校验码为

# Numerical spot check: draw random parameter values, build the matrix they
# define, and compare it with the symbolic reconstructions evaluated on the
# same numbers.
import numpy as np

# Raw-entry symbols in the same row-by-row order that flattening
# ``A_params_rand.tolist()`` produces, so the zip below pairs each symbol
# with its numeric value.
elements = [a11, a12, a21, a22]

params = [sx, theta, sy, m]
params_rand = {p: np.random.rand() for p in params}
A_params_rand = A_params.subs(params_rand)
matrix_rand = dict(zip(elements, ub.flatten(A_params_rand.tolist())))
A_matrix_rand = A_matrix.subs(matrix_rand)
# NOTE: ``A_solved_recon`` (the solver-based reconstruction) and ``A_recon``
# (the manual reconstruction) are not created in this snippet and must
# already be defined when it runs.
A_solved_rand = A_solved_recon.subs(matrix_rand)
A_recon_rand = A_recon.subs(matrix_rand)

mat1 = np.array(A_matrix_rand.tolist()).astype(float)
mat2 = np.array(A_params_rand.tolist()).astype(float)
mat3 = np.array(A_recon_rand.tolist()).astype(float)
# Sanity check: substituting the entries back must reproduce the same numbers.
assert np.allclose(mat1, mat2)

# Residual of the manual reconstruction; all zeros when it is exact.
print(mat2 - mat3)

# Numeric value of the solver-based reconstruction; computed but not printed.
mat4 = np.array(A_solved_rand.tolist()).astype(float)

随机值似乎总是在矩阵中的a22处产生一些错误,所以我认为从手动输入的分解矩阵的sympy重建是错误的,或者分解本身是错误的。任何帮助都将非常有价值。

经过与同事的讨论,原来我在代码中犯了一个简单的错误。我交换了 sin 和 cos 项。在使用@Stéphane Laurent 的分解时修复此问题可以正确重建矩阵:

import sympy
import ubelt as ub

# Assumptions shared by every parameter symbol below: all are declared real.
domain = {'real': True}

# Decomposition parameters: rotation angle, per-axis scales, and shear.
theta = sympy.symbols('theta', **domain)
sx, sy = sympy.symbols('sx, sy', **domain)
m = sympy.symbols('m', **domain)
params = [sx, theta, sy, m]

S = sympy.Matrix([  # scale
    [sx,  0],
    [ 0, sy]])

H = sympy.Matrix([  # shear
    [1, m],
    [0, 1]])

R = sympy.Matrix((  # rotation
    [sympy.cos(theta), -sympy.sin(theta)],
    [sympy.sin(theta),  sympy.cos(theta)]))

# Parametric form of the matrix: the product rotation @ shear @ scale.
A_params = sympy.simplify((R @ H @ S))
# Raw-entry form of the matrix: one real symbol per element.
a11, a12, a21, a22 = sympy.symbols(
    'a11, a12, a21, a22', real=True)
A_matrix = sympy.Matrix(((a11, a12), (a21, a22)))

print(ub.hzcat(['A_matrix = ', sympy.pretty(A_matrix)]))
print(ub.hzcat(['A_params = ', sympy.pretty(A_params)]))


# This is the guided solution by Stéphane Laurent
# The x-scale is the norm of the first column and the angle is that column's
# direction (the first column of A_params is [sx*cos(theta), sx*sin(theta)]).
recon_sx = sympy.sqrt(a11 * a11 + a21 * a21)
recon_theta = sympy.atan2(a21, a11)
recon_sin_t = sympy.sin(recon_theta)
recon_cos_t = sympy.cos(recon_theta)

# Product m * sy: with a12 = sy*(m*cos - sin) and a22 = sy*(m*sin + cos),
# a12*cos + a22*sin collapses to m*sy because cos**2 + sin**2 == 1.
recon_msy = a12 * recon_cos_t + a22 * recon_sin_t


# condition2 = sympy.simplify(sympy.Eq(recon_sin_t, 0))
# condition1 = sympy.simplify(sympy.Not(condition2))
# Choose the sy expression whose divisor (sin or cos) has the larger magnitude.
condition1 = sympy.Gt(recon_sin_t ** 2, recon_cos_t ** 2)
condition2 = sympy.Le(recon_sin_t ** 2, recon_cos_t ** 2)
# From a12 = m*sy*cos - sy*sin, solved for sy.
sy_cond1 = (recon_msy * recon_cos_t - a12) / recon_sin_t
# From a22 = m*sy*sin + sy*cos, solved for sy.
sy_cond2 = (a22 - recon_msy * recon_sin_t) / recon_cos_t
recon_sy = sympy.Piecewise((sy_cond1, condition1), (sy_cond2, condition2))
recon_m = sympy.simplify(recon_msy / recon_sy)


# Substitute the decomposition into the "A_params" to reconstruct "A_matrix"
recon_symbols = {
    sx: recon_sx,
    theta: recon_theta,
    m: recon_m,
    sy: recon_sy
}

# Report a simplified form of each recovered parameter. ``symval`` is only
# rebound locally, so the entries of ``recon_symbols`` themselves are unchanged.
for sym, symval in recon_symbols.items():
    # symval = sympy.radsimp(symval)
    symval = sympy.trigsimp(symval)
    symval = sympy.simplify(symval)
    # radsimp is skipped for results that are still Piecewise.
    if not isinstance(symval, sympy.Piecewise):
        symval = sympy.radsimp(symval)
    print('\n=====')
    print('sym = {!r}'.format(sym))
    print('symval  = {!r}'.format(symval))
    print('--')
    sympy.pretty_print(symval)
    print('=====\n')

# Plug the recovered parameters into the parametric matrix and simplify.
A_recon = A_params.subs(recon_symbols)
A_recon = sympy.simplify(A_recon)
print(ub.hzcat(['A_recon = ', sympy.pretty(A_recon)]))

使用 Laurent 明确定义的分解的重建输出:

A_matrix = ⎡a₁₁  a₁₂⎤
           ⎢        ⎥
           ⎣a₂₁  a₂₂⎦
A_params = ⎡sx⋅cos(θ)  sy⋅(m⋅cos(θ) - sin(θ))⎤
           ⎢                                 ⎥
           ⎣sx⋅sin(θ)  sy⋅(m⋅sin(θ) + cos(θ))⎦

=====
sym = sx
symval  = sqrt(a11**2 + a21**2)
--
   _____________
  ╱    2      2
╲╱  a₁₁  + a₂₁
=====


=====
sym = theta
symval  = atan2(a21, a11)
--
atan2(a₂₁, a₁₁)
=====


=====
sym = m
symval  = (a11*a12 + a21*a22)/(a11*a22 - a12*a21)
--
a₁₁⋅a₁₂ + a₂₁⋅a₂₂
─────────────────
a₁₁⋅a₂₂ - a₁₂⋅a₂₁
=====


=====
sym = sy
symval  = (a11*a22*sqrt(a11**2 + a21**2) - a12*a21*sqrt(a11**2 + a21**2))/(a11**2 + a21**2)
--
           _____________              _____________
          ╱    2      2              ╱    2      2
a₁₁⋅a₂₂⋅╲╱  a₁₁  + a₂₁   - a₁₂⋅a₂₁⋅╲╱  a₁₁  + a₂₁
───────────────────────────────────────────────────
                       2      2
                    a₁₁  + a₂₁
=====

A_recon = ⎡a₁₁  a₁₂⎤
          ⎢        ⎥
          ⎣a₂₁  a₂₂⎦

我还能够让求解器生成一个能正确重构“A_matrix”的解,尽管我不得不费一番周折,并且得到的分解形式有所不同(有点奇怪)。但它确实产生了正确的答案:

# Let the solver derive the decomposition directly from the matrix equality,
# then verify it by substituting the result back into the parametric matrix.
mat_equation = sympy.Eq(A_matrix, A_params)
solve_for = (sx, theta, sy, m)
solutions = sympy.solve(mat_equation, *solve_for, dict=True)
# Simplified expression per solved symbol. If the solver returns several
# solutions, later ones overwrite earlier entries for the same symbol.
solved = {}
# Unused solve() keyword arguments, kept for reference:
# minimal=True, quick=True, cubics=False, quartics=False, quintics=False, check=False
for sol in solutions:
    for sym, symsol0 in sol.items():
        # Chain of simplification passes applied to each raw solution.
        symsol = sympy.radsimp(symsol0)
        symsol = sympy.trigsimp(symsol)
        symsol = sympy.simplify(symsol)
        symsol = sympy.radsimp(symsol)
        print('\n=====')
        print('sym = {!r}'.format(sym))
        print('symsol  = {!r}'.format(symsol))
        print('--')
        sympy.pretty_print(symsol, wrap_line=False)
        solved[sym] = symsol
        print('--')
        print('=====\n')

# Substitute the solved parameters into the parametric matrix; a correct
# decomposition simplifies back to the raw-entry matrix.
A_solved_recon = sympy.simplify(A_params.subs(solved))

print(ub.hzcat(['A_solved_recon = ', sympy.pretty(A_solved_recon)]))

虽然我还没有弄清楚所有的细节,但这个 sympy 计算的分解似乎是正确的:

=====
sym = sx
symsol  = -sqrt(a11**2 + a21**2)
--
    _____________
   ╱    2      2 
-╲╱  a₁₁  + a₂₁  
--
=====


=====
sym = theta
symsol  = -2*atan((a11 + sqrt(a11**2 + a21**2))/a21)
--
       ⎛         _____________⎞
       ⎜        ╱    2      2 ⎟
       ⎜a₁₁ + ╲╱  a₁₁  + a₂₁  ⎟
-2⋅atan⎜──────────────────────⎟
       ⎝         a₂₁          ⎠
--
=====


=====
sym = m
symsol  = (a11*a12 + a21*a22)/(a11*a22 - a12*a21)
--
a₁₁⋅a₁₂ + a₂₁⋅a₂₂
─────────────────
a₁₁⋅a₂₂ - a₁₂⋅a₂₁
--
=====


=====
sym = sy
symsol  = (-a11*a22*sqrt(a11**2 + a21**2) + a12*a21*sqrt(a11**2 + a21**2))/(a11**2 + a21**2)
--
             _____________              _____________
            ╱    2      2              ╱    2      2 
- a₁₁⋅a₂₂⋅╲╱  a₁₁  + a₂₁   + a₁₂⋅a₂₁⋅╲╱  a₁₁  + a₂₁  
─────────────────────────────────────────────────────
                        2      2                     
                     a₁₁  + a₂₁                      
--
=====

A_solved_recon = ⎡a₁₁  a₁₂⎤
                 ⎢        ⎥
                 ⎣a₂₁  a₂₂⎦