在代码中使用 a.any() 或 a.all() 错误来解决 Coupled ODE

Use a.any() or a.all() error in the code to solve Coupled ODE

上下文: 我不确定这是否是 post 这个问题的正确站点,如果不是,请告诉我。我的目标是求解半人马座阿尔法星系统代码中给出的耦合微分方程。

代码:

#Import scipy, numpy and mpmath
import scipy as sci
import numpy as np
import mpmath as mp
#Import matplotlib and associated modules for 3D and animations
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import animation
#Import decimal for better precision
from decimal import *
getcontext().prec = 10000

#Define universal gravitation constant
G=Decimal(6.67408e-11) #N-m2/kg2
#Reference quantities
m_nd=Decimal(1.989e+30) #kg #mass of the sun
r_nd=Decimal(5.326e+12) #m #distance between stars in Alpha Centauri
v_nd=Decimal(30000) #m/s #relative velocity of earth around the sun
t_nd=Decimal(79.91*365*24*3600*0.51) #s #orbital period of Alpha Centauri
#Net constants
K1=G*t_nd*m_nd/(r_nd**2*v_nd)
K2=v_nd*t_nd/r_nd

#Define masses
m1=Decimal(1.1) #Alpha Centauri A
m2=Decimal(0.907) #Alpha Centauri B 
m3=Decimal(1.0) #Third Star

#Define initial position vectors    
r1=np.array([Decimal(-0.5),Decimal(0),Decimal(0)])
r2=np.array([Decimal(0.5),Decimal(0),Decimal(0)])
r3=np.array([Decimal(0),Decimal(1),Decimal(0)])

#Find Centre of Mass
r_com=(m1*r1+m2*r2+m3*r3)/(m1+m2+m3)
#Define initial velocities
v1=np.array([Decimal(0.01),Decimal(0.01),Decimal(0)])
v2=np.array([Decimal(-0.05),Decimal(0),Decimal(-0.1)])
v3=np.array([Decimal(0),Decimal(-0.01),Decimal(0)])

#Find velocity of COM
v_com=(m1*v1+m2*v2+m3*v3)/(m1+m2+m3)#Define initial velocities

def ThreeBodyEquations(w,t,G,m1,m2,m3):
    r1=w[:3]
    r2=w[3:6]
    r3=w[6:9]
    v1=w[9:12]
    v2=w[12:15]
    v3=w[15:18]
    r12=sci.linalg.norm(r2-r1)
    r13=sci.linalg.norm(r3-r1)
    r23=sci.linalg.norm(r3-r2)
    
    dv1bydt=K1*m2*(r2-r1)/r12**3+K1*m3*(r3-r1)/r13**3+(61**2)*r1
    dv2bydt=K1*m1*(r1-r2)/r12**3+K1*m3*(r3-r2)/r23**3+(61**2)*r2
    dv3bydt=K1*m1*(r1-r3)/r13**3+K1*m2*(r2-r3)/r23**3+(61**2)*r3
    dr1bydt=K2*v1
    dr2bydt=K2*v2
    dr3bydt=K2*v3
    r12_derivs=sci.concatenate((dr1bydt,dr2bydt))
    r_derivs=sci.concatenate((r12_derivs,dr3bydt))
    v12_derivs=sci.concatenate((dv1bydt,dv2bydt))
    v_derivs=sci.concatenate((v12_derivs,dv3bydt))
    derivs=sci.concatenate((r_derivs,v_derivs))
    return derivs

#Package initial parameters
init_params=np.array([r1,r2,r3,v1,v2,v3]) #Initial parameters
init_params=init_params.flatten() #Flatten to make 1D array
time_span=sci.linspace(0,20,500) #20 orbital periods and 500 points

#Run the ODE solver
three_body_sol=mp.odefun(ThreeBodyEquations,time_span,init_params,time_span)

r1_sol=three_body_sol[:,:3]
r2_sol=three_body_sol[:,3:6]
r3_sol=three_body_sol[:,6:9]

#Create figure
fig=plt.figure(figsize=(15,15))
#Create 3D axes
ax=fig.add_subplot(111,projection="3d")
#Plot the orbits
ax.plot(r1_sol[:,0],r1_sol[:,1],r1_sol[:,2],color="darkblue")
ax.plot(r2_sol[:,0],r2_sol[:,1],r2_sol[:,2],color="tab:red")
#Plot the final positions of the stars
ax.scatter(r1_sol[-1,0],r1_sol[-1,1],r1_sol[-1,2],color="darkblue",marker="o",s=100,label="Alpha Centauri A")
ax.scatter(r2_sol[-1,0],r2_sol[-1,1],r2_sol[-1,2],color="tab:red",marker="o",s=100,label="Alpha Centauri B")
#Add a few more bells and whistles
ax.set_xlabel("x-coordinate",fontsize=14)
ax.set_ylabel("y-coordinate",fontsize=14)
ax.set_zlabel("z-coordinate",fontsize=14)
ax.set_title("Visualization of orbits of stars in a two-body system\n",fontsize=14)
ax.legend(loc="upper left",fontsize=14)

令我惊讶的是,我收到了这个错误

ValueError                                Traceback (most recent call last)
<ipython-input-11-8ecff918f44e> in <module>
     88 #Run the ODE solver
     89 import scipy.integrate
---> 90 three_body_sol=mp.odefun(ThreeBodyEquations,time_span,init_params,time_span)
     91 
     92 r1_sol=three_body_sol[:,:3]
/usr/local/lib/python3.8/dist-packages/mpmath/calculus/odes.py in odefun(ctx, F, x0, y0, tol, degree, method, verbose)
    228 
    229     """
--> 230     if tol:
    231         tol_prec = int(-ctx.log(tol, 2))+10
    232     else:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

现在我推测 Python 希望我在输入初始参数时使用 a.any()a.all()np.any(time_span)np.any(init_params) 也会抛出一个错误。有人可以告诉我出了什么问题,我该如何纠正?提前谢谢你

您需要阅读并理解文档。 mpmath.odefunscipy.integrate.odeint 有着根本的不同。 mpmath.odefun 提供了一个更类似于 scipy.integrate.ode 步进器 class 的动态解决方案对象,因为在调用时它不计算(很多),它只是初始化一个对象。在对返回对象的后续调用中计算并存储以“密集输出”形式存在的实际解决方案数据。根据需要扩展该数据的时间范围。

如何做到这一点可以在文档示例中看到。在你的情况下,这可以作为

three_body_fun=mp.odefun(ThreeBodyEquations,time_span[0],init_params, tol=1e-4, degree=5)
three_body_sol = [ three_body_fun(t) for t in time_span]

要使其开始工作,您需要将 ThreeBodyEquations 更改为仅按该顺序包含参数 t,w。没有必要将常量作为参数传递,它们取自全局上下文。请注意 w 及其切片是简单列表,您需要转换为向量格式才能应用向量减法。

使用numpy/scipy数组存储多精度数据,进行向量运算应该没有问题。然而,范数函数将使用浮点平方根函数,这会使多精度实现的所有努力无效,因此最好编写自己的欧几里德范数或使用 mpmath 变体(如果可用)。

我没有测试 Decimal 和 mpmath 协同工作的效果,我只是将前者的导入替换为

Decimal = lambda x: mp.mpf(x)

您可能希望将 mpmath 精度设置为合理的值,使用 25 位小数可能是明智的,使用 80 位可以作为参考值,使用 10000 位或大约 3000 位小数会不合理地增加计算时间,尤其是因为您的常量和输入没有那样的准确性。