使用数据类型或其他特定于库的变量作为 hydra 中的参数

use data types or other library-specific variables as arguments in hydra

我想使用 python 数据类型——内置的和从 numpy、tensorflow 等库导入的——作为我的 hydra 配置中的参数。 类似于:

# config.yaml

arg1: np.float32
arg2: tf.float16

我目前正在这样做:

# config.yaml

arg1: 'float32'
arg2: 'float16
# my python code
# ...
DTYPES_LOOKUP = {
  'float32': np.float32,
  'float16': tf.float16
}
arg1 = DTYPES_LOOKUP[config.arg1]
arg2 = DTYPES_LOOKUP[config.arg2]

有没有更hydronic/优雅的解决方案?谢谢!

hydra.utils.get_class函数是否为您解决了这个问题?

# config.yaml

arg1: numpy.float32  # note: use "numpy" here, not "np"
arg2: tensorflow.float16
# python code
...
from hydra.utils import get_class
arg1 = get_class(config.arg1)
arg2 = get_class(config.arg2)

更新 1:使用自定义解析器

根据下面 miccio 的评论,这里是使用 OmegaConf custom resolver 包装 get_class 函数的演示。

from omegaconf import OmegaConf
from hydra.utils import get_class

OmegaConf.register_new_resolver(name="get_cls", resolver=lambda cls: get_class(cls))

config = OmegaConf.create("""
# config.yaml

arg1: "${get_cls: numpy.float32}"
arg2: "${get_cls: tensorflow.float16}"
""")

arg1 = config.arg1
arg1 = config.arg2

更新 2:

原来 get_class("numpy.float32") 成功了,但 get_class("tensorflow.float16") 引发了 ValueError。 原因是 get_class 检查返回值确实是 class(使用 isinstance(cls, type))。

函数hydra.utils.get_method稍微宽松一些,只检查返回值是否可调用,但这仍然不适用于tf.float16

>>> isinstance(tf.float16, type)
False
>>> callable(tf.float16)
False

包装 tensorflow.as_dtype 函数的自定义解析器可能是有序的。

>>> tf.as_dtype("float16")
tf.float16