限制 hydra 结构化配置中的可能值

Restrict possible values in hydra structured configs

我尝试为 hydra framework. I use structured config schema 采用我的应用程序,我想限制某些字段的可能值。有什么办法吗?

这是我的代码:

my_app.py:

import hydra


@dataclass
class Config:
    # possible values are 'foo' and 'bar'
    some_value: str = "foo"


@hydra.main(config_path="configs", config_name="config")
def main(cfg: Config):
    print(cfg)


if __name__ == "__main__":
    main()

configs/config.yaml:

# value is incorrect.
# I need hydra to throw an exception in this case
some_value: "barrr"

几个选项:

1) 如果您可接受的值是可枚举的,请使用 Enum 类型:

from enum import Enum
from dataclasses import dataclass

class SomeValue(Enum):
    foo = 1
    bar = 2

@dataclass
class Config:
    # possible values are 'foo' and 'bar'
    some_value: SomeValue = SomeValue.foo

如果不需要花哨的逻辑来验证 some_value,这是我推荐的解决方案。

2) 如果你使用的是yaml文件,你可以使用OmegaConf注册一个custom resolver:

# my_python_file.py
from omegaconf import OmegaConf

def check_some_value(value: str) -> str:
    assert value in ("foo", "bar")
    return value

OmegaConf.register_new_resolver("check_foo_bar", check_some_value)

@hydra.main(...)
...

if __name__ == "__main__":
    main()
# my_yaml_file.yaml
some_value: ${check_foo_bar:foo}

当您在 python 代码中访问 cfg.some_value 时,如果该值与 check_some_value 函数不一致,将引发 AssertionError

3) 配置完成后,您可以调用OmegaConf.to_object 创建您的数据类的实例。这意味着将调用数据类的 __post_init__ 函数。

import hydra
from dataclasses import dataclass
from omegaconf import DictConfig, OmegaConf

@dataclass
class Config:
    # possible values are 'foo' and 'bar'
    some_value: str = "foo"

    def __post_init__(self) -> None:
        assert self.some_value in ("foo", "bar")

@hydra.main(config_path="configs", config_name="config")
def main(dict_cfg: DictConfg):
    cfg: Config = OmegaConf.to_object(dict_cfg)
    print(cfg)

if __name__ == "__main__":
    main()