from unittest import mock

from boto3 import client
from moto import mock_aws

from tests.providers.aws.utils import (
    AWS_ACCOUNT_NUMBER,
    AWS_REGION_EU_WEST_1,
    AWS_REGION_US_EAST_1,
    AWS_REGION_US_EAST_2,
    set_mocked_aws_provider,
)


class Test_cloudtrail_multi_region_enabled_logging_management_events:
    @mock_aws
    def test_no_trails(self):
        """Run the check with no trails created and expect one FAIL finding."""
        from prowler.providers.aws.services.cloudtrail.cloudtrail_service import (
            Cloudtrail,
        )

        aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])

        # Context managers are entered left to right, so the provider patch is
        # already active when Cloudtrail(aws_provider) is constructed.
        with mock.patch(
            "prowler.providers.common.provider.Provider.get_global_provider",
            return_value=aws_provider,
        ), mock.patch(
            "prowler.providers.aws.services.cloudtrail.cloudtrail_multi_region_enabled_logging_management_events.cloudtrail_multi_region_enabled_logging_management_events.cloudtrail_client",
            new=Cloudtrail(aws_provider),
        ):
            # Import the check while both patches are active.
            from prowler.providers.aws.services.cloudtrail.cloudtrail_multi_region_enabled_logging_management_events.cloudtrail_multi_region_enabled_logging_management_events import (
                cloudtrail_multi_region_enabled_logging_management_events,
            )

            check_under_test = (
                cloudtrail_multi_region_enabled_logging_management_events()
            )
            findings = check_under_test.execute()

            assert len(findings) == 1
            finding = findings[0]
            # With no trail to report on, the finding is keyed by the account.
            assert finding.resource_id == AWS_ACCOUNT_NUMBER
            expected_arn = f"arn:aws:cloudtrail:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:trail"
            assert finding.resource_arn == expected_arn
            assert finding.region == AWS_REGION_US_EAST_1
            assert finding.status == "FAIL"
            expected_message = "No CloudTrail trails enabled and logging management events were found."
            assert finding.status_extended == expected_message
            assert finding.resource_tags == []

    @mock_aws
    def test_compliant_trail_advanced_event_selector(self):
        """Expect PASS for a logging multi-region trail with a Management-only advanced selector."""
        trail_name = "trail_test_us"
        bucket_name = "bucket_test_us"

        cloudtrail_api = client("cloudtrail", region_name=AWS_REGION_US_EAST_1)
        s3_api = client("s3", region_name=AWS_REGION_US_EAST_1)

        # The trail is created against an existing bucket, then switched on.
        s3_api.create_bucket(Bucket=bucket_name)
        created_trail = cloudtrail_api.create_trail(
            Name=trail_name,
            S3BucketName=bucket_name,
            IsMultiRegionTrail=True,
        )
        cloudtrail_api.start_logging(Name=trail_name)

        # Single advanced selector restricted only by event category.
        management_selector = {
            "Name": "Management events selector",
            "FieldSelectors": [{"Field": "eventCategory", "Equals": ["Management"]}],
        }
        cloudtrail_api.put_event_selectors(
            TrailName=trail_name,
            AdvancedEventSelectors=[management_selector],
        )

        from prowler.providers.aws.services.cloudtrail.cloudtrail_service import (
            Cloudtrail,
        )

        aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])

        # Context managers are entered left to right, so the provider patch is
        # already active when Cloudtrail(aws_provider) is constructed.
        with mock.patch(
            "prowler.providers.common.provider.Provider.get_global_provider",
            return_value=aws_provider,
        ), mock.patch(
            "prowler.providers.aws.services.cloudtrail.cloudtrail_multi_region_enabled_logging_management_events.cloudtrail_multi_region_enabled_logging_management_events.cloudtrail_client",
            new=Cloudtrail(aws_provider),
        ):
            # Import the check while both patches are active.
            from prowler.providers.aws.services.cloudtrail.cloudtrail_multi_region_enabled_logging_management_events.cloudtrail_multi_region_enabled_logging_management_events import (
                cloudtrail_multi_region_enabled_logging_management_events,
            )

            check_under_test = (
                cloudtrail_multi_region_enabled_logging_management_events()
            )
            findings = check_under_test.execute()

            assert len(findings) == 1
            finding = findings[0]
            assert finding.resource_id == trail_name
            assert finding.resource_arn == created_trail["TrailARN"]
            assert finding.region == AWS_REGION_US_EAST_1
            assert finding.status == "PASS"
            expected_message = f"Trail {trail_name} from home region {AWS_REGION_US_EAST_1} is multi-region, is logging and have management events enabled."
            assert finding.status_extended == expected_message
            assert finding.resource_tags == []

    @mock_aws
    def test_non_compliant_trail_advanced_event_selector(self):
        """Expect the account-level FAIL when the advanced selector is read-only.

        A multi-region trail is created and started in AWS_REGION_US_EAST_1, but
        its only advanced event selector also carries a ``readOnly`` field
        selector. The check is expected to return exactly one FAIL finding keyed
        by the account rather than by the trail.
        """
        cloudtrail_client_us_east_1 = client(
            "cloudtrail", region_name=AWS_REGION_US_EAST_1
        )
        s3_client_us_east_1 = client("s3", region_name=AWS_REGION_US_EAST_1)
        trail_name_us = "trail_test_us"
        bucket_name_us = "bucket_test_us"
        s3_client_us_east_1.create_bucket(Bucket=bucket_name_us)
        cloudtrail_client_us_east_1.create_trail(
            Name=trail_name_us, S3BucketName=bucket_name_us, IsMultiRegionTrail=True
        )
        cloudtrail_client_us_east_1.start_logging(Name=trail_name_us)
        cloudtrail_client_us_east_1.put_event_selectors(
            TrailName=trail_name_us,
            AdvancedEventSelectors=[
                {
                    "Name": "Management events selector",
                    "FieldSelectors": [
                        # NOTE(review): "Managment" looks like a misspelling of
                        # "Management"; confirm whether the check result depends
                        # on this value before correcting it.
                        {"Field": "eventCategory", "Equals": ["Managment"]},
                        {"Field": "readOnly", "Equals": ["true"]},
                    ],
                }
            ],
        )

        from prowler.providers.aws.services.cloudtrail.cloudtrail_service import (
            Cloudtrail,
        )

        aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])

        # Context managers are entered left to right, so the provider patch is
        # already active when Cloudtrail(aws_provider) is constructed.
        with mock.patch(
            "prowler.providers.common.provider.Provider.get_global_provider",
            return_value=aws_provider,
        ), mock.patch(
            "prowler.providers.aws.services.cloudtrail.cloudtrail_multi_region_enabled_logging_management_events.cloudtrail_multi_region_enabled_logging_management_events.cloudtrail_client",
            new=Cloudtrail(aws_provider),
        ):
            # Import the check while both patches are active.
            from prowler.providers.aws.services.cloudtrail.cloudtrail_multi_region_enabled_logging_management_events.cloudtrail_multi_region_enabled_logging_management_events import (
                cloudtrail_multi_region_enabled_logging_management_events,
            )

            check = cloudtrail_multi_region_enabled_logging_management_events()
            result = check.execute()
            # Pin the number of findings so an unexpected extra report (for
            # example a PASS for the trail) cannot hide behind result[0].
            assert len(result) == 1
            assert result[0].resource_id == AWS_ACCOUNT_NUMBER
            assert (
                result[0].resource_arn
                == f"arn:aws:cloudtrail:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:trail"
            )
            assert result[0].region == AWS_REGION_US_EAST_1
            assert result[0].status == "FAIL"
            assert (
                result[0].status_extended
                == "No CloudTrail trails enabled and logging management events were found."
            )
            assert result[0].resource_tags == []

    @mock_aws
    def test_compliant_trail_classic_event_selector(self):
        """Expect PASS for a logging multi-region trail with a classic selector.

        The classic event selector uses ``ReadWriteType`` "All" with
        ``IncludeManagementEvents`` set to True. The single finding is expected
        to describe the trail itself and to carry no tags, since none were set
        when the trail was created.
        """
        cloudtrail_client_us_east_1 = client(
            "cloudtrail", region_name=AWS_REGION_US_EAST_1
        )
        s3_client_us_east_1 = client("s3", region_name=AWS_REGION_US_EAST_1)
        trail_name_us = "trail_test_us"
        bucket_name_us = "bucket_test_us"
        s3_client_us_east_1.create_bucket(Bucket=bucket_name_us)
        trail_us = cloudtrail_client_us_east_1.create_trail(
            Name=trail_name_us, S3BucketName=bucket_name_us, IsMultiRegionTrail=True
        )
        cloudtrail_client_us_east_1.start_logging(Name=trail_name_us)
        cloudtrail_client_us_east_1.put_event_selectors(
            TrailName=trail_name_us,
            EventSelectors=[
                {
                    "ReadWriteType": "All",
                    "IncludeManagementEvents": True,
                    "DataResources": [],
                }
            ],
        )

        from prowler.providers.aws.services.cloudtrail.cloudtrail_service import (
            Cloudtrail,
        )

        aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])

        # Context managers are entered left to right, so the provider patch is
        # already active when Cloudtrail(aws_provider) is constructed.
        with mock.patch(
            "prowler.providers.common.provider.Provider.get_global_provider",
            return_value=aws_provider,
        ), mock.patch(
            "prowler.providers.aws.services.cloudtrail.cloudtrail_multi_region_enabled_logging_management_events.cloudtrail_multi_region_enabled_logging_management_events.cloudtrail_client",
            new=Cloudtrail(aws_provider),
        ):
            # Import the check while both patches are active.
            from prowler.providers.aws.services.cloudtrail.cloudtrail_multi_region_enabled_logging_management_events.cloudtrail_multi_region_enabled_logging_management_events import (
                cloudtrail_multi_region_enabled_logging_management_events,
            )

            check = cloudtrail_multi_region_enabled_logging_management_events()
            result = check.execute()
            assert len(result) == 1
            assert result[0].resource_id == trail_name_us
            assert result[0].resource_arn == trail_us["TrailARN"]
            assert result[0].region == AWS_REGION_US_EAST_1
            assert result[0].status == "PASS"
            assert (
                result[0].status_extended
                == f"Trail {trail_name_us} from home region {AWS_REGION_US_EAST_1} is multi-region, is logging and have management events enabled."
            )
            # Same tag expectation as the advanced-selector PASS test: the
            # trail was created without tags.
            assert result[0].resource_tags == []

    @mock_aws
    def test_non_compliant_trail_classic_event_selector(self):
        """Expect the account-level FAIL when the classic selector skips management events.

        The trail is multi-region and logging, but its classic event selector
        uses ``ReadWriteType`` "ReadOnly" with ``IncludeManagementEvents`` set
        to False. The check is expected to return exactly one FAIL finding
        keyed by the account rather than by the trail.
        """
        cloudtrail_client_us_east_1 = client(
            "cloudtrail", region_name=AWS_REGION_US_EAST_1
        )
        s3_client_us_east_1 = client("s3", region_name=AWS_REGION_US_EAST_1)
        trail_name_us = "trail_test_us"
        bucket_name_us = "bucket_test_us"
        s3_client_us_east_1.create_bucket(Bucket=bucket_name_us)
        cloudtrail_client_us_east_1.create_trail(
            Name=trail_name_us, S3BucketName=bucket_name_us, IsMultiRegionTrail=True
        )
        cloudtrail_client_us_east_1.start_logging(Name=trail_name_us)
        cloudtrail_client_us_east_1.put_event_selectors(
            TrailName=trail_name_us,
            EventSelectors=[
                {
                    "ReadWriteType": "ReadOnly",
                    "IncludeManagementEvents": False,
                    "DataResources": [],
                }
            ],
        )

        from prowler.providers.aws.services.cloudtrail.cloudtrail_service import (
            Cloudtrail,
        )

        aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])

        # Context managers are entered left to right, so the provider patch is
        # already active when Cloudtrail(aws_provider) is constructed.
        with mock.patch(
            "prowler.providers.common.provider.Provider.get_global_provider",
            return_value=aws_provider,
        ), mock.patch(
            "prowler.providers.aws.services.cloudtrail.cloudtrail_multi_region_enabled_logging_management_events.cloudtrail_multi_region_enabled_logging_management_events.cloudtrail_client",
            new=Cloudtrail(aws_provider),
        ):
            # Import the check while both patches are active.
            from prowler.providers.aws.services.cloudtrail.cloudtrail_multi_region_enabled_logging_management_events.cloudtrail_multi_region_enabled_logging_management_events import (
                cloudtrail_multi_region_enabled_logging_management_events,
            )

            check = cloudtrail_multi_region_enabled_logging_management_events()
            result = check.execute()
            # Pin the number of findings so an unexpected extra report (for
            # example a PASS for the trail) cannot hide behind result[0].
            assert len(result) == 1
            assert result[0].resource_id == AWS_ACCOUNT_NUMBER
            assert (
                result[0].resource_arn
                == f"arn:aws:cloudtrail:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:trail"
            )
            assert result[0].region == AWS_REGION_US_EAST_1
            assert result[0].status == "FAIL"
            assert (
                result[0].status_extended
                == "No CloudTrail trails enabled and logging management events were found."
            )
            assert result[0].resource_tags == []

    @mock_aws
    def test_access_denied(self):
        """Expect no findings when the service's ``trails`` attribute is None."""
        from prowler.providers.aws.services.cloudtrail.cloudtrail_service import (
            Cloudtrail,
        )

        aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])

        # Context managers are entered left to right, so the provider patch is
        # already active when Cloudtrail(aws_provider) is constructed. The
        # second patch yields the Cloudtrail instance passed as ``new``.
        with mock.patch(
            "prowler.providers.common.provider.Provider.get_global_provider",
            return_value=aws_provider,
        ), mock.patch(
            "prowler.providers.aws.services.cloudtrail.cloudtrail_multi_region_enabled_logging_management_events.cloudtrail_multi_region_enabled_logging_management_events.cloudtrail_client",
            new=Cloudtrail(aws_provider),
        ) as patched_service:
            # Import the check while both patches are active.
            from prowler.providers.aws.services.cloudtrail.cloudtrail_multi_region_enabled_logging_management_events.cloudtrail_multi_region_enabled_logging_management_events import (
                cloudtrail_multi_region_enabled_logging_management_events,
            )

            # Presumably None is how the service signals that trails could not
            # be listed (per the test name) -- verify against the service code.
            patched_service.trails = None

            check_under_test = (
                cloudtrail_multi_region_enabled_logging_management_events()
            )
            findings = check_under_test.execute()
            assert len(findings) == 0

    @mock_aws
    def test_no_trails_two_regions(self):
        """Audit two regions with no trails and expect one FAIL finding per region.

        The ``@mock_aws`` decorator is required so that the ``Cloudtrail``
        service constructed below talks to moto instead of attempting real AWS
        calls, matching every other test in this class.
        """
        from prowler.providers.aws.services.cloudtrail.cloudtrail_service import (
            Cloudtrail,
        )

        aws_provider = set_mocked_aws_provider(
            [AWS_REGION_US_EAST_1, AWS_REGION_US_EAST_2]
        )

        # Context managers are entered left to right, so the provider patch is
        # already active when Cloudtrail(aws_provider) is constructed.
        with mock.patch(
            "prowler.providers.common.provider.Provider.get_global_provider",
            return_value=aws_provider,
        ), mock.patch(
            "prowler.providers.aws.services.cloudtrail.cloudtrail_multi_region_enabled_logging_management_events.cloudtrail_multi_region_enabled_logging_management_events.cloudtrail_client",
            new=Cloudtrail(aws_provider),
        ):
            # Import the check while both patches are active.
            from prowler.providers.aws.services.cloudtrail.cloudtrail_multi_region_enabled_logging_management_events.cloudtrail_multi_region_enabled_logging_management_events import (
                cloudtrail_multi_region_enabled_logging_management_events,
            )

            check = cloudtrail_multi_region_enabled_logging_management_events()
            result = check.execute()
            assert len(result) == 2
            # Both audited regions must be present, so a finding for an
            # unexpected region fails the test instead of being ignored.
            assert {report.region for report in result} == {
                AWS_REGION_US_EAST_1,
                AWS_REGION_US_EAST_2,
            }
            for report in result:
                # FIXME(review): an earlier note said the ARN carried the
                # global region rather than the audited one; the expectation
                # here is the audited region -- confirm that note is obsolete.
                assert (
                    report.resource_arn
                    == f"arn:aws:cloudtrail:{report.region}:{AWS_ACCOUNT_NUMBER}:trail"
                )
                assert report.resource_id == AWS_ACCOUNT_NUMBER
                assert report.status == "FAIL"
                assert (
                    report.status_extended
                    == "No CloudTrail trails enabled and logging management events were found."
                )
                assert report.resource_tags == []

    @mock_aws
    def test_trail_multi_region_auditing_other_region(self):
        """Audit only AWS_REGION_EU_WEST_1 while the trail lives in AWS_REGION_US_EAST_1."""
        aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])

        trail_name = "trail_test_us"
        bucket_name = "bucket_test_us"

        cloudtrail_api = client("cloudtrail", region_name=AWS_REGION_US_EAST_1)
        s3_api = client("s3", region_name=AWS_REGION_US_EAST_1)

        # Multi-region trail, logging, with a classic selector that includes
        # management events for all read/write types.
        s3_api.create_bucket(Bucket=bucket_name)
        cloudtrail_api.create_trail(
            Name=trail_name,
            S3BucketName=bucket_name,
            IsMultiRegionTrail=True,
        )
        cloudtrail_api.start_logging(Name=trail_name)
        classic_selector = {
            "ReadWriteType": "All",
            "IncludeManagementEvents": True,
            "DataResources": [],
        }
        cloudtrail_api.put_event_selectors(
            TrailName=trail_name,
            EventSelectors=[classic_selector],
        )

        from prowler.providers.aws.services.cloudtrail.cloudtrail_service import (
            Cloudtrail,
        )

        # Context managers are entered left to right, so the provider patch is
        # already active when Cloudtrail(aws_provider) is constructed.
        with mock.patch(
            "prowler.providers.common.provider.Provider.get_global_provider",
            return_value=aws_provider,
        ), mock.patch(
            "prowler.providers.aws.services.cloudtrail.cloudtrail_multi_region_enabled_logging_management_events.cloudtrail_multi_region_enabled_logging_management_events.cloudtrail_client",
            new=Cloudtrail(aws_provider),
        ):
            # Import the check while both patches are active.
            from prowler.providers.aws.services.cloudtrail.cloudtrail_multi_region_enabled_logging_management_events.cloudtrail_multi_region_enabled_logging_management_events import (
                cloudtrail_multi_region_enabled_logging_management_events,
            )

            check_under_test = (
                cloudtrail_multi_region_enabled_logging_management_events()
            )
            findings = check_under_test.execute()

            assert len(findings) == 1
            finding = findings[0]
            # FAIL is expected: the multi-region trail was created from a
            # region that is not being audited, so the finding is the
            # account-level one for the audited region.
            assert finding.resource_id == AWS_ACCOUNT_NUMBER
            expected_arn = f"arn:aws:cloudtrail:{AWS_REGION_EU_WEST_1}:{AWS_ACCOUNT_NUMBER}:trail"
            assert finding.resource_arn == expected_arn
            assert finding.status == "FAIL"
            expected_message = "No CloudTrail trails enabled and logging management events were found."
            assert finding.status_extended == expected_message
            assert finding.region == AWS_REGION_EU_WEST_1
            assert finding.resource_tags == []
