/*
 * Decompiled with CFR 0.152.
 */
package de.rub.nds.tlsattacker.core.quic.packet;

import de.rub.nds.modifiablevariable.ModifiableVariableHolder;
import de.rub.nds.protocol.exception.CryptoException;
import de.rub.nds.tlsattacker.core.constants.AlgorithmResolver;
import de.rub.nds.tlsattacker.core.constants.HKDFAlgorithm;
import de.rub.nds.tlsattacker.core.constants.ProtocolVersion;
import de.rub.nds.tlsattacker.core.crypto.HKDFunction;
import de.rub.nds.tlsattacker.core.quic.constants.QuicRetryConstants;
import de.rub.nds.tlsattacker.core.quic.constants.QuicVersion;
import de.rub.nds.tlsattacker.core.quic.packet.RetryPacket;
import de.rub.nds.tlsattacker.core.state.Context;
import de.rub.nds.tlsattacker.core.state.quic.QuicContext;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.Mac;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.spec.ChaCha20ParameterSpec;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Static helpers for QUIC packet protection.
 *
 * <p>Derives the packet-protection material (AEAD key, IV and header-protection key) for the
 * Initial, Handshake, 0-RTT and 1-RTT (application) encryption levels and stores it on the {@link
 * QuicContext}. It also computes header-protection masks and the Retry integrity tag.
 */
public class QuicPacketCryptoComputations extends ModifiableVariableHolder {
    private static final Logger LOGGER = LogManager.getLogger();

    /** Length in bytes of the key and header-protection key of the Initial encryption level. */
    private static final int INITIAL_KEY_LENGTH = 16;

    /** Length in bytes of every derived AEAD IV. */
    private static final int IV_LENGTH = 12;

    /** Key length in bytes used for the AES-128 based cipher suites. */
    private static final int KEY_LENGTH_128 = 16;

    /** Key length in bytes used for the AES-256 and ChaCha20 based cipher suites. */
    private static final int KEY_LENGTH_256 = 32;

    /** Tag length in bits of the AES-GCM computation producing the Retry integrity tag. */
    private static final int RETRY_INTEGRITY_TAG_LENGTH_BITS = 128;

    /** Cipher transformation used for AES-based header protection. */
    private static final String AES_HEADER_PROTECTION_TRANSFORMATION = "AES/ECB/NoPadding";

    /** Cipher algorithm name used for ChaCha20-based header protection. */
    private static final String CHACHA20_ALGORITHM = "ChaCha20";

    /**
     * Computes the header-protection mask for a packet.
     *
     * <p>For a ChaCha20 cipher, the first 4 bytes of {@code sample} are read as a little-endian
     * block counter and bytes 4..15 are used as the nonce. Five zero bytes are then encrypted,
     * giving a 5-byte mask.
     *
     * <p>Any other cipher is initialised with an AES key and encrypts {@code sample} directly.
     *
     * @param cipher the header-protection cipher; it is (re-)initialised by this call
     * @param headerProtectionKey the raw header-protection key
     * @param sample the ciphertext sample taken from the protected packet
     * @return the header-protection mask
     * @throws CryptoException if the cipher rejects the key, the parameters or the input length
     */
    public static byte[] generateHeaderProtectionMask(
            Cipher cipher, byte[] headerProtectionKey, byte[] sample) throws CryptoException {
        try {
            if (cipher.getAlgorithm().equals(CHACHA20_ALGORITHM)) {
                // counter = sample[0..3] (little endian), nonce = sample[4..15]
                ByteBuffer counterBytes = ByteBuffer.wrap(Arrays.copyOfRange(sample, 0, 4));
                counterBytes.order(ByteOrder.LITTLE_ENDIAN);
                int counter = counterBytes.getInt();
                byte[] nonce = Arrays.copyOfRange(sample, 4, 16);
                ChaCha20ParameterSpec param = new ChaCha20ParameterSpec(nonce, counter);
                SecretKeySpec keySpec = new SecretKeySpec(headerProtectionKey, cipher.getAlgorithm());
                cipher.init(Cipher.ENCRYPT_MODE, keySpec, param);
                // The keystream of five zero bytes is the mask.
                return cipher.doFinal(new byte[5]);
            }
            SecretKeySpec keySpec = new SecretKeySpec(headerProtectionKey, "AES");
            cipher.init(Cipher.ENCRYPT_MODE, keySpec);
            return cipher.doFinal(sample);
        } catch (InvalidAlgorithmParameterException
                | InvalidKeyException
                | BadPaddingException
                | IllegalBlockSizeException e) {
            throw new CryptoException("Could not generate header protection mask", e);
        }
    }

