/*
 * Decompiled with CFR 0.152.
 */
package de.rub.nds.tlsattacker.core.quic.crypto;

import de.rub.nds.modifiablevariable.util.DataConverter;
import de.rub.nds.protocol.exception.CryptoException;
import de.rub.nds.tlsattacker.core.quic.constants.QuicPacketType;
import de.rub.nds.tlsattacker.core.quic.packet.HandshakePacket;
import de.rub.nds.tlsattacker.core.quic.packet.InitialPacket;
import de.rub.nds.tlsattacker.core.quic.packet.QuicPacket;
import de.rub.nds.tlsattacker.core.quic.packet.QuicPacketCryptoComputations;
import de.rub.nds.tlsattacker.core.state.quic.QuicContext;
import de.rub.nds.tlsattacker.transport.ConnectionEndType;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.spec.AlgorithmParameterSpec;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Removes QUIC header protection from packets and decrypts their payloads (RFC 9001, Sections 5.3
 * and 5.4).
 *
 * <p>Key material is taken from the {@link QuicContext}. The server-side secrets are used when the
 * context's talking connection end is the server, and the client-side secrets when it is the
 * client. When {@code echoQuic} is enabled in the config, this selection is inverted.
 */
public class QuicDecryptor {
    private static final Logger LOGGER = LogManager.getLogger();
    private final QuicContext context;

    public QuicDecryptor(QuicContext context) {
        this.context = context;
    }

    /** Removes header protection from an Initial packet using the Initial header protection masks. */
    public void removeHeaderProtectionInitial(InitialPacket packet) throws CryptoException {
        this.removeHeaderProtection(
                packet,
                QuicPacketCryptoComputations.generateInitialServerHeaderProtectionMask(
                        this.context, packet.getHeaderProtectionSample()),
                QuicPacketCryptoComputations.generateInitialClientHeaderProtectionMask(
                        this.context, packet.getHeaderProtectionSample()));
    }

    /** Removes header protection from a Handshake packet using the Handshake header protection masks. */
    public void removeHeaderProtectionHandshake(HandshakePacket packet) throws CryptoException {
        this.removeHeaderProtection(
                packet,
                QuicPacketCryptoComputations.generateHandshakeServerHeaderProtectionMask(
                        this.context, packet.getHeaderProtectionSample()),
                QuicPacketCryptoComputations.generateHandshakeClientHeaderProtectionMask(
                        this.context, packet.getHeaderProtectionSample()));
    }

    /** Removes header protection from a 0-RTT packet using the 0-RTT header protection masks. */
    public void removeHeaderProtectionZeroRTT(QuicPacket packet) throws CryptoException {
        this.removeHeaderProtection(
                packet,
                QuicPacketCryptoComputations.generateZeroRTTServerHeaderProtectionMask(
                        this.context, packet.getHeaderProtectionSample()),
                QuicPacketCryptoComputations.generateZeroRTTClientHeaderProtectionMask(
                        this.context, packet.getHeaderProtectionSample()));
    }

    /** Removes header protection from a 1-RTT packet using the application header protection masks. */
    public void removeHeaderProtectionOneRTT(QuicPacket packet) throws CryptoException {
        // "OneRRT" is the (misspelled) name of the helper in QuicPacketCryptoComputations.
        this.removeHeaderProtection(
                packet,
                QuicPacketCryptoComputations.generateOneRTTServerHeaderProtectionMask(
                        this.context, packet.getHeaderProtectionSample()),
                QuicPacketCryptoComputations.generateOneRRTClientHeaderProtectionMask(
                        this.context, packet.getHeaderProtectionSample()));
    }

