Python 用于计算基本拉格朗日多项式的代码对于大量数据花费的时间太长

Python code to calculate base Lagrange Polynomial is taking too long for large numbers of data

因此,我需要帮助最大限度地减少仅通过使用 NumPy 运行 具有大量数据的代码所需的时间。我认为 for 循环使我的代码效率低下..但我不知道如何将 for 循环变成列表理解,这可能有助于它 运行 更快..

def lagrange(p,node,n,x):
    m=[]

        #base lagrange polynomial
    for i in range(n):
        for j in range(p+1):
            L=1
            for k in range(p+1):
                if k!=j:
                    L= L*(x[i] - node[k])/(node[j] - node[k])
                
            m.append(L)
    lagrange= np.array(m).reshape(n,p+1)
    return lagrange


def interpolant(a,b,p,n,x,f):
        m=[]
        node=np.linspace(a,b,p+1)
        for j in range(n):
             polynomial=0 
            for i in range(p+1):
                polynomial += f(node[i]) * lagrange(p,node,n,x) 
        m.append(polynomial) 

        interpolant = np.array(inter) 

                return interpolant
    

看来 lagrange_poly(...) 的值无缘无故 重新计算了 n*(p+1) 这非常非常昂贵!您可以在循环之前计算一次,将其存储在一个变量中,稍后再使用该变量。

固定代码如下:

def uniform_poly_interpolation(a,b,p,n,x,f,produce_fig):
        inter=[]
        xhat=np.linspace(a,b,p+1)
        #use for loop to iterate interpolant.
        mat = lagrange_poly(p,xhat,n,x,1e-10)[0]
        for j in range(n):
                po=0 
                for i in range(p+1):
                    po += f(xhat[i]) * mat[i,j]
        inter.append(po) 

        interpolant = np.array(inter) 
        return interpolant

这应该快得多。


此外,执行速度很慢,因为从 CPython 访问 Numpy 数组的标量值非常慢。 Numpy 旨在与数组一起使用,而不是在循环中提取标量值。此外,循环 CPython 解释器相对较慢。您可以使用 Numba 有效地解决这个问题,它使用 JIT-compiler.

将您的代码编译为非常快速的本机代码

这是 Numba 代码:

import numba as nb

@nb.njit
def lagrange_poly(p, xhat, n, x, tol):
    error_flag = 0
    er = 1
    lagrange_matrix = np.empty((n, p+1), dtype=np.float64)

    for l in range(p):
        if abs(xhat[l] - xhat[l+1]) < tol:
            error_flag = er

    # Base lagrange polynomial
    for i in range(n):
        for j in range(p+1):
            L = 1.0
            for k in range(p+1):
                if k!=j:
                    L = L * (x[i] - xhat[k]) / (xhat[j] - xhat[k])
            lagrange_matrix[i, j] = L

    return lagrange_matrix, error_flag

总的来说,这应该快几个数量级