用 mock.patch 补上 moto 尚未实现的方法

Augmenting moto with mock patch where method is not yet implemented

我正在编写一个 Lambda 函数,它获取 CloudWatch(CW)日志组列表,并对每个日志组运行“导出到 S3”任务。

我正在用 pytest 编写自动化测试,并使用了 moto.mock_logs(以及其他 moto 模拟),但 moto 尚未实现 create_export_task()(调用时抛出 NotImplementedError)。

为了让其他所有方法继续使用 moto.mock_logs,我尝试用 mock.patch 只修补 create_export_task() 这一个方法,但它找不到要修补的正确目标(报 ImportError)。

我已经用 mock.Mock() 成功实现了所需的功能,但想知道能否用 mock.patch() 做到同样的事情?

能正常运行的代码:lambda.py

# lambda.py
"""Export CloudWatch Logs to S3 every 24 hours."""
import logging
import os
from time import time
import boto3
from botocore.config import Config
from botocore.exceptions import ClientError

logger = logging.getLogger(__name__)
# basicConfig is a no-op when the root logger already has handlers (the AWS
# Lambda Python runtime may have installed one), so also set this module's
# logger level explicitly; otherwise the INFO messages below could be dropped.
logging.basicConfig(level=logging.INFO)
logger.setLevel(logging.INFO)


class CloudWatchLogsS3Archive:
    """Export CloudWatch Logs log groups to an S3 bucket.

    The time of the previous export for a log group is read from and written to
    SSM Parameter Store (``get_last_export_time`` / ``put_export_time``).
    """

    # Shared by both clients: botocore's "adaptive" retry mode retries
    # throttling/transient errors, adding client-side rate limiting.
    botocore_config = Config(retries={"max_attempts": 10, "mode": "adaptive"})

    def __init__(self, s3_bucket, account_id) -> None:
        """Store the export settings and create the Logs and SSM clients.

        Args:
            s3_bucket: name of the destination S3 bucket.
            account_id: AWS account id; ``check_valid_inputs`` requires it to
                be 12 digits.
        """
        self.s3_bucket = s3_bucket
        self.account_id = account_id
        # NOTE(review): the three attributes below are not read by any method
        # of this class — confirm whether external code relies on them.
        self.extra_args = {}
        self.log_groups = []
        self.log_groups_to_export = []
        self.logs = boto3.client("logs", config=self.botocore_config)
        self.ssm = boto3.client("ssm", config=self.botocore_config)

    def check_valid_inputs(self):
        """Check that required inputs are present and valid.

        Raises:
            ValueError: if ``account_id`` is not exactly 12 ASCII digits.
        """
        # str() so an int id is validated instead of crashing in len().
        account_id = str(self.account_id)
        is_valid = len(account_id) == 12 and all(
            ch in "0123456789" for ch in account_id
        )
        if not is_valid:
            logger.error("Account Id must be valid 12-digit AWS account id")
            raise ValueError("Account Id must be valid 12-digit AWS account id")

    def collect_log_groups(self):
        """Yield the name of every CloudWatch Logs log group, across all pages."""
        paginator = self.logs.get_paginator("describe_log_groups")
        for page in paginator.paginate():
            for log_group in page["logGroups"]:
                yield log_group["logGroupName"]

    def get_last_export_time(self, Name) -> str:
        """Return the last export time stored in SSM parameter ``Name``.

        The value is returned as stored, i.e. as a string. When the parameter
        does not exist yet, ``"0"`` is returned so the export starts at epoch 0.

        Raises:
            botocore.exceptions.ClientError: for any SSM error other than
                ``ParameterNotFound`` (after logging it).
        """
        try:
            response = self.ssm.get_parameter(Name=Name)
        except ClientError as exc:
            # Modeled exceptions such as ParameterNotFound subclass
            # ClientError, so one handler covers both; branch on the code.
            # "%s" keeps a '%' inside the AWS message from breaking logging.
            logger.warning("%s", exc)
            if exc.response["Error"]["Code"] == "ParameterNotFound":
                return "0"
            raise
        return response["Parameter"]["Value"]

    def set_export_time(self):
        """Return the current time in milliseconds since the Unix epoch.

        Despite the name, nothing is stored here; see ``put_export_time``.
        """
        return round(time() * 1000)

    def put_export_time(self, put_time, Name):
        """Store ``put_time`` (as a string) in SSM parameter ``Name``, overwriting it."""
        self.ssm.put_parameter(Name=Name, Value=str(put_time), Overwrite=True)

    def create_export_tasks(
        self, log_group_name, fromTime, toTime, s3_bucket, account_id
    ):
        """Create a CloudWatch Logs export task for one log group.

        Args:
            log_group_name: log group to export.
            fromTime: start of the range in epoch milliseconds (int or numeric
                string, e.g. the value returned by ``get_last_export_time``).
            toTime: end of the range in epoch milliseconds.
            s3_bucket: destination bucket name.
            account_id: used as the first segment of the S3 key prefix
                ``"<account_id>/<log_group_name without surrounding '/'>"``.

        Returns:
            True if the task was created, False if creating it failed. Failures
            are logged rather than raised, so one log group cannot stop a batch.
        """
        try:
            response = self.logs.create_export_task(
                logGroupName=log_group_name,
                fromTime=int(fromTime),
                to=toTime,
                destination=s3_bucket,
                destinationPrefix="{}/{}".format(account_id, log_group_name.strip("/")),
            )
            logger.info("✔   Task created: %s", response["taskId"])
        except self.logs.exceptions.LimitExceededException:
            # The retry config above presumably already backed off and retried
            # this error inside botocore; reaching here means retries ran out.
            # NOTE(review): no further retry happens here despite the message.
            logger.warning(
                "⚠   Too many concurrently running export tasks "
                "(LimitExceededException); backing off and retrying..."
            )
            return False
        except Exception as e:
            logger.exception(
                "✖   Error exporting %s: %s",
                log_group_name,
                getattr(e, "message", repr(e)),
            )
            return False
        return True


