从 CLI 覆盖 hydra 配置组
Overwriting hydra configuration groups from CLI
我正在尝试从 CLI 覆盖一组参数,但我不确定该怎么做。我的 conf 结构如下
conf
├── config.yaml
├── optimizer
│ ├── adamw.yaml
│ ├── adam.yaml
│ ├── default.yaml
│ └── sgd.yaml
├── task
│ ├── default.yaml
│ └── nlp
│ ├── default_seq2seq.yaml
│ ├── summarization.yaml
│ └── text_classification.yaml
我的task/default
看起来像这样
# @package task
defaults:
- _self_
- /optimizer/adam@cfg.optimizer
_target_: src.core.task.Task
_recursive_: false
cfg:
prefix_sep: ${training.prefix_sep}
而 optimiser/default
看起来像这样
_target_: null
lr: ${training.lr}
weight_decay: 0.001
no_decay:
- bias
- LayerNorm.weight
和一个特定的优化器,比如 adam.yaml,看起来像这样
defaults:
- default
_target_: torch.optim.Adam
最终我要计算的配置是这样的
task:
_target_: src.task.nlp.nli_generation.task.NLIGenerationTask
_recursive_: false
cfg:
prefix_sep: ${training.prefix_sep}
optimizer:
_target_: torch.optim.Adam
lr: ${training.lr}
weight_decay: 0.001
no_decay:
- bias
- LayerNorm.weight
我希望能够通过 CLI 修改优化器(例如,使用 sgd),但我不确定如何实现。我试过了,但我明白为什么会失败,this
python train.py task.cfg.optimizer=sgd # fails
python train.py task.cfg.optimizer=/optimizer/sgd #fails
关于如何实现这一目标的任何提示?
Github 讨论 here.
您不能覆盖此表单中的默认列表条目。
参见 this。
特别是:
CONFIG : A config to use when creating the output config. e.g. db/mysql, db/mysql@backup.
GROUP_DEFAULT : An overridable config. e.g. db: mysql, db@backup: mysql.
为了能够覆盖默认列表条目,您需要将其定义为 GROUP_DEFAULT。
在你的情况下,它可能看起来像
defaults:
- _self_
- /optimizer@cfg.optimizer: adam
我正在尝试从 CLI 覆盖一组参数,但我不确定该怎么做。我的 conf 结构如下
conf
├── config.yaml
├── optimizer
│ ├── adamw.yaml
│ ├── adam.yaml
│ ├── default.yaml
│ └── sgd.yaml
├── task
│ ├── default.yaml
│ └── nlp
│ ├── default_seq2seq.yaml
│ ├── summarization.yaml
│ └── text_classification.yaml
我的task/default
看起来像这样
# @package task
defaults:
- _self_
- /optimizer/adam@cfg.optimizer
_target_: src.core.task.Task
_recursive_: false
cfg:
prefix_sep: ${training.prefix_sep}
而 optimiser/default
看起来像这样
_target_: null
lr: ${training.lr}
weight_decay: 0.001
no_decay:
- bias
- LayerNorm.weight
和一个特定的优化器,比如 adam.yaml,看起来像这样
defaults:
- default
_target_: torch.optim.Adam
最终我要计算的配置是这样的
task:
_target_: src.task.nlp.nli_generation.task.NLIGenerationTask
_recursive_: false
cfg:
prefix_sep: ${training.prefix_sep}
optimizer:
_target_: torch.optim.Adam
lr: ${training.lr}
weight_decay: 0.001
no_decay:
- bias
- LayerNorm.weight
我希望能够通过 CLI 修改优化器(例如,使用 sgd),但我不确定如何实现。我试过了,但我明白为什么会失败,this
python train.py task.cfg.optimizer=sgd # fails
python train.py task.cfg.optimizer=/optimizer/sgd #fails
关于如何实现这一目标的任何提示?
Github 讨论 here.
您不能覆盖此表单中的默认列表条目。 参见 this。 特别是:
CONFIG : A config to use when creating the output config. e.g. db/mysql, db/mysql@backup.
GROUP_DEFAULT : An overridable config. e.g. db: mysql, db@backup: mysql.
为了能够覆盖默认列表条目,您需要将其定义为 GROUP_DEFAULT。 在你的情况下,它可能看起来像
defaults:
- _self_
- /optimizer@cfg.optimizer: adam