How to poll and keep track of external job status through Airflow?

I'm using Airflow to poll boto3 for the status of a SageMaker Autopilot job. I'm using a PythonSensor to wait for the statuses from JobStatus and JobSecondaryStatus to return Completed, which then ends the entire pipeline. These are the values they can contain, which I enumerated in my code:

'AutoMLJobStatus': 'Completed'|'InProgress'|'Failed'|'Stopped'|'Stopping',
'AutoMLJobSecondaryStatus': 'Starting'|'AnalyzingData'|'FeatureEngineering'|'ModelTuning'|'MaxCandidatesReached'|'Failed'|'Stopped'|'MaxAutoMLJobRuntimeReached'|'Stopping'|'CandidateDefinitionsGenerated'|'GeneratingExplainabilityReport'|'Completed'|'ExplainabilityError'|'DeployingModel'|'ModelDeploymentError'

_sagemaker_job_status receives the automl_job_name through XCom from an upstream task, and that part works fine. Using this job name, I can pass it to describe_auto_ml_job() to get the statuses via AutoMLJobStatus and AutoMLJobSecondaryStatus.
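
For reference, a minimal sketch of that lookup outside of Airflow (the region and job name here are illustrative):

import boto3

client = boto3.client("sagemaker", region_name="us-east-1")
response = client.describe_auto_ml_job(AutoMLJobName="example-automl-job")

job_status = response["AutoMLJobStatus"]  # e.g. "InProgress"
secondary_job_status = response["AutoMLJobSecondaryStatus"]  # e.g. "AnalyzingData"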

The main purpose is to send messages through Slack showing every unique stage the job goes through. Currently, I'm trying to save all the unique job statuses to a set, and then check against that set before sending a message containing the job statuses.

But every time _sagemaker_job_status gets poked, the values of the sets seem to be the same, so a Slack message gets sent on every poke; I logged the sets and both are empty each time. I made a simpler example of this further down.
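
As far as I can tell, each poke calls the sensor's callable from scratch, so function-local state is rebuilt empty every time (and in reschedule mode even module-level state would be lost, since each poke runs in a fresh process). A standalone sketch of what I mean:

def _poke_callable():
    past_job_statuses = set()  # re-created on every call
    past_job_statuses.add("InProgress")
    return past_job_statuses

# Simulating three pokes: the set never accumulates across calls.
for _ in range(3):
    print(_poke_callable())  # {'InProgress'} every time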

import airflow
from airflow import DAG
from airflow.exceptions import AirflowFailException
from airflow.operators.dummy import DummyOperator
from airflow.sensors.python import PythonSensor
import boto3

# JobStatus / JobSecondaryStatus are the enums defined at the bottom of this
# post; task_fail_slack_alert and task_success_slack_alert are my own Slack
# callback helpers (not shown).


def _sagemaker_job_status(templates_dict, **context):
    """
    Checks the SageMaker AutoMLJobStatus and AutoMLJobSecondaryStatus
    for updates and when both are complete the entire process is marked as
    successful
    """
    automl_job_name = templates_dict.get("automl_job_name")
    if not automl_job_name:
        error_message = "AutoMLJobName was not passed from upstream"
        print(error_message)
        task_fail_slack_alert(
            context=context,
            extra_message=error_message,
        )
    client = boto3.client("sagemaker", "us-east-1")
    response = client.describe_auto_ml_job(
        AutoMLJobName=automl_job_name,
    )
    job_status = response.get("AutoMLJobStatus")
    secondary_job_status = response.get("AutoMLJobSecondaryStatus")
    past_job_statuses = set()  # NB: re-created from scratch on every poke
    past_secondary_job_statuses = set()
    print(f"Past Job Statuses : {past_job_statuses}")
    print(f"Past Secondary Job Statuses : {past_secondary_job_statuses}")
    # If the job status has not been already seen
    if (
        job_status not in past_job_statuses
        and secondary_job_status not in past_secondary_job_statuses
    ):
        message = f"""
            JobStatus : {job_status}
            JobSecondaryStatus : {secondary_job_status}
            """
        print(message)
        task_success_slack_alert(
            context=context,
            extra_message=message,
        )
    past_job_statuses.add(job_status)
    past_secondary_job_statuses.add(secondary_job_status)
    # If the main job fails
    if job_status == JobStatus.Failed.value: 
        error_message = "SageMaker Autopilot Job Failed!"
        task_fail_slack_alert(
            context=context,
            extra_message=error_message,
        )
        raise AirflowFailException(error_message)
    
    return (
        job_status == JobStatus.Completed.value 
        and secondary_job_status == JobSecondaryStatus.Completed.value
    )

args = {
    "owner": "Yudhiesh",
    "start_date": airflow.utils.dates.days_ago(1),
    "on_failure_callback": task_fail_slack_alert,
}

with DAG(
    dag_id="02_lasic_retraining_sagemaker_autopilot",
    default_args=args,
    # schedule_interval is a DAG argument, not a default_arg
    schedule_interval="@once",
    render_template_as_native_obj=True,
) as dag:

    sagemaker_job_status = PythonSensor(
        task_id="sagemaker_job_status",
        python_callable=_sagemaker_job_status,
        templates_dict={
            "automl_job_name": "{{task_instance.xcom_pull(task_ids='train_model_sagemaker_autopilot')}}",  # noqa: E501
        },
    )

    end = DummyOperator(
        task_id="end",
    )

    sagemaker_job_status >> end

I created a similar setup to the one before, but this time I randomly generate values from the JobStatus and JobSecondaryStatus enums and try to print only the unique values, and it works just fine. Can anyone explain why this happens and what I can do to get the main example to work?

import airflow
import random
from airflow import DAG
from airflow.sensors.python import PythonSensor
from airflow.operators.dummy import DummyOperator
from airflow.exceptions import AirflowFailException

# JobStatus and JobSecondaryStatus are the enums defined further down.

def _mimic_sagemaker_job_status():
    job_statuses = [status.value for status in JobStatus]
    job_secondary_statuses = [
        secondary_status.value for secondary_status in JobSecondaryStatus
    ]
    past_job_statuses = set()
    past_secondary_job_statuses = set()
    job_status = random.choice(job_statuses)
    job_secondary_status = random.choice(job_secondary_statuses)
    if (
        job_status not in past_job_statuses
        and job_secondary_status not in past_secondary_job_statuses
    ):
        message = f"""
            JobStatus : {job_status}
            JobSecondaryStatus : {job_secondary_status}
            """
        # Send alerts on every new job status update
        print(message)
    past_job_statuses.add(job_status)
    past_secondary_job_statuses.add(job_secondary_status)
    if (
        job_status == JobStatus.Failed.value
        or job_secondary_status == JobSecondaryStatus.Failed.value
    ):
        raise AirflowFailException("SageMaker Autopilot Job Failed!")

    return (
        job_secondary_status == JobSecondaryStatus.Completed.value
        and job_status == JobStatus.Completed.value
    )


with DAG(
    dag_id="04_sagemaker_sensor",
    start_date=airflow.utils.dates.days_ago(3),
    schedule_interval="@once",
    render_template_as_native_obj=True,
) as dag:

    wait_for_status = PythonSensor(
        task_id="wait_for_status",
        python_callable=_mimic_sagemaker_job_status,
    )

    end = DummyOperator(
        task_id="end",
    )

    wait_for_status >> end

The enums used in the code above:

from enum import Enum

class JobStatus(Enum):
    """
    Enum of all the potential values of a SageMaker Autopilot job status
    """

    Completed = "Completed"
    InProgress = "InProgress"
    Failed = "Failed"
    Stopped = "Stopped"
    Stopping = "Stopping"


class JobSecondaryStatus(Enum):
    """
    Enum of all the potential values of a SageMaker Autopilot job secondary
    status
    """

    Starting = "Starting"
    AnalyzingData = "AnalyzingData"
    FeatureEngineering = "FeatureEngineering"
    ModelTuning = "ModelTuning"
    MaxCandidatesReached = "MaxCandidatesReached"
    Failed = "Failed"
    Stopped = "Stopped"
    MaxAutoMLJobRuntimeReached = "MaxAutoMLJobRuntimeReached"
    Stopping = "Stopping"
    CandidateDefinitionsGenerated = "CandidateDefinitionsGenerated"
    GeneratingExplainabilityReport = "GeneratingExplainabilityReport"
    Completed = "Completed"
    ExplainabilityError = "ExplainabilityError"
    DeployingModel = "DeployingModel"
    ModelDeploymentError = "ModelDeploymentError"

EDIT:

I suppose another workaround for the main example would be to have an operator, before sagemaker_job_status, create a temporary file containing a JSON of the sets; then inside sagemaker_job_status I could check the job statuses saved to the file and only print them if they are unique. I just realized I could also use a database.
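
For example, a minimal sketch of the database idea using Airflow's own metadata database through Variables (the key autopilot_seen_statuses is just a name I made up):

import json

from airflow.models import Variable


def record_status(job_status, secondary_job_status):
    """Return True the first time a pair of statuses is seen."""
    seen = json.loads(
        Variable.get(
            "autopilot_seen_statuses",
            default_var='{"primary": [], "secondary": []}',
        )
    )
    is_new = (
        job_status not in seen["primary"]
        or secondary_job_status not in seen["secondary"]
    )
    if is_new:
        seen["primary"].append(job_status)
        seen["secondary"].append(secondary_job_status)
        Variable.set("autopilot_seen_statuses", json.dumps(seen))
    return is_new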

So, as I couldn't get it to work as is, I resorted to creating a JSON file to store the different SageMaker Autopilot job statuses, which I read from and write to inside the PythonSensor.

This receives the AutoMLJobName from the previous step, creates a temporary file of the job statuses, and returns the AutoMLJobName along with the name of the JSON file.

import json
import tempfile

def _create_job_status_json(templates_dict, **context):
    automl_job_name = templates_dict.get("sagemaker_autopilot_data_paths")
    if not automl_job_name:
        error_message = "AutoMLJobName was not passed from upstream"
        print(error_message)
        task_fail_slack_alert(
            context=context,
            extra_message=error_message,
        )
    initial = {
        "JobStatus": [],
        "JobSecondaryStatus": [],
    }
    file = tempfile.NamedTemporaryFile(mode="w", delete=False)
    json.dump({"Status": initial}, file)
    file.close()  # closing flushes; the sensor re-opens the file by name
    return (file.name, automl_job_name)
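
Note that the downstream sensor can unpack this return value straight from templates_dict because the DAG sets render_template_as_native_obj=True, so the templated XCom pull renders as a native Python tuple rather than its string form. Roughly:

# What templates_dict["automl_job_data"] looks like after rendering
# (values are made up):
automl_job_data = ("/tmp/tmpl2x9k3qw", "example-automl-job")
file_name, automl_job_name = automl_job_data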

Next, this function reads the JSON file by its name and then checks the different job statuses against the boto3 SageMaker client. If the main job fails, the entire run fails. If either status is unique, it adds the job statuses to the dictionary, and once done it writes the dictionary back to the JSON file. When the whole job is complete, it sends some details about the best model as a Slack message. It returns True when both job statuses are Completed. Note that I also delete the JSON file when the job either succeeds or fails.

import json
import os

import boto3
from airflow.exceptions import AirflowFailException

# JobStatus / JobSecondaryStatus are the enums from earlier;
# task_fail_slack_alert and task_success_slack_alert are my Slack helpers.

def _sagemaker_job_status(templates_dict, **context):
    """
    Checks the SageMaker AutoMLJobStatus and AutoMLJobSecondaryStatus
    for updates and when both are complete the entire process is marked as
    successful
    """
    file_name, automl_job_name = templates_dict.get("automl_job_data")
    job_status_dict = {}
    client = boto3.client("sagemaker", "us-east-1")
    if not client:  # defensive only; boto3.client raises rather than returning None
        raise AirflowFailException(
            "Unable to get access to boto3 sagemaker client",
        )
    with open(file_name, "r") as json_file:
        response = client.describe_auto_ml_job(
            AutoMLJobName=automl_job_name,
        )
        job_status = response.get("AutoMLJobStatus")
        secondary_job_status = response.get("AutoMLJobSecondaryStatus")
        job_status_dict = json.load(json_file)
        status = job_status_dict.get("Status")
        past_job_statuses = status.get("JobStatus")
        past_secondary_job_statuses = status.get("JobSecondaryStatus")
        if job_status == JobStatus.Failed.value:
            error_message = "SageMaker Autopilot Job Failed!"
            task_fail_slack_alert(
                context=context,
                extra_message=error_message,
            )
            os.remove(file_name)
            raise AirflowFailException(error_message)
        if (
            job_status not in past_job_statuses
            or secondary_job_status not in past_secondary_job_statuses
        ):
            message = f"""
                JobStatus : {job_status}
                JobSecondaryStatus : {secondary_job_status}
                """
            print(message)
            task_success_slack_alert(
                context=context,
                extra_message=message,
            )
            past_job_statuses.append(job_status)
            past_secondary_job_statuses.append(secondary_job_status)

    with open(file_name, "w") as file:
        json.dump(job_status_dict, file)

        if (
            job_status == JobStatus.Completed.value
            and secondary_job_status == JobSecondaryStatus.Completed.value
        ):
            os.remove(file_name)
            response = client.describe_auto_ml_job(
                AutoMLJobName=automl_job_name,
            )
            best_candidate = response.get("BestCandidate")
            best_candidate_id = best_candidate.get("CandidateName")
            best_metric_name = (
                best_candidate.get("FinalAutoMLJobObjectiveMetric")
                .get("MetricName")
                .split(":")[1]
                .upper()
            )
            best_metric_value = round(
                best_candidate.get("FinalAutoMLJobObjectiveMetric").get(
                    "Value",
                ),
                3,
            )
            message = f"""
                Best Candidate ID : {best_candidate_id}
                Best Candidate Metric Score : {best_metric_value}{best_metric_name}
            """  # noqa: E501
            task_success_slack_alert(
                context=context,
                extra_message=message,
            )

        return (
            job_status == JobStatus.Completed.value
            and secondary_job_status == JobSecondaryStatus.Completed.value
        )

The DAG code:

import airflow
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.sensors.python import PythonSensor

args = {
    "owner": "Yudhiesh",
    "start_date": airflow.utils.dates.days_ago(1),
    "on_failure_callback": task_fail_slack_alert,
}

with DAG(
    dag_id="02_lasic_retraining_sagemaker_autopilot",
    default_args=args,
    # schedule_interval is a DAG argument, not a default_arg
    schedule_interval="@once",
    render_template_as_native_obj=True,
) as dag:


    create_job_status_json = PythonOperator(
        task_id="create_job_status_json",
        python_callable=_create_job_status_json,
        templates_dict={
            "sagemaker_autopilot_data_paths": "{{task_instance.xcom_pull(task_ids='train_model_sagemaker_autopilot')}}",  # noqa: E501
        },
    )

    sagemaker_job_status = PythonSensor(
        task_id="sagemaker_job_status",
        python_callable=_sagemaker_job_status,
        templates_dict={
            "automl_job_data": "{{task_instance.xcom_pull(task_ids='create_job_status_json')}}",  # noqa: E501
        },
    )
    # train_model_sagemaker_autopilot is not included but it initiates the training through boto3
    train_model_sagemaker_autopilot >> create_job_status_json

    create_job_status_json >> sagemaker_job_status