import logging
import threading
from http.server import HTTPServer
from ipaddress import IPv4Address
from typing import Callable, Optional, Type

from agentpluginapi import (
    AgentBinaryDownloadReservation,
    AgentBinaryDownloadTicket,
    ITCPPortSelector,
    LocalMachineInfo,
    ReservationID,
)
from monkeytoolbox import (
    create_daemon_thread,
    insecure_generate_random_string,
    secure_generate_random_string,
)
from monkeytypes import Event, Lock, NetworkPort, OperatingSystem

from .http_agent_binary_request_handler import AgentBinaryHTTPRequestHandler

logger = logging.getLogger(__name__)

# A zero-argument callable that returns the request handler *class* (not an instance).
# HTTPAgentBinaryServer calls it once in __init__ and hands the resulting class to HTTPServer.
AgentBinaryHTTPHandlerFactory = Callable[[], Type[AgentBinaryHTTPRequestHandler]]


class HTTPAgentBinaryServer:
    """
    Serves Agent binaries over HTTP

    Allows clients to register for an Agent binary to be served. The server will serve the
    requested binary until it is deregistered or the server is stopped. A stopped server is
    started again by the next call to `start()` or `register()`.

    :param local_machine_info: Information about the local machine, used to find the interface
        (and therefore the IP address) through which a requestor can reach this server
    :param tcp_port_selector: The TCP port selector to use
    :param get_handler_class: A function that returns the HTTP handler class to use
    :param create_event: A function that the server will use to create events
    :param lock: A lock that guards starting/stopping the server and (de)registering downloads
    :param poll_interval: The interval to poll for server shutdown, in seconds
    """

    def __init__(
        self,
        local_machine_info: LocalMachineInfo,
        tcp_port_selector: ITCPPortSelector,
        get_handler_class: AgentBinaryHTTPHandlerFactory,
        create_event: Callable[[], Event],
        lock: Lock,
        poll_interval: float = 0.5,
    ):
        self._local_machine_info = local_machine_info
        self._tcp_port_selector = tcp_port_selector
        self._handler_class = get_handler_class()
        self._create_event = create_event
        self._lock = lock
        self._poll_interval = poll_interval
        self._port: Optional[NetworkPort] = None
        self._server: Optional[HTTPServer] = None
        self._server_thread: Optional[threading.Thread] = None

    def register(
        self,
        operating_system: OperatingSystem,
        requestor_ip: IPv4Address,
        agent_binary_wrapper_template: bytes | None = None,
    ) -> AgentBinaryDownloadTicket:
        """
        Register to download an Agent binary

        If the server is not running, it will be started.

        :param operating_system: The operating system for the Agent binary to serve
        :param requestor_ip: The IP address of the client that will download the Agent binary
        :param agent_binary_wrapper_template: A template that transforms the Agent binary
            before serving.
            This may be used to, e.g., convert the binary into a self-extracting shell script.
        :raises RuntimeError: If no free TCP port could be found for the server, or no local
            interface can reach `requestor_ip`
        :raises Exception: If the server failed to start
        :returns: A ticket to download the Agent binary
        """
        with self._lock:
            if not self.server_is_running():
                self._start_server()

            reservation_id = secure_generate_random_string(n=5)
            url = self._build_request_url(reservation_id, operating_system, requestor_ip)
            reservation = AgentBinaryDownloadReservation(
                reservation_id,
                operating_system,
                agent_binary_wrapper_template,
                url,
                self._create_event(),
            )
            self._handler_class.reserve_download(reservation)

            return AgentBinaryDownloadTicket(reservation_id, url, reservation.download_completed)

    def _build_request_url(
        self,
        reservation_id: ReservationID,
        operating_system: OperatingSystem,
        requestor_ip: IPv4Address,
    ) -> str:
        """
        Build the download URL for a reservation, using the local IP the requestor can reach

        :raises RuntimeError: If no local interface can reach `requestor_ip`
        """
        interface_to_target = self._local_machine_info.get_interface_to_target(requestor_ip)

        if interface_to_target is None:
            raise RuntimeError(
                f"Could not find an interface to the target {requestor_ip} to serve Agent binaries"
            )

        server_ip = interface_to_target.ip
        return f"http://{server_ip}:{self._port}/{operating_system.value}/{reservation_id}"

    def server_is_running(self) -> bool:
        """Return True if the thread running the HTTP server exists and is alive"""
        return self._server_thread is not None and self._server_thread.is_alive()

    def _start_server(self):
        """
        Create the HTTP server (if needed) and run it on a new daemon thread

        Callers are expected to hold `self._lock`.
        """
        if self._server is None:
            self._server = self._create_server()
        # A Thread can only be started once, so a thread that has exited (e.g. after stop())
        # must be replaced rather than reused. serve_forever() may be called again on the
        # same server after shutdown().
        if self._server_thread is None or not self._server_thread.is_alive():
            self._server_thread = self._create_server_thread(self._server)
            self._server_thread.start()

    def _create_server(self) -> HTTPServer:
        """
        Bind a new HTTPServer on a free TCP port on all interfaces

        :raises RuntimeError: If no free TCP port could be found
        """
        port = self._tcp_port_selector.get_free_tcp_port(
            # Allow 443, 80 in the future?
            preferred_ports=list(map(NetworkPort, [8080, 8008, 8000, 8443]))
        )
        if port is None:
            raise RuntimeError("Could not find a free TCP port to serve Agent binaries")

        server = HTTPServer(("0.0.0.0", int(port)), self._handler_class)
        # Record the port only once the bind succeeded, so a failed bind leaves no stale port
        self._port = port
        return server

    def _create_server_thread(self, server: HTTPServer) -> threading.Thread:
        """Create (but do not start) a daemon thread that runs `server.serve_forever`"""
        thread_name = f"HTTPAgentBinaryServer-{insecure_generate_random_string(n=8)}"
        return create_daemon_thread(
            target=server.serve_forever,
            name=thread_name,
            args=(self._poll_interval,),
        )

    def deregister(self, reservation_id: ReservationID) -> None:
        """
        Deregister an Agent binary from being served

        :param reservation_id: The ID of the reservation to deregister
        :raises KeyError: If the reservation ID is not registered
        """
        with self._lock:
            self._handler_class.clear_reservation(reservation_id)

    def start(self):
        """
        Start the server if it is not already running

        :raises Exception: If the server failed to start
        """
        # Use the same lock as register() so concurrent callers cannot start two servers
        with self._lock:
            if not self.server_is_running():
                logger.debug("Starting the HTTP server")
                self._start_server()

    def stop(self, timeout: Optional[float] = None):
        """
        Stop the server

        Once the server thread has exited, the listening socket is closed and the server state
        is cleared so that a later `start()` or `register()` starts a fresh server. If the
        server does not stop within `timeout`, its state is left untouched.

        :param timeout: The maximum amount of time to wait for the server to stop, in seconds. If
            not provided or set to None, it will block until the server shuts down
        """
        server = self._server
        server_thread = self._server_thread
        if server is None or server_thread is None:
            return

        if server_thread.is_alive():
            logger.debug("Stopping the HTTP server")
            server.shutdown()
            server_thread.join(timeout)

        if server_thread.is_alive():
            logger.warning("Timed out waiting for HTTP server to stop")
            return

        logger.debug("The HTTP server has stopped")

        # The lock is taken only after the thread has exited; holding it across shutdown()
        # and join() would block register() for the whole wait.
        with self._lock:
            # register()/start() may have replaced the thread while we were waiting; only
            # tear down the server if it is still the one that was just stopped.
            if self._server is server and self._server_thread is server_thread:
                server.server_close()
                self._server = None
                self._server_thread = None
                self._port = None
