合并 python 中的数据类

Merge dataclasses in python

我有一个像这样的数据类:

import dataclasses
import jax.numpy as jnp

@dataclasses.dataclass
class Metric:
    score1: jnp.ndarray
    score2: jnp.ndarray
    score3: jnp.ndarray

在我的代码中,我创建了它的多个实例,有没有一种简单的方法可以将每个属性中的两个属性合并? 例如,如果我有:

a = Metric(score1=jnp.array([10,10,10]), score2=jnp.array([20,20,20]), score3=jnp.array([30,30,30]))
b = Metric(score1=jnp.array([10,10,10]), score2=jnp.array([20,20,20]), score3=jnp.array([30,30,30]))

我想合并它们,比如有一个单一的指标包含:

score1=jnp.array([10,10,10,10,10,10]), score2=jnp.array([20,20,20,20,20,20]) and score3=jnp.array([30,30,30,30,30,30])

最简单的事情可能就是定义一个方法:

import dataclasses
import jax.numpy as jnp


@dataclasses.dataclass
class Metric:
    score1: jnp.ndarray
    score2: jnp.ndarray
    score3: jnp.ndarray

    def concatenate(self, other):
        return Metric(
            jnp.concatenate((self.score1, other.score1)),
            jnp.concatenate((self.score2, other.score2)),
            jnp.concatenate((self.score3, other.score3)),
        )

然后 a.concatenate(b)。您也可以改为调用方法 __add__,这样就可以只使用 a + b。这更整洁,但可能会与 element-wise 添加混淆。

通过将 class Metric 注册为 pytree_node. google/flax,一个建立在jax,提供了 flax.struct.dataclass 助手来做到这一点。注册后,您可以使用 jax.tree_util 包来操作 Metric 个实例:

from flax.struct import dataclass as flax_dataclass
from jax.tree_util import tree_multimap
import jax.numpy as jnp

@flax_dataclass
class Metric:
    score1: jnp.ndarray
    score2: jnp.ndarray
    score3: jnp.ndarray

a = Metric(score1=jnp.array([10,10,10]), score2=jnp.array([20,20,20]), score3=jnp.array([30,30,30]))
b = Metric(score1=jnp.array([10,10,10]), score2=jnp.array([20,20,20]), score3=jnp.array([30,30,30]))

tree_multimap(lambda x, y: jnp.concatenate([x, y]), a, b)

给出:

Metric(score1=DeviceArray([10, 10, 10, 10, 10, 10], dtype=int32), score2=DeviceArray([20, 20, 20, 20, 20, 20], dtype=int32), score3=DeviceArray([30, 30, 30, 30, 30, 30], dtype=int32))