    /**
     * Computes the header-protection mask with the Initial header-protection cipher and the Initial
     * client header-protection key.
     */
    public static byte[] generateInitialClientHeaderProtectionMask(QuicContext context, byte[] sample)
            throws CryptoException {
        return generateHeaderProtectionMask(
                context.getInitalHeaderProtectionCipher(),
                context.getInitialClientHeaderProtectionKey(),
                sample);
    }

    /**
     * Computes the header-protection mask with the Initial header-protection cipher and the Initial
     * server header-protection key.
     */
    public static byte[] generateInitialServerHeaderProtectionMask(QuicContext context, byte[] sample)
            throws CryptoException {
        return generateHeaderProtectionMask(
                context.getInitalHeaderProtectionCipher(),
                context.getInitialServerHeaderProtectionKey(),
                sample);
    }

    /**
     * Computes the header-protection mask with the negotiated header-protection cipher and the
     * Handshake client header-protection key.
     */
    public static byte[] generateHandshakeClientHeaderProtectionMask(QuicContext context, byte[] sample)
            throws CryptoException {
        return generateHeaderProtectionMask(
                context.getHeaderProtectionCipher(),
                context.getHandshakeClientHeaderProtectionKey(),
                sample);
    }

    /**
     * Computes the header-protection mask with the negotiated header-protection cipher and the
     * Handshake server header-protection key.
     */
    public static byte[] generateHandshakeServerHeaderProtectionMask(QuicContext context, byte[] sample)
            throws CryptoException {
        return generateHeaderProtectionMask(
                context.getHeaderProtectionCipher(),
                context.getHandshakeServerHeaderProtectionKey(),
                sample);
    }

    /**
     * Computes the header-protection mask with the negotiated header-protection cipher and the
     * application (1-RTT) client header-protection key.
     */
    public static byte[] generateOneRTTClientHeaderProtectionMask(QuicContext context, byte[] sample)
            throws CryptoException {
        return generateHeaderProtectionMask(
                context.getHeaderProtectionCipher(),
                context.getApplicationClientHeaderProtectionKey(),
                sample);
    }

    /**
     * Misspelled alias of {@link #generateOneRTTClientHeaderProtectionMask(QuicContext, byte[])},
     * kept so existing callers keep compiling.
     *
     * @deprecated use {@link #generateOneRTTClientHeaderProtectionMask(QuicContext, byte[])}
     */
    @Deprecated
    public static byte[] generateOneRRTClientHeaderProtectionMask(QuicContext context, byte[] sample)
            throws CryptoException {
        return generateOneRTTClientHeaderProtectionMask(context, sample);
    }

    /**
     * Computes the header-protection mask with the negotiated header-protection cipher and the
     * application (1-RTT) server header-protection key.
     */
    public static byte[] generateOneRTTServerHeaderProtectionMask(QuicContext context, byte[] sample)
            throws CryptoException {
        return generateHeaderProtectionMask(
                context.getHeaderProtectionCipher(),
                context.getApplicationServerHeaderProtectionKey(),
                sample);
    }

    /**
     * Computes the header-protection mask with the 0-RTT header-protection cipher and the 0-RTT
     * client header-protection key.
     */
    public static byte[] generateZeroRTTClientHeaderProtectionMask(QuicContext context, byte[] sample)
            throws CryptoException {
        return generateHeaderProtectionMask(
                context.getZeroRTTHeaderProtectionCipher(),
                context.getZeroRTTClientHeaderProtectionKey(),
                sample);
    }

