如何将 dag_run.conf 用于类型参数

How to use dag_run.conf for typed arguments

我有一个创建 Google Dataproc 集群并向其提交作业的 DAG。

我希望能够通过 dag_run.conf 参数自定义集群(工人数量)和作业(传递给它的参数)。

集群创建

对于集群创建,我写了一个类似这样的逻辑:

DataprocCreateClusterOperator(...
        cluster_config = {...
             num_workers = "{% if 'cluster' is in dag_run.conf and 'secondary_worker_config' is in dag_run.conf['cluster'] and 'num_instances' is in dag_run.conf['cluster']['secondary_worker_config'] %}{{ dag_run.conf['cluster']['secondary_worker_config']['num_instances'] }}{% else %}16{% endif %}"
        }
)

也就是说,如果cluster.secondary_worker_config.num_instancesdag_run.conf中可用,则使用它,否则回退到默认值16

但是,在呈现时,它会扩展为 Python 字符串,如 "16",导致失败,因为 num_workers 参数必须是 intlong.

我无法在运算符声明期间将其解析为 int:

num_workers = int("{% ... %}")

因为这会尝试将整个 jinja 脚本解释为一个整数(而不是结果值)。

使用 | int jinja 过滤器都不能解决问题。

作业提交

我在提交作业时遇到了类似的问题。 运算符期望一个 job dict 参数,字段 spark.args 为 spark 作业提供参数。该字段必须是一个可迭代的,并且应该是一个字符串列表,例如:["--arg=foo", "bar"].

我希望能够通过 dag_run.conf:

提供一些参数来添加它们
{
    args = ["--new_arg=baz", "bar2"]
}

但是将这些参数添加到初始列表似乎是不可能的。您可以为所有附加参数获得一个参数:["--arg=foo", "bar", "--new_arg=baz bar2"],或者一个包含所有参数的字符串。

在任何情况下,生成的作业提交都没有按预期工作...


是否有解决此问题的现有方法?

如果不是,有没有办法在“模板渲染”之后添加一个“铸造步骤”,在提供程序操作符中或直接在 BaseOperator 摘要 class 中?


编辑

我认为 提出的解决方案是可行的方法。但是,对于那些不想升级 Airflow 的人,我尝试实施 Jarek 提出的解决方案。

import unittest
import datetime
from typing import Any

from airflow import DAG
from airflow.models import BaseOperator, TaskInstance


# Define an operator which check its argument type at runtime (during "execute")
class TypedOperator(BaseOperator):
    def __init__(self, int_param: int, **kwargs):
        super(TypedOperator, self).__init__(**kwargs)
        self.int_param = int_param

    def execute(self, context: Any):
        assert(type(self.int_param) is int)


# Extend the "typed" operator with an operator handling templating
class TemplatedOperator(TypedOperator):
    template_fields = ['templated_param']

    def __init__(self,
                 templated_param: str = "{% if 'value' is in dag_run.conf %}{{ dag_run.conf['value'] }}{% else %}16{% endif %}",
                 **kwargs):
        super(TemplatedOperator, self).__init__(int_param=int(templated_param), **kwargs)


# Run a test, instantiating a task and executing it
class JinjaTest(unittest.TestCase):

    def test_templating(self):
        print("Start test")
        dag = DAG("jinja_test_dag", default_args=dict(
            start_date=datetime.date.today().isoformat()
        ))
        print("Task intanciation (regularly done by scheduler)")
        task = TemplatedOperator(task_id="my_task", dag=dag)
        print("Done")

        print("Task execution (only done when DAG triggered)")
        context = TaskInstance(task=task, execution_date=datetime.datetime.now()).get_template_context()
        task.execute(context)
        print("Done")

        self.assertTrue(True)

给出输出:


Start test

Task intanciation (regularly done by scheduler)

Ran 1 test in 0.006s

FAILED (errors=1)

Error
Traceback (most recent call last):
  File "/home/alexis/AdYouLike/Repositories/data-airflow-dags/tests/data_airflow_dags/utils/tasks/test_jinja.py", line 38, in test_templating
    task = TemplatedOperator(task_id="my_task", dag=dag)
  File "/home/alexis/AdYouLike/Repositories/data-airflow-dags/.venv/lib/python3.6/site-packages/airflow/models/baseoperator.py", line 89, in __call__
    obj: BaseOperator = type.__call__(cls, *args, **kwargs)
  File "/home/alexis/AdYouLike/Repositories/data-airflow-dags/tests/data_airflow_dags/utils/tasks/test_jinja.py", line 26, in __init__
    super(TemplatedOperator, self).__init__(int_param=int(templated_param), **kwargs)
ValueError: invalid literal for int() with base 10: "{% if 'value' is in dag_run.conf %}{{ dag_run.conf['value'] }}{% else %}16{% endif %}"

如您所见,这在任务实例化步骤失败了,因为在 TemplatedOperator.__init__ 中我们尝试转换为 int JINJA 模板(而不是渲染值)。 也许我在这个解决方案中遗漏了一点,但它似乎无法按原样使用。

最简单的方法是定义派生自 DataprocCreateClusterOperator 的自定义运算符。这非常简单,您甚至可以在 dag 文件中完成:

概念上是这样的:

class MyDataprocCreateClusterOperator(DataprocCreateClusterOperator):
   template_fields = DataprocCreateClusterOperator.template_fields + ['my_param']
   def __init__(my_param='{{ ... }}', .....):
      super(int_param=int(my_param), ....)

不幸的是,所有 Jinja 模板都呈现为字符串,因此 @JarekPotiuk 提出的解决方案是您最好的选择。

但是,对于使用 Airflow 2.1+ 的任何人或者如果您想升级,可以在 DAG 级别设置一个新参数:render_template_as_native_obj

启用此参数时,Jinja 模板的输出将作为本机 Python 类型(例如列表、元组、整数等)返回。在此处了解更多信息:https://airflow.apache.org/docs/apache-airflow/stable/concepts/operators.html#rendering-fields-as-native-python-objects