Kedro - 如何将嵌套参数直接传递给节点

Kedro - how to pass nested parameters directly to node

kedro 建议将参数存储在 conf/base/parameters.yml 中。让我们假设它看起来像这样:

step_size: 1
model_params:
    learning_rate: 0.01
    test_data_ratio: 0.2
    num_train_steps: 10000

现在假设我有一些 data_engineering 管道,其 nodes.py 具有如下所示的功能:

def some_pipeline_step(num_train_steps):
    """
    Takes the parameter `num_train_steps` as argument.
    """
    pass

我将如何着手将嵌套参数直接传递给 data_engineering/pipeline.py 中的这个函数?我尝试失败:

from kedro.pipeline import Pipeline, node

from .nodes import split_data


def create_pipeline(**kwargs):
    return Pipeline(
        [
            node(
                some_pipeline_step,
                ["params:model_params.num_train_steps"],
                dict(
                    train_x="train_x",
                    train_y="train_y",
                ),
            )
        ]
    )

我知道我可以使用 ['parameters'] 将所有参数传递给函数,或者使用 ['params:model_params'] 传递所有 model_params 参数,但这似乎不雅,我觉得必须成为一种方式。非常感谢任何输入!

(免责声明:我是 Kedro 团队的一员)

感谢您的提问。遗憾的是,当前版本的 Kedro 不支持嵌套参数。临时解决方案是在节点内使用顶级键(正如您已经指出的那样)或使用某种参数过滤器装饰节点函数,这也不优雅。

可能最可行的解决方案是通过覆盖 _get_feed_dict 方法来自定义 ProjectContext(在 src/<package_name>/run.py 中)class:

class ProjectContext(KedroContext):
    # ...


    def _get_feed_dict(self) -> Dict[str, Any]:
        """Get parameters and return the feed dictionary."""
        params = self.params
        feed_dict = {"parameters": params}

        def _add_param_to_feed_dict(param_name, param_value):
            """This recursively adds parameter paths to the `feed_dict`,
            whenever `param_value` is a dictionary itself, so that users can
            specify specific nested parameters in their node inputs.

            Example:

                >>> param_name = "a"
                >>> param_value = {"b": 1}
                >>> _add_param_to_feed_dict(param_name, param_value)
                >>> assert feed_dict["params:a"] == {"b": 1}
                >>> assert feed_dict["params:a.b"] == 1
            """
            key = "params:{}".format(param_name)
            feed_dict[key] = param_value

            if isinstance(param_value, dict):
                for key, val in param_value.items():
                    _add_param_to_feed_dict("{}.{}".format(param_name, key), val)

        for param_name, param_value in params.items():
            _add_param_to_feed_dict(param_name, param_value)

        return feed_dict

另请注意,此问题已 addressed on develop 并将在下一个版本中提供。修复使用了上面代码片段中的方法。

正如 Dmitry 所提到的,kedro 0.16.0 introduced 节点输入中嵌套的参数值可以通过 . 运算符访问:

node(func, "params:a.b", None)

kedro 0.17.6 enabled overriding nested parameters with params in CLI,例如

kedro run --params="model.model_tuning.booster:gbtree"