    /**
     * Computes the header-protection mask with the 0-RTT header-protection cipher and the 0-RTT
     * server header-protection key.
     */
    public static byte[] generateZeroRTTServerHeaderProtectionMask(QuicContext context, byte[] sample)
            throws CryptoException {
        return generateHeaderProtectionMask(
                context.getZeroRTTHeaderProtectionCipher(),
                context.getZeroRTTServerHeaderProtectionKey(),
                sample);
    }

    /**
     * Derives the Initial secrets, keys, IVs and header-protection keys for both directions and
     * stores them on {@code context}.
     *
     * <p>The Initial secret is HKDF-Extract of the context's initial salt over the first destination
     * connection ID. The client and server secrets are then expanded from it with the labels
     * "client in" and "server in".
     *
     * <p>For {@link QuicVersion#NEGOTIATION_VERSION} nothing is derived and the secrets are marked
     * as not initialised.
     *
     * @param context the QUIC context to read inputs from and store results on
     * @throws CryptoException if an HKDF operation fails
     * @throws NoSuchAlgorithmException if the HKDF's MAC algorithm is unavailable
     */
    public static void calculateInitialSecrets(QuicContext context)
            throws CryptoException, NoSuchAlgorithmException {
        LOGGER.debug("Initialize Quic Initial Secrets");
        QuicVersion version = context.getQuicVersion();
        if (version == QuicVersion.NEGOTIATION_VERSION) {
            LOGGER.debug("Version Negotiation Packets do not have initial secrets. They are not encrypted.");
            context.setInitialSecretsInitialized(false);
            return;
        }
        HKDFAlgorithm hkdfAlgorithm = context.getInitialHKDFAlgorithm();
        // The MAC instance is only needed for its output length (= length of the expanded secrets).
        Mac mac = Mac.getInstance(hkdfAlgorithm.getMacAlgorithm().getJavaName());
        int secretLength = mac.getMacLength();
        context.setInitialSecret(
                HKDFunction.extract(
                        hkdfAlgorithm, context.getInitialSalt(), context.getFirstDestinationConnectionId()));

        context.setInitialClientSecret(
                HKDFunction.expandLabel(
                        hkdfAlgorithm,
                        context.getInitialSecret(),
                        "client in",
                        new byte[0],
                        secretLength,
                        ProtocolVersion.TLS13));
        context.setInitialClientKey(
                deriveKeyFromSecret(version, hkdfAlgorithm, context.getInitialClientSecret(), INITIAL_KEY_LENGTH));
        context.setInitialClientIv(
                deriveIvFromSecret(version, hkdfAlgorithm, context.getInitialClientSecret()));
        context.setInitialClientHeaderProtectionKey(
                deriveHeaderProtectionKeyFromSecret(
                        version, hkdfAlgorithm, context.getInitialClientSecret(), INITIAL_KEY_LENGTH));

        context.setInitialServerSecret(
                HKDFunction.expandLabel(
                        hkdfAlgorithm,
                        context.getInitialSecret(),
                        "server in",
                        new byte[0],
                        secretLength,
                        ProtocolVersion.TLS13));
        context.setInitialServerKey(
                deriveKeyFromSecret(version, hkdfAlgorithm, context.getInitialServerSecret(), INITIAL_KEY_LENGTH));
        context.setInitialServerIv(
                deriveIvFromSecret(version, hkdfAlgorithm, context.getInitialServerSecret()));
        context.setInitialServerHeaderProtectionKey(
                deriveHeaderProtectionKeyFromSecret(
                        version, hkdfAlgorithm, context.getInitialServerSecret(), INITIAL_KEY_LENGTH));
        context.setInitialSecretsInitialized(true);
    }

