"""t2st.py

Created by Takeru Hayasaka on 2023-01-21.
Copyright (c) 2023 BBSakura Networks Inc. All rights reserved.
"""

from __future__ import annotations

from struct import pack
from typing import TYPE_CHECKING, Any, ClassVar

if TYPE_CHECKING:
    from exabgp.bgp.message.open.capability.negotiated import Negotiated

from exabgp.bgp.message import Action
from exabgp.bgp.message.update.nlri.mup.nlri import MUP
from exabgp.bgp.message.update.nlri.nlri import NLRI
from exabgp.bgp.message.update.nlri.qualifier import RouteDistinguisher
from exabgp.protocol.family import AFI, SAFI
from exabgp.protocol.ip import IP
from exabgp.util.types import Buffer

# https://datatracker.ietf.org/doc/html/draft-mpmz-bess-mup-safi-03

# +-----------------------------------+
# |           RD  (8 octets)          |
# +-----------------------------------+
# |      Endpoint Length (1 octet)    |
# +-----------------------------------+
# |      Endpoint Address (variable)  |
# +-----------------------------------+
# | Architecture specific Endpoint    |
# |         Identifier (variable)     |
# +-----------------------------------+

# 3gpp-5g Specific BGP Type 2 ST Route
# +-----------------------------------+
# |          TEID (0-4 octets)        |
# +-----------------------------------+

# Size limits for the MUP Type 2 Session Transformed (T2ST) route, in bits.
# The endpoint length field covers the endpoint address followed by an
# optional TEID, so each maximum endpoint length is address width + max TEID.
MUP_T2ST_IPV4_SIZE_BITS: int = 32  # IPv4 endpoint address width
MUP_T2ST_IPV6_SIZE_BITS: int = 128  # IPv6 endpoint address width
MUP_T2ST_TEID_MAX_SIZE: int = 32  # widest TEID allowed
MUP_T2ST_IPV4_MAX_ENDPOINT: int = MUP_T2ST_IPV4_SIZE_BITS + MUP_T2ST_TEID_MAX_SIZE  # == 64
MUP_T2ST_IPV6_MAX_ENDPOINT: int = MUP_T2ST_IPV6_SIZE_BITS + MUP_T2ST_TEID_MAX_SIZE  # == 160


