# -*- coding: utf-8 -*-
# Copyright: (c) 2021, Jordan Borean (@jborean93) <jborean93@gmail.com>
# MIT License (see LICENSE or https://opensource.org/licenses/MIT)

import base64
import struct
import typing

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.padding import PKCS7

from psrpcore._exceptions import MissingCipherError
from psrpcore.types import PSCryptoProvider


def create_keypair() -> typing.Tuple[rsa.RSAPrivateKey, bytes]:
    """Create RSA keypair.

    Create an RSA keypair for use with the PSRemoting session key exchange.

    Returns:
        Tuple[rsa.RSAPrivateKey, bytes]: The RSA Private key generated and the
            public key that can be sent to the remote PSSession.
    """
    key_size = 2048
    private_key = rsa.generate_private_key(
        public_exponent=65537,
        key_size=key_size,
        backend=default_backend(),
    )
    public_numbers = private_key.public_key().public_numbers()

    # The public key bytes follow a set structure defined in MS-PSRP, a CryptoAPI
    # PUBLICKEYBLOB (all integers little-endian):
    #   BLOBHEADER: bType=PUBLICKEYBLOB (0x06), bVersion=2, reserved=0,
    #               aiKeyAlg=CALG_RSA_KEYX (0x0000a400)
    #   RSAPUBKEY:  magic="RSA1", bitlen, pubexp
    #   modulus:    bitlen / 8 bytes, little-endian
    header = b"\x06\x02\x00\x00\x00\xa4\x00\x00" + b"RSA1" + struct.pack("<I", key_size)
    exponent = struct.pack("<I", public_numbers.e)
    # Fixed width so the modulus length always matches the bitlen in the header.
    modulus = public_numbers.n.to_bytes(key_size // 8, "little")

    public_key_bytes = header + exponent + modulus

    return private_key, public_key_bytes


def encrypt_session_key(
    exchange_key: bytes,
    session_key: bytes,
) -> bytes:
    """Encrypt session key.

    Encrypts the PSRemoting session key generated by the server so it can be
    sent to the client.

    Args:
        exchange_key: The public key blob generated by the client (for example
            by create_keypair) to encrypt the session key with.
        session_key: The session key to encrypt.

    Returns:
        bytes: The encrypted session key to send to the remote PSSession.
    """
    # The exchange key is a CryptoAPI PUBLICKEYBLOB: an 8 byte BLOBHEADER, then
    # RSAPUBKEY (magic, bitlen, pubexp) followed by the little-endian modulus.
    # Only the exponent and modulus are needed to rebuild the public key; the
    # remaining header fields are ignored.
    exponent = struct.unpack("<I", exchange_key[16:20])[0]
    modulus = int.from_bytes(exchange_key[20:], "little")

    public_key = rsa.RSAPublicNumbers(exponent, modulus).public_key(default_backend())

    # cryptography produces the ciphertext big-endian while CryptoAPI blobs
    # store it little-endian, hence the reversal.
    encrypted_key = public_key.encrypt(
        session_key,
        padding.PKCS1v15(),
    )[::-1]

    # SIMPLEBLOB header: BLOBHEADER (bType=SIMPLEBLOB 0x01, bVersion=2,
    # reserved=0, aiKeyAlg=CALG_AES_256 0x6610) followed by the ALG_ID of the
    # key used to encrypt it (CALG_RSA_KEYX 0xa400).
    encrypted_key_bytes = b"\x01\x02\x00\x00\x10\x66\x00\x00\x00\xa4\x00\x00" + encrypted_key

    return encrypted_key_bytes


def decrypt_session_key(
    exchange_key: rsa.RSAPrivateKey,
    encrypted_session_key: bytes,
) -> bytes:
    """Decrypt session key.

    Recovers the PSRemoting session key that the server encrypted with the
    client's public key.

    Args:
        exchange_key: The RSA private key matching the public key that was
            sent to the server.
        encrypted_session_key: The encrypted session key blob received from
            the server.

    Returns:
        bytes: The plaintext session key.
    """
    # The first 12 bytes are the Win32 Crypto blob header (BLOBHEADER + ALG_ID).
    # The ciphertext that follows is little-endian, but RSA decryption expects
    # big-endian, so the byte order is flipped first.
    ciphertext = bytes(reversed(encrypted_session_key[12:]))
    return exchange_key.decrypt(ciphertext, padding.PKCS1v15())


class PSRemotingCrypto(PSCryptoProvider):
    """PSCryptoProvider used by PSRP for serializing SecureStrings.

    Values are encrypted with AES in CBC mode using the key supplied to
    register_key and an all-zero IV. Plaintext is UTF-16-LE encoded and PKCS7
    padded to the AES block size. Ciphertext is exchanged as base64 text.
    """

    def __init__(self) -> None:
        # Set by register_key; None until a session key is available.
        self._cipher: typing.Optional[Cipher] = None
        self._padding = PKCS7(algorithms.AES.block_size)

    def decrypt(self, value: str) -> str:
        """Decrypt a base64 encoded SecureString value.

        Args:
            value: The base64 encoded AES ciphertext.

        Returns:
            str: The decrypted string.

        Raises:
            MissingCipherError: No session key has been registered yet.
        """
        if self._cipher is None:
            raise MissingCipherError()

        b_enc = base64.b64decode(value)

        decryptor = self._cipher.decryptor()
        b_padded = decryptor.update(b_enc) + decryptor.finalize()

        unpadder = self._padding.unpadder()
        b_dec = unpadder.update(b_padded) + unpadder.finalize()

        # surrogatepass keeps lone surrogates round-trippable through UTF-16.
        return b_dec.decode("utf-16-le", errors="surrogatepass")

    def encrypt(self, value: str) -> str:
        """Encrypt a string into a base64 encoded SecureString value.

        Args:
            value: The string to encrypt.

        Returns:
            str: The base64 encoded AES ciphertext.

        Raises:
            MissingCipherError: No session key has been registered yet.
        """
        if self._cipher is None:
            raise MissingCipherError()

        b_value = value.encode("utf-16-le", errors="surrogatepass")

        padder = self._padding.padder()
        b_padded = padder.update(b_value) + padder.finalize()

        encryptor = self._cipher.encryptor()
        b_enc = encryptor.update(b_padded) + encryptor.finalize()

        return base64.b64encode(b_enc).decode()

    def register_key(
        self,
        key: bytes,
    ) -> None:
        """Register the session key used for encryption and decryption.

        Args:
            key: The AES session key.
        """
        algorithm = algorithms.AES(key)
        mode = modes.CBC(b"\x00" * 16)  # PSRP doesn't use an IV
        self._cipher = Cipher(algorithm, mode, default_backend())


__all__ = [
    # Session key exchange helpers.
    "create_keypair",
    "decrypt_session_key",
    "encrypt_session_key",
    # Re-export of cryptography's rsa module (create_keypair returns an rsa.RSAPrivateKey).
    "rsa",
    # SecureString crypto provider.
    "PSRemotingCrypto",
]
