Kedro - 如何将嵌套参数直接传递给节点
Kedro - how to pass nested parameters directly to node
kedro
建议将参数存储在 conf/base/parameters.yml
中。让我们假设它看起来像这样:
step_size: 1
model_params:
learning_rate: 0.01
test_data_ratio: 0.2
num_train_steps: 10000
现在假设我有一些 data_engineering
管道,其 nodes.py
具有如下所示的功能:
def some_pipeline_step(num_train_steps):
"""
Takes the parameter `num_train_steps` as argument.
"""
pass
我将如何着手将嵌套参数直接传递给 data_engineering/pipeline.py
中的这个函数?我尝试失败:
from kedro.pipeline import Pipeline, node
from .nodes import split_data
def create_pipeline(**kwargs):
return Pipeline(
[
node(
some_pipeline_step,
["params:model_params.num_train_steps"],
dict(
train_x="train_x",
train_y="train_y",
),
)
]
)
我知道我可以使用 ['parameters']
将所有参数传递给函数,或者使用 ['params:model_params']
传递所有 model_params
参数,但这似乎不雅,我觉得必须成为一种方式。非常感谢任何输入!
(免责声明:我是 Kedro 团队的一员)
感谢您的提问。遗憾的是,当前版本的 Kedro 不支持嵌套参数。临时解决方案是在节点内使用顶级键(正如您已经指出的那样)或使用某种参数过滤器装饰节点函数,这也不优雅。
可能最可行的解决方案是通过覆盖 _get_feed_dict
方法来自定义 ProjectContext
(在 src/<package_name>/run.py
中)class:
class ProjectContext(KedroContext):
# ...
def _get_feed_dict(self) -> Dict[str, Any]:
"""Get parameters and return the feed dictionary."""
params = self.params
feed_dict = {"parameters": params}
def _add_param_to_feed_dict(param_name, param_value):
"""This recursively adds parameter paths to the `feed_dict`,
whenever `param_value` is a dictionary itself, so that users can
specify specific nested parameters in their node inputs.
Example:
>>> param_name = "a"
>>> param_value = {"b": 1}
>>> _add_param_to_feed_dict(param_name, param_value)
>>> assert feed_dict["params:a"] == {"b": 1}
>>> assert feed_dict["params:a.b"] == 1
"""
key = "params:{}".format(param_name)
feed_dict[key] = param_value
if isinstance(param_value, dict):
for key, val in param_value.items():
_add_param_to_feed_dict("{}.{}".format(param_name, key), val)
for param_name, param_value in params.items():
_add_param_to_feed_dict(param_name, param_value)
return feed_dict
另请注意,此问题已 addressed on develop 并将在下一个版本中提供。修复使用了上面代码片段中的方法。
正如 Dmitry 所提到的,kedro 0.16.0
introduced 节点输入中嵌套的参数值可以通过 .
运算符访问:
node(func, "params:a.b", None)
而 kedro 0.17.6
enabled overriding nested parameters with params
in CLI,例如
kedro run --params="model.model_tuning.booster:gbtree"
kedro
建议将参数存储在 conf/base/parameters.yml
中。让我们假设它看起来像这样:
step_size: 1
model_params:
learning_rate: 0.01
test_data_ratio: 0.2
num_train_steps: 10000
现在假设我有一些 data_engineering
管道,其 nodes.py
具有如下所示的功能:
def some_pipeline_step(num_train_steps):
"""
Takes the parameter `num_train_steps` as argument.
"""
pass
我将如何着手将嵌套参数直接传递给 data_engineering/pipeline.py
中的这个函数?我尝试失败:
from kedro.pipeline import Pipeline, node
from .nodes import split_data
def create_pipeline(**kwargs):
return Pipeline(
[
node(
some_pipeline_step,
["params:model_params.num_train_steps"],
dict(
train_x="train_x",
train_y="train_y",
),
)
]
)
我知道我可以使用 ['parameters']
将所有参数传递给函数,或者使用 ['params:model_params']
传递所有 model_params
参数,但这似乎不雅,我觉得必须成为一种方式。非常感谢任何输入!
(免责声明:我是 Kedro 团队的一员)
感谢您的提问。遗憾的是,当前版本的 Kedro 不支持嵌套参数。临时解决方案是在节点内使用顶级键(正如您已经指出的那样)或使用某种参数过滤器装饰节点函数,这也不优雅。
可能最可行的解决方案是通过覆盖 _get_feed_dict
方法来自定义 ProjectContext
(在 src/<package_name>/run.py
中)class:
class ProjectContext(KedroContext):
# ...
def _get_feed_dict(self) -> Dict[str, Any]:
"""Get parameters and return the feed dictionary."""
params = self.params
feed_dict = {"parameters": params}
def _add_param_to_feed_dict(param_name, param_value):
"""This recursively adds parameter paths to the `feed_dict`,
whenever `param_value` is a dictionary itself, so that users can
specify specific nested parameters in their node inputs.
Example:
>>> param_name = "a"
>>> param_value = {"b": 1}
>>> _add_param_to_feed_dict(param_name, param_value)
>>> assert feed_dict["params:a"] == {"b": 1}
>>> assert feed_dict["params:a.b"] == 1
"""
key = "params:{}".format(param_name)
feed_dict[key] = param_value
if isinstance(param_value, dict):
for key, val in param_value.items():
_add_param_to_feed_dict("{}.{}".format(param_name, key), val)
for param_name, param_value in params.items():
_add_param_to_feed_dict(param_name, param_value)
return feed_dict
另请注意,此问题已 addressed on develop 并将在下一个版本中提供。修复使用了上面代码片段中的方法。
正如 Dmitry 所提到的,kedro 0.16.0
introduced 节点输入中嵌套的参数值可以通过 .
运算符访问:
node(func, "params:a.b", None)
而 kedro 0.17.6
enabled overriding nested parameters with params
in CLI,例如
kedro run --params="model.model_tuning.booster:gbtree"