如何使用 RK4 方法修复 mass-spring-system 模拟中不正确的能量守恒问题

How to fix incorrect energy conservation problem in mass-spring-system simulation using RK4 method

我正在做一个模拟,你创建不同的一定质量的球,由你可以定义的 springs 连接(在下面的程序中,所有 springs 都有自然长度 L 和 spring常数k)。我是怎么做的 我创建了一个函数 accel(b,BALLS),(注意 b 是特定的球,BALLS 是所有的球 objects 在不同的更新阶段) 这让我通过计算作用在这个球上的所有力(球的张力 springs 连接到它和重力) 我认为这个函数绝对正确,问题出在 while 循环的其他地方。然后我使用这个网站上描述的 RK4 方法:http://spiff.rit.edu/richmond/nbody/OrbitRungeKutta4.pdf in the while loop to update velocity and position of each ball. To test my understanding of the method I first made a simulation where only two balls and one spring is involved on Desmos: https://www.desmos.com/calculator/4ag5gkerag 我考虑了能量显示,发现 RK4 确实比欧拉方法好得多。现在我在 python 中实现了它,希望它能与球和 spring 的任意配置一起工作,但是当我有两个球和一个 spring 时,能量甚至都不守恒!我看不出我做了什么不同,至少当涉及两个球时。当我向系统引入第三个球和第二个 spring 时,能量每秒增加数百。这是我第一次用RK4编写模拟代码,我希望你们能发现其中的错误。我有一个想法,也许问题是因为有多个身体,当我同时更新他们的 kas 或 kvs 时会出现困难,但我又一次无法发现模拟两个球时这段代码所做的事情之间的任何区别以及我在 Desmos 文件中使用的方法。这是我在 python:

中的代码
    import pygame
    import sys
    import math
    import numpy as np
    
    
    pygame.init()
    width = 1200
    height = 900
    SCREEN = pygame.display.set_mode((width, height))
    font = pygame.font.Font(None, 25)
    TIME = pygame.time.Clock()
    
    dampwall = 1
    dt = 0.003
    g = 20
    k=10
    L=200
    
    
    def dist(a, b):
        return math.sqrt((a[0] - b[0])*(a[0] - b[0]) + (a[1] - b[1])*(a[1] - b[1]))
    
    
    def mag(a):
        return dist(a, [0, 0])
    
    def dp(a, b):
        return a[0]*b[0]+a[1]*b[1]
    
    
    def norm(a):
        return list(np.array(a)/mag(a))
    
    
    def reflect(a, b):
        return norm([2*a[1]*b[0]*b[1]+a[0]*(b[0]**2 - b[1]**2), 2*a[0]*b[0]*b[1]+a[1]*(-b[0]**2 + b[1]**2)])
    
    
    
    
    class ball:
        def __init__(self, x, y, vx, vy, mass,spr,index,ka,kv):
            self.r = [x, y]
            self.v = [vx, vy]
    
            self.radius = 5
            self.mass = mass
            self.spr=spr
            self.index = index
            self.ka=ka
            self.kv=kv
            
        def detectbounce(self,width,height):
            if self.r[0] + self.radius > width/2 and self.r[0]+self.v[0] > self.r[0] or  self.r[0] - self.radius < -width/2 and self.r[0]+self.v[0] < self.r[0] or self.r[1] + self.radius > height/2 and self.r[1]+self.v[1] > self.r[1] or self.r[1] - self.radius < -height/2 and self.r[1]+self.v[1] < self.r[1]:
                return True
            
    
        def bounce_walls(self, width, height):
            
            
            if self.r[0] + self.radius > width/2 and self.r[0]+self.v[0] > self.r[0]:
                self.v[0] *= -dampwall
    
            if self.r[0] - self.radius < -width/2 and self.r[0]+self.v[0] < self.r[0]:
                self.v[0] *= -dampwall
    
            if self.r[1] + self.radius > height/2 and self.r[1]+self.v[1] > self.r[1]:
                self.v[1] *= -dampwall
    
            if self.r[1] - self.radius < -height/2 and self.r[1]+self.v[1] < self.r[1]:
                self.v[1] *= -dampwall
        
        
    
        def update_r(self,v, h):
    
            self.r[0] += v[0] * h 
            self.r[1] += v[1] * h
        
        def un_update_r(self,v, h):
    
            self.r[0] += -v[0] * h 
            self.r[1] += -v[1] * h
    
        
        def KE(self):
            return 0.5 * self.mass * mag(self.v)**2
    
        def GPE(self):
            return self.mass * g * (-self.r[1] + height)
        
    
        def draw(self, screen, width, height):
            pygame.draw.circle(screen, (0, 0, 255), (self.r[0] +
                               width / 2, self.r[1] + height / 2), self.radius)
            
    
    
    
    #(self, x, y, vx, vy, mass,spr,index,ka,kv):
    # balls = [ball(1, 19, 0, 0,5,[1],0,[0,0,0,0],[0,0,0,0]), ball(250, 20, 0,0,1,[0],1,[0,0,0,0],[0,0,0,0])]   
    # springs = [[0, 1]]
    
    balls = [ball(1, 19, 0, 0,5,[1,3],0,[0,0,0,0],[0,0,0,0]), ball(250, 20, 0,0,2,[0,2,3],1,[0,0,0,0],[0,0,0,0]),ball(450, 0, 0,0,2,[1,3],1,[0,0,0,0],[0,0,0,0]),ball(250, -60, 0,0,2,[0,1,2],1,[0,0,0,0],[0,0,0,0])]   
    springs = [[0, 1],[1,2],[0,3],[1,3],[2,3]]
    
    
    
    
    
    
    
    
    
    def accel(b,BALLS):
    
        A=[0,g]
        for i in range(0,len(b.spr)):
            ball1=b
            ball2=BALLS[b.spr[i]]
            r1 = norm(list(np.array(ball2.r) - np.array(ball1.r)))
            lnow = dist(ball1.r, ball2.r)
            force = k * (lnow - L)
            A[0]+=force/ball1.mass*r1[0]
            A[1]+=force/ball1.mass*r1[1]
            
        return A
            
    initE=0
    while True:
        TIME.tick(200)
        SCREEN.fill((0, 0, 0))
    
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()
    
        #compute k1a and k1v for all balls
        for ball in balls:
    
                ball.ka[0]=accel(ball,balls)
                ball.kv[0]=ball.v
                
        #create newb1 based on 'updated' position of all balls with their own k1v
        newb=[]
        for ball in balls:
                ball.update_r(ball.kv[0], dt/2)
                newb.append(ball)
                ball.un_update_r(ball.kv[0], dt/2)
                
        #compute k2a and k2v for all balls based on newb1
        for ball in balls:
                ball.update_r(ball.kv[0], dt/2)
                ball.ka[1]=accel(ball,newb)
                ball.un_update_r(ball.kv[0], dt/2)
                
                ball.kv[1]=[ball.v[0]+0.5*dt*ball.ka[0][0],ball.v[1]+0.5*dt*ball.ka[0][1]]
    
        #create newb2 based on 'updated' position of all balls with their own k2v       
        newb=[]
        for ball in balls:
     
                ball.update_r(ball.kv[1], dt/2)
                newb.append(ball)
                ball.un_update_r(ball.kv[1], dt/2)
                
        #compute k3a and k3v for all balls
        for ball in balls:
            
                ball.update_r(ball.kv[1], dt/2)
                ball.ka[2]=accel(ball,newb)
                ball.un_update_r(ball.kv[1], dt/2)
                
                ball.kv[2]=[ball.v[0]+0.5*dt*ball.ka[1][0],ball.v[1]+0.5*dt*ball.ka[1][1]]
        
        newb=[]
        for ball in balls:
    
                ball.update_r(ball.kv[2], dt)
                newb.append(ball)
                ball.un_update_r(ball.kv[2], dt)
        
        #compute k4a and k4v for all balls
        for ball in balls:
                ball.update_r(ball.kv[2], dt)
                ball.ka[3]=accel(ball,newb)
                ball.un_update_r(ball.kv[2], dt)
                
                ball.kv[3]=[ball.v[0]+dt*ball.ka[2][0],ball.v[1]+dt*ball.ka[2][1]]
                
        #final stage of update
        for ball in balls:
            if ball.detectbounce(width,height)==True:
                ball.bounce_walls(width, height)
            else:
                ball.v=[ball.v[0]+dt*(ball.ka[0][0]+2*ball.ka[1][0]+2*ball.ka[2][0]+ball.ka[3][0])/6, ball.v[1]+dt*(ball.ka[0][1]+2*ball.ka[1][1]+2*ball.ka[2][1]+ball.ka[3][1])/6]
                ball.r=[ball.r[0]+dt*(ball.kv[0][0]+2*ball.kv[1][0]+2*ball.kv[2][0]+ball.kv[3][0])/6, ball.r[1]+dt*(ball.kv[0][1]+2*ball.kv[1][1]+2*ball.kv[2][1]+ball.kv[3][1])/6]
            
        for ball in balls:      
            ball.draw(SCREEN, width, height)
            for i in range(0,len(ball.spr)):
                ball1=ball
                ball2=balls[ball.spr[i]]
                pygame.draw.line(SCREEN, (0, 0, 155), (
                    ball1.r[0]+width/2, ball1.r[1]+height/2), (ball2.r[0]+width/2, ball2.r[1]+height/2))
        
        #check for energy        
                
        KE = 0
        EPE = 0
        GPE = 0
        for i in range(0, len(springs)):
    
            EPE += 1/2 * k * \
                (L - dist(balls[springs[i][0]].r,
                 balls[springs[i][1]].r))**2
    
        for i in range(0, len(balls)):
            KE += balls[i].KE()
            GPE += balls[i].GPE()
    
    
        if initE == 0:
                initE += KE+EPE+GPE
    
    
        text = font.render('init Energy: ' + str(round(initE,1))+' '+'KE: ' + str(round(KE, 1)) + ' '+'EPE: ' + str(round(EPE, 1))+' ' + 'GPE: ' + str(round(GPE, 1)) + ' ' + 'Total: ' + str(round(KE+EPE+GPE, 1)) + ' ' + 'Diff: ' + str(round((KE+EPE+GPE-initE), 1)),
                               True, (255, 255, 255))
    
        textRect = text.get_rect()
        textRect.center = (370, 70)
        SCREEN.blit(text, textRect)
                
    
        pygame.display.flip()

