python 跨模块更改命名空间?
python change namespace across modules?
假设我创建了一个名为 my_mod
的模块,其中有两个文件 __init__.py
和 my_func.py
:
__init__.py:
import numpy
import cupy
xp = numpy # xp defaults to numpy
# a decorator that change `xp` based on input array class
def array_dispatch(fcn):
def wrapper(x, *args, **kwargs):
global xp
xp = cupy if isinstance(x, cupy.ndarray) else numpy
return fcn(x, *args, **kwargs)
return wrapper
from .my_func import *
和my_func.py:
from my_mod import xp, array_dispatch
@array_dispatch
def print_xp(x):
print(xp.__name__)
基本上我想 print_xp
根据输入 x
的 class 打印出“numpy”或“cupy”:如果输入 x
到print_xp
是一个numpy数组,然后打印出“numpy”;如果 x
是一个 cupy 数组,那么它应该打印出“cupy”。
然而,目前它总是打印出“numpy”,这是xp
的默认值。有人可以帮助我理解为什么,解决方案是什么?谢谢!
要回答您的具体问题,请不要这样做:
from my_mod import xp, array_dispatch
改为使用
import my_mod
然后在你的函数中引用my_mod.xp
:
@my_mod.array_dispatch
def print_xp(x):
print(my_mod.xp.__name__)
然后您将看到 my_mod
的全局命名空间的更新...
不过,您真的应该尽量避免使用这样的全局变量。
编辑:如果我正确理解您想要的内容,这是我会采用的方法。
import inspect
import cupy
import numpy
def array_dispatch(fcn):
sig = inspect.signature(fcn)
param = sig.parameters.get("xp")
if param is None:
raise ValueError("function must have an `xp` paramter")
if param.kind is not inspect.Parameter.KEYWORD_ONLY:
raise ValueError(f"`xp` parameter must be keyword only, got {param.kind}")
def wrapper(x, *args, **kwargs):
if isinstance(x, cupy.ndarray):
xp = cupy
elif isinstance(x, numpy.ndarray):
xp = numpy
else:
raise TypeError(f"expected either a numpy.ndarray or a cupy.ndarray, got {type(x)}")
return fcn(x, *args, xp=xp, **kwargs)
return wrapper
那么,这个装饰器的一个示例用户:
从 my_mod 导入 xp,array_dispatch
@array_dispatch
def frobnicate(x, *, xp):
return xp.tanh(x) + 42
import numpy as np
print(frobnicate(np.arange(10)))
定义xp
为数组
import numpy
import cupy
xp = [numpy] # xp defaults to numpy
# a decorator that change `xp` based on input array class
def array_dispatch(fcn):
def wrapper(x, *args, **kwargs):
global xp
xp[0] = cupy if isinstance(x, cupy.ndarray) else numpy
return fcn(x, *args, **kwargs)
return wrapper
from .my_func import *
和
from my_mod import xp, array_dispatch
@array_dispatch
def print_xp(x):
print(xp[0].__name__)
假设我创建了一个名为 my_mod
的模块,其中有两个文件 __init__.py
和 my_func.py
:
__init__.py:
import numpy
import cupy
xp = numpy # xp defaults to numpy
# a decorator that change `xp` based on input array class
def array_dispatch(fcn):
def wrapper(x, *args, **kwargs):
global xp
xp = cupy if isinstance(x, cupy.ndarray) else numpy
return fcn(x, *args, **kwargs)
return wrapper
from .my_func import *
和my_func.py:
from my_mod import xp, array_dispatch
@array_dispatch
def print_xp(x):
print(xp.__name__)
基本上我想 print_xp
根据输入 x
的 class 打印出“numpy”或“cupy”:如果输入 x
到print_xp
是一个numpy数组,然后打印出“numpy”;如果 x
是一个 cupy 数组,那么它应该打印出“cupy”。
然而,目前它总是打印出“numpy”,这是xp
的默认值。有人可以帮助我理解为什么,解决方案是什么?谢谢!
要回答您的具体问题,请不要这样做:
from my_mod import xp, array_dispatch
改为使用
import my_mod
然后在你的函数中引用my_mod.xp
:
@my_mod.array_dispatch
def print_xp(x):
print(my_mod.xp.__name__)
然后您将看到 my_mod
的全局命名空间的更新...
不过,您真的应该尽量避免使用这样的全局变量。
编辑:如果我正确理解您想要的内容,这是我会采用的方法。
import inspect
import cupy
import numpy
def array_dispatch(fcn):
sig = inspect.signature(fcn)
param = sig.parameters.get("xp")
if param is None:
raise ValueError("function must have an `xp` paramter")
if param.kind is not inspect.Parameter.KEYWORD_ONLY:
raise ValueError(f"`xp` parameter must be keyword only, got {param.kind}")
def wrapper(x, *args, **kwargs):
if isinstance(x, cupy.ndarray):
xp = cupy
elif isinstance(x, numpy.ndarray):
xp = numpy
else:
raise TypeError(f"expected either a numpy.ndarray or a cupy.ndarray, got {type(x)}")
return fcn(x, *args, xp=xp, **kwargs)
return wrapper
那么,这个装饰器的一个示例用户:
从 my_mod 导入 xp,array_dispatch
@array_dispatch
def frobnicate(x, *, xp):
return xp.tanh(x) + 42
import numpy as np
print(frobnicate(np.arange(10)))
定义xp
为数组
import numpy
import cupy
xp = [numpy] # xp defaults to numpy
# a decorator that change `xp` based on input array class
def array_dispatch(fcn):
def wrapper(x, *args, **kwargs):
global xp
xp[0] = cupy if isinstance(x, cupy.ndarray) else numpy
return fcn(x, *args, **kwargs)
return wrapper
from .my_func import *
和
from my_mod import xp, array_dispatch
@array_dispatch
def print_xp(x):
print(xp[0].__name__)