使用数据类型或其他特定于库的变量作为 hydra 中的参数
use data types or other library-specific variables as arguments in hydra
我想使用 python 数据类型——内置的和从 numpy、tensorflow 等库导入的——作为我的 hydra 配置中的参数。
类似于:
# config.yaml
arg1: np.float32
arg2: tf.float16
我目前正在这样做:
# config.yaml
arg1: 'float32'
arg2: 'float16
# my python code
# ...
DTYPES_LOOKUP = {
'float32': np.float32,
'float16': tf.float16
}
arg1 = DTYPES_LOOKUP[config.arg1]
arg2 = DTYPES_LOOKUP[config.arg2]
有没有更hydronic/优雅的解决方案?谢谢!
hydra.utils.get_class
函数是否为您解决了这个问题?
# config.yaml
arg1: numpy.float32 # note: use "numpy" here, not "np"
arg2: tensorflow.float16
# python code
...
from hydra.utils import get_class
arg1 = get_class(config.arg1)
arg2 = get_class(config.arg2)
更新 1:使用自定义解析器
根据下面 miccio 的评论,这里是使用 OmegaConf custom resolver 包装 get_class
函数的演示。
from omegaconf import OmegaConf
from hydra.utils import get_class
OmegaConf.register_new_resolver(name="get_cls", resolver=lambda cls: get_class(cls))
config = OmegaConf.create("""
# config.yaml
arg1: "${get_cls: numpy.float32}"
arg2: "${get_cls: tensorflow.float16}"
""")
arg1 = config.arg1
arg1 = config.arg2
更新 2:
原来 get_class("numpy.float32")
成功了,但 get_class("tensorflow.float16")
引发了 ValueError。
原因是 get_class
检查返回值确实是 class(使用 isinstance(cls, type)
)。
函数hydra.utils.get_method
稍微宽松一些,只检查返回值是否可调用,但这仍然不适用于tf.float16
。
>>> isinstance(tf.float16, type)
False
>>> callable(tf.float16)
False
包装 tensorflow.as_dtype
函数的自定义解析器可能是有序的。
>>> tf.as_dtype("float16")
tf.float16
我想使用 python 数据类型——内置的和从 numpy、tensorflow 等库导入的——作为我的 hydra 配置中的参数。 类似于:
# config.yaml
arg1: np.float32
arg2: tf.float16
我目前正在这样做:
# config.yaml
arg1: 'float32'
arg2: 'float16
# my python code
# ...
DTYPES_LOOKUP = {
'float32': np.float32,
'float16': tf.float16
}
arg1 = DTYPES_LOOKUP[config.arg1]
arg2 = DTYPES_LOOKUP[config.arg2]
有没有更hydronic/优雅的解决方案?谢谢!
hydra.utils.get_class
函数是否为您解决了这个问题?
# config.yaml
arg1: numpy.float32 # note: use "numpy" here, not "np"
arg2: tensorflow.float16
# python code
...
from hydra.utils import get_class
arg1 = get_class(config.arg1)
arg2 = get_class(config.arg2)
更新 1:使用自定义解析器
根据下面 miccio 的评论,这里是使用 OmegaConf custom resolver 包装 get_class
函数的演示。
from omegaconf import OmegaConf
from hydra.utils import get_class
OmegaConf.register_new_resolver(name="get_cls", resolver=lambda cls: get_class(cls))
config = OmegaConf.create("""
# config.yaml
arg1: "${get_cls: numpy.float32}"
arg2: "${get_cls: tensorflow.float16}"
""")
arg1 = config.arg1
arg1 = config.arg2
更新 2:
原来 get_class("numpy.float32")
成功了,但 get_class("tensorflow.float16")
引发了 ValueError。
原因是 get_class
检查返回值确实是 class(使用 isinstance(cls, type)
)。
函数hydra.utils.get_method
稍微宽松一些,只检查返回值是否可调用,但这仍然不适用于tf.float16
。
>>> isinstance(tf.float16, type)
False
>>> callable(tf.float16)
False
包装 tensorflow.as_dtype
函数的自定义解析器可能是有序的。
>>> tf.as_dtype("float16")
tf.float16