fb-hydra:如何实现 2 个嵌套结构化配置?
fb-hydra: How to implement 2 nested structured configs?
我有 2 个子配置和一个具有这些子配置的主(?)配置。我设计的配置如下:
from dataclasses import dataclass, field
import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING, DictConfig
from typing import Any, List
@dataclass
class DBConfig:
    """Common database connection schema.

    ``driver`` and ``port`` are ``MISSING`` placeholders here; the concrete
    subclasses below provide real values for them.
    """

    host: str = "localhost"
    driver: str = MISSING
    port: int = MISSING


@dataclass
class MySQLConfig(DBConfig):
    """MySQL settings; ``host`` is inherited from DBConfig."""

    driver: str = "mysql"
    port: int = 3306


@dataclass
class PostGreSQLConfig(DBConfig):
    """PostgreSQL settings, adding a ``timeout`` field."""

    driver: str = "postgresql"
    port: int = 5432
    timeout: int = 10
@dataclass
class ConnectionConfig:
    """Sub-config for the DB connection: target class path plus parameters.

    NOTE: Hydra 1.0 only processes a defaults list in the primary config.
    The ``defaults`` field below therefore does not select ``params``.
    It is composed as an ordinary field, and ``params`` stays ``???``.
    """

    target: str = "app.my_class.MyClass"
    params: DBConfig = MISSING
    # Meant to choose the "mysql" option as the default for ``params``.
    defaults: List[Any] = field(
        default_factory=lambda: [{"params": "mysql"}]
    )
@dataclass
class AConfig:
    """Base parameter schema with a single ``name`` field."""

    name: str = "foo"


@dataclass
class BConfig(AConfig):
    """AConfig variant whose ``age`` defaults to 10."""

    age: int = 10


@dataclass
class CConfig(AConfig):
    """AConfig variant whose ``age`` defaults to 20."""

    age: int = 20
@dataclass
class SomeOtherConfig:
    """Second sub-config: a target class path plus AConfig parameters.

    NOTE: Hydra 1.0 only processes a defaults list in the primary config.
    The ``defaults`` field below therefore does not select ``params``.
    It is composed as an ordinary field, and ``params`` stays ``???``.
    """

    target: str = "app.my_class.MyClass2"
    params: AConfig = MISSING
    # Meant to choose the "bconfig" option as the default for ``params``.
    defaults: List[Any] = field(
        default_factory=lambda: [{"params": "bconfig"}]
    )
@dataclass
class Config:
    """Top-level config holding the two sub-configs.

    ``default_factory`` builds fresh sub-config objects per instance.
    A dataclass instance used directly as a default is unhashable, so
    Python 3.11+ ``dataclasses`` rejects it. On older versions, one
    instance would be shared by every Config.
    """

    db_connection: ConnectionConfig = field(default_factory=ConnectionConfig)
    some_other: SomeOtherConfig = field(default_factory=SomeOtherConfig)
@hydra.main(config_name="config")
def my_app(cfg: DictConfig) -> None:
    """Print the composed config as YAML.

    ``DictConfig.pretty()`` was deprecated during the OmegaConf 2.0 series
    and later removed. ``OmegaConf.to_yaml`` is its replacement.
    """
    from omegaconf import OmegaConf  # local: keeps this snippet's imports as-is

    print(OmegaConf.to_yaml(cfg))
    # Intended next step once composition works:
    # connection = hydra.utils.instantiate(cfg)
    # print(connection)
if __name__ == "__main__":
    cs = ConfigStore.instance()
    # The primary config goes in first, then every option of "params".
    cs.store(name="config", node=Config)
    # NOTE: the DB schemas and the AConfig schemas share the single "params"
    # group, so both sub-configs draw from the same pool of options.
    for option_name, schema in (
        ("mysql", MySQLConfig),
        ("postgresql", PostGreSQLConfig),
        ("bconfig", BConfig),
        ("cconfig", CConfig),
    ):
        cs.store(group="params", name=option_name, node=schema)
    my_app()
当我不带任何选项运行程序时,我期望的输出是:
db_connection:
  target: app.my_class.MyClass
  params:
    host: localhost
    driver: mysql
    port: 3306
some_other:
  target: app.my_class.MyClass2
  params:
    name: "foo"
    age: 10
但实际得到的结果是:
db_connection:
  target: app.my_class.MyClass
  params: ???
  defaults:
  - params: mysql
some_other:
  target: app.my_class.MyClass2
  params: ???
  defaults:
  - params: bconfig