这是由 Lutz Lehmann 编辑和更正的内容,并进行了一些额外的改进:

import pygame
import sys
import math
import numpy as np


pygame.init()
width = 1200
height = 900
SCREEN = pygame.display.set_mode((width, height))
font = pygame.font.Font(None, 25)
TIME = pygame.time.Clock()

dampwall = 1
dt = 0.003
g = 5
k = 10
L = 200

digits = 6


def dist(a, b):
    return math.sqrt((a[0] - b[0])*(a[0] - b[0]) + (a[1] - b[1])*(a[1] - b[1]))


def mag(a):
    return dist(a, [0, 0])


def dp(a, b):
    return a[0]*b[0]+a[1]*b[1]


def norm(a):
    return list(np.array(a)/mag(a))


def reflect(a, b):
    return norm([2*a[1]*b[0]*b[1]+a[0]*(b[0]**2 - b[1]**2), 2*a[0]*b[0]*b[1]+a[1]*(-b[0]**2 + b[1]**2)])


class Ball:
    def __init__(self, x, y, vx, vy, mass, spr, index, ka, kv):
        self.r = [x, y]
        self.v = [vx, vy]

        self.radius = 5
        self.mass = mass
        self.spr = spr
        self.index = index
        self.ka = ka
        self.kv = kv

    def copy(self):
        return Ball(self.r[0], self.r[1], self.v[0], self.v[1], self.mass, self.spr, self.index, self.ka, self.kv)

    def detectbounce(self, width, height):
        if self.r[0] + self.radius > width/2 and self.r[0]+self.v[0] > self.r[0] or self.r[0] - self.radius < -width/2 and self.r[0]+self.v[0] < self.r[0] or self.r[1] + self.radius > height/2 and self.r[1]+self.v[1] > self.r[1] or self.r[1] - self.radius < -height/2 and self.r[1]+self.v[1] < self.r[1]:
            return True

    def bounce_walls(self, width, height):

        if self.r[0] + self.radius > width/2 and self.r[0]+self.v[0] > self.r[0]:
            self.v[0] *= -dampwall

        if self.r[0] - self.radius < -width/2 and self.r[0]+self.v[0] < self.r[0]:
            self.v[0] *= -dampwall

        if self.r[1] + self.radius > height/2 and self.r[1]+self.v[1] > self.r[1]:
            self.v[1] *= -dampwall

        if self.r[1] - self.radius < -height/2 and self.r[1]+self.v[1] < self.r[1]:
            self.v[1] *= -dampwall

    def update_r(self, v, h):

        self.r[0] += v[0] * h
        self.r[1] += v[1] * h

    def un_update_r(self, v, h):

        self.r[0] += -v[0] * h
        self.r[1] += -v[1] * h

    def KE(self):
        return 0.5 * self.mass * mag(self.v)**2

    def GPE(self):
        return self.mass * g * (-self.r[1] + height)

    def draw(self, screen, width, height):
        pygame.draw.circle(screen, (0, 0, 255), (self.r[0] +
                           width / 2, self.r[1] + height / 2), self.radius)