    /**
     * Removes header protection from {@code packet} and restores its full packet number.
     *
     * <p>Byte 0 of the selected mask unmasks the flags byte: the low 5 bits for short headers and
     * the low 4 bits for long headers. Bytes {@code 1..n} unmask the {@code n}-byte packet number.
     * The packet's unprotected flags, packet number length, protected and unprotected packet number
     * are set. The unprotected packet number is written to {@code protectedHeaderHelper`, and the
     * restored/plain packet number is computed.
     *
     * @param packet the received packet; its protected flags and protected packet number/payload
     *     must already be parsed
     * @param serverHeaderProtectionMask mask used when the server-side secrets are selected
     * @param clientHeaderProtectionMask mask used when the client-side secrets are selected
     */
    public void removeHeaderProtection(
            QuicPacket packet, byte[] serverHeaderProtectionMask, byte[] clientHeaderProtectionMask) {
        ConnectionEndType keyOwner = this.resolveKeyOwner();
        byte[] headerProtectionMask;
        switch (keyOwner) {
            case SERVER:
                headerProtectionMask = serverHeaderProtectionMask;
                break;
            case CLIENT:
                headerProtectionMask = clientHeaderProtectionMask;
                break;
            default:
                LOGGER.error("Unknown connectionEndType: {}", keyOwner);
                return;
        }

        byte flags = (Byte) packet.getProtectedFlags().getValue();
        // RFC 9001 5.4.1: short headers protect 5 flag bits, long headers protect 4.
        int protectedFlagBits = QuicPacketType.isShortHeaderPacket(flags) ? 0x1F : 0x0F;
        byte unprotectedFlags = (byte) (flags ^ (headerProtectionMask[0] & protectedFlagBits));
        packet.setUnprotectedFlags(unprotectedFlags);

        // The two low bits of the unprotected flags encode (packet number length - 1).
        int encodedLength = (unprotectedFlags & 0x03) + 1;
        packet.setPacketNumberLength(encodedLength);

        byte[] protectedPacketNumber = new byte[encodedLength];
        System.arraycopy(
                packet.getProtectedPacketNumberAndPayload().getValue(),
                0,
                protectedPacketNumber,
                0,
                encodedLength);
        packet.setProtectedPacketNumber(protectedPacketNumber);

        // Read the length back through the modifiable variable so that any configured
        // modification of the packet number length is honoured.
        int packetNumberLength = (Integer) packet.getPacketNumberLength().getValue();
        byte[] unprotectedPacketNumber = new byte[packetNumberLength];
        for (int i = 0; i < packetNumberLength; i++) {
            unprotectedPacketNumber[i] =
                    (byte) (headerProtectionMask[i + 1] ^ protectedPacketNumber[i]);
        }
        packet.protectedHeaderHelper.write(unprotectedPacketNumber);
        packet.setUnprotectedPacketNumber(unprotectedPacketNumber);
        this.restorePacketNumber(packet);
    }

    /**
     * Reconstructs the full packet number from the truncated, unprotected one.
     *
     * <p>The value is produced by {@link QuicPacket#decodePacketNumber}. The reference value is the
     * last packet number recorded for the packet's number space, or 0 if none has been recorded. If
     * the decoded number's encoding is not longer than the truncated bytes, the truncated bytes are
     * stored instead, which preserves their original width (e.g. leading zero bytes).
     */
    private void restorePacketNumber(QuicPacket packet) {
        int largestPn = 0;
        // NOTE(review): RFC 9000 A.3 uses the *largest* processed packet number; getLast() is only
        // that if the lists are kept in ascending order — confirm. 0-RTT packets are not listed
        // here, so they always decode against 0 — confirm this is intended.
        switch (packet.getPacketType()) {
            case INITIAL_PACKET:
                if (!this.context.getReceivedInitialPacketNumbers().isEmpty()) {
                    largestPn = this.context.getReceivedInitialPacketNumbers().getLast();
                }
                break;
            case HANDSHAKE_PACKET:
                if (!this.context.getReceivedHandshakePacketNumbers().isEmpty()) {
                    largestPn = this.context.getReceivedHandshakePacketNumbers().getLast();
                }
                break;
            case ONE_RTT_PACKET:
                if (!this.context.getReceivedOneRTTPacketNumbers().isEmpty()) {
                    largestPn = this.context.getReceivedOneRTTPacketNumbers().getLast();
                }
                break;
            default:
                break;
        }

        byte[] truncatedBytes = (byte[]) packet.getUnprotectedPacketNumber().getValue();
        int truncatedPn = DataConverter.bytesToInt(truncatedBytes);
        int pnNBits = (Integer) packet.getPacketNumberLength().getValue() * 8;
        long decodedPn = packet.decodePacketNumber(truncatedPn, largestPn, pnNBits);
        LOGGER.debug("Decoded pktNumber: {}, raw pktNumber: {}", decodedPn, truncatedPn);

        packet.setRestoredPacketNumber((int) decodedPn);
        packet.setPlainPacketNumber((int) decodedPn);
        byte[] restoredBytes = (byte[]) packet.getRestoredPacketNumber().getValue();
        if (truncatedBytes.length >= restoredBytes.length) {
            packet.setRestoredPacketNumber(truncatedBytes);
            packet.setPlainPacketNumber(truncatedPn);
        }
    }

