How to access Xcom value in DbtCloudRunJobOperator

I am writing a DAG and running into trouble when trying to pull XCom values from a task.

Here is what I want to achieve:

  1. Write a function func that reads the parameters of a manually triggered dag_run from {{ dag_run.conf }}, e.g. {"dlf_40_start_at": "2022-01-01", "dlf_40_end_at": "2022-01-02"}, and pushes the values to XCom (see the trigger sketch after this list for how the conf gets there):


def func(**kwargs):
    dag_run_conf = kwargs["dag_run"].conf
    task_instance: TaskInstance = kwargs.get("task_instance") or kwargs.get("ti")
    task_instance.xcom_push(key='current_load_window_start', value=dag_run_conf['dlf_40_start_at'])
    task_instance.xcom_push(key='current_load_window_end', value=dag_run_conf['dlf_40_end_at'])

  2. Call func from a PythonOperator to push the values to XCom:
    extract_manual_config_parameter = PythonOperator(
        task_id='extract_manual_config_parameter',
        python_callable=func,
        dag=dag_capture_min_id,
        provide_context=True  # Remove this if you are using Airflow >=2.0
    )


  3. Define my user_defined_macros as:
def get_return_value(a, b):
    print("dlf_40_start_at", a, "dlf_40_end_at", b)
    return [a, b]
  4. In the DbtCloudRunJobOperator, have the Jinja template evaluated and take the values from it.
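For reference, this is how the conf dict from step 1 reaches dag_run.conf when the DAG is triggered by hand; a sketch using the same local Client that my full code below already uses (the dag_id here is illustrative):

from airflow.api.client.local_client import Client

# Manually trigger the DAG with the conf that func expects; at runtime
# the dict is available to every task as dag_run.conf.
client = Client(None, None)
client.trigger_dag(dag_id="scv_matching_process_closure_dag",
                   conf={"dlf_40_start_at": "2022-01-01",
                         "dlf_40_end_at": "2022-01-02"})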

Here is my full code:

from datetime import timedelta, datetime

from airflow import DAG
from airflow.api.client.local_client import Client
from airflow.models import Variable, TaskInstance, XCom
from airflow.models.dagrun import DagRun
from airflow.operators.python_operator import PythonOperator

from common.dbt_cloud_plugin.operators.domain_dbt_cloud_job_operator import DomainDbtCloudJobOperator
from common.utils.slack_utils import SlackUtils
from data_product_insights.scv_main_pipeline.utils import snowflake_operator, stage_config, \
    create_common_tasks_for_matching_closure_dag

from data_product_insights.scv_main_pipeline.config import DbtStageConfig

stg = Variable.get('stage')
dbt_config = {
    'dev': DbtStageConfig(dbt_cloud_job_name='scv-ga-data-extraction-dev'),
    'stg': DbtStageConfig(dbt_cloud_job_name='scv-ga-data-extraction-stg'),
    'prd': DbtStageConfig(dbt_cloud_job_name='scv-ga-data-extraction-prd')
}
dbt_stage_config = dbt_config[stg]


def close_process_and_trigger_downstream(**kwargs: dict) -> None:  # pylint: disable=unused-argument
    dag_run: DagRun = kwargs.get("dag_run")
    loop_index = dag_run.conf.get("loop_index", 1)
    process_id = dag_run.conf.get("process_id")

    sql = f"UPDATE SCV_STAGING.ETL_PROCESS SET ETL_END_LOAD_DATE = CURRENT_TIMESTAMP(), " \
          f"IS_NEXT_ETL = FALSE, ETL_STATUS = 'done', MAX_SESSIONID = {loop_index} " \
          f"WHERE PROCESS_ID = {process_id} AND TYPE = 'matching'"

    snowflake_operator(
        task_id="close_matching_process_record",
        sql=sql
    ).execute(kwargs)

    client = Client(None, None)
    client.trigger_dag(stage_config.dag_id_scv_search_event_pipeline,
                       run_id=datetime.now().strftime("%d_%m_%Y_%H_%M_%S"),
                       conf={"process_id": process_id})


def func(**kwargs):
    dag_run_conf = kwargs["dag_run"].conf
    task_instance: TaskInstance = kwargs.get("task_instance") or kwargs.get("ti")
    task_instance.xcom_push(key='current_load_window_start', value=dag_run_conf['dlf_40_start_at'])
    task_instance.xcom_push(key='current_load_window_end', value=dag_run_conf['dlf_40_end_at'])


def get_return_value(a, b):
    print("dlf_40_start_at", a, "dlf_40_end_at", b)
    return [a, b]


with DAG(dag_id=stage_config.dag_id_scv_matching_process_closure_dag,
         description='Matching engine of scv data',
         schedule_interval=None,
         default_args={
             'start_date': stage_config.start_date,
             'retries': stage_config.retries,
             'retry_delay': timedelta(minutes=1),
             'on_failure_callback': SlackUtils(stage_config.slack_connection_id).post_slack_failure,
             'retry_exponential_backoff': stage_config.retry_back_off,
         },
         user_defined_macros={"get_return_value": get_return_value},
         params=stage_config.as_dict()) as dag_capture_min_id:


    extract_manual_config_parameter = PythonOperator(
        task_id='extract_manual_config_parameter',
        python_callable=func,
        dag=dag_capture_min_id,
        provide_context=True  # Remove this if you are using Airflow >=2.0
    )

    return_value = "{{ get_return_value( task_instance.xcom_pull(task_ids='extract_manual_config_parameter', " \
                   "key='current_load_window_start'), task_instance.xcom_pull(" \
                   "task_ids='extract_manual_config_parameter', key='current_load_window_end')) }} "


    close_matching_process_and_trigger_downstream_dags = PythonOperator(
        task_id="close_matching_process_and_trigger_downstream_dags",
        python_callable=close_process_and_trigger_downstream,
        provide_context=True
    )


    dbt_40_dlf = DomainDbtCloudJobOperator(
        task_id='dbt_40_dlf',
        xcom_task_id='extract_manual_config_parameter',
        dbt_cloud_conn_id=dbt_stage_config.dbt_cloud_connection,
        job_name=dbt_stage_config.dbt_cloud_job_name,
        data={
            "cause": "Kicked off from Airflow",
            "git_branch": dbt_stage_config.dbt_job_dev_version,
            "steps_override": ['dbt test'],
            "xcom_value": return_value,
            "haha": "haha"
        },
        dag=dag_capture_min_id
    ).build_operators()

    create_common_tasks_for_matching_closure_dag(
        dag=dag_capture_min_id,
        downstream_task=extract_manual_config_parameter
    )
    extract_manual_config_parameter >> dbt_40_dlf.first
    dbt_40_dlf.last >> close_matching_process_and_trigger_downstream_dags

The log output shows that the Jinja template is not being evaluated:

INFO - The data is {'cause': 'Kicked off via Airflow: scv_matching_process_closure_dag', 'git_branch': 'development', 'steps_override': ['dbt test'], 'xcom_value': "{{ get_return_value(task_instance.xcom_pull(task_ids='extract_manual_config_parameter', key='current_load_window_start'), task_instance.xcom_pull(task_ids='extract_manual_config_parameter', key='current_load_window_end')) }} ", 'haha': 'haha'}

Am I doing something wrong?

XCom automatically stores the return value of a task, so you can simply return your return_value from the callable instead of pushing it explicitly; it is saved as an XCom under the ID of the task it came from.

Then, in your next task, you only need to do something like this:

my_xcom = kwargs['task_instance'].xcom_pull(
    task_ids='my_previous_task_name')
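
Put together, a minimal sketch of that simplified flow (the conf keys and the task_id are taken from your question; the dag_id and dates are purely illustrative):

from datetime import datetime

from airflow import DAG
from airflow.operators.python_operator import PythonOperator


def extract_window(**kwargs):
    # Returning the value stores it automatically as an XCom with
    # key="return_value"; no explicit xcom_push is needed.
    conf = kwargs["dag_run"].conf
    return [conf["dlf_40_start_at"], conf["dlf_40_end_at"]]


with DAG(dag_id="xcom_return_value_demo",
         schedule_interval=None,
         start_date=datetime(2022, 1, 1)) as dag:

    extract_manual_config_parameter = PythonOperator(
        task_id="extract_manual_config_parameter",
        python_callable=extract_window,
        provide_context=True  # Remove this if you are using Airflow >=2.0
    )

    # In a templated field, xcom_pull defaults to key="return_value",
    # so no custom keys and no user_defined_macros are required:
    return_value = "{{ task_instance.xcom_pull(task_ids='extract_manual_config_parameter') }}"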

I think your approach could be made to work with a few changes, but it is very over-engineered.
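
One concrete change worth checking: Airflow only renders Jinja in the fields an operator lists in its template_fields. DomainDbtCloudJobOperator is your in-house operator, so this is a guess at its internals, but your log is exactly what you would see if data is not declared there. A sketch of the kind of fix, assuming the operator it builds is an ordinary BaseOperator subclass:

from common.dbt_cloud_plugin.operators.domain_dbt_cloud_job_operator import DomainDbtCloudJobOperator


class TemplatedDbtCloudJobOperator(DomainDbtCloudJobOperator):
    # Jinja is rendered only for the fields named here; declaring "data"
    # lets the templated xcom_value inside it be evaluated at runtime.
    template_fields = ("data",)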