How to define a dataclass so each of its attributes is the list of its subclass attributes?

I have this code:

from dataclasses import dataclass
from typing import List

@dataclass
class Position:
    name: str
    lon: float
    lat: float

@dataclass
class Section:
    positions: List[Position]

pos1 = Position('a', 52, 10)
pos2 = Position('b', 46, -10)
pos3 = Position('c', 45, -10)

sec = Section([pos1, pos2, pos3])

print(sec.positions)

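How can I create additional attributes in the dataclass Section so that they are lists of the attributes of its "subclass" Position?

In my example, I would like the section object to also return: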

sec.name = ['a', 'b', 'c']   #[pos1.name,pos2.name,pos3.name]
sec.lon = [52, 46, 45]       #[pos1.lon,pos2.lon,pos3.lon]
sec.lat = [10, -10, -10]     #[pos1.lat,pos2.lat,pos3.lat]

I tried defining the dataclass as:

@dataclass
class Section:
    positions: List[Position]
    names :  List[Position.name]

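But it doesn't work, because Position.name is not a class attribute (the field has no default, so it only exists on instances). I could set the object attributes later in the code (e.g. via sec.name = [x.name for x in sec.positions]), but it would be better if this could be done at the dataclass definition level.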
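For reference, here is a minimal sketch of that manual approach. It shows the main drawback: the assigned list is computed once, so it goes stale when positions changes later. The data values are only illustrative.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@dataclass
class Section:
    positions: List[Position]


sec = Section([Position('a', 52, 10), Position('b', 46, -10)])

# manual approach: build the list once and attach it to the instance
sec.name = [x.name for x in sec.positions]
print(sec.name)  # ['a', 'b']

# the attached list is a snapshot: later changes to positions are not reflected
sec.positions.append(Position('c', 45, -10))
print(sec.name)  # still ['a', 'b']
```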

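After posting this question, I found the beginning of an answer.

But I wonder whether there is a more generic/"automatic" way to define the Section methods .names(), .lons(), .lats(), ..., so that the developer does not have to define each method individually and they are instead created from the attributes of the Position objects?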

You can create a new field after __init__ has been called:

from dataclasses import dataclass, field, fields
from typing import List


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@dataclass
class Section:
    positions: List[Position]
    _pos: dict = field(init=False, repr=False)

    def __post_init__(self):
        # attach _pos as a read-only property once __init__ is done
        # (note: this rebinds the class attribute on every instantiation)
        Section._pos = property(Section._get_positions)

    def _get_positions(self):
        _pos = {}

        # iterate over all fields of Position and collect each one into a list;
        # the loop variable is "name" so it does not shadow dataclasses.field
        for name in [f.name for f in fields(self.positions[0])]:
            _pos[name] = [getattr(p, name) for p in self.positions]
        return _pos


pos1 = Position('a', 52, 10)
pos2 = Position('b', 46, -10)
pos3 = Position('c', 45, -10)

sec = Section([pos1, pos2, pos3])

print(sec.positions)
print(sec._pos['name'])
print(sec._pos['lon'])
print(sec._pos['lat'])

Output:

[Position(name='a', lon=52, lat=10), Position(name='b', lon=46, lat=-10), Position(name='c', lon=45, lat=-10)]
['a', 'b', 'c']
[52, 46, 45]
[10, -10, -10]
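Because the answer above assigns the property inside __post_init__, the class attribute is rebound on every instantiation. A minimal sketch of the same idea, assuming it is acceptable to declare the property directly in the class body, avoids that. It uses fields(Position) instead of fields(self.positions[0]), so an empty section also works:

```python
from dataclasses import dataclass, fields
from typing import Dict, List


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@dataclass
class Section:
    positions: List[Position]

    @property
    def _pos(self) -> Dict[str, list]:
        # recomputed on every access, so it always reflects self.positions;
        # a property is not an annotated class attribute, so it is not a dataclass field
        return {
            f.name: [getattr(p, f.name) for p in self.positions]
            for f in fields(Position)
        }


sec = Section([Position('a', 52, 10), Position('b', 46, -10), Position('c', 45, -10)])
print(sec._pos['name'])  # ['a', 'b', 'c']
print(sec._pos['lat'])   # [10, -10, -10]
```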

Edit:

If you just need it to be more generic, you can override __getattr__:

from dataclasses import dataclass, field, fields
from typing import List


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@dataclass
class Section:
    positions: List[Position]

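    def __getattr__(self, keyName):
        # only called when normal attribute lookup fails, e.g. for sec.names
        for f in fields(Position):
            if f"{f.name}s" == keyName:
                return [getattr(x, f.name) for x in self.positions]
        # unknown name: raise instead of implicitly returning None,
        # so hasattr(), copy and pickle keep working
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {keyName!r}")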

pos1 = Position('a', 52, 10)
pos2 = Position('b', 46, -10)
pos3 = Position('c', 45, -10)

sec = Section([pos1, pos2, pos3])

print(sec.names)
print(sec.lons)
print(sec.lats)

Output:

['a', 'b', 'c']
[52, 46, 45]
[10, -10, -10]
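One subtlety of __getattr__ is what happens for names that do not match. A sketch, assuming the fallback raises AttributeError and the fields are read from the Position class rather than from self.positions[0], shows that hasattr() then reports False and copy.copy() works. If the fallback touched self.positions on an instance that has no positions yet (as copy does internally), the lookup would recurse into __getattr__ again.

```python
import copy
from dataclasses import dataclass, fields
from typing import List


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@dataclass
class Section:
    positions: List[Position]

    def __getattr__(self, key_name):
        # only called when normal lookup fails (e.g. "names", "lons", "lats")
        for f in fields(Position):
            if f"{f.name}s" == key_name:
                return [getattr(x, f.name) for x in self.positions]
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {key_name!r}")


sec = Section([Position('a', 52, 10), Position('b', 46, -10)])
print(sec.names)            # ['a', 'b']
print(hasattr(sec, 'foo'))  # False, instead of silently getting None
print(copy.copy(sec).lons)  # [52, 46]
```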

After some thought, I came up with an alternative solution:

from dataclasses import dataclass
from typing import List

@dataclass
class Position:
    name: str
    lon: float
    lat: float

@dataclass
class Section:
    positions: List[Position]

    def names(self):
        return [x.name for x in self.positions]

    def lons(self):
        return [x.lon for x in self.positions]

    def lats(self):
        return [x.lat for x in self.positions]


pos1 = Position('a', 52, 10)
pos2 = Position('b', 46, -10)
pos3 = Position('c', 45, -10)

sec = Section([pos1, pos2, pos3])

print(sec.positions)
print(sec.names())
print(sec.lons())
print(sec.lats())

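But I wonder whether there is a more generic/"automatic" way to define the Section methods .names(), .lons(), .lats(), ..., so that the developer does not have to define each method individually and they are instead created from the attributes of the Position objects?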
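A minimal sketch of one way to do that: the three methods differ only in the attribute name, so they can be generated in a loop over fields(Position) after the class is defined. The factory _make_getter is a hypothetical helper introduced here. It binds the attribute name per call, which avoids the late-binding closure pitfall.

```python
from dataclasses import dataclass, fields
from typing import List


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@dataclass
class Section:
    positions: List[Position]


def _make_getter(attr):
    # hypothetical helper: returns a method that collects attr from every position
    def getter(self):
        return [getattr(p, attr) for p in self.positions]
    return getter


# one method per Position field: names(), lons(), lats()
for f in fields(Position):
    setattr(Section, f"{f.name}s", _make_getter(f.name))


sec = Section([Position('a', 52, 10), Position('b', 46, -10), Position('c', 45, -10)])
print(sec.names())  # ['a', 'b', 'c']
print(sec.lons())   # [52, 46, 45]
print(sec.lats())   # [10, -10, -10]
```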

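As I understand it, you want to declare dataclasses that are flat data containers (like Position) and nest them inside a container field of another dataclass (like Section). The outer dataclass should then be able to access a list of each attribute of its inner dataclass(es) through simple name access.

We can implement this on top of how regular dataclasses work (call it, for example, introspect) and make it opt-in, similar to the already existing flags: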

from dataclasses import is_dataclass, fields, dataclass as dc
from typing import get_args, get_origin

# existing dataclass signature, plus an "introspect" keyword
def dataclass(_cls=None, *, init=True, repr=True, eq=True, order=False,
              unsafe_hash=False, frozen=False, introspect=False):

    def wrap(cls):
        # run the original dataclass decorator
        cls = dc(cls, init=init, repr=repr, eq=eq, order=order,
                 unsafe_hash=unsafe_hash, frozen=frozen)

        # add our custom "introspect" logic on top
        if introspect:
            for field in fields(cls):
                # only consider dataclasses nested in list/set/tuple containers;
                # get_origin() maps both List[X] and list[X] to list
                if get_origin(field.type) not in (list, set, tuple):
                    continue
                args = get_args(field.type)
                if not args or not is_dataclass(args[0]):
                    continue
                contained_dc = args[0]
                # once we got them, add their fields as properties
                for dc_field in fields(contained_dc):
                    # if there are name clashes, use f"{field.name}_{dc_field.name}" instead
                    property_name = dc_field.name
                    # bind loop variables as defaults to avoid late-binding bugs
                    def magic_property(self, field=field, dc_field=dc_field):
                        return [getattr(attr, dc_field.name) for attr in getattr(self, field.name)]
                    # here is where the magic happens
                    setattr(cls, property_name, property(magic_property))
        return cls

    # handle being called with or without parentheses
    if _cls is None:
        return wrap
    return wrap(_cls)
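The decorator decides which fields to introspect by looking at the container type of each annotation. Here is a small sketch of how the public helpers typing.get_origin() and typing.get_args() (Python 3.8+) report that information, as an alternative to reading the private _name attribute of typing aliases. The builtin list[X] form assumes Python 3.9+.

```python
from typing import Dict, List, Set, Tuple, get_args, get_origin


class Position:
    pass


# get_origin() strips the parameters; get_args() returns them as a tuple
assert get_origin(List[Position]) is list
assert get_origin(Set[Position]) is set
assert get_origin(Tuple[Position, ...]) is tuple
assert get_args(List[Position]) == (Position,)
assert get_args(Tuple[Position, ...])[0] is Position

# builtin generics (Python 3.9+) report the same origin
assert get_origin(list[Position]) is list

# plain types have no origin, and dict is not list/set/tuple, so both would be skipped
assert get_origin(int) is None
assert get_origin(Dict[str, Position]) is dict
```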

The resulting dataclass function can now be used like this:

from typing import List

# regular dataclass
@dataclass
class Position:
    name: str
    lon: float
    lat: float

# this one will introspect its fields and try to add magic properties
@dataclass(introspect=True)
class Section:
    positions: List[Position]

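That's it. The properties are added while the class is constructed, and because they are recomputed on every access, they even reflect changes to any of the objects during their lifetime: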

>>> p_1 = Position("1", 1.0, 1.1)
>>> p_2 = Position("2", 2.0, 2.1)
>>> p_3 = Position("3", 3.0, 3.1)
>>> section = Section([p_1, p_2, p_3])
>>> section.name
['1', '2', '3']
>>> section.lon
[1.0, 2.0, 3.0]
>>> p_1.lon = 5.0
>>> section.lon
[5.0, 2.0, 3.0]
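The comment in the decorator mentions using prefixed names such as f"{field.name}_{dc_field.name}" when names clash, for example when a class holds two lists of Position. Here is a minimal standalone sketch of that variant. The decorator add_prefixed_list_properties and the Route class are hypothetical names used only for illustration, and only list fields are handled.

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import List, get_args, get_origin


def add_prefixed_list_properties(cls):
    """Hypothetical helper: for every list-of-dataclass field, add one read-only
    property per inner field, named f"{field.name}_{inner.name}"."""
    for field in fields(cls):
        if get_origin(field.type) is not list:
            continue
        args = get_args(field.type)
        if not args or not is_dataclass(args[0]):
            continue
        for inner in fields(args[0]):
            # default arguments bind the current loop values
            def getter(self, outer=field.name, attr=inner.name):
                return [getattr(item, attr) for item in getattr(self, outer)]
            setattr(cls, f"{field.name}_{inner.name}", property(getter))
    return cls


@dataclass
class Position:
    name: str
    lon: float
    lat: float


# decorators apply bottom-up: @dataclass runs first, then the helper
@add_prefixed_list_properties
@dataclass
class Route:
    start: List[Position]
    end: List[Position]


route = Route([Position("s", 1.0, 1.1)],
              [Position("e1", 2.0, 2.1), Position("e2", 3.0, 3.1)])
print(route.start_name)  # ['s']
print(route.end_lon)     # [2.0, 3.0]
```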