# (self, x, y, vx, vy, mass,spr,index,ka,kv):


# balls = [Ball(1, 19, 0, 0, 1, [1], 0, [0, 0, 0, 0], [0, 0, 0, 0]),
#          Ball(250, 20, 0, 0, 1, [0], 1, [0, 0, 0, 0], [0, 0, 0, 0])]
# springs = [[0, 1]]

balls = [Ball(1, 19, 0, 0,5,[1,3],0,[0,0,0,0],[0,0,0,0]), Ball(250, 20, 0,0,2,[0,2,3],1,[0,0,0,0],[0,0,0,0]),Ball(450, 0, 0,0,2,[1,3],1,[0,0,0,0],[0,0,0,0]),Ball(250, -60, 0,0,2,[0,1,2],1,[0,0,0,0],[0,0,0,0])]

# n=5
# resprings=[]

# for i in range(0,n):
#     for j in range(0,n):
#         if i==0 and j==0:
#             resprings.append([1,2,n,n+1,2*n])
#         if i==n and j==0:
#             resprings.apend([n*(n-1)+1,n*(n-1)+2,n*(n-2),n*(n-3),n*(n-2)+1])
#         if j==0 and i!=0 or i!=n:
#             resprings.append([(i-1)*n+1,(i-1)*n+2,(i-2)*n,(i-2)*n+1,(i)*n,(i)*n+1])
        
            

def getsprings(B):
    S=[]
    for i in range(0,len(B)):
        theball=B[i]
        for j in range(len(theball.spr)):
            spring=sorted([i,theball.spr[j]])
            if spring not in S:
                S.append(spring)

    return S
            
    
springs = getsprings(balls)    
        
    





def accel(b, BALLS):

    A = [0, g]
    for i in range(0, len(b.spr)):
        ball1 = b
        ball2 = BALLS[b.spr[i]]
        r1 = norm(list(np.array(ball2.r) - np.array(ball1.r)))
        lnow = dist(ball1.r, ball2.r)
        force = k * (lnow - L)
        A[0] += force/ball1.mass*r1[0]
        A[1] += force/ball1.mass*r1[1]

    return A


