python 跨模块更改命名空间?

python change namespace across modules?

假设我创建了一个名为 my_mod 的模块,其中有两个文件 __init__.pymy_func.py:
__init__.py:

import numpy
import cupy

xp = numpy # xp defaults to numpy

# a decorator that change `xp` based on input array class
def array_dispatch(fcn):
    def wrapper(x, *args, **kwargs):
        global xp
        xp = cupy if isinstance(x, cupy.ndarray) else numpy
        return fcn(x, *args, **kwargs)
    return wrapper

from .my_func import *

my_func.py:

from my_mod import xp, array_dispatch

@array_dispatch
def print_xp(x):
    print(xp.__name__)

基本上我想 print_xp 根据输入 x 的 class 打印出“numpy”或“cupy”:如果输入 xprint_xp是一个numpy数组,然后打印出“numpy”;如果 x 是一个 cupy 数组,那么它应该打印出“cupy”。

然而,目前它总是打印出“numpy”,这是xp的默认值。有人可以帮助我理解为什么,解决方案是什么?谢谢!

要回答您的具体问题,请不要这样做:

from my_mod import xp, array_dispatch

改为使用

import my_mod

然后在你的函数中引用my_mod.xp

@my_mod.array_dispatch
def print_xp(x):
    print(my_mod.xp.__name__)

然后您将看到 my_mod 的全局命名空间的更新...

不过,您真的应该尽量避免使用这样的全局变量。

编辑:如果我正确理解您想要的内容,这是我会采用的方法。

import inspect
import cupy
import numpy

def array_dispatch(fcn):
    sig = inspect.signature(fcn)
    param = sig.parameters.get("xp")
    if param is None:
        raise ValueError("function must have an `xp` paramter")
    if param.kind is not inspect.Parameter.KEYWORD_ONLY:
        raise ValueError(f"`xp` parameter must be keyword only, got {param.kind}")
    
    def wrapper(x, *args, **kwargs):
        if isinstance(x, cupy.ndarray):
            xp = cupy
        elif isinstance(x, numpy.ndarray):
            xp = numpy
        else:
            raise TypeError(f"expected either a numpy.ndarray or a cupy.ndarray, got {type(x)}")
        return fcn(x, *args, xp=xp, **kwargs)
    return wrapper

那么,这个装饰器的一个示例用户:

从 my_mod 导入 xp,array_dispatch

@array_dispatch
def frobnicate(x, *, xp):
    return xp.tanh(x) + 42


import numpy as np

print(frobnicate(np.arange(10)))

定义xp为数组

import numpy
import cupy

xp = [numpy] # xp defaults to numpy

# a decorator that change `xp` based on input array class
def array_dispatch(fcn):
    def wrapper(x, *args, **kwargs):
        global xp
        xp[0] = cupy if isinstance(x, cupy.ndarray) else numpy
        return fcn(x, *args, **kwargs)
    return wrapper

from .my_func import *

from my_mod import xp, array_dispatch

@array_dispatch
def print_xp(x):
    print(xp[0].__name__)