从 CLI 覆盖 hydra 配置组

Overwriting hydra configuration groups from CLI

我正在尝试从 CLI 覆盖一组参数,但我不确定该怎么做。我的 conf 结构如下

conf
├── config.yaml
├── optimizer
│   ├── adamw.yaml
│   ├── adam.yaml
│   ├── default.yaml
│   └── sgd.yaml
├── task
│   ├── default.yaml
│   └── nlp
│       ├── default_seq2seq.yaml
│       ├── summarization.yaml
│       └── text_classification.yaml

我的task/default 看起来像这样

# @package task
defaults:
  - _self_
  - /optimizer/adam@cfg.optimizer

_target_: src.core.task.Task
_recursive_: false

cfg:
  prefix_sep: ${training.prefix_sep}

optimiser/default 看起来像这样

_target_: null
lr: ${training.lr}
weight_decay: 0.001
no_decay:
  - bias
  - LayerNorm.weight

和一个特定的优化器,比如 adam.yaml,看起来像这样

defaults:
  - default

_target_: torch.optim.Adam

最终我要计算的配置是这样的

task:
  _target_: src.task.nlp.nli_generation.task.NLIGenerationTask
  _recursive_: false
  cfg:
    prefix_sep: ${training.prefix_sep}
    optimizer:
      _target_: torch.optim.Adam
      lr: ${training.lr}
      weight_decay: 0.001
      no_decay:
      - bias
      - LayerNorm.weight

我希望能够通过 CLI 修改优化器(例如,使用 sgd),但我不确定如何实现。我试过了,但我明白为什么会失败,this

python train.py task.cfg.optimizer=sgd # fails
python train.py task.cfg.optimizer=/optimizer/sgd #fails

关于如何实现这一目标的任何提示?

Github 讨论 here.

您不能覆盖此表单中的默认列表条目。 参见 this。 特别是:

CONFIG : A config to use when creating the output config. e.g. db/mysql, db/mysql@backup.
GROUP_DEFAULT : An overridable config. e.g. db: mysql, db@backup: mysql.

为了能够覆盖默认列表条目,您需要将其定义为 GROUP_DEFAULT。 在你的情况下,它可能看起来像

defaults:
  - _self_
  - /optimizer@cfg.optimizer: adam