使用不同的默认值集解析参数

Parsing args with different set of default values

我有一组训练参数和一组调整参数。它们共享相同的名称但不同的默认值。我想使用 argparse 来定义要使用的默认值组并解析这些值。

我了解到可以使用 add_subparsers 为每种模式设置子解析器。但是,它们的名称是相同的,这意味着我必须设置相同的参数两次(这很长)。

我还尝试包含两个解析器,第一个解析一些参数以确定使用哪组默认值,然后使用 parser.set_defaults(**defaults) 为第二个解析器设置默认值,就像这样:

train_defaults = dict(
    optimizer='AdamW',
    lr=1e-3,
    strategy='linear',
    warmup_steps=5_000,
    weight_decay=0.3
)


tune_defaults = dict(
    optimizer='SGD',
    lr=1e-2,
    strategy='cosine',
    warmup_steps=500,
    weight_decay=0.0
)

selector = argparse.ArgumentParser(description='Mode Selector')
mode = selector.add_mutually_exclusive_group()
mode.add_argument('-t', '--train', action='store_true', help='train model')
mode.add_argument('-u', '--tune', action='store_true', help='tune model')
select, unknown = selector.parse_known_args()
defaults = tune_defaults if select.tune else select.train
parser.set_defaults(**defaults)
args, unknown = parser.parse_known_args()

但是两个解析器会在某些args上发生冲突,例如-td引用parser中的--train_data,但它也会被selector解析将引发异常:

usage: run.py [-h] [-pt | -pa] [-t] [-u] [-v]
run.py: error: argument -t/--train: ignored explicit argument 'd'

(这是一个 MWE,实际参数可能会有所不同。

如您所见,多解析器解决方案可能容易出错。我看到两个选择:

使用环境变量

像这样:

import os

do_tuning = os.getenv("DO_THE_TUNING_MODE", None) is not None

...

defaults = tune_defaults if do_tuning else select.train
parser = argparse.ArgumentParser()
...
parser.set_defaults(**defaults)
args, unknown = parser.parse_known_args()

点赞

DO_THE_TUNING_MODE=1 run.py <options>

export DO_THE_TUNING_MODE=1
run.py <options>

(或者当然,不要设置为训练模式)

  • 优点:
    • Tuning/selection 方法在解析器之外,因此不会发生冲突
    • 用户可以在 shell 会话中将“状态”设置为调整或训练,而不必在 运行
    • 时连续设置选项
  • 缺点:
    • 环境变量的设置不如一次性调用命令行选项简单
    • 很容易忘记环境变量的设置

使用子解析器

这可能是最好的解决方案。您表示您不想这样做,因为您有太多选择,但这就是功能的用途。

def add_parsing_options(parser):
    # All your 40 options go here
    parser.add_argument(...)
    parser.add_argument(...)


parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
tuning_parser = subparser.add_parser("tune")
training_parser = subparser.add_parser("train")
add_parsing_options(tuning_parser)
add_parsing_options(training_parser)
tuning_parser.set_defaults(**tune_defaults)
training_parser.set_defaults(**train_defaults)
args, unknown = parser.parse_known_args()

点赞

run.py train <options>

run.py tune <options>
  • 优点:
    • 使用哪个模式的工具时明确
  • 缺点:
    • 每次使用该工具时都要输入一个额外的参数

我通过一些硬代码部分解决了这个问题,即 由于第一个解析器仅用于设置第二个解析器的默认参数,因此只有几个参数,在我的例子中是 2。 所以我所做的是将 sys.argv 分成两部分:

import sys
select, unknown = selector.parse_known_args(sys.argv[:3])
args, unknown = parser.parse_known_args(sys.argv[3:])

优点:

  • 拥有其他方法的大部分优点
  • 每次都不需要输入额外的参数

缺点:

  • 值 3 是超参数