    /**
     * Derives the Handshake keys, IVs and header-protection keys from the TLS client/server
     * handshake traffic secrets and stores them on the QUIC context.
     *
     * <p>This method also does the following:
     * <ul>
     *   <li>Selects the AEAD cipher and header-protection cipher from the selected cipher suite.</li>
     *   <li>Stores the suite's HKDF algorithm on the QUIC context; {@link
     *       #calculateApplicationSecrets(Context)} reads it from there.</li>
     * </ul>
     *
     * <p>For unsupported suites a warning is logged, a 16-byte key length is used and the
     * header-protection cipher is left unchanged.
     *
     * @param context the context holding the TLS and QUIC state
     * @throws NoSuchPaddingException if a required cipher transformation is unavailable
     * @throws NoSuchAlgorithmException if a required cipher algorithm is unavailable
     * @throws CryptoException if an HKDF operation fails
     */
    public static void calculateHandshakeSecrets(Context context)
            throws NoSuchPaddingException, NoSuchAlgorithmException, CryptoException {
        LOGGER.debug("Initialize Quic Handshake Secrets");
        QuicContext quicContext = context.getQuicContext();
        quicContext.setAeadCipher(
                Cipher.getInstance(
                        AlgorithmResolver.getCipher(context.getTlsContext().getSelectedCipherSuite())
                                .getJavaName()));
        int keyLength = KEY_LENGTH_128;
        switch (context.getTlsContext().getSelectedCipherSuite()) {
            case TLS_AES_128_CCM_SHA256:
            case TLS_AES_128_GCM_SHA256:
                keyLength = KEY_LENGTH_128;
                quicContext.setHeaderProtectionCipher(Cipher.getInstance(AES_HEADER_PROTECTION_TRANSFORMATION));
                break;
            case TLS_AES_256_GCM_SHA384:
                keyLength = KEY_LENGTH_256;
                quicContext.setHeaderProtectionCipher(Cipher.getInstance(AES_HEADER_PROTECTION_TRANSFORMATION));
                break;
            case TLS_CHACHA20_POLY1305_SHA256:
                keyLength = KEY_LENGTH_256;
                quicContext.setHeaderProtectionCipher(Cipher.getInstance(CHACHA20_ALGORITHM));
                break;
            default:
                LOGGER.warn("Unsupported Cipher Suite: {}", context.getTlsContext().getSelectedCipherSuite());
                break;
        }
        quicContext.setHkdfAlgorithm(
                AlgorithmResolver.getHKDFAlgorithm(context.getTlsContext().getSelectedCipherSuite()));
        HKDFAlgorithm hkdfAlgorithm = quicContext.getHkdfAlgorithm();
        QuicVersion version = quicContext.getQuicVersion();

        quicContext.setHandshakeClientSecret(context.getTlsContext().getClientHandshakeTrafficSecret());
        quicContext.setHandshakeClientKey(
                deriveKeyFromSecret(version, hkdfAlgorithm, quicContext.getHandshakeClientSecret(), keyLength));
        quicContext.setHandshakeClientIv(
                deriveIvFromSecret(version, hkdfAlgorithm, quicContext.getHandshakeClientSecret()));
        quicContext.setHandshakeClientHeaderProtectionKey(
                deriveHeaderProtectionKeyFromSecret(
                        version, hkdfAlgorithm, quicContext.getHandshakeClientSecret(), keyLength));

        quicContext.setHandshakeServerSecret(context.getTlsContext().getServerHandshakeTrafficSecret());
        quicContext.setHandshakeServerKey(
                deriveKeyFromSecret(version, hkdfAlgorithm, quicContext.getHandshakeServerSecret(), keyLength));
        quicContext.setHandshakeServerIv(
                deriveIvFromSecret(version, hkdfAlgorithm, quicContext.getHandshakeServerSecret()));
        quicContext.setHandshakeServerHeaderProtectionKey(
                deriveHeaderProtectionKeyFromSecret(
                        version, hkdfAlgorithm, quicContext.getHandshakeServerSecret(), keyLength));
        quicContext.setHandshakeSecretsInitialized(true);
    }