@MUP.register_mup_route(archtype=1, code=4)
class Type2SessionTransformedRoute(MUP):
    """BGP-MUP Type 2 Session Transformed (T2ST) route.

    Registered with MUP for architecture type 1, route type code 4. The route
    keeps its complete wire encoding (4-byte header followed by the payload)
    and decodes every field lazily from that buffer.

    Buffer layout:
        bytes 0-3    header: arch(1) + code(2) + length(1)
        bytes 4-11   Route Distinguisher
        byte  12     endpoint length in bits (endpoint address bits + TEID bits)
        bytes 13+    endpoint address (4 bytes for IPv4, 16 bytes otherwise)
        remainder    TEID, present only when the endpoint length exceeds the
                     endpoint address width
    """

    NAME: ClassVar[str] = 'Type2SessionTransformedRoute'
    SHORT_NAME: ClassVar[str] = 'T2ST'

    # Wire format offsets (after 4-byte header: arch(1) + code(2) + length(1))
    HEADER_SIZE: ClassVar[int] = 4
    RD_OFFSET: ClassVar[int] = 4  # Bytes 4-11: RD (8 bytes)
    ENDPOINT_LEN_OFFSET: ClassVar[int] = 12  # Byte 12: endpoint length
    ENDPOINT_IP_OFFSET: ClassVar[int] = 13  # Bytes 13+: endpoint IP

    def __init__(self, packed: Buffer, afi: AFI) -> None:
        """Create T2ST with complete wire format.

        Args:
            packed: Complete wire format including 4-byte header
            afi: Address family; selects whether the endpoint address is
                decoded as IPv4 (4 bytes) or IPv6 (16 bytes)
        """
        MUP.__init__(self, afi)
        self._packed: Buffer = packed

    @staticmethod
    def _endpoint_ip_bits(afi: AFI) -> int:
        """Return the endpoint address width in bits: 32 for IPv4, 128 for anything else."""
        return MUP_T2ST_IPV4_SIZE_BITS if afi == AFI.ipv4 else MUP_T2ST_IPV6_SIZE_BITS

    @classmethod
    def make_t2st(
        cls,
        rd: RouteDistinguisher,
        endpoint_len: int,
        endpoint_ip: IP,
        teid: int,
        afi: AFI,
    ) -> 'Type2SessionTransformedRoute':
        """Build a T2ST route from its semantic fields.

        Args:
            rd: Route Distinguisher.
            endpoint_len: Endpoint length in bits (endpoint address bits + TEID bits).
            endpoint_ip: Endpoint address; its own AFI decides whether 32 or
                128 bits of ``endpoint_len`` belong to the address.
            teid: TEID value; packed as a 32-bit big-endian integer, of which
                only the low ceil(teid_bits / 8) bytes are kept.
            afi: Address family stored on the resulting route.

        Raises:
            Exception: if ``endpoint_len`` is shorter than the endpoint address,
                or leaves more than 32 bits for the TEID.
        """
        payload = bytes(rd.pack_rd()) + pack('!B', endpoint_len) + endpoint_ip.pack_ip()

        ip_bits = cls._endpoint_ip_bits(endpoint_ip.afi)
        teid_bits = endpoint_len - ip_bits

        if teid_bits < 0:
            # Endpoint length must at least cover the endpoint address itself.
            raise Exception(
                'endpoint length %d is shorter than the endpoint address (%d bits)' % (endpoint_len, ip_bits)
            )
        if teid_bits > MUP_T2ST_TEID_MAX_SIZE:
            raise Exception('teid is too large %d (range 0~32)' % teid_bits)

        teid_packed = pack('!I', teid)

        # Keep only as many trailing TEID bytes as its bit length needs (rounded
        # up). The guard is required: teid_packed[-0:] would be all 4 bytes.
        teid_bytes = (teid_bits + 7) // 8
        if teid_bytes > 0:
            payload += teid_packed[-teid_bytes:]

        # Include 4-byte header: arch(1) + code(2) + length(1) + payload
        packed = pack('!BHB', cls.ARCHTYPE, cls.CODE, len(payload)) + payload
        return cls(packed, afi)

    @property
    def rd(self) -> RouteDistinguisher:
        """Route Distinguisher, decoded from bytes 4-11."""
        return RouteDistinguisher.unpack_routedistinguisher(self._packed[self.RD_OFFSET : self.ENDPOINT_LEN_OFFSET])

    @property
    def endpoint_len(self) -> int:
        """Endpoint length in bits (byte 12): endpoint address bits plus TEID bits."""
        return self._packed[self.ENDPOINT_LEN_OFFSET]

    @property
    def endpoint_ip(self) -> IP:
        """Endpoint address starting at byte 13: 4 bytes for IPv4, 16 bytes otherwise."""
        start = self.ENDPOINT_IP_OFFSET
        return IP.create_ip(self._packed[start : start + self._endpoint_ip_bits(self.afi) // 8])

    @property
    def teid(self) -> int:
        """TEID as an unsigned big-endian integer, or 0 when no TEID bits are present."""
        ip_bits = self._endpoint_ip_bits(self.afi)
        if self.endpoint_len <= ip_bits:
            return 0
        # The TEID is every byte that follows the endpoint address.
        return int.from_bytes(self._packed[self.ENDPOINT_IP_OFFSET + ip_bits // 8 :], 'big')

    def index(self) -> bytes:
        return MUP.index(self)

    def __eq__(self, other: Any) -> bool:
        """Compare decoded fields: RD, TEID, endpoint length and endpoint address."""
        return (
            isinstance(other, Type2SessionTransformedRoute)
            and self.rd == other.rd
            and self.teid == other.teid
            # Must compare against *other*; comparing self with self was always true.
            and self.endpoint_len == other.endpoint_len
            and self.endpoint_ip == other.endpoint_ip
        )

    def __ne__(self, other: Any) -> bool:
        return not self.__eq__(other)

    def __str__(self) -> str:
        return '{}:{}:{}:{}:{}:'.format(
            self._prefix(),
            self.rd._str(),
            self.endpoint_len,
            self.endpoint_ip,
            self.teid,
        )

    def __hash__(self) -> int:
        # Direct _packed hash - include afi since MUP supports both IPv4 and IPv6
        return hash((self.afi, self._packed))

    @classmethod
    def unpack_nlri(
        cls, afi: AFI, safi: SAFI, data: Buffer, action: Action, addpath: Any, negotiated: Negotiated
    ) -> tuple[NLRI, Buffer]:
        """Decode a T2ST route from its wire format.

        The parent parser passes the complete wire format of this route,
        including the 4-byte header. ``data`` becomes the instance buffer as-is,
        and ``b''`` is returned as the leftover.

        Raises:
            Exception: if the endpoint length implies a TEID wider than 32 bits
                (for IPv4 or IPv6).
        """
        endpoint_len = data[cls.ENDPOINT_LEN_OFFSET]
        teid_len = endpoint_len - cls._endpoint_ip_bits(afi)

        if teid_len > MUP_T2ST_TEID_MAX_SIZE:
            if afi == AFI.ipv4:
                raise Exception(
                    'endpoint length is too large %d (max %d for Ipv4)' % (endpoint_len, MUP_T2ST_IPV4_MAX_ENDPOINT)
                )
            if afi == AFI.ipv6:
                raise Exception(
                    'endpoint length is too large %d (max %d for Ipv6)' % (endpoint_len, MUP_T2ST_IPV6_MAX_ENDPOINT)
                )

        return cls(data, afi), b''

    def json(self, announced: bool = True, compact: bool | None = None) -> str:
        content = '"name": "{}", '.format(self.NAME)
        content += ' "arch": %d, ' % self.ARCHTYPE
        content += '"code": %d, ' % self.CODE
        content += '"endpoint_len": %d, ' % self.endpoint_len
        content += '"endpoint_ip": "{}", '.format(str(self.endpoint_ip))
        content += self.rd.json() + ', '
        content += '"teid": "{}", '.format(str(self.teid))
        content += '"raw": "{}"'.format(self._raw())
        return '{{ {} }}'.format(content)
