from unittest import mock

from prowler.providers.aws.services.dms.dms_service import ReplicationTasks
from tests.providers.aws.utils import (
    AWS_ACCOUNT_NUMBER,
    AWS_REGION_US_EAST_1,
    set_mocked_aws_provider,
)

# Identifiers for the fake DMS resources used by the tests below. Every ARN
# lives in the mocked account/region exported by tests.providers.aws.utils.
_DMS_ARN_PREFIX = f"arn:aws:dms:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}"

DMS_ENDPOINT_NAME = "dms-endpoint"
DMS_ENDPOINT_ARN = f"{_DMS_ARN_PREFIX}:endpoint:{DMS_ENDPOINT_NAME}"

DMS_INSTANCE_NAME = "rep-instance"
DMS_INSTANCE_ARN = f"{_DMS_ARN_PREFIX}:rep:{DMS_INSTANCE_NAME}"

DMS_REPLICATION_TASK_ARN = f"{_DMS_ARN_PREFIX}:task:rep-task"


class Test_dms_replication_task_source_logging_enabled:
    """Tests for the ``dms_replication_task_source_logging_enabled`` check.

    Each test builds the ``replication_tasks`` mapping of a mocked
    ``dms_client``, runs the check with that client patched in, and asserts
    on the returned findings.
    """

    # Dotted path of the check module; ``dms_client`` is patched on it.
    _CHECK_MODULE = (
        "prowler.providers.aws.services.dms."
        "dms_replication_task_source_logging_enabled."
        "dms_replication_task_source_logging_enabled"
    )

    @staticmethod
    def _replication_task(logging_enabled, log_components):
        """Return a ``replication_tasks`` mapping holding the single task ``rep-task``.

        Args:
            logging_enabled: Value for the task's ``logging_enabled`` flag.
            log_components: List of ``{"Id": ..., "Severity": ...}`` dicts for
                the task's ``log_components``.

        Returns:
            A dict keyed by ``DMS_REPLICATION_TASK_ARN``. Source and target
            endpoints both use ``DMS_ENDPOINT_ARN`` and tags are empty.
        """
        return {
            DMS_REPLICATION_TASK_ARN: ReplicationTasks(
                arn=DMS_REPLICATION_TASK_ARN,
                id="rep-task",
                region=AWS_REGION_US_EAST_1,
                source_endpoint_arn=DMS_ENDPOINT_ARN,
                target_endpoint_arn=DMS_ENDPOINT_ARN,
                logging_enabled=logging_enabled,
                log_components=log_components,
                tags=[],
            )
        }

    @classmethod
    def _run_check(cls, replication_tasks):
        """Run the check against a mocked ``dms_client`` and return its findings.

        Args:
            replication_tasks: Mapping assigned to ``dms_client.replication_tasks``.

        Returns:
            The list returned by ``check.execute()``.
        """
        dms_client = mock.MagicMock()
        dms_client.replication_tasks = replication_tasks

        with (
            mock.patch(
                "prowler.providers.common.provider.Provider.get_global_provider",
                return_value=set_mocked_aws_provider([AWS_REGION_US_EAST_1]),
            ),
            mock.patch(f"{cls._CHECK_MODULE}.dms_client", new=dms_client),
        ):
            # Import and run the check while both patches are active, so it
            # sees the mocked provider and the mocked dms_client.
            from prowler.providers.aws.services.dms.dms_replication_task_source_logging_enabled.dms_replication_task_source_logging_enabled import (
                dms_replication_task_source_logging_enabled,
            )

            check = dms_replication_task_source_logging_enabled()
            return check.execute()

    @staticmethod
    def _assert_single_finding(result, status, status_extended):
        """Assert ``result`` is exactly one finding for ``rep-task``.

        Args:
            result: Findings returned by the check.
            status: Expected ``status`` (``"PASS"`` or ``"FAIL"``).
            status_extended: Expected ``status_extended`` message.
        """
        assert len(result) == 1
        assert result[0].status == status
        assert result[0].status_extended == status_extended
        assert result[0].resource_id == "rep-task"
        assert result[0].resource_arn == DMS_REPLICATION_TASK_ARN
        assert result[0].resource_tags == []
        assert result[0].region == "us-east-1"

    def test_no_dms_replication_tasks(self):
        result = self._run_check({})

        assert len(result) == 0

    def test_dms_replication_task_logging_not_enabled(self):
        result = self._run_check(
            self._replication_task(
                logging_enabled=False,
                log_components=[
                    {"Id": "SOURCE_CAPTURE", "Severity": "LOGGER_SEVERITY_DEFAULT"}
                ],
            )
        )

        self._assert_single_finding(
            result,
            "FAIL",
            "DMS Replication Task rep-task does not have logging enabled for source events.",
        )

    def test_dms_replication_task_logging_enabled_source_capture_only(self):
        result = self._run_check(
            self._replication_task(
                logging_enabled=True,
                log_components=[
                    {"Id": "SOURCE_CAPTURE", "Severity": "LOGGER_SEVERITY_DEFAULT"}
                ],
            )
        )

        self._assert_single_finding(
            result,
            "FAIL",
            "DMS Replication Task rep-task does not meet the minimum severity level of logging in Source Unload events.",
        )

    def test_dms_replication_task_logging_enabled_source_unload_only(self):
        result = self._run_check(
            self._replication_task(
                logging_enabled=True,
                log_components=[
                    {"Id": "SOURCE_UNLOAD", "Severity": "LOGGER_SEVERITY_DEFAULT"}
                ],
            )
        )

        self._assert_single_finding(
            result,
            "FAIL",
            "DMS Replication Task rep-task does not meet the minimum severity level of logging in Source Capture events.",
        )

    def test_dms_replication_task_logging_enabled_source_unload_capture_with_not_enough_severity_on_capture(
        self,
    ):
        result = self._run_check(
            self._replication_task(
                logging_enabled=True,
                log_components=[
                    {"Id": "SOURCE_CAPTURE", "Severity": "LOGGER_SEVERITY_INFO"},
                    {"Id": "SOURCE_UNLOAD", "Severity": "LOGGER_SEVERITY_DEFAULT"},
                ],
            )
        )

        self._assert_single_finding(
            result,
            "FAIL",
            "DMS Replication Task rep-task does not meet the minimum severity level of logging in Source Capture events.",
        )

    def test_dms_replication_task_logging_enabled_source_unload_capture_with_not_enough_severity_on_unload(
        self,
    ):
        result = self._run_check(
            self._replication_task(
                logging_enabled=True,
                log_components=[
                    {"Id": "SOURCE_CAPTURE", "Severity": "LOGGER_SEVERITY_DEFAULT"},
                    {"Id": "SOURCE_UNLOAD", "Severity": "LOGGER_SEVERITY_INFO"},
                ],
            )
        )

        self._assert_single_finding(
            result,
            "FAIL",
            "DMS Replication Task rep-task does not meet the minimum severity level of logging in Source Unload events.",
        )

    def test_dms_replication_task_logging_enabled_source_unload_capture_with_not_enough_severity_on_both(
        self,
    ):
        result = self._run_check(
            self._replication_task(
                logging_enabled=True,
                log_components=[
                    {"Id": "SOURCE_CAPTURE", "Severity": "LOGGER_SEVERITY_INFO"},
                    {"Id": "SOURCE_UNLOAD", "Severity": "LOGGER_SEVERITY_INFO"},
                ],
            )
        )

        self._assert_single_finding(
            result,
            "FAIL",
            "DMS Replication Task rep-task does not meet the minimum severity level of logging in Source Capture and Source Unload events.",
        )

    def test_dms_replication_task_logging_enabled_source_unload_capture_with_enough_severity_on_both(
        self,
    ):
        result = self._run_check(
            self._replication_task(
                logging_enabled=True,
                log_components=[
                    {"Id": "SOURCE_CAPTURE", "Severity": "LOGGER_SEVERITY_DEFAULT"},
                    {"Id": "SOURCE_UNLOAD", "Severity": "LOGGER_SEVERITY_DEFAULT"},
                ],
            )
        )

        self._assert_single_finding(
            result,
            "PASS",
            "DMS Replication Task rep-task has logging enabled with the minimum severity level in source events.",
        )