    /**
     * Derives the application (1-RTT) keys, IVs and header-protection keys from the TLS
     * client/server application traffic secrets and stores them on the QUIC context.
     *
     * <p>The HKDF algorithm is read from the QUIC context; {@link
     * #calculateHandshakeSecrets(Context)} sets it, so presumably that method has already run.
     *
     * <p>For unsupported suites a warning is logged and a 16-byte key length is used.
     *
     * @param context the context holding the TLS and QUIC state
     * @throws NoSuchPaddingException declared for API compatibility
     * @throws NoSuchAlgorithmException declared for API compatibility
     * @throws CryptoException if an HKDF operation fails
     */
    public static void calculateApplicationSecrets(Context context)
            throws NoSuchPaddingException, NoSuchAlgorithmException, CryptoException {
        LOGGER.debug("Initialize Quic Application Secrets");
        QuicContext quicContext = context.getQuicContext();
        int keyLength = KEY_LENGTH_128;
        switch (context.getTlsContext().getSelectedCipherSuite()) {
            case TLS_AES_128_CCM_SHA256:
            case TLS_AES_128_GCM_SHA256:
                keyLength = KEY_LENGTH_128;
                break;
            case TLS_AES_256_GCM_SHA384:
            case TLS_CHACHA20_POLY1305_SHA256:
                keyLength = KEY_LENGTH_256;
                break;
            default:
                LOGGER.warn("Unsupported Cipher Suite: {}", context.getTlsContext().getSelectedCipherSuite());
                break;
        }
        HKDFAlgorithm hkdfAlgorithm = quicContext.getHkdfAlgorithm();
        QuicVersion version = quicContext.getQuicVersion();

        quicContext.setApplicationClientSecret(context.getTlsContext().getClientApplicationTrafficSecret());
        quicContext.setApplicationClientKey(
                deriveKeyFromSecret(version, hkdfAlgorithm, quicContext.getApplicationClientSecret(), keyLength));
        quicContext.setApplicationClientIv(
                deriveIvFromSecret(version, hkdfAlgorithm, quicContext.getApplicationClientSecret()));
        quicContext.setApplicationClientHeaderProtectionKey(
                deriveHeaderProtectionKeyFromSecret(
                        version, hkdfAlgorithm, quicContext.getApplicationClientSecret(), keyLength));

        quicContext.setApplicationServerSecret(context.getTlsContext().getServerApplicationTrafficSecret());
        quicContext.setApplicationServerKey(
                deriveKeyFromSecret(version, hkdfAlgorithm, quicContext.getApplicationServerSecret(), keyLength));
        quicContext.setApplicationServerIv(
                deriveIvFromSecret(version, hkdfAlgorithm, quicContext.getApplicationServerSecret()));
        quicContext.setApplicationServerHeaderProtectionKey(
                deriveHeaderProtectionKeyFromSecret(
                        version, hkdfAlgorithm, quicContext.getApplicationServerSecret(), keyLength));
        quicContext.setApplicationSecretsInitialized(true);
    }

