合并 python 中的数据类
Merge dataclasses in python
我有一个像这样的数据类:
import dataclasses
import jax.numpy as jnp
@dataclasses.dataclass
class Metric:
score1: jnp.ndarray
score2: jnp.ndarray
score3: jnp.ndarray
在我的代码中,我创建了它的多个实例,有没有一种简单的方法可以将每个属性中的两个属性合并?
例如,如果我有:
a = Metric(score1=jnp.array([10,10,10]), score2=jnp.array([20,20,20]), score3=jnp.array([30,30,30]))
b = Metric(score1=jnp.array([10,10,10]), score2=jnp.array([20,20,20]), score3=jnp.array([30,30,30]))
我想合并它们,比如有一个单一的指标包含:
score1=jnp.array([10,10,10,10,10,10]), score2=jnp.array([20,20,20,20,20,20]) and score3=jnp.array([30,30,30,30,30,30])
最简单的事情可能就是定义一个方法:
import dataclasses
import jax.numpy as jnp
@dataclasses.dataclass
class Metric:
score1: jnp.ndarray
score2: jnp.ndarray
score3: jnp.ndarray
def concatenate(self, other):
return Metric(
jnp.concatenate((self.score1, other.score1)),
jnp.concatenate((self.score2, other.score2)),
jnp.concatenate((self.score3, other.score3)),
)
然后 a.concatenate(b)
。您也可以改为调用方法 __add__
,这样就可以只使用 a + b
。这更整洁,但可能会与 element-wise 添加混淆。
通过将 class Metric
注册为 pytree_node
. google/flax
,一个建立在jax
,提供了 flax.struct.dataclass
助手来做到这一点。注册后,您可以使用 jax.tree_util
包来操作 Metric
个实例:
from flax.struct import dataclass as flax_dataclass
from jax.tree_util import tree_multimap
import jax.numpy as jnp
@flax_dataclass
class Metric:
score1: jnp.ndarray
score2: jnp.ndarray
score3: jnp.ndarray
a = Metric(score1=jnp.array([10,10,10]), score2=jnp.array([20,20,20]), score3=jnp.array([30,30,30]))
b = Metric(score1=jnp.array([10,10,10]), score2=jnp.array([20,20,20]), score3=jnp.array([30,30,30]))
tree_multimap(lambda x, y: jnp.concatenate([x, y]), a, b)
给出:
Metric(score1=DeviceArray([10, 10, 10, 10, 10, 10], dtype=int32), score2=DeviceArray([20, 20, 20, 20, 20, 20], dtype=int32), score3=DeviceArray([30, 30, 30, 30, 30, 30], dtype=int32))
我有一个像这样的数据类:
import dataclasses
import jax.numpy as jnp
@dataclasses.dataclass
class Metric:
score1: jnp.ndarray
score2: jnp.ndarray
score3: jnp.ndarray
在我的代码中,我创建了它的多个实例,有没有一种简单的方法可以将每个属性中的两个属性合并? 例如,如果我有:
a = Metric(score1=jnp.array([10,10,10]), score2=jnp.array([20,20,20]), score3=jnp.array([30,30,30]))
b = Metric(score1=jnp.array([10,10,10]), score2=jnp.array([20,20,20]), score3=jnp.array([30,30,30]))
我想合并它们,比如有一个单一的指标包含:
score1=jnp.array([10,10,10,10,10,10]), score2=jnp.array([20,20,20,20,20,20]) and score3=jnp.array([30,30,30,30,30,30])
最简单的事情可能就是定义一个方法:
import dataclasses
import jax.numpy as jnp
@dataclasses.dataclass
class Metric:
score1: jnp.ndarray
score2: jnp.ndarray
score3: jnp.ndarray
def concatenate(self, other):
return Metric(
jnp.concatenate((self.score1, other.score1)),
jnp.concatenate((self.score2, other.score2)),
jnp.concatenate((self.score3, other.score3)),
)
然后 a.concatenate(b)
。您也可以改为调用方法 __add__
,这样就可以只使用 a + b
。这更整洁,但可能会与 element-wise 添加混淆。
通过将 class Metric
注册为 pytree_node
. google/flax
,一个建立在jax
,提供了 flax.struct.dataclass
助手来做到这一点。注册后,您可以使用 jax.tree_util
包来操作 Metric
个实例:
from flax.struct import dataclass as flax_dataclass
from jax.tree_util import tree_multimap
import jax.numpy as jnp
@flax_dataclass
class Metric:
score1: jnp.ndarray
score2: jnp.ndarray
score3: jnp.ndarray
a = Metric(score1=jnp.array([10,10,10]), score2=jnp.array([20,20,20]), score3=jnp.array([30,30,30]))
b = Metric(score1=jnp.array([10,10,10]), score2=jnp.array([20,20,20]), score3=jnp.array([30,30,30]))
tree_multimap(lambda x, y: jnp.concatenate([x, y]), a, b)
给出:
Metric(score1=DeviceArray([10, 10, 10, 10, 10, 10], dtype=int32), score2=DeviceArray([20, 20, 20, 20, 20, 20], dtype=int32), score3=DeviceArray([30, 30, 30, 30, 30, 30], dtype=int32))