如何检查函数的输入是否在数据类型限制内?

How to check if function's input is within the data type limit?

我有一个函数,它接受一个 array-like 参数和一个 value 参数作为输入。在此函数的单元测试期间(我使用 hypothesis),如果抛出非常大的 valuenp.float128 无法处理的),则函数失败。

检测此类值并正确处理它们的好方法是什么?

下面是我的函数的代码:

def find_nearest(my_array, value):
    """ Find the nearest value in an unsorted array.
    """
    # Convert to numpy array and drop NaN values.
    my_array = np.array(my_array, copy=False, dtype=np.float128)
    my_array = my_array[~np.isnan(my_array)]

    return my_array[(np.abs(my_array - value)).argmin()]

引发错误的示例:

find_nearest([0.0, 1.0], 1.8446744073709556e+19)

抛出:0.0,但正确答案是 1.0

如果我不能抛出正确答案,至少我希望能够抛出异常。问题是现在我不知道如何识别错误的输入。适合其他情况的更一般的答案是可取的,因为我认为这是一个反复出现的问题。

注意,float128 实际上不是 128 位精度!它实际上是一个 longdouble 实现:https://en.wikipedia.org/wiki/Extended_precision。这种存储类型的精度是 63 位——这就是它在 1e+19 附近失败的原因,因为这对你来说是 63 位二进制位。当然,如果你的数组中的差异大于 1,它将能够区分那个数字,这只是意味着你试图让它区分的任何差异都必须大于你的 1/2**63输入 value.

What is the internal precision of numpy.float128? 这是一个详细说明相同内容的旧答案。我已经完成测试并确认 np.float128 正好是具有 63 位精度的 longdouble

我建议您为 value 设置一个最大值,如果您的值大于该值,则:

  1. 将值减少到该数字,前提是数组中的所有内容都将小于该数字。

  2. 报错

像这样:

VALUE_MAX = 1e18
def find_nearest(my_array, value):
    if value > VALUE_MAX:
        value = VALUE_MAX
    ...

或者,您可以选择更科学的方法,例如实际比较您的 value 与数组的最大值:

def find_nearest(my_array, value):
    my_array = np.array(my_array, dtype=np.float128)
    if value > np.amax(my_array):
        value = np.amax(my_array)
    elif value < np.amin(my_array):
        value = np.amin(my_array)
    ...

这样你就可以确保你永远不会 运行 遇到这个问题 - 因为你的值总是最多与数组的最大值一样大,或者至少与数组的最小值一样大。

这里的问题似乎不是 float128 无法处理 1.844...e+19,而是您可能无法将两个具有如此根本不同尺度的浮点数相加并且期望得到准确的结果:

In [1]: 1.8446744073709556e+19 - 1.0 == 1.8446744073709556e+19
Out[1]: True

如果你真的需要这种精度,你最好的选择是使用 Decimal 对象并将它们作为 dtype 'object':

放入一个 numpy 数组中
In [1]: from decimal import Decimal

In [2]: big_num = Decimal(1.8446744073709556e+19)

In [3]: big_num  # Note the slight innaccuracies due to floating point conversion
Out[3]: Decimal('18446744073709555712')

In [4]: a = np.array([Decimal(0.0), Decimal(1.0)], dtype='object')

In [5]: a[np.abs(a - big_num).argmin()]
Out[5]: Decimal('1')

请注意,这将比典型的 Numpy 操作慢得多,因为它必须为每个计算恢复到 Python 而不是能够利用其自己的优化库(因为 numpy 没有 Decimal类型)。

编辑:

如果您不需要这个解决方案,只是想知道您当前的代码是否会失败,我建议使用非常科学的方法 "just try":

fails = len(set(my_array)) == len(set(my_array - value))

这可以确保当您减去 valuemy_array 中的唯一数字 X 时,您会得到唯一的结果。这是关于减法的一个普遍真实的事实,如果它失败了,那是因为浮点运算不够精确,无法将 value - X 作为不同于 valueX 的数字来处理。