def lambda_handler(event, context):
    """AWS Lambda entry point: export every CloudWatch Logs log group to S3.

    Reads ``S3_BUCKET`` and ``ACCOUNT_ID`` from the environment. For each log
    group it exports events since the time stored in the SSM parameter named
    after the log group, then stores the new export time there. ``event`` and
    ``context`` are not used.

    Raises:
        KeyError: if ``S3_BUCKET`` or ``ACCOUNT_ID`` is not set.
        ValueError: if ``ACCOUNT_ID`` is not a valid account id.
    """
    s3_bucket = os.environ["S3_BUCKET"]
    account_id = os.environ["ACCOUNT_ID"]
    c = CloudWatchLogsS3Archive(s3_bucket, account_id)
    c.check_valid_inputs()
    for log_group_name in c.collect_log_groups():
        fromTime = c.get_last_export_time(log_group_name)
        toTime = c.set_export_time()
        c.create_export_tasks(log_group_name, fromTime, toTime, s3_bucket, account_id)
        # put_export_time's signature is (put_time, Name); pass by keyword so
        # the timestamp and the parameter name cannot be swapped.
        # NOTE(review): create_export_tasks logs and swallows failures, so the
        # stored time advances even when the export failed and that window is
        # never exported — consider only storing it after a successful export.
        c.put_export_time(put_time=toTime, Name=log_group_name)

测试代码(pytest):test_lambda.py

# test_lambda.py
"""Test Lambda Function"""
import os
from unittest import mock

import boto3
import moto
import pytest


@pytest.fixture(autouse=True)
def f_aws_credentials(monkeypatch):
    """Provide fake AWS credentials and region for moto in every test.

    ``autouse=True`` must be passed to ``pytest.fixture`` itself: pytest ignores
    default-valued parameters of a fixture function, so declaring it there does
    not make the fixture run automatically. ``monkeypatch`` restores the
    original environment after each test, so the fake values do not leak.
    """
    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
    monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing")
    monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
    monkeypatch.setenv("AWS_DEFAULT_REGION", "us-east-1")


@moto.mock_logs
@moto.mock_ssm
def test_create_export_tasks():
    """create_export_tasks forwards the expected arguments to CreateExportTask.

    moto (at the version this was written against) does not implement
    ``create_export_task``, so that one client method is replaced with a
    ``mock.Mock``. The rest of the Logs/SSM API stays backed by moto.
    """
    # NOTE(review): the code under test is shown as lambda.py, but ``lambda``
    # is a Python keyword and cannot be imported — confirm the module name.
    from cloudwatch_logs_s3_archive import CloudWatchLogsS3Archive

    log_group_name = "/log-exporter-last-export/first"
    s3_bucket = "s3_bucket"
    account_id = 123412341234

    c = CloudWatchLogsS3Archive("bucket", "123412341234")
    logs = boto3.client("logs")
    for suffix in ("first", "second", "third"):
        logs.create_log_group(logGroupName="/log-exporter-last-export/" + suffix)

    toTime = c.set_export_time()
    fromTime = c.get_last_export_time(Name=log_group_name)
    c.logs.create_export_task = mock.Mock(
        return_value={"taskId": "I am mocked via mock.Mock"}
    )

    c.create_export_tasks(log_group_name, fromTime, toTime, s3_bucket, account_id)

    # Checks both the call count and the arguments. A bare
    # ``mock.assert_called`` without parentheses would assert nothing.
    c.logs.create_export_task.assert_called_once_with(
        logGroupName=log_group_name,
        fromTime=int(fromTime),
        to=toTime,
        destination=s3_bucket,
        destinationPrefix="{}/{}".format(account_id, log_group_name.strip("/")),
    )

> I'm wondering if I can do the same with mock.patch()?(我想知道能否用 mock.patch() 做同样的事情?)

当然可以,使用 mock.patch.object():

# Replace create_export_task on this one client instance only; the real
# method is put back automatically when the ``with`` block exits.
with mock.patch.object(
    c.logs,
    "create_export_task",
    return_value={"taskId": "I am mocked via mock.Mock"},
) as mocked_create_export_task:
    c.create_export_tasks(
        "/log-exporter-last-export/first", fromTime, toTime, "s3_bucket", 123412341234
    )
    assert mocked_create_export_task.called

如果您不想使用上下文管理器,建议安装 pytest-mock 插件,它提供了一个方便的 mocker fixture。这样您的测试会写成:

def test_create_export_tasks(mocker):
    """Patch create_export_task through pytest-mock's ``mocker`` fixture.

    ``...`` stands for the elided setup that creates ``c``, ``fromTime`` and
    ``toTime``. ``mocker`` undoes the patch when the test finishes.
    """
    ...
    mocked_create_export_task = mocker.patch.object(
        c.logs,
        "create_export_task",
        return_value={"taskId": "I am mocked via mock.Mock"},
    )
    c.create_export_tasks(
        "/log-exporter-last-export/first", fromTime, toTime, "s3_bucket", 123412341234
    )
    assert mocked_create_export_task.called

mocker 基本上是 unittest.mock 模块的代理,提供相同的功能和方法,区别在于它会在测试结束时自动撤销所有补丁,因此您少了一件需要操心的事。