initE = 0
while True:
    TIME.tick(200)
    SCREEN.fill((0, 0, 0))

    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            pygame.quit()
            sys.exit()
    for ball in balls:
        ball.bounce_walls(width, height)

    # compute k1a and k1v for all balls
    for ball in balls:

        ball.ka[0] = accel(ball, balls)
        ball.kv[0] = ball.v

    # create newb1 based on 'updated' position of all balls with their own k1v
    newb = []
    for ball in balls:
        ball.update_r(ball.kv[0], dt/2)
        newb.append(ball.copy())
        ball.un_update_r(ball.kv[0], dt/2)

    # compute k2a and k2v for all balls based on newb1
    for ball in balls:
        ball.update_r(ball.kv[0], dt/2)
        ball.ka[1] = accel(ball, newb)
        ball.un_update_r(ball.kv[0], dt/2)

        ball.kv[1] = [ball.v[0]+0.5*dt*ball.ka[0]
                      [0], ball.v[1]+0.5*dt*ball.ka[0][1]]

    # create newb2 based on 'updated' position of all balls with their own k2v
    newb = []
    for ball in balls:

        ball.update_r(ball.kv[1], dt/2)
        newb.append(ball.copy())
        ball.un_update_r(ball.kv[1], dt/2)

    # compute k3a and k3v for all balls
    for ball in balls:

        ball.update_r(ball.kv[1], dt/2)
        ball.ka[2] = accel(ball, newb)
        ball.un_update_r(ball.kv[1], dt/2)

        ball.kv[2] = [ball.v[0]+0.5*dt*ball.ka[1]
                      [0], ball.v[1]+0.5*dt*ball.ka[1][1]]

    newb = []
    for ball in balls:

        ball.update_r(ball.kv[2], dt)
        newb.append(ball.copy())
        ball.un_update_r(ball.kv[2], dt)

    # compute k4a and k4v for all balls
    for ball in balls:
        ball.update_r(ball.kv[2], dt)
        ball.ka[3] = accel(ball, newb)
        ball.un_update_r(ball.kv[2], dt)

        ball.kv[3] = [ball.v[0]+dt*ball.ka[2][0], ball.v[1]+dt*ball.ka[2][1]]

    # final stage of update
    for ball in balls:
        ball.v = [ball.v[0]+dt*(ball.ka[0][0]+2*ball.ka[1][0]+2*ball.ka[2][0]+ball.ka[3][0])/6,
                  ball.v[1]+dt*(ball.ka[0][1]+2*ball.ka[1][1]+2*ball.ka[2][1]+ball.ka[3][1])/6]
        ball.r = [ball.r[0]+dt*(ball.kv[0][0]+2*ball.kv[1][0]+2*ball.kv[2][0]+ball.kv[3][0])/6,
                  ball.r[1]+dt*(ball.kv[0][1]+2*ball.kv[1][1]+2*ball.kv[2][1]+ball.kv[3][1])/6]

    for ball in balls:
        ball.draw(SCREEN, width, height)
        for i in range(0, len(ball.spr)):
            ball1 = ball
            ball2 = balls[ball.spr[i]]
            pygame.draw.line(SCREEN, (0, 0, 155), (
                ball1.r[0]+width/2, ball1.r[1]+height/2), (ball2.r[0]+width/2, ball2.r[1]+height/2))

    # check for energy

    KE = 0
    EPE = 0
    GPE = 0
    for i in range(0, len(springs)):

        EPE += 1/2 * k * \
            (L - dist(balls[springs[i][0]].r,
             balls[springs[i][1]].r))**2

    for i in range(0, len(balls)):
        KE += balls[i].KE()
        GPE += balls[i].GPE()

    if initE == 0:
        initE += KE+EPE+GPE
    
    
    text1 = font.render(f"initial energy: {str(round(initE, digits))}", True, (255, 255, 255))
    text2 = font.render(f"kinetic energy: {str(round(KE, digits))}", True, (255, 255, 255))
    text3 = font.render(f"elastic potential energy: {str(round(EPE, digits))}", True, (255, 255, 255))
    text4 = font.render(f"gravitational energy: {str(round(GPE, digits))}", True, (255, 255, 255))
    text5 = font.render(f"total energy: {str(round(KE + EPE + GPE, digits))}", True, (255, 255, 255))
    text6 = font.render(f"change in energy: {str(round(KE + EPE + GPE - initE, digits))}", True, (255, 255, 255))

    SCREEN.blit(text1, (10, 10))
    SCREEN.blit(text2, (10, 60))
    SCREEN.blit(text3, (10, 110))
    SCREEN.blit(text4, (10, 160))
    SCREEN.blit(text5, (10, 210))
    SCREEN.blit(text6, (10, 260))
    

    pygame.display.flip()

直接错误好像是这个

    for ball in balls:
            ...
            newb1.append(ball)
            ...

因为 ball 只是对 class ball 实例的引用,因此 newb1 是对 balls 中对象的引用列表,如果你操纵一个或另一个,它总是被改变的相同数据记录。

你需要应用复制机制,因为你有列表的列表,你需要一个深复制,或者一个专用的复制成员方法,否则你只复制ball实例中的数组引用,所以你得到不同的实例,但指向相同的数组。

可能不是错误,但在同一范围内将 class 名称也作为变量名称仍然是个坏主意。