    /** Decrypts an Initial packet's payload with the Initial secrets and Initial AEAD cipher. */
    public void decryptInitialPacket(InitialPacket packet) throws CryptoException {
        this.decrypt(
                packet,
                this.context.getInitialServerIv(),
                this.context.getInitialServerKey(),
                this.context.getInitialClientIv(),
                this.context.getInitialClientKey(),
                this.context.getInitialAeadCipher());
    }

    /** Decrypts a Handshake packet's payload with the Handshake secrets. */
    public void decryptHandshakePacket(HandshakePacket packet) throws CryptoException {
        this.decrypt(
                packet,
                this.context.getHandshakeServerIv(),
                this.context.getHandshakeServerKey(),
                this.context.getHandshakeClientIv(),
                this.context.getHandshakeClientKey(),
                this.context.getAeadCipher());
    }

    /** Decrypts a 1-RTT packet's payload with the application secrets. */
    public void decryptOneRTTPacket(QuicPacket packet) throws CryptoException {
        this.decrypt(
                packet,
                this.context.getApplicationServerIv(),
                this.context.getApplicationServerKey(),
                this.context.getApplicationClientIv(),
                this.context.getApplicationClientKey(),
                this.context.getAeadCipher());
    }

    /**
     * Decrypts the payload of a packet whose header protection has already been removed, and stores
     * the plaintext via {@code setUnprotectedPayload}.
     *
     * <p>The nonce is the restored packet number, left-padded with zeros to the IV length and XORed
     * with the IV (RFC 9001 5.3). The associated data is the unprotected header up to and including
     * the packet number.
     *
     * @throws CryptoException if key material is missing, the packet lengths are inconsistent, or
     *     AEAD decryption/authentication fails
     */
    private void decrypt(
            QuicPacket packet,
            byte[] serverIv,
            byte[] serverKey,
            byte[] clientIv,
            byte[] clientKey,
            Cipher cipher)
            throws CryptoException {
        ConnectionEndType keyOwner = this.resolveKeyOwner();
        byte[] decryptionIv;
        byte[] decryptionKey;
        switch (keyOwner) {
            case SERVER:
                decryptionIv = serverIv;
                decryptionKey = serverKey;
                break;
            case CLIENT:
                decryptionIv = clientIv;
                decryptionKey = clientKey;
                break;
            default:
                LOGGER.error("Unknown connectionEndType: {}", keyOwner);
                return;
        }
        if (decryptionIv == null || decryptionKey == null) {
            throw decryptionFailure(packet, "no " + keyOwner + " key material available");
        }

        // The packet length covers the packet number plus the encrypted payload (incl. AEAD tag).
        int packetNumberLength = (Integer) packet.getPacketNumberLength().getValue();
        int payloadLength = (Integer) packet.getPacketLength().getValue() - packetNumberLength;
        if (payloadLength < 0) {
            throw decryptionFailure(packet, "packet length is shorter than the packet number length");
        }
        byte[] encryptedPayload = new byte[payloadLength];
        System.arraycopy(
                packet.getProtectedPacketNumberAndPayload().getValue(),
                packetNumberLength,
                encryptedPayload,
                0,
                payloadLength);

        byte[] restoredPacketNumber = (byte[]) packet.getRestoredPacketNumber().getValue();
        if (restoredPacketNumber.length > decryptionIv.length) {
            throw decryptionFailure(packet, "packet number is longer than the IV");
        }
        // Right-align the packet number in an IV-sized buffer, then XOR in the IV.
        byte[] nonce = new byte[decryptionIv.length];
        System.arraycopy(
                restoredPacketNumber,
                0,
                nonce,
                nonce.length - restoredPacketNumber.length,
                restoredPacketNumber.length);
        for (int i = 0; i < nonce.length; i++) {
            nonce[i] ^= decryptionIv[i];
        }

        int associatedDataLength =
                packet.offsetToPacketNumber
                        + ((byte[]) packet.getUnprotectedPacketNumber().getValue()).length;
        byte[] associatedData = new byte[associatedDataLength];
        System.arraycopy(
                packet.completeUnprotectedHeader.getValue(),
                0,
                associatedData,
                0,
                associatedDataLength);

        try {
            byte[] decryptedPayload =
                    this.aeadDecrypt(associatedData, encryptedPayload, nonce, decryptionKey, cipher);
            packet.setUnprotectedPayload(decryptedPayload);
        } catch (IllegalArgumentException
                | IllegalStateException
                | InvalidAlgorithmParameterException
                | InvalidKeyException
                | BadPaddingException
                | IllegalBlockSizeException ex) {
            throw new CryptoException(
                    "Could not decrypt " + packet.getPacketType().getName(), (Throwable) ex);
        }
    }

