传递给 scipy.optimize.curve_fit 的函数需要满足哪些特定要求才能 运行?

What specific requirements does the function passed to scipy.optimize.curve_fit need to fulfill in order to run?

我正在使用 matplotlib 的 hist 函数将统计模型拟合到分布。例如,我的代码使用以下代码符合指数分布:

    try:

        def expDist(x, a, x0):
            return a*(exp(-(x/x0))/x0)

        self.n, self.bins, patches = plt.hist(self.getDataSet(), self.getDatasetSize()/10, normed=1, facecolor='blue', alpha = 0.55)
        popt,pcov = curve_fit(expDist,self.bins[:-1], self.n, p0=[1,mean])
        print "Fitted gaussian curve to data with params a %f, x0 %f" % (popt[0], popt[1])
        self.a = popt[0]
        self.x0 = popt[1]

        self.fitted = True
    except RuntimeError:
        print "Unable to fit data to exponential curve"

运行良好,但是当我修改它以在 ab

之间进行均匀分布时执行相同的操作
    def uniDist(x, a, b):
        if((x >= a)and(x <= b)):
            return float(1.0/float(b-a))
        else:
            return 0.000

    try:



        self.n, self.bins, patches = plt.hist(self.getDataSet(), self.getDatasetSize()/10, normed=1, facecolor='blue', alpha = 0.55)
        popt,pcov = curve_fit(uniDist,self.bins[:-1], self.n, p0=[a, b])
        print "Fitted uniform distribution curve to data with params a %f, b %f" % (popt[0], popt[1])
        self.a = popt[0]
        self.b = popt[1]

        self.fitted = True
    except RuntimeError:
        print "Unable to fit data to uniform distribution pdf curve"

代码崩溃,

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

问题似乎是 curve_fit 中的某处,该函数正在尝试调用要拟合的函数(expDist,在本例中为 uniDist)一组值,但我无法弄清楚 expDist 函数如何能够在不崩溃的情况下接受任何可迭代的东西?

您的怀疑部分正确。 curve_fit 确实向函数传递了一个可迭代对象,但不仅仅是任何可迭代对象:numpy.ndarray。这些恰好拥有矢量化算术运算符,所以

a*(exp(-(x/x0))/x0)

将简单地 element-wise 处理输入数组而不会出现任何错误(并且输出正确)。甚至没有太多魔法:对于函数的每次评估,参数 ax0 将是标量,只有 x 是数组。

现在,uniDist 的问题在于它不仅包含算术运算符:它还包含比较运算符。只要将单个数组与标量进行比较,这些方法就可以正常工作:

>>> import numpy as np
>>> a = np.arange(5)
>>> a
array([0, 1, 2, 3, 4])
>>> a>2
array([False, False, False,  True,  True], dtype=bool)

以上说明在数组和标量上使用比较运算符将再次产生 element-wise 结果。当您尝试将逻辑运算符应用于其中两个布尔数组时,会出现您看到的错误:

>>> a>2
array([False, False, False,  True,  True], dtype=bool)
>>> a<4
array([ True,  True,  True,  True, False], dtype=bool)
>>> (a>2) and (a<4)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

错误信息有点混乱。可以追溯到 python 会尝试为 array1 and array2 得出一个结果(在本机 python 中会 return 任一数组基于它们的空性).然而,numpy 怀疑这不是你想做的,忍不住去猜。

由于您希望您的函数对两个布尔数组(来自比较操作)进行操作 element-wise,因此您必须使用 & 运算符。这是本机 python 中的 "binary and",但对于 numpy 数组,这会为您提供数组的 element-wise "logical and"。您还可以使用 numpy.logical_and(或者在您的情况下 scipy.logical_and)更明确:

>>> (a>2) & (a<4)
array([False, False, False,  True, False], dtype=bool)
>>> np.logical_and(a>2,a<4)
array([False, False, False,  True, False], dtype=bool)

请注意,对于 & 的情况,您始终必须在比较中加上括号,因为 a>2&a<4 又会是模棱两可的(对程序员而言)和错误的(考虑到您 想要的它来做)。由于布尔值的 "binary and" 的行为完全符合您的预期,因此可以安全地重写您的函数以使用 & 而不是 and 来比较这两个比较。

但是,您仍然需要更改一个步骤:在输入 ndarray 的情况下,if 的行为也会有所不同。 Python忍不住在一个if里单选,如果你把数组放进去也是这样。但是您真正想要做的是限制输出的元素 element-wise (再次)。所以你要么必须遍历你的数组(不要),要么以矢量化的方式再次做这个选择。后者是惯用的 numpy/scipy:

import scipy as sp
def uniDist(x, a, b):
    return sp.where((a<=x) & (x<=b), 1.0/(b-a), 0.0)

这个(即numpy.where) will return an array the same size as x. For the elements where the condition is True, the value of the output will be 1/(b-a). For the rest the output is 0. For scalar x, the return value is a numpy scalar. Note that I removed the float conversion in the above example, since having 1.0 in the numerator will definitely give you true division, despite your using python 2. Although I would suggest using python 3, or at least from __future__ import division.


小提示:即使对于标量情况,我也建议使用 python 的运算符链接进行比较,这有助于实现这一目的。我的意思是你可以简单地做 if a <= x <= b: ...,与大多数语言不同,这在功能上等同于你写的(但更漂亮)。