通过 scipy 拟合函数传递 numpy 数组

Passing numpy array through scipy fitting function

具有拟合功能:

def conv_gauss(x, constant, c, beta, sigma): 
    return constant * np.exp((x-c)/beta) * math.erfc((x-c)/(np.sqrt(2)*sigma) + sigma/(np.sqrt(2)*beta))

以及 x 和 y 数据:

x_data = [5751.0, 5752.0, 5753.0, 5754.0, 5755.0, 5756.0, 5757.0, 5758.0, 5759.0, 5760.0, 5761.0, 5762.0, 5763.0, 5764.0, 5765.0, 5766.0, 5767.0, 5768.0, 5769.0, 5770.0, 5771.0, 5772.0, 5773.0, 5774.0, 5775.0, 5776.0, 5777.0, 5778.0, 5779.0, 5780.0, 5781.0, 5782.0, 5783.0, 5784.0, 5785.0, 5786.0, 5787.0, 5788.0, 5789.0, 5790.0, 5791.0, 5792.0, 5793.0, 5794.0, 5795.0, 5796.0, 5797.0, 5798.0, 5799.0, 5800.0, 5801.0, 5802.0, 5803.0, 5804.0, 5805.0, 5806.0, 5807.0, 5808.0, 5809.0, 5810.0, 5811.0, 5812.0, 5813.0, 5814.0, 5815.0, 5816.0, 5817.0, 5818.0, 5819.0, 5820.0, 5821.0, 5822.0, 5823.0, 5824.0, 5825.0, 5826.0, 5827.0, 5828.0, 5829.0, 5830.0, 5831.0, 5832.0, 5833.0, 5834.0, 5835.0, 5836.0, 5837.0, 5838.0, 5839.0, 5840.0, 5841.0, 5842.0, 5843.0, 5844.0, 5845.0, 5846.0, 5847.0, 5848.0, 5849.0, 5850.0]

y_data = [2250.0, 2259.0, 2382.0, 2546.0, 2527.0, 2837.0, 2972.0, 3154.0, 3345.0, 3664.0, 4209.0, 4415.0, 4857.0, 5372.0, 5985.0, 6743.0, 7735.0, 8634.0, 9555.0, 11085.0, 12155.0, 13598.0, 15205.0, 17236.0, 19170.0, 21515.0, 24252.0, 26325.0, 28877.0, 31945.0, 35266.0, 38525.0, 42205.0, 45128.0, 48527.0, 52116.0, 55999.0, 59361.0, 62160.0, 65897.0, 69085.0, 71548.0, 73684.0, 74978.0, 76676.0, 76531.0, 76203.0, 75874.0, 74414.0, 72638.0, 69968.0, 67469.0, 64681.0, 61530.0, 58005.0, 54193.0, 51032.0, 47377.0, 43618.0, 40237.0, 37268.0, 33845.0, 30856.0, 27464.0, 25015.0, 22483.0, 20294.0, 18180.0, 16023.0, 14171.0, 12534.0, 11074.0, 9764.0, 8708.0, 7672.0, 6668.0, 5988.0, 5208.0, 4585.0, 4163.0, 3845.0, 3326.0, 2904.0, 2784.0, 2529.0, 2271.0, 2219.0, 2103.0, 1943.0, 1766.0, 1782.0, 1650.0, 1578.0, 1576.0, 1481.0, 1483.0, 1357.0, 1365.0, 1338.0, 1269.0]

通过拟合函数:

popt, pcov = optimize.curve_fit(conv_gauss, x_data, y_data, p0=[1,5800,1,50])

通过函数传递数组时出错。由于 curve_fit 需要拟合数据,我不知道如何解决这个问题。错误如下所示:

TypeError: only size-1 arrays can be converted to Python scalars

math.erfc替换为scipy.special.erfc

math.erfc 适用于标量。

后者也在array_likes

