从节点和系数创建 BSpline

Create BSpline from knots and coefficients

如果只知道点和系数,如何创建样条曲线?我在这里使用 scipy.interpolate.BSpline,但也对其他标准包开放。所以基本上我希望能够为某人提供那些简短的系数数组,以便他们能够重新创建对数据的拟合。请参阅下面失败的红色虚线曲线。

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import BSpline, LSQUnivariateSpline

x = np.linspace(0, 10, 50) # x-data
y = np.exp(-(x-5)**2/4)    # y-data

# define the knot positions

t = [1, 2, 4, 5, 6, 8, 9]

# get spline fit

s1 = LSQUnivariateSpline(x, y, t)

x2 = np.linspace(0, 10, 200) # new x-grid
y2 = s1(x2) # evaluate spline on that new grid

# FAILED: try to construct BSpline using the knots and coefficients

k = s1.get_knots()
c = s1.get_coeffs()
s2 = BSpline(t,c,2)

# plotting

plt.plot(x, y, label='original')
plt.plot(t, s1(t),'o', label='knots')
plt.plot(x2, y2, '--', label='spline 1')
plt.plot(x2, s2(x2), 'r:', label='spline 2') 
plt.legend()

The fine print under get_knots 说:

Internally, the knot vector contains 2*k additional boundary knots.

这意味着,要从 get_knots 中获得一个可用的结数组,应该在数组的开头添加 k 个左边界结的副本,并在数组的开头添加 k 个副本右边界结在最后。这里 k 是样条的阶数,通常为 3(您要求默认阶数 LSQUnivariateSpline,因此为 3)。所以:

kn = s1.get_knots()
kn = 3*[kn[0]] + list(kn) + 3*[kn[-1]]
c = s1.get_coeffs()
s2 = BSpline(kn, c, 3)    # not "2" as in your sample; we are working with a cubic spline 

现在,样条 s2 与 s1 相同:

等效地,kn = 4*[x[0]] + t + 4*[x[-1]] 会起作用:您的 t 列表仅包含内部结,因此添加 x[0]x[-1],然后每个重复 k 倍。

重复的数学原因是 B 样条需要一些空间来构建,因为它们的 inductive definition 需要 (k-1) 度样条存在于我们定义的每个间隔周围k 次样条。

如果你不太关心结位置的细节,这里有一个稍微更紧凑的方法。 tk 数组就是您要查找的内容。一旦 tk 到手,就可以使用 y=splev(x,tk,der=0) 行复制样条。

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import splrep,splev
import matplotlib.pyplot as plt 

### Input data
x_arr = np.linspace(0, 10, 50)  # x-data
y_arr = np.exp(-(x_arr-5)**2/4)    # y-data
### Order of spline 
order = 3
### Make the spline 
tk = splrep(x_arr, y_arr, k=order) # Returns the knots and coefficents
### Evaluate the spline using the knots and coefficents on the domian x
x = np.linspace(0, 10, 1000) # new x-grid
y = splev(x, tk, der=0)
### Plot
f,ax=plt.subplots()
ax.scatter(x_arr, y_arr, label='original')
ax.plot(x,y,label='Spline')
ax.legend(fontsize=15)
plt.show()