import io
import logging
from multiprocessing.managers import SyncManager

from agentpluginapi import IAgentBinaryRepository, RetrievalError
from monkeytypes import OperatingSystem

from infection_monkey.island_api_client import IIslandAPIClient, IslandAPIError

logger = logging.getLogger(__name__)


class CachingAgentBinaryRepository(IAgentBinaryRepository):
    """
    CachingAgentBinaryRepository implements the IAgentBinaryRepository interface and downloads the
    requested agent binary from the island on request. The agent binary is cached so that only one
    successful request is actually sent to the island for each requested binary.

    The lock and the cache are created through a SyncManager, so they are proxies that can be
    shared with other processes. Failed downloads are not cached; a later call will retry them.
    """

    def __init__(self, island_api_client: IIslandAPIClient, manager: SyncManager):
        """
        :param island_api_client: Client used to download agent binaries from the island
        :param manager: SyncManager used to create the shared lock and the binary cache
        """
        self._lock = manager.Lock()
        self._cache = manager.dict()
        self._island_api_client = island_api_client

    def get_agent_binary(self, operating_system: OperatingSystem) -> io.BytesIO:
        """
        Get the agent binary for the given operating system.

        :param operating_system: The OS whose agent binary is requested
        :return: A new in-memory stream over the agent binary
        :raises RetrievalError: If the binary could not be downloaded from the island
        """
        # If multiple calls to get_agent_binary() are made simultaneously before the result of
        # _download_binary_from_island() is cached, then multiple requests will be sent to the
        # island. Hold a mutex around the call to _download_binary_from_island() so that only one
        # request per OS will be sent to the island. Note that a single lock is shared by all
        # operating systems, so downloads for different OSes are serialized as well.
        with self._lock:
            binary = self._download_binary_from_island(operating_system)

        # Build a fresh stream on every call so callers never share a read position.
        return io.BytesIO(binary)

    def _download_binary_from_island(self, operating_system: OperatingSystem) -> bytes:
        """
        Return the agent binary for `operating_system`, downloading it from the island on a miss.

        Called by get_agent_binary() with self._lock held, so concurrent cache misses for the same
        OS do not trigger duplicate downloads.

        :raises RetrievalError: If the island API call fails
        """
        # One .get() is a single round-trip to the manager process, instead of two for `in` + `[]`.
        cached = self._cache.get(operating_system)
        if cached is not None:
            return cached

        try:
            data = self._island_api_client.get_agent_binary(operating_system)
        except IslandAPIError as err:
            raise RetrievalError(err) from err

        # Only successful downloads are cached; errors propagate above and will be retried.
        self._cache[operating_system] = data
        return data
