如何为省略默认值的`dataclass`定义`__str__`?

How to define `__str__` for `dataclass` that omits default values?

给定一个 dataclass 实例,我希望 `print()` 和 `str()` 只列出非默认的字段值。当 dataclass 有很多字段、而只有少数字段被更改时,这很有用。

@dataclasses.dataclass
class X:
  """Sample dataclass with three defaulted fields, used for the printing example."""
  a: int = dataclasses.field(default=1)
  b: bool = dataclasses.field(default=False)
  c: float = dataclasses.field(default=2.0)

x = X(b=True)
# With only the dataclass-generated __repr__, print() shows every field
# (X(a=1, b=True, c=2.0)); the goal is to show just the changed one.
print(x)  # Desired output: X(b=True)

解决方案是添加一个自定义的 `__str__()` 方法:

@dataclasses.dataclass
class X:
  a: int = 1
  b: bool = False
  c: float = 2.0

  def __str__(self):
    """Returns a string containing only the non-default field values.

    A field is left out when its current value equals its default.  For
    fields declared with ``default_factory`` the factory is called to get
    the value to compare against, because ``field.default`` is ``MISSING``
    for such fields.
    """
    parts = []
    for field in dataclasses.fields(self):
      if field.default_factory is not dataclasses.MISSING:
        default = field.default_factory()
      else:
        default = field.default
      value = getattr(self, field.name)  # read once; used for compare and repr
      if value != default:
        parts.append(f'{field.name}={value!r}')
    s = ', '.join(parts)
    return f'{type(self).__name__}({s})'

x = X(b=True)
# print() and str() both go through the custom __str__ defined above.
print(x)        # X(b=True)
print(str(x))   # X(b=True)
# repr() still uses the dataclass-generated __repr__, which lists every field.
print(repr(x))  # X(a=1, b=True, c=2.0)
# In f-strings, {x} and {x!s} use __str__, while {x!r} uses __repr__.
print(f'{x}, {x!s}, {x!r}')  # X(b=True), X(b=True), X(a=1, b=True, c=2.0)

这也可以使用装饰器来实现:

def terse_str(cls):  # Decorator for class.
  """Install a ``__str__`` on *cls* that lists only non-default fields.

  Intended for dataclasses.  The installed method calls
  ``dataclasses.fields`` each time it runs, so it works whether this
  decorator is applied before or after ``@dataclasses.dataclass``.
  Returns *cls* itself so it can be used as a decorator.
  """
  def __str__(self):
    """Returns a string containing only the non-default field values."""
    parts = []
    for field in dataclasses.fields(self):
      # default_factory fields have field.default == MISSING, so build the
      # factory value to compare against instead.
      if field.default_factory is not dataclasses.MISSING:
        default = field.default_factory()
      else:
        default = field.default
      value = getattr(self, field.name)
      if value != default:
        # !r matches the dataclass-generated __repr__ (e.g. quotes strings).
        parts.append(f'{field.name}={value!r}')
    s = ', '.join(parts)
    return f'{type(self).__name__}({s})'

  cls.__str__ = __str__
  return cls

@terse_str
@dataclasses.dataclass
class X:
  """Example dataclass used with the terse_str decorator."""
  a: int = dataclasses.field(default=1)
  b: bool = dataclasses.field(default=False)
  c: float = dataclasses.field(default=2.0)

我建议的一项改进是:先计算 `dataclasses.fields` 的结果,再缓存其中各字段的默认值。这有助于提高性能,因为目前每次调用 `__str__` 时都会重新调用 `dataclasses.fields`。

下面是一个使用元类(metaclass)方式的简单示例。由于用到了海象运算符 `:=`,它需要 Python 3.8 或更高版本。

请注意,我还对其稍作修改,使它能够处理通过 `default_factory` 定义默认值的字段(通常用于可变类型)。

from __future__ import annotations
import dataclasses


def terse_str(name, bases, cls_dict):  # Metaclass for class
    """Create the class with a ``__str__`` that lists only non-default fields.

    Used as ``metaclass=terse_str``.  It receives the class name, bases and
    namespace, injects ``__str__`` into the namespace, and builds the class
    with plain ``type``.

    Each concrete class's field defaults (``default_factory()`` results or
    ``default`` values) are computed on the first ``str()`` of one of its
    instances and cached, so ``dataclasses.fields`` is not re-evaluated on
    every call.  The cache is keyed by ``type(self)``, so a subclass gets its
    own field list and its own name in the output.
    """
    defaults_by_class: dict[type, dict[str, object]] = {}

    def __str__(self):
        """Returns a string containing only the non-default field values."""
        cls = type(self)
        defaults = defaults_by_class.get(cls)
        if defaults is None:
            defaults = {}
            for f in dataclasses.fields(cls):
                if f.default_factory is not dataclasses.MISSING:
                    # Build the factory default once so later calls can
                    # compare against it with ==.
                    defaults[f.name] = f.default_factory()
                else:
                    # Fields with no default keep MISSING, which normally
                    # never equals a real value, so they are always shown.
                    defaults[f.name] = f.default
            defaults_by_class[cls] = defaults

        s = ', '.join([f'{field}={val!r}'
                       for field, default in defaults.items()
                       if (val := getattr(self, field)) != default])
        return f'{cls.__name__}({s})'

    cls_dict['__str__'] = __str__
    return type(name, bases, cls_dict)