    /**
     * Derives the 0-RTT keys, IVs and header-protection keys from the TLS client early traffic
     * secret and stores them on the QUIC context.
     *
     * <p>This method also stores the early-data cipher suite, its AEAD cipher, its
     * header-protection cipher and its HKDF algorithm on the QUIC context.
     *
     * <p>For unsupported suites a warning is logged, a 16-byte key length is used and the 0-RTT
     * header-protection cipher is left unchanged.
     *
     * @param context the context holding the TLS and QUIC state
     * @throws CryptoException if an HKDF operation fails
     * @throws NoSuchPaddingException if a required cipher transformation is unavailable
     * @throws NoSuchAlgorithmException if a required cipher algorithm is unavailable
     */
    public static void calculateZeroRTTSecrets(Context context)
            throws CryptoException, NoSuchPaddingException, NoSuchAlgorithmException {
        LOGGER.debug("Initialize Quic 0-RTT Secrets");
        QuicContext quicContext = context.getQuicContext();
        quicContext.setZeroRTTCipherSuite(context.getTlsContext().getEarlyDataCipherSuite());
        quicContext.setZeroRTTAeadCipher(
                Cipher.getInstance(
                        context.getTlsContext().getEarlyDataCipherSuite().getCipherAlgorithm().getJavaName()));
        int keyLength = KEY_LENGTH_128;
        switch (quicContext.getZeroRTTCipherSuite()) {
            case TLS_AES_128_CCM_SHA256:
            case TLS_AES_128_GCM_SHA256:
                keyLength = KEY_LENGTH_128;
                quicContext.setZeroRTTHeaderProtectionCipher(Cipher.getInstance(AES_HEADER_PROTECTION_TRANSFORMATION));
                break;
            case TLS_AES_256_GCM_SHA384:
                keyLength = KEY_LENGTH_256;
                quicContext.setZeroRTTHeaderProtectionCipher(Cipher.getInstance(AES_HEADER_PROTECTION_TRANSFORMATION));
                break;
            case TLS_CHACHA20_POLY1305_SHA256:
                keyLength = KEY_LENGTH_256;
                quicContext.setZeroRTTHeaderProtectionCipher(Cipher.getInstance(CHACHA20_ALGORITHM));
                break;
            default:
                LOGGER.warn("Unsupported Cipher Suite: {}", quicContext.getZeroRTTCipherSuite());
                break;
        }
        quicContext.setZeroRTTHKDFAlgorithm(AlgorithmResolver.getHKDFAlgorithm(quicContext.getZeroRTTCipherSuite()));
        HKDFAlgorithm hkdfAlgorithm = quicContext.getZeroRTTHKDFAlgorithm();
        QuicVersion version = quicContext.getQuicVersion();

        quicContext.setZeroRTTClientSecret(context.getTlsContext().getClientEarlyTrafficSecret());
        quicContext.setZeroRTTClientKey(
                deriveKeyFromSecret(version, hkdfAlgorithm, quicContext.getZeroRTTClientSecret(), keyLength));
        quicContext.setZeroRTTClientIv(
                deriveIvFromSecret(version, hkdfAlgorithm, quicContext.getZeroRTTClientSecret()));
        quicContext.setZeroRTTClientHeaderProtectionKey(
                deriveHeaderProtectionKeyFromSecret(
                        version, hkdfAlgorithm, quicContext.getZeroRTTClientSecret(), keyLength));

        // NOTE(review): the "server" 0-RTT material is derived from the *client* early traffic
        // secret as well. This looks intentional, since only a client early traffic secret is
        // read here. Confirm before changing.
        quicContext.setZeroRTTServerSecret(context.getTlsContext().getClientEarlyTrafficSecret());
        quicContext.setZeroRTTServerKey(
                deriveKeyFromSecret(version, hkdfAlgorithm, quicContext.getZeroRTTServerSecret(), keyLength));
        quicContext.setZeroRTTServerIv(
                deriveIvFromSecret(version, hkdfAlgorithm, quicContext.getZeroRTTServerSecret()));
        quicContext.setZeroRTTServerHeaderProtectionKey(
                deriveHeaderProtectionKeyFromSecret(
                        version, hkdfAlgorithm, quicContext.getZeroRTTServerSecret(), keyLength));
        quicContext.setZeroRTTSecretsInitialized(true);
    }

    /** HKDF-Expand-Label of {@code secret} with the version's key label to {@code keyLength} bytes. */
    private static byte[] deriveKeyFromSecret(
            QuicVersion version, HKDFAlgorithm hkdfAlgorithm, byte[] secret, int keyLength)
            throws CryptoException {
        return HKDFunction.expandLabel(
                hkdfAlgorithm, secret, version.getKeyLabel(), new byte[0], keyLength, ProtocolVersion.TLS13);
    }

    /** HKDF-Expand-Label of {@code secret} with the version's IV label to {@link #IV_LENGTH} bytes. */
    private static byte[] deriveIvFromSecret(QuicVersion version, HKDFAlgorithm hkdfAlgorithm, byte[] secret)
            throws CryptoException {
        return HKDFunction.expandLabel(
                hkdfAlgorithm, secret, version.getIvLabel(), new byte[0], IV_LENGTH, ProtocolVersion.TLS13);
    }

    /**
     * HKDF-Expand-Label of {@code secret} with the version's header-protection label to {@code
     * keyLength} bytes.
     */
    private static byte[] deriveHeaderProtectionKeyFromSecret(
            QuicVersion version, HKDFAlgorithm hkdfAlgorithm, byte[] secret, int keyLength)
            throws CryptoException {
        return HKDFunction.expandLabel(
                hkdfAlgorithm,
                secret,
                version.getHeaderProtectionLabel(),
                new byte[0],
                keyLength,
                ProtocolVersion.TLS13);
    }