首先,在 Hydra 1.0 中,默认列表(defaults list)仅在主配置(primary config)中受支持。
以下给出两个版本:第一个版本对您的示例改动尽可能少,第二个版本则做了一些整理。
示例 1:
from dataclasses import dataclass, field
import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING, DictConfig
from typing import Any, List
@dataclass
class DBConfig:
    """Common database connection schema.

    ``driver`` and ``port`` are ``MISSING`` placeholders here; the concrete
    subclasses below provide real values for them.
    """

    host: str = "localhost"
    driver: str = MISSING
    port: int = MISSING


@dataclass
class MySQLConfig(DBConfig):
    """MySQL settings; ``host`` is inherited from DBConfig."""

    driver: str = "mysql"
    port: int = 3306


@dataclass
class PostGreSQLConfig(DBConfig):
    """PostgreSQL settings, adding a ``timeout`` field."""

    driver: str = "postgresql"
    port: int = 5432
    timeout: int = 10
@dataclass
class ConnectionConfig:
    """Sub-config for the DB connection: target class path plus parameters.

    ``params`` is left ``MISSING`` so that an option from the
    ``db_connection/params`` group can supply it.
    """

    target: str = "app.my_class.MyClass"
    params: DBConfig = MISSING
@dataclass
class AConfig:
    """Base parameter schema with a single ``name`` field."""

    name: str = "foo"


@dataclass
class BConfig(AConfig):
    """AConfig variant whose ``age`` defaults to 10."""

    age: int = 10


@dataclass
class CConfig(AConfig):
    """AConfig variant whose ``age`` defaults to 20."""

    age: int = 20
@dataclass
class SomeOtherConfig:
    """Second sub-config: a target class path plus AConfig parameters.

    ``params`` is left ``MISSING`` so that an option from the
    ``some_other/params`` group can supply it.
    """

    target: str = "app.my_class.MyClass2"
    params: AConfig = MISSING
@dataclass
class Config:
    """Primary config; in Hydra 1.0 only this config's defaults list is used.

    Each defaults entry names a config group and the option picked when no
    override is given. The group names mirror the node paths
    (``db_connection/params``, ``some_other/params``) so the chosen option
    ends up at that node.

    The sub-configs use ``default_factory``. A dataclass instance used
    directly as a default is unhashable, so Python 3.11+ rejects it; older
    versions would share one instance across all Configs.
    """

    db_connection: ConnectionConfig = field(default_factory=ConnectionConfig)
    some_other: SomeOtherConfig = field(default_factory=SomeOtherConfig)
    defaults: List[Any] = field(
        default_factory=lambda: [
            {"db_connection/params": "mysql"},
            {"some_other/params": "bconfig"},
        ]
    )
@hydra.main(config_name="config")
def my_app(cfg: DictConfig) -> None:
    """Print the composed config as YAML.

    ``DictConfig.pretty()`` was deprecated during the OmegaConf 2.0 series
    and later removed. ``OmegaConf.to_yaml`` is its replacement.
    """
    from omegaconf import OmegaConf  # local: keeps this snippet's imports as-is

    print(OmegaConf.to_yaml(cfg))
if __name__ == "__main__":
    cs = ConfigStore.instance()
    cs.store(name="config", node=Config)
    # Each option is stored under a group path that matches the node it
    # fills, so the primary config's defaults list can refer to it.
    registrations = (
        ("db_connection/params", "mysql", MySQLConfig),
        ("db_connection/params", "postgresql", PostGreSQLConfig),
        ("some_other/params", "bconfig", BConfig),
        ("some_other/params", "cconfig", CConfig),
    )
    for group_path, option_name, schema in registrations:
        cs.store(group=group_path, name=option_name, node=schema)
    my_app()
示例 2:
from dataclasses import dataclass, field
import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING, DictConfig
from hydra.types import ObjectConf
from typing import Any, List
@dataclass
class DBConfig:
    """Common database connection schema.

    ``driver`` and ``port`` are ``MISSING`` placeholders here; the concrete
    subclasses below provide real values for them.
    """

    host: str = "localhost"
    driver: str = MISSING
    port: int = MISSING


@dataclass
class MySQLConfig(DBConfig):
    """MySQL settings; ``host`` is inherited from DBConfig."""

    driver: str = "mysql"
    port: int = 3306


@dataclass
class PostGreSQLConfig(DBConfig):
    """PostgreSQL settings, adding a ``timeout`` field."""

    driver: str = "postgresql"
    port: int = 5432
    timeout: int = 10
@dataclass
class AConfig:
    """Base parameter schema with a single ``name`` field."""

    name: str = "foo"


