Python: access a dataclass default_factory field without instantiating it

I have a dataclass Bar whose class method create_instance() needs to inspect a field(default_factory=...) attribute of the class before returning an instance of the dataclass:

from dataclasses import dataclass, field
from typing import Dict


@dataclass
class Bar:
    number: int
    static_dict: Dict[str, int] = field(
        default_factory=lambda: {
            'foo': 1,
        }
    )

    @classmethod
    def create_instance(cls) -> 'Bar':
        access = cls.static_dict
        return cls(number=access['foo'])

if __name__ == '__main__':
    foo = Bar.static_dict  # 1) fails already
    foo = Bar.create_instance()  # 2) fails too since it uses 1)

This fails with:

AttributeError: type object 'Bar' has no attribute 'static_dict'

Without the type hint on the return value, e.g.

@classmethod
def create_instance(cls):
    access = cls.static_dict
    return cls(number=access['foo'])

it fails as well, reporting that the dataclass itself has no such attribute:

AttributeError: type object 'Bar' has no attribute 'static_dict'.

The same error occurs with a static method, e.g. with a type hint:

@staticmethod
def create_instance_static() -> 'Bar':
    access = Bar.static_dict
    return Bar(number=access['foo'])

My very ugly workaround is to define a global variable and pass it to the class, so that I can use the create_instance() method:

from dataclasses import dataclass, field
from typing import Dict

GLOBAL_DICT = {
    'foo': 2,
}


@dataclass
class BarWorkaround:
    number: int
    static_dict: Dict[str, int] = field(
        default_factory=lambda: GLOBAL_DICT
    )

    @classmethod
    def create_instance(cls) -> 'BarWorkaround':
        access = GLOBAL_DICT
        return cls(number=access['foo'])


if __name__ == '__main__':
    foo = BarWorkaround.create_instance()  # works as intended

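I can't say I like defining a global variable for something like this. My guess is that the way dataclasses are constructed interferes with my access to the static default attribute, but I don't understand how.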
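One detail of the workaround is worth spelling out: `lambda: GLOBAL_DICT` hands every instance the same dict object, so a change made through one instance shows up in all others and in the global itself. A minimal sketch of a variant, assuming a fresh dict per instance is wanted, replaces the global with a module-level factory function. The names `make_static_dict` and `BarWithFactory` are hypothetical and only serve the illustration.

```python
from dataclasses import dataclass, field
from typing import Dict


def make_static_dict() -> Dict[str, int]:
    """Hypothetical factory: builds a new dict on every call."""
    return {'foo': 2}


@dataclass
class BarWithFactory:
    number: int
    static_dict: Dict[str, int] = field(default_factory=make_static_dict)

    @classmethod
    def create_instance(cls) -> 'BarWithFactory':
        # The factory is an ordinary function, so no instance is needed here.
        return cls(number=make_static_dict()['foo'])


a = BarWithFactory.create_instance()
b = BarWithFactory.create_instance()
a.static_dict['x'] = 1
print(a.number)               # 2
print('x' in b.static_dict)   # False: b got its own dict
```

This still keeps a name at module level, but it is a function rather than shared mutable state.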

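The attribute lookup fails because @dataclass removes the class attribute for any field declared with default_factory; the factory is kept only on the field's Field object. You could keep create_instance as a class method and call dataclasses.fields, which returns a tuple of the fields defined for a given dataclass. Each element of that tuple is a Field object, and its default_factory attribute holds the callable you passed to field().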

For example:

from dataclasses import dataclass, field, fields
from typing import Dict


@dataclass
class Bar:
    number: int
    static_dict: Dict[str, int] = field(
        default_factory=lambda: {
            'foo': 1,
        }
    )

    @classmethod
    def create_instance(cls) -> 'Bar':
        # use `cls` rather than `Bar` so that subclasses work too
        access = next(f for f in fields(cls)
                      if f.name == 'static_dict').default_factory()

        return cls(number=access['foo'])


if __name__ == '__main__':
    # this should fail (`static_dict` is an instance - not class - attribute)
    # foo = Bar.static_dict

    # this should work though
    static_dict_field = next(f for f in fields(Bar) if f.name == 'static_dict')
    lambda_fn = static_dict_field.default_factory
    print(lambda_fn())  # {'foo': 1}

    foo = Bar.create_instance()  # 2) works with `dataclasses.fields`
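The lookup above can be wrapped in a small reusable function. The sketch below is only an illustration: `Demo` and `default_for` are hypothetical names, not part of `dataclasses`. It also shows the difference that causes the original error: a plain default stays on the class as an attribute, while a `default_factory` field does not.

```python
from dataclasses import MISSING, dataclass, field, fields
from typing import Any, Dict


@dataclass
class Demo:
    plain: int = 3  # plain default: remains available as Demo.plain
    made: Dict[str, int] = field(default_factory=lambda: {'foo': 1})


def default_for(cls: type, name: str) -> Any:
    """Hypothetical helper: return the default value of field `name`."""
    by_name = {f.name: f for f in fields(cls)}
    f = by_name[name]  # raises KeyError if there is no such field
    if f.default_factory is not MISSING:
        return f.default_factory()  # builds a new object on every call
    if f.default is not MISSING:
        return f.default
    raise ValueError(f'{cls.__name__}.{name} has no default')


print(Demo.plain)                  # 3
print(hasattr(Demo, 'made'))       # False
print(default_for(Demo, 'made'))   # {'foo': 1}
```

Checking against `MISSING` is how `dataclasses` itself marks "no default given", so the helper can tell a missing default apart from a default of `None`.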

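If you plan to call create_instance many times, it may be a good idea to cache the default_factory value for performance reasons, e.g. by using a "cached" class attribute, or a simple workaround like this: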

from dataclasses import dataclass, field, fields
from typing import Dict, ClassVar, Callable


@dataclass
class Bar:
    number: int
    static_dict: Dict[str, int] = field(
        default_factory=lambda: {
            'foo': 1,
        }
    )
    # added for type hinting and IDE support
    #
    # note: `dataclasses` ignores anything annotated with `ClassVar`, or else
    # not annotated to begin with.
    __static_dict_factory__: ClassVar[Callable[[], Dict[str, int]]]

    @classmethod
    def create_instance(cls) -> 'Bar':
        return Bar(number=cls.__static_dict_factory__()['foo'])


# need to set it outside, or perhaps consider using a metaclass approach
setattr(Bar, '__static_dict_factory__',
        next(f for f in fields(Bar) if f.name == 'static_dict').default_factory)


if __name__ == '__main__':
    foo1 = Bar.create_instance()
    foo2 = Bar.create_instance()

    foo1.static_dict['key'] = 123
    print(foo1.static_dict)

    # assert each instance gets its own copy of the dict
    assert foo1.static_dict != foo2.static_dict
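If setting an attribute on the class after its definition feels awkward, another way to cache the lookup is `functools.lru_cache` on a small module-level function. This is a sketch under that assumption, with `factory_for` as a hypothetical name. Only the search through `fields()` is cached; the factory itself still runs on every call, so each instance keeps its own dict.

```python
from dataclasses import dataclass, field, fields
from functools import lru_cache
from typing import Any, Callable, Dict


@lru_cache(maxsize=None)
def factory_for(cls: type, name: str) -> Callable[[], Any]:
    """Hypothetical helper: find a field's default_factory once per (cls, name)."""
    return next(f for f in fields(cls) if f.name == name).default_factory


@dataclass
class Bar:
    number: int
    static_dict: Dict[str, int] = field(default_factory=lambda: {'foo': 1})

    @classmethod
    def create_instance(cls) -> 'Bar':
        # The cached value is the factory, not its result.
        return cls(number=factory_for(cls, 'static_dict')()['foo'])


foo1 = Bar.create_instance()
foo2 = Bar.create_instance()
print(foo1.number)                               # 1
print(foo1.static_dict is not foo2.static_dict)  # True
```

Because the cache key includes `cls`, a subclass that overrides the field gets its own entry.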