    /**
     * Computes the Retry integrity tag for {@code packet}.
     *
     * <p>The pseudo packet is fed as AAD into AES-GCM with an empty plaintext. Key and IV are taken
     * from {@link QuicRetryConstants} for the context's QUIC version.
     *
     * @param context the QUIC context (supplies the original destination connection ID and version)
     * @param packet the Retry packet whose fields form the pseudo packet
     * @return the 16-byte GCM tag
     * @throws CryptoException if the AES-GCM computation fails or yields an empty tag
     */
    public static byte[] calculateRetryIntegrityTag(QuicContext context, RetryPacket packet) {
        byte[] pseudoPacket = buildRetryPseudoPacket(context, packet);
        LOGGER.trace("Build Integrity Check Pseudo Packet {}", pseudoPacket);
        byte[] computedTag;
        try {
            SecretKeySpec secretKey =
                    new SecretKeySpec(QuicRetryConstants.getRetryIntegrityTagKey(context.getQuicVersion()), "AES");
            Cipher cipher = Cipher.getInstance("AES/GCM/NoPadding");
            GCMParameterSpec gcmParameterSpec =
                    new GCMParameterSpec(
                            RETRY_INTEGRITY_TAG_LENGTH_BITS,
                            QuicRetryConstants.getRetryIntegrityTagIv(context.getQuicVersion()));
            cipher.init(Cipher.ENCRYPT_MODE, secretKey, gcmParameterSpec);
            cipher.updateAAD(pseudoPacket);
            computedTag = cipher.doFinal();
        } catch (InvalidAlgorithmParameterException
                | InvalidKeyException
                | NoSuchAlgorithmException
                | BadPaddingException
                | IllegalBlockSizeException
                | NoSuchPaddingException e) {
            throw new CryptoException("Error while computing Retry Integrity Tag", e);
        }
        if (computedTag.length == 0) {
            throw new CryptoException("Attempted to compute Retry Integrity Tag for verification but result is empty!");
        }
        return computedTag;
    }

    /**
     * Serialises the Retry pseudo packet.
     *
     * <p>Layout: ODCID length (1 byte), ODCID, first byte (unprotected flags), version, DCID
     * length field, DCID, SCID length field, SCID, Retry token.
     *
     * <p>The length fields are written as stored in the packet (they may be modified on purpose).
     * The buffer, however, is sized from the bytes actually written. A length field that disagrees
     * with its connection ID therefore neither overflows the buffer nor pads the AAD with zeros.
     */
    private static byte[] buildRetryPseudoPacket(QuicContext context, RetryPacket packet) {
        byte[] originalDestinationConnectionId = context.getFirstDestinationConnectionId();
        byte[] version = context.getQuicVersion().getByteValue();
        byte flags = (Byte) packet.getUnprotectedFlags().getValue();
        byte destinationConnectionIdLength = (Byte) packet.getDestinationConnectionIdLength().getValue();
        byte[] destinationConnectionId = (byte[]) packet.getDestinationConnectionId().getValue();
        byte sourceConnectionIdLength = (Byte) packet.getSourceConnectionIdLength().getValue();
        byte[] sourceConnectionId = (byte[]) packet.getSourceConnectionId().getValue();
        byte[] retryToken = (byte[]) packet.retryToken.getValue();

        int size =
                1 + originalDestinationConnectionId.length
                        + 1
                        + version.length
                        + 1 + destinationConnectionId.length
                        + 1 + sourceConnectionId.length
                        + retryToken.length;
        return ByteBuffer.allocate(size)
                .put((byte) originalDestinationConnectionId.length)
                .put(originalDestinationConnectionId)
                .put(flags)
                .put(version)
                .put(destinationConnectionIdLength)
                .put(destinationConnectionId)
                .put(sourceConnectionIdLength)
                .put(sourceConnectionId)
                .put(retryToken)
                .array();
    }
}

