/*
 * Decompiled with CFR 0.152.
 */
package de.rub.nds.tlsattacker.core.quic.crypto;

import de.rub.nds.protocol.exception.CryptoException;
import de.rub.nds.tlsattacker.core.quic.packet.HandshakePacket;
import de.rub.nds.tlsattacker.core.quic.packet.InitialPacket;
import de.rub.nds.tlsattacker.core.quic.packet.OneRTTPacket;
import de.rub.nds.tlsattacker.core.quic.packet.QuicPacket;
import de.rub.nds.tlsattacker.core.quic.packet.QuicPacketCryptoComputations;
import de.rub.nds.tlsattacker.core.state.quic.QuicContext;
import de.rub.nds.tlsattacker.transport.ConnectionEndType;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.NoSuchAlgorithmException;
import java.security.spec.AlgorithmParameterSpec;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Applies QUIC packet protection to outgoing packets: AEAD encryption of the payload and header
 * protection of the first byte and the packet number field (cf. RFC 9001, Section 5).
 *
 * <p>The direction of the keys is chosen from {@link QuicContext#getTalkingConnectionEndType()}:
 * server keys/masks when the server is talking, client keys/masks when the client is talking. Any
 * other value is logged as an error and the packet is left unmodified.
 */
public class QuicEncryptor {
    private static final Logger LOGGER = LogManager.getLogger();

    /** Bits of the first byte covered by header protection for 1-RTT (short header) packets. */
    private static final int SHORT_HEADER_FLAGS_MASK = 0x1F;

    /** Bits of the first byte covered by header protection for all other packet types. */
    private static final int LONG_HEADER_FLAGS_MASK = 0x0F;

    /** Authentication tag length, in bits, used with AES-GCM. */
    private static final int GCM_TAG_LENGTH_BITS = 128;

    private final QuicContext context;

    /**
     * Creates an encryptor that reads keys, IVs and ciphers from the given context.
     *
     * @param context the QUIC context holding the negotiated key material
     */
    public QuicEncryptor(QuicContext context) {
        this.context = context;
    }

    /**
     * Applies header protection to an Initial packet using the Initial header protection masks.
     *
     * @param packet the packet whose protected flags and packet number are set
     * @throws CryptoException if a header protection mask cannot be generated
     */
    public void addHeaderProtectionInitial(InitialPacket packet) throws CryptoException {
        addHeaderProtection(
                packet,
                QuicPacketCryptoComputations.generateInitialServerHeaderProtectionMask(
                        context, packet.getHeaderProtectionSample()),
                QuicPacketCryptoComputations.generateInitialClientHeaderProtectionMask(
                        context, packet.getHeaderProtectionSample()));
    }

    /**
     * Applies header protection to a Handshake packet using the Handshake header protection masks.
     *
     * @param packet the packet whose protected flags and packet number are set
     * @throws CryptoException if a header protection mask cannot be generated
     */
    public void addHeaderProtectionHandshake(HandshakePacket packet) throws CryptoException {
        addHeaderProtection(
                packet,
                QuicPacketCryptoComputations.generateHandshakeServerHeaderProtectionMask(
                        context, packet.getHeaderProtectionSample()),
                QuicPacketCryptoComputations.generateHandshakeClientHeaderProtectionMask(
                        context, packet.getHeaderProtectionSample()));
    }

    /**
     * Applies header protection to a 0-RTT packet using the 0-RTT header protection masks.
     *
     * @param packet the packet whose protected flags and packet number are set
     * @throws CryptoException if a header protection mask cannot be generated
     */
    public void addHeaderProtectionZeroRTT(QuicPacket packet) throws CryptoException {
        addHeaderProtection(
                packet,
                QuicPacketCryptoComputations.generateZeroRTTServerHeaderProtectionMask(
                        context, packet.getHeaderProtectionSample()),
                QuicPacketCryptoComputations.generateZeroRTTClientHeaderProtectionMask(
                        context, packet.getHeaderProtectionSample()));
    }

    /**
     * Applies header protection to a 1-RTT packet using the 1-RTT header protection masks.
     *
     * <p>The method name ("RRT") is kept as-is for API compatibility.
     *
     * @param packet the packet whose protected flags and packet number are set
     * @throws CryptoException if a header protection mask cannot be generated
     */
    public void addHeaderProtectionOneRRT(QuicPacket packet) throws CryptoException {
        addHeaderProtection(
                packet,
                QuicPacketCryptoComputations.generateOneRTTServerHeaderProtectionMask(
                        context, packet.getHeaderProtectionSample()),
                QuicPacketCryptoComputations.generateOneRRTClientHeaderProtectionMask(
                        context, packet.getHeaderProtectionSample()));
    }

    /**
     * XORs the talking side's header protection mask into the first byte and the packet number.
     *
     * <p>Mask byte 0 protects the flags (low 5 bits for {@link OneRTTPacket}, low 4 bits
     * otherwise); mask bytes 1..n protect the n packet number bytes.
     */
    private void addHeaderProtection(
            QuicPacket packet, byte[] serverHeaderProtectionMask, byte[] clientHeaderProtectionMask) {
        byte[] headerProtectionMask;
        ConnectionEndType connectionEndType = context.getTalkingConnectionEndType();
        switch (connectionEndType) {
            case SERVER:
                headerProtectionMask = serverHeaderProtectionMask;
                break;
            case CLIENT:
                headerProtectionMask = clientHeaderProtectionMask;
                break;
            default:
                LOGGER.error("Unknown connectionEndType: {}", connectionEndType);
                return;
        }

        byte flags = (Byte) packet.getUnprotectedFlags().getValue();
        int flagsMask =
                packet instanceof OneRTTPacket ? SHORT_HEADER_FLAGS_MASK : LONG_HEADER_FLAGS_MASK;
        packet.setProtectedFlags((byte) (flags ^ (headerProtectionMask[0] & flagsMask)));

        byte[] unprotectedPacketNumber = (byte[]) packet.getUnprotectedPacketNumber().getValue();
        int packetNumberLength = (Integer) packet.getPacketNumberLength().getValue();
        byte[] protectedPacketNumber = new byte[packetNumberLength];
        for (int i = 0; i < packetNumberLength; i++) {
            // Mask byte 0 was consumed by the flags, so packet number byte i uses mask byte i + 1.
            protectedPacketNumber[i] =
                    (byte) (unprotectedPacketNumber[i] ^ headerProtectionMask[i + 1]);
        }
        packet.setProtectedPacketNumber(protectedPacketNumber);
    }

    /**
     * Encrypts the payload of an Initial packet with the Initial keys and Initial AEAD cipher.
     *
     * @param packet the packet whose protected payload is set
     * @throws CryptoException if AEAD encryption fails
     */
    public void encryptInitialPacket(InitialPacket packet) throws CryptoException {
        encrypt(
                packet,
                context.getInitialServerIv(),
                context.getInitialServerKey(),
                context.getInitialClientIv(),
                context.getInitialClientKey(),
                context.getInitialAeadCipher());
    }

    /**
     * Encrypts the payload of a Handshake packet with the Handshake keys and the negotiated cipher.
     *
     * @param packet the packet whose protected payload is set
     * @throws CryptoException if AEAD encryption fails
     */
    public void encryptHandshakePacket(HandshakePacket packet) throws CryptoException {
        encrypt(
                packet,
                context.getHandshakeServerIv(),
                context.getHandshakeServerKey(),
                context.getHandshakeClientIv(),
                context.getHandshakeClientKey(),
                context.getAeadCipher());
    }

    /**
     * Encrypts the payload of a 1-RTT packet with the application keys and the negotiated cipher.
     *
     * <p>The method name ("RRT") is kept as-is for API compatibility.
     *
     * @param packet the packet whose protected payload is set
     * @throws CryptoException if AEAD encryption fails
     */
    public void encryptOneRRTPacket(QuicPacket packet) throws CryptoException {
        encrypt(
                packet,
                context.getApplicationServerIv(),
                context.getApplicationServerKey(),
                context.getApplicationClientIv(),
                context.getApplicationClientKey(),
                context.getAeadCipher());
    }

    /**
     * Encrypts the payload of a 0-RTT packet with the 0-RTT keys and 0-RTT AEAD cipher.
     *
     * @param packet the packet whose protected payload is set
     * @throws CryptoException if AEAD encryption fails
     */
    public void encryptZeroRTTPacket(QuicPacket packet) throws CryptoException {
        encrypt(
                packet,
                context.getZeroRTTServerIv(),
                context.getZeroRTTServerKey(),
                context.getZeroRTTClientIv(),
                context.getZeroRTTClientKey(),
                context.getZeroRTTAeadCipher());
    }

    /**
     * AEAD-encrypts the unprotected payload, using the complete unprotected header as associated
     * data, and stores the result as the packet's protected payload.
     *
     * @throws CryptoException if the cipher rejects the key, data or state
     */
    private void encrypt(
            QuicPacket packet,
            byte[] serverIv,
            byte[] serverKey,
            byte[] clientIv,
            byte[] clientKey,
            Cipher cipher)
            throws CryptoException {
        byte[] encryptionIv;
        byte[] encryptionKey;
        ConnectionEndType connectionEndType = context.getTalkingConnectionEndType();
        switch (connectionEndType) {
            case SERVER:
                encryptionIv = serverIv;
                encryptionKey = serverKey;
                break;
            case CLIENT:
                encryptionIv = clientIv;
                encryptionKey = clientKey;
                break;
            default:
                LOGGER.error("Unknown connectionEndType: {}", connectionEndType);
                return;
        }

        byte[] plaintext = (byte[]) packet.getUnprotectedPayload().getValue();
        // NOTE(review): RFC 9001 Section 5.3 derives the nonce from the full reconstructed packet
        // number; this uses the packet number bytes as stored on the packet (the same bytes that
        // header protection treats as the encoded field). Confirm truncation cannot occur here.
        byte[] nonce =
                buildNonce(encryptionIv, (byte[]) packet.getUnprotectedPacketNumber().getValue());
        byte[] associatedData = (byte[]) packet.completeUnprotectedHeader.getValue();
        try {
            packet.setProtectedPayload(
                    aeadEncrypt(associatedData, plaintext, nonce, encryptionKey, cipher));
        } catch (IllegalArgumentException
                | IllegalStateException
                | InvalidKeyException
                | BadPaddingException
                | IllegalBlockSizeException ex) {
            throw new CryptoException("Could not encrypt " + packet.getPacketType().getName(), ex);
        } catch (InvalidAlgorithmParameterException ex) {
            // Deliberately best-effort: the protected payload is left unset instead of aborting,
            // but the cause is logged so the missing payload can be diagnosed.
            LOGGER.warn("Ignoring InvalidAlgorithmParameterException; protected payload not set", ex);
        } catch (NoSuchPaddingException | NoSuchAlgorithmException ex) {
            // The JCE provider cannot supply the requested AEAD transformation.
            throw new RuntimeException(ex);
        }
    }

    /**
     * Builds the AEAD nonce: the packet number bytes, left-padded with zeros to the IV length,
     * XORed with the IV.
     *
     * @param iv the packet protection IV
     * @param packetNumber the packet number bytes in network byte order; must not be longer than
     *     {@code iv}
     * @return a new array of {@code iv.length} bytes
     */
    private static byte[] buildNonce(byte[] iv, byte[] packetNumber) {
        byte[] nonce = new byte[iv.length];
        // Right-align the packet number in the zero-filled buffer (left zero padding).
        System.arraycopy(
                packetNumber, 0, nonce, nonce.length - packetNumber.length, packetNumber.length);
        for (int i = 0; i < nonce.length; i++) {
            nonce[i] ^= iv[i];
        }
        return nonce;
    }

    /**
     * Encrypts {@code plaintext} with a fresh cipher of the same transformation as {@code
     * aeadCipher}.
     *
     * <p>"ChaCha20-Poly1305" uses a ChaCha20 key with an {@link IvParameterSpec}; every other
     * transformation is treated as AES-GCM with a 128-bit tag.
     *
     * @return ciphertext followed by the authentication tag, as produced by {@link
     *     Cipher#doFinal(byte[])}
     */
    private byte[] aeadEncrypt(
            byte[] associatedData, byte[] plaintext, byte[] nonce, byte[] key, Cipher aeadCipher)
            throws InvalidKeyException, IllegalBlockSizeException, BadPaddingException,
                    NoSuchPaddingException, NoSuchAlgorithmException,
                    InvalidAlgorithmParameterException {
        String transformation = aeadCipher.getAlgorithm();
        // A fresh instance is used, so the context's cipher object is only consulted for its
        // algorithm name and is never re-initialised here.
        Cipher cipher = Cipher.getInstance(transformation);
        String keyAlgorithm;
        AlgorithmParameterSpec parameterSpec;
        if (transformation.equals("ChaCha20-Poly1305")) {
            keyAlgorithm = "ChaCha20";
            parameterSpec = new IvParameterSpec(nonce);
        } else {
            keyAlgorithm = "AES";
            parameterSpec = new GCMParameterSpec(GCM_TAG_LENGTH_BITS, nonce);
        }
        cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(key, keyAlgorithm), parameterSpec);
        cipher.updateAAD(associatedData);
        return cipher.doFinal(plaintext);
    }
}