@dataclasses.dataclass
class X(metaclass=terse_str):
    """Example dataclass whose str() omits fields left at their defaults."""
    a: int = 1
    b: bool = False
    c: float = 2.0
    # default_factory gives each instance its own fresh list of ints.
    d: list[int] = dataclasses.field(default_factory=lambda: [1, 2, 3])


# Only `b` differs from its default.
x1 = X(b=True)
# b=False equals its default and is omitted; c=3 != 2.0 and d=[1, 2] differs
# from the factory default [1, 2, 3], so both of those are shown.
x2 = X(b=False, c=3, d=[1, 2])

print(x1)    # X(b=True)
print(x2)    # X(c=3, d=[1, 2])

最后,这里有一个简单粗略的测试,用来确认缓存确实对重复调用 `str()` 或 `print` 有益:

import dataclasses
from timeit import timeit

def terse_str(cls):  # Decorator for class.
    """Install an uncached, defaults-omitting ``__str__`` on *cls*.

    Baseline for the benchmark: ``dataclasses.fields`` is called on every
    ``str()``.  Output formatting (``!r``) and ``default_factory`` handling
    match the cached metaclass variant, so both produce the same fields.
    """
    def __str__(self):
        """Returns a string containing only the non-default field values."""
        parts = []
        for field in dataclasses.fields(self):
            if field.default_factory is not dataclasses.MISSING:
                default = field.default_factory()
            else:
                default = field.default
            value = getattr(self, field.name)
            if value != default:
                parts.append(f'{field.name}={value!r}')
        s = ', '.join(parts)
        return f'{type(self).__name__}({s})'

    cls.__str__ = __str__
    return cls


def terse_str_meta(name, bases, cls_dict):  # Metaclass for class
    """Build the class with a cached, defaults-omitting ``__str__``.

    Used as ``metaclass=terse_str_meta``.  Each concrete class's field
    defaults are computed on the first ``str()`` of one of its instances and
    reused afterwards, so ``dataclasses.fields`` is not called on every
    ``str()``.  Keying the cache by ``type(self)`` keeps subclasses correct
    (their own fields and their own name).
    """
    defaults_by_class = {}

    def __str__(self):
        cls = type(self)
        defaults = defaults_by_class.get(cls)
        if defaults is None:
            defaults = {}
            for f in dataclasses.fields(cls):
                if f.default_factory is not dataclasses.MISSING:
                    defaults[f.name] = f.default_factory()
                else:
                    defaults[f.name] = f.default
            defaults_by_class[cls] = defaults

        # The walrus operator (Python 3.8+) reads each attribute only once.
        s = ', '.join([f'{field}={val!r}'
                       for field, default in defaults.items()
                       if (val := getattr(self, field)) != default])
        return f'{cls.__name__}({s})'

    cls_dict['__str__'] = __str__
    return type(name, bases, cls_dict)


@terse_str
@dataclasses.dataclass
class X:
    """Benchmark subject whose str() comes from the terse_str decorator."""
    a: int = dataclasses.field(default=1)
    b: bool = dataclasses.field(default=False)
    c: float = dataclasses.field(default=2.0)


@dataclasses.dataclass
class X_Cached(metaclass=terse_str_meta):
    """Benchmark subject: same fields as X, str() built via terse_str_meta."""
    a: int = dataclasses.field(default=1)
    b: bool = dataclasses.field(default=False)
    c: float = dataclasses.field(default=2.0)


# Time repeated str() calls for each variant.  Each timed statement also
# constructs a fresh instance, so the figures include __init__ cost as well.
print(f"Simple:  {timeit('str(X(b=True))', globals=globals()):.3f}")
print(f"Cached:  {timeit('str(X_Cached(b=True))', globals=globals()):.3f}")

# Sanity check: both variants should list the same non-default field (b=True).
print()
print(X(b=True))
print(X_Cached(b=True))

结果:

Simple:  2.177
Cached:  1.168