    /**
     * Decrypts and authenticates {@code ciphertext} with the given AEAD cipher.
     *
     * <p>A cipher whose algorithm name is exactly {@code "ChaCha20-Poly1305"} is initialised with a
     * ChaCha20 key and an {@link IvParameterSpec}. Any other cipher is treated as AES-GCM with a
     * 128-bit tag. The authentication tag is expected to be appended to {@code ciphertext}.
     *
     * @param associatedData additional authenticated data
     * @param ciphertext encrypted payload followed by the authentication tag
     * @param nonce per-packet AEAD nonce
     * @param key raw AEAD key bytes
     * @param aeadCipher cipher instance to (re)initialise; it is reinitialised on every call
     * @return the decrypted plaintext
     * @throws BadPaddingException (in practice {@code AEADBadTagException}) if authentication fails
     */
    public byte[] aeadDecrypt(
            byte[] associatedData, byte[] ciphertext, byte[] nonce, byte[] key, Cipher aeadCipher)
            throws InvalidAlgorithmParameterException,
                    InvalidKeyException,
                    IllegalBlockSizeException,
                    BadPaddingException {
        String keyAlgorithm;
        AlgorithmParameterSpec parameterSpec;
        if (aeadCipher.getAlgorithm().equals("ChaCha20-Poly1305")) {
            keyAlgorithm = "ChaCha20";
            parameterSpec = new IvParameterSpec(nonce);
        } else {
            keyAlgorithm = "AES";
            parameterSpec = new GCMParameterSpec(128, nonce);
        }
        aeadCipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(key, keyAlgorithm), parameterSpec);
        aeadCipher.updateAAD(associatedData);
        return aeadCipher.doFinal(ciphertext);
    }

    /**
     * Returns the connection end whose secrets are selected. This is the context's talking
     * connection end, or its peer when {@code echoQuic} is enabled in the config.
     */
    private ConnectionEndType resolveKeyOwner() {
        ConnectionEndType talking = this.context.getTalkingConnectionEndType();
        return this.context.getConfig().isEchoQuic().booleanValue() ? talking.getPeer() : talking;
    }

    /** Builds the exception reported when a packet cannot be decrypted for {@code reason}. */
    private static CryptoException decryptionFailure(QuicPacket packet, String reason) {
        return new CryptoException(
                "Could not decrypt " + packet.getPacketType().getName() + ": " + reason);
    }
}

