在标量输入上生成 np.vectorize return 标量值
make np.vectorize return scalar value on scalar input
以下代码 return 是一个数组而不是预期的浮点值。
def f(x):
return x+1
f = np.vectorize(f, otypes=[np.float])
>>> f(10.5)
array(11.5)
如果输入是标量而不是奇怪的数组类型,有没有办法强制它 return 简单标量值?
我觉得很奇怪它默认情况下不这样做,因为所有其他 ufunc 像 np.cos、np.sin 等都做 return 常规标量
编辑:
这是有效的代码:
import numpy as np
import functools
def as_scalar_if_possible(func):
@functools.wraps(func) #this is here just to preserve signature
def wrapper(*args, **kwargs):
return func(*args, **kwargs)[()]
return wrapper
@as_scalar_if_possible
@np.vectorize
def f(x):
return x + 1
print(f(11.5)) # 打印 12.5
结果在技术上是一个标量,因为它的形状是 ()
。例如,np.array(11.5)[0]
不是有效操作,将导致异常。事实上,在大多数情况下,返回的结果将充当标量。
例如
x = np.array(11.5)
print(x + 1) # prints 12.5
print(x < 12) # prints True, rather than [ True]
x[0] # raises IndexError
如果你想得到一个 "proper" 标量值,那么你可以只包装矢量化函数来检查返回数组的形状。这就是 numpy ufuncs 在幕后所做的事情。
例如
import numpy as np
def as_scalar_if_possible(func):
def wrapper(arr):
arr = func(arr)
return arr if arr.shape else np.asscalar(arr)
return wrapper
@as_scalar_if_possible
@np.vectorize
def f(x):
return x + 1
print(f(11.5)) # prints 12.5
以下代码 return 是一个数组而不是预期的浮点值。
def f(x):
return x+1
f = np.vectorize(f, otypes=[np.float])
>>> f(10.5)
array(11.5)
如果输入是标量而不是奇怪的数组类型,有没有办法强制它 return 简单标量值?
我觉得很奇怪它默认情况下不这样做,因为所有其他 ufunc 像 np.cos、np.sin 等都做 return 常规标量
编辑: 这是有效的代码:
import numpy as np
import functools
def as_scalar_if_possible(func):
@functools.wraps(func) #this is here just to preserve signature
def wrapper(*args, **kwargs):
return func(*args, **kwargs)[()]
return wrapper
@as_scalar_if_possible
@np.vectorize
def f(x):
return x + 1
print(f(11.5)) # 打印 12.5
结果在技术上是一个标量,因为它的形状是 ()
。例如,np.array(11.5)[0]
不是有效操作,将导致异常。事实上,在大多数情况下,返回的结果将充当标量。
例如
x = np.array(11.5)
print(x + 1) # prints 12.5
print(x < 12) # prints True, rather than [ True]
x[0] # raises IndexError
如果你想得到一个 "proper" 标量值,那么你可以只包装矢量化函数来检查返回数组的形状。这就是 numpy ufuncs 在幕后所做的事情。
例如
import numpy as np
def as_scalar_if_possible(func):
def wrapper(arr):
arr = func(arr)
return arr if arr.shape else np.asscalar(arr)
return wrapper
@as_scalar_if_possible
@np.vectorize
def f(x):
return x + 1
print(f(11.5)) # prints 12.5