@dataclass
class BConfig(AConfig):
    """AConfig variant whose ``age`` defaults to 10."""

    age: int = 10


@dataclass
class CConfig(AConfig):
    """AConfig variant whose ``age`` defaults to 20."""

    age: int = 20
# Option chosen for each config group when no override is given.
defaults = [{"db_connection": "mysql"}, {"some_other": "bconfig"}]


@dataclass
class Config:
    """Primary config whose two nodes are filled from the selected options."""

    db_connection: ObjectConf = MISSING
    some_other: ObjectConf = MISSING
    # Build a fresh copy per instance. Returning the module-level list
    # itself would let a mutation made through one instance leak into
    # every later instance (and into ``defaults``).
    defaults: List[Any] = field(
        default_factory=lambda: [dict(entry) for entry in defaults]
    )
# Register the primary config and one option per group. Each option pairs a
# placeholder target name with its parameter schema.
# NOTE(review): ObjectConf is a Hydra 1.0-era API; confirm it still exists
# in the Hydra version you run.
cs = ConfigStore.instance()
cs.store(name="config", node=Config)
cs.store(
    group="db_connection",
    name="mysql",
    node=ObjectConf(target="MySQL", params=MySQLConfig),
)
cs.store(
    group="db_connection",
    name="postgresql",
    node=ObjectConf(target="PostgreSQL", params=PostGreSQLConfig),
)
cs.store(
    group="some_other",
    name="bconfig",
    node=ObjectConf(target="ClassB", params=BConfig()),
)
cs.store(
    group="some_other",
    name="cconfig",
    # Use CConfig, not the base AConfig: AConfig has no ``age`` field, so
    # selecting "cconfig" would silently lose age=20.
    node=ObjectConf(target="ClassC", params=CConfig()),
)
@hydra.main(config_name="config")
def my_app(cfg: DictConfig) -> None:
    """Print the composed config as YAML.

    ``DictConfig.pretty()`` was deprecated during the OmegaConf 2.0 series
    and later removed. ``OmegaConf.to_yaml`` is its replacement.
    """
    from omegaconf import OmegaConf  # local: keeps this snippet's imports as-is

    print(OmegaConf.to_yaml(cfg))
if __name__ == "__main__":
    # ConfigStore entries are registered at module level, so only the
    # Hydra entry point needs launching here.
    my_app()
我有 2 个子配置和一个具有这些子配置的主(?)配置。我设计的配置如下:
from dataclasses import dataclass, field
import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING, DictConfig
from typing import Any, List
@dataclass
class DBConfig:
    """Common database connection schema.

    ``driver`` and ``port`` are ``MISSING`` placeholders here; the concrete
    subclasses below provide real values for them.
    """

    host: str = "localhost"
    driver: str = MISSING
    port: int = MISSING


@dataclass
class MySQLConfig(DBConfig):
    """MySQL settings; ``host`` is inherited from DBConfig."""

    driver: str = "mysql"
    port: int = 3306


@dataclass
class PostGreSQLConfig(DBConfig):
    """PostgreSQL settings, adding a ``timeout`` field."""

    driver: str = "postgresql"
    port: int = 5432
    timeout: int = 10
@dataclass
class ConnectionConfig:
    """Sub-config for the DB connection: target class path plus parameters.

    NOTE: Hydra 1.0 only processes a defaults list in the primary config.
    The ``defaults`` field below therefore does not select ``params``.
    It is composed as an ordinary field, and ``params`` stays ``???``.
    """

    target: str = "app.my_class.MyClass"
    params: DBConfig = MISSING
    # Meant to choose the "mysql" option as the default for ``params``.
    defaults: List[Any] = field(
        default_factory=lambda: [{"params": "mysql"}]
    )
@dataclass
class AConfig:
    """Base parameter schema with a single ``name`` field."""

    name: str = "foo"


@dataclass
class BConfig(AConfig):
    """AConfig variant whose ``age`` defaults to 10."""

    age: int = 10


@dataclass
class CConfig(AConfig):
    """AConfig variant whose ``age`` defaults to 20."""

    age: int = 20
@dataclass
class SomeOtherConfig:
    """Second sub-config: a target class path plus AConfig parameters.

    NOTE: Hydra 1.0 only processes a defaults list in the primary config.
    The ``defaults`` field below therefore does not select ``params``.
    It is composed as an ordinary field, and ``params`` stays ``???``.
    """

    target: str = "app.my_class.MyClass2"
    params: AConfig = MISSING
    # Meant to choose the "bconfig" option as the default for ``params``.
    defaults: List[Any] = field(
        default_factory=lambda: [{"params": "bconfig"}]
    )