@Vulwsztyn 是正确的。我只想补充一点,您最初的猜测 p0 不会导致 curve_fit 的良好收敛。 constant 需要接近 max(y_data)。此外,我将 sigma 设置为 1。看看下面的代码和结果图。

import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit
from scipy.special import erfc


def conv_gauss(x, constant, c, beta, sigma):
    return (constant * np.exp((x - c) / beta)
            * erfc((x - c) / np.sqrt(2) / sigma + sigma / np.sqrt(2) / beta))


x_data = [5751.0, 5752.0, 5753.0, 5754.0, 5755.0, 5756.0, 5757.0, 5758.0,
          5759.0, 5760.0, 5761.0, 5762.0, 5763.0, 5764.0, 5765.0, 5766.0,
          5767.0, 5768.0, 5769.0, 5770.0, 5771.0, 5772.0, 5773.0, 5774.0,
          5775.0, 5776.0, 5777.0, 5778.0, 5779.0, 5780.0, 5781.0, 5782.0,
          5783.0, 5784.0, 5785.0, 5786.0, 5787.0, 5788.0, 5789.0, 5790.0,
          5791.0, 5792.0, 5793.0, 5794.0, 5795.0, 5796.0, 5797.0, 5798.0,
          5799.0, 5800.0, 5801.0, 5802.0, 5803.0, 5804.0, 5805.0, 5806.0,
          5807.0, 5808.0, 5809.0, 5810.0, 5811.0, 5812.0, 5813.0, 5814.0,
          5815.0, 5816.0, 5817.0, 5818.0, 5819.0, 5820.0, 5821.0, 5822.0,
          5823.0, 5824.0, 5825.0, 5826.0, 5827.0, 5828.0, 5829.0, 5830.0,
          5831.0, 5832.0, 5833.0, 5834.0, 5835.0, 5836.0, 5837.0, 5838.0,
          5839.0, 5840.0, 5841.0, 5842.0, 5843.0, 5844.0, 5845.0, 5846.0,
          5847.0, 5848.0, 5849.0, 5850.0]

y_data = [2250.0, 2259.0, 2382.0, 2546.0, 2527.0, 2837.0, 2972.0, 3154.0,
          3345.0, 3664.0, 4209.0, 4415.0, 4857.0, 5372.0, 5985.0, 6743.0,
          7735.0, 8634.0, 9555.0, 11085.0, 12155.0, 13598.0, 15205.0, 17236.0,
          19170.0, 21515.0, 24252.0, 26325.0, 28877.0, 31945.0, 35266.0,
          38525.0, 42205.0, 45128.0, 48527.0, 52116.0, 55999.0, 59361.0,
          62160.0, 65897.0, 69085.0, 71548.0, 73684.0, 74978.0, 76676.0,
          76531.0, 76203.0, 75874.0, 74414.0, 72638.0, 69968.0, 67469.0,
          64681.0, 61530.0, 58005.0, 54193.0, 51032.0, 47377.0, 43618.0,
          40237.0, 37268.0, 33845.0, 30856.0, 27464.0, 25015.0, 22483.0,
          20294.0, 18180.0, 16023.0, 14171.0, 12534.0, 11074.0, 9764.0,
          8708.0, 7672.0, 6668.0, 5988.0, 5208.0, 4585.0, 4163.0, 3845.0,
          3326.0, 2904.0, 2784.0, 2529.0, 2271.0, 2219.0, 2103.0, 1943.0,
          1766.0, 1782.0, 1650.0, 1578.0, 1576.0, 1481.0, 1483.0, 1357.0,
          1365.0, 1338.0, 1269.0]

popt, pcov = curve_fit(conv_gauss, x_data, y_data, p0=[7.5e4, 5.8e3, 1, 1])
plt.plot(x_data, y_data, linestyle='none', marker='o')
plt.plot(np.linspace(5751, 5850, 100), conv_gauss(np.linspace(5751, 5850, 100),
                                                  *popt))
plt.show()