@dataclass
class Config:
    """Top-level config holding the two sub-configs.

    ``default_factory`` builds fresh sub-config objects per instance.
    A dataclass instance used directly as a default is unhashable, so
    Python 3.11+ ``dataclasses`` rejects it. On older versions, one
    instance would be shared by every Config.
    """

    db_connection: ConnectionConfig = field(default_factory=ConnectionConfig)
    some_other: SomeOtherConfig = field(default_factory=SomeOtherConfig)
@hydra.main(config_name="config")
def my_app(cfg: DictConfig) -> None:
    """Print the composed config as YAML.

    ``DictConfig.pretty()`` was deprecated during the OmegaConf 2.0 series
    and later removed. ``OmegaConf.to_yaml`` is its replacement.
    """
    from omegaconf import OmegaConf  # local: keeps this snippet's imports as-is

    print(OmegaConf.to_yaml(cfg))
    # Intended next step once composition works:
    # connection = hydra.utils.instantiate(cfg)
    # print(connection)
if __name__ == "__main__":
    cs = ConfigStore.instance()
    # The primary config goes in first, then every option of "params".
    cs.store(name="config", node=Config)
    # NOTE: the DB schemas and the AConfig schemas share the single "params"
    # group, so both sub-configs draw from the same pool of options.
    for option_name, schema in (
        ("mysql", MySQLConfig),
        ("postgresql", PostGreSQLConfig),
        ("bconfig", BConfig),
        ("cconfig", CConfig),
    ):
        cs.store(group="params", name=option_name, node=schema)
    my_app()
当我不带任何选项运行程序时,我期望的输出是:
db_connection:
  target: app.my_class.MyClass
  params:
    host: localhost
    driver: mysql
    port: 3306
some_other:
  target: app.my_class.MyClass2
  params:
    name: "foo"
    age: 10
但实际得到的结果是:
db_connection:
  target: app.my_class.MyClass
  params: ???
  defaults:
  - params: mysql
some_other:
  target: app.my_class.MyClass2
  params: ???
  defaults:
  - params: bconfig
首先,在 Hydra 1.0 中,默认列表(defaults list)仅在主配置(primary config)中受支持。以下给出两个版本:第一个版本对您的示例改动尽可能少,第二个版本则做了一些整理。
示例 1:
from dataclasses import dataclass, field
import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING, DictConfig
from typing import Any, List
@dataclass
class DBConfig:
    """Common database connection schema.

    ``driver`` and ``port`` are ``MISSING`` placeholders here; the concrete
    subclasses below provide real values for them.
    """

    host: str = "localhost"
    driver: str = MISSING
    port: int = MISSING


@dataclass
class MySQLConfig(DBConfig):
    """MySQL settings; ``host`` is inherited from DBConfig."""

    driver: str = "mysql"
    port: int = 3306


@dataclass
class PostGreSQLConfig(DBConfig):
    """PostgreSQL settings, adding a ``timeout`` field."""

    driver: str = "postgresql"
    port: int = 5432
    timeout: int = 10
@dataclass
class ConnectionConfig:
    """Sub-config for the DB connection: target class path plus parameters.

    ``params`` is left ``MISSING`` so that an option from the
    ``db_connection/params`` group can supply it.
    """

    target: str = "app.my_class.MyClass"
    params: DBConfig = MISSING
@dataclass
class AConfig:
    """Base parameter schema with a single ``name`` field."""

    name: str = "foo"


@dataclass
class BConfig(AConfig):
    """AConfig variant whose ``age`` defaults to 10."""

    age: int = 10


@dataclass
class CConfig(AConfig):
    """AConfig variant whose ``age`` defaults to 20."""

    age: int = 20
@dataclass
class SomeOtherConfig:
    """Second sub-config: a target class path plus AConfig parameters.

    ``params`` is left ``MISSING`` so that an option from the
    ``some_other/params`` group can supply it.
    """

    target: str = "app.my_class.MyClass2"
    params: AConfig = MISSING
@dataclass
class Config:
    """Primary config; in Hydra 1.0 only this config's defaults list is used.

    Each defaults entry names a config group and the option picked when no
    override is given. The group names mirror the node paths
    (``db_connection/params``, ``some_other/params``) so the chosen option
    ends up at that node.

    The sub-configs use ``default_factory``. A dataclass instance used
    directly as a default is unhashable, so Python 3.11+ rejects it; older
    versions would share one instance across all Configs.
    """

    db_connection: ConnectionConfig = field(default_factory=ConnectionConfig)
    some_other: SomeOtherConfig = field(default_factory=SomeOtherConfig)
    defaults: List[Any] = field(
        default_factory=lambda: [
            {"db_connection/params": "mysql"},
            {"some_other/params": "bconfig"},
        ]
    )
@hydra.main(config_name="config")
def my_app(cfg: DictConfig) -> None:
    """Print the composed config as YAML.

    ``DictConfig.pretty()`` was deprecated during the OmegaConf 2.0 series
    and later removed. ``OmegaConf.to_yaml`` is its replacement.
    """
    from omegaconf import OmegaConf  # local: keeps this snippet's imports as-is

    print(OmegaConf.to_yaml(cfg))
if __name__ == "__main__":
    cs = ConfigStore.instance()
    cs.store(name="config", node=Config)
    # Each option is stored under a group path that matches the node it
    # fills, so the primary config's defaults list can refer to it.
    registrations = (
        ("db_connection/params", "mysql", MySQLConfig),
        ("db_connection/params", "postgresql", PostGreSQLConfig),
        ("some_other/params", "bconfig", BConfig),
        ("some_other/params", "cconfig", CConfig),
    )
    for group_path, option_name, schema in registrations:
        cs.store(group=group_path, name=option_name, node=schema)
    my_app()
示例 2:
from dataclasses import dataclass, field
import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING, DictConfig
from hydra.types import ObjectConf
from typing import Any, List
@dataclass
class DBConfig:
    """Common database connection schema.

    ``driver`` and ``port`` are ``MISSING`` placeholders here; the concrete
    subclasses below provide real values for them.
    """

    host: str = "localhost"
    driver: str = MISSING
    port: int = MISSING


@dataclass
class MySQLConfig(DBConfig):
    """MySQL settings; ``host`` is inherited from DBConfig."""

    driver: str = "mysql"
    port: int = 3306


@dataclass
class PostGreSQLConfig(DBConfig):
    """PostgreSQL settings, adding a ``timeout`` field."""

    driver: str = "postgresql"
    port: int = 5432
    timeout: int = 10
@dataclass
class AConfig:
    """Base parameter schema with a single ``name`` field."""

    name: str = "foo"


@dataclass
class BConfig(AConfig):
    """AConfig variant whose ``age`` defaults to 10."""

    age: int = 10


@dataclass
class CConfig(AConfig):
    """AConfig variant whose ``age`` defaults to 20."""

    age: int = 20
# Option chosen for each config group when no override is given.
defaults = [{"db_connection": "mysql"}, {"some_other": "bconfig"}]


@dataclass
class Config:
    """Primary config whose two nodes are filled from the selected options."""

    db_connection: ObjectConf = MISSING
    some_other: ObjectConf = MISSING
    # Build a fresh copy per instance. Returning the module-level list
    # itself would let a mutation made through one instance leak into
    # every later instance (and into ``defaults``).
    defaults: List[Any] = field(
        default_factory=lambda: [dict(entry) for entry in defaults]
    )
# Register the primary config and one option per group. Each option pairs a
# placeholder target name with its parameter schema.
# NOTE(review): ObjectConf is a Hydra 1.0-era API; confirm it still exists
# in the Hydra version you run.
cs = ConfigStore.instance()
cs.store(name="config", node=Config)
cs.store(
    group="db_connection",
    name="mysql",
    node=ObjectConf(target="MySQL", params=MySQLConfig),
)
cs.store(
    group="db_connection",
    name="postgresql",
    node=ObjectConf(target="PostgreSQL", params=PostGreSQLConfig),
)
cs.store(
    group="some_other",
    name="bconfig",
    node=ObjectConf(target="ClassB", params=BConfig()),
)
cs.store(
    group="some_other",
    name="cconfig",
    # Use CConfig, not the base AConfig: AConfig has no ``age`` field, so
    # selecting "cconfig" would silently lose age=20.
    node=ObjectConf(target="ClassC", params=CConfig()),
)
@hydra.main(config_name="config")
def my_app(cfg: DictConfig) -> None:
    """Print the composed config as YAML.

    ``DictConfig.pretty()`` was deprecated during the OmegaConf 2.0 series
    and later removed. ``OmegaConf.to_yaml`` is its replacement.
    """
    from omegaconf import OmegaConf  # local: keeps this snippet's imports as-is

    print(OmegaConf.to_yaml(cfg))
if __name__ == "__main__":
    # ConfigStore entries are registered at module level, so only the
    # Hydra entry point needs launching here.
    my_app()