/*
 * Decompiled with CFR 0.152.
 */
package de.rub.nds.tlsattacker.core.record.cipher.cryptohelper;

import de.rub.nds.modifiablevariable.util.DataConverter;
import de.rub.nds.protocol.constants.MacAlgorithm;
import de.rub.nds.protocol.exception.CryptoException;
import de.rub.nds.tlsattacker.core.constants.AlgorithmResolver;
import de.rub.nds.tlsattacker.core.constants.CipherAlgorithm;
import de.rub.nds.tlsattacker.core.constants.CipherSuite;
import de.rub.nds.tlsattacker.core.constants.HKDFAlgorithm;
import de.rub.nds.tlsattacker.core.constants.PRFAlgorithm;
import de.rub.nds.tlsattacker.core.constants.ProtocolVersion;
import de.rub.nds.tlsattacker.core.constants.Tls13KeySetType;
import de.rub.nds.tlsattacker.core.crypto.HKDFunction;
import de.rub.nds.tlsattacker.core.crypto.MD5Utils;
import de.rub.nds.tlsattacker.core.crypto.PseudoRandomFunction;
import de.rub.nds.tlsattacker.core.crypto.SSLUtils;
import de.rub.nds.tlsattacker.core.layer.context.TlsContext;
import de.rub.nds.tlsattacker.core.record.cipher.cryptohelper.KeyBlockParser;
import de.rub.nds.tlsattacker.core.record.cipher.cryptohelper.KeySet;
import de.rub.nds.tlsattacker.core.workflow.chooser.Chooser;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

public class KeyDerivator {
    private static final Logger LOGGER = LogManager.getLogger();
    private static final int AEAD_IV_LENGTH = 12;

    public static byte[] calculateMasterSecret(TlsContext tlsContext, byte[] clientServerRandom) throws CryptoException {
        Chooser chooser = tlsContext.getChooser();
        if (chooser.getSelectedProtocolVersion() == ProtocolVersion.SSL3) {
            LOGGER.debug("Calculate SSL MasterSecret with Client and Server Nonces, which are: {}", (Object)clientServerRandom);
            return SSLUtils.calculateMasterSecretSSL3(chooser.getPreMasterSecret(), clientServerRandom);
        }
        PRFAlgorithm prfAlgorithm = AlgorithmResolver.getPRFAlgorithm(chooser.getSelectedProtocolVersion(), chooser.getSelectedCipherSuite());
        if (chooser.isUseExtendedMasterSecret()) {
            LOGGER.debug("Calculating ExtendedMasterSecret");
            byte[] sessionHash = tlsContext.getDigest().digest(chooser.getSelectedProtocolVersion(), chooser.getSelectedCipherSuite());
            LOGGER.debug("Premastersecret: {}", (Object)chooser.getPreMasterSecret());
            LOGGER.debug("SessionHash: {}", (Object)sessionHash);
            byte[] extendedMasterSecret = PseudoRandomFunction.compute(prfAlgorithm, chooser.getPreMasterSecret(), "extended master secret", sessionHash, 48);
            return extendedMasterSecret;
        }
        LOGGER.debug("Calculating MasterSecret");
        byte[] masterSecret = PseudoRandomFunction.compute(prfAlgorithm, chooser.getPreMasterSecret(), "master secret", clientServerRandom, 48);
        return masterSecret;
    }

    public static KeySet generateKeySet(TlsContext tlsContext, ProtocolVersion protocolVersion, Tls13KeySetType keySetType) throws NoSuchAlgorithmException, CryptoException {
        if (protocolVersion.is13()) {
            return KeyDerivator.getTls13KeySet(tlsContext, keySetType);
        }
        return KeyDerivator.getTlsKeySet(tlsContext);
    }

    public static KeySet generateKeySet(TlsContext tlsContext) throws NoSuchAlgorithmException, CryptoException {
        return KeyDerivator.generateKeySet(tlsContext, tlsContext.getChooser().getSelectedProtocolVersion(), tlsContext.getActiveKeySetTypeWrite());
    }

    private static KeySet getTls13KeySet(TlsContext tlsContext, Tls13KeySetType keySetType) throws CryptoException {
        byte[] serverSecret;
        byte[] clientSecret;
        CipherSuite cipherSuite = tlsContext.getChooser().getSelectedCipherSuite();
        if (null == keySetType) {
            throw new CryptoException("Unknown KeySetType: null");
        }
        switch (keySetType) {
            case HANDSHAKE_TRAFFIC_SECRETS: {
                clientSecret = tlsContext.getChooser().getClientHandshakeTrafficSecret();
                serverSecret = tlsContext.getChooser().getServerHandshakeTrafficSecret();
                break;
            }
            case APPLICATION_TRAFFIC_SECRETS: {
                clientSecret = tlsContext.getChooser().getClientApplicationTrafficSecret();
                serverSecret = tlsContext.getChooser().getServerApplicationTrafficSecret();
                break;
            }
            case EARLY_TRAFFIC_SECRETS: {
                cipherSuite = tlsContext.getChooser().getEarlyDataCipherSuite();
                clientSecret = tlsContext.getChooser().getClientEarlyTrafficSecret();
                serverSecret = tlsContext.getChooser().getClientEarlyTrafficSecret();
                break;
            }
            case NONE: {
                LOGGER.warn("KeySet is NONE! , returning empty KeySet");
                return new KeySet(keySetType);
            }
            default: {
                throw new CryptoException("Unknown KeySetType:" + keySetType.name());
            }
        }
        LOGGER.debug("ActiveKeySetType is {}", (Object)keySetType);
        CipherAlgorithm cipherAlg = cipherSuite.getCipherAlgorithm();
        if (cipherAlg == null) {
            LOGGER.debug("No cipher algorithm found for cipher suite: {}, falling back to TLS_AES_128_GCM_SHA256", (Object)cipherSuite);
            cipherAlg = CipherSuite.TLS_AES_128_GCM_SHA256.getCipherAlgorithm();
        }
        KeySet keySet = new KeySet(keySetType);
        HKDFAlgorithm hkdfAlgorithm = AlgorithmResolver.getHKDFAlgorithm(cipherSuite);
        keySet.setClientWriteKey(HKDFunction.expandLabel(hkdfAlgorithm, clientSecret, "key", new byte[0], cipherAlg.getKeySize(), tlsContext.getChooser().getSelectedProtocolVersion()));
        LOGGER.debug("Client write key: {}", (Object)keySet.getClientWriteKey());
        keySet.setServerWriteKey(HKDFunction.expandLabel(hkdfAlgorithm, serverSecret, "key", new byte[0], cipherAlg.getKeySize(), tlsContext.getChooser().getSelectedProtocolVersion()));
        LOGGER.debug("Server write key: {}", (Object)keySet.getServerWriteKey());
        keySet.setClientWriteIv(HKDFunction.expandLabel(hkdfAlgorithm, clientSecret, "iv", new byte[0], 12, tlsContext.getChooser().getSelectedProtocolVersion()));
        LOGGER.debug("Client write IV: {}", (Object)keySet.getClientWriteIv());
        keySet.setServerWriteIv(HKDFunction.expandLabel(hkdfAlgorithm, serverSecret, "iv", new byte[0], 12, tlsContext.getChooser().getSelectedProtocolVersion()));
        LOGGER.debug("Server write IV: {}", (Object)keySet.getServerWriteIv());
        keySet.setClientSnKey(HKDFunction.expandLabel(hkdfAlgorithm, clientSecret, "sn", new byte[0], cipherAlg.getKeySize(), tlsContext.getChooser().getSelectedProtocolVersion()));
        LOGGER.debug("Client sn key: {}", (Object)keySet.getClientSnKey());
        keySet.setServerSnKey(HKDFunction.expandLabel(hkdfAlgorithm, serverSecret, "sn", new byte[0], cipherAlg.getKeySize(), tlsContext.getChooser().getSelectedProtocolVersion()));
        LOGGER.debug("Server sn key: {}", (Object)keySet.getServerSnKey());
        keySet.setServerWriteMacSecret(new byte[0]);
        keySet.setClientWriteMacSecret(new byte[0]);
        return keySet;
    }

    private static KeySet getTlsKeySet(TlsContext tlsContext) throws CryptoException {
        byte[] keyBlock;
        ProtocolVersion protocolVersion = tlsContext.getChooser().getSelectedProtocolVersion();
        CipherSuite cipherSuite = tlsContext.getChooser().getSelectedCipherSuite();
        byte[] masterSecret = tlsContext.getChooser().getMasterSecret();
        byte[] seed = DataConverter.concatenate((byte[][])new byte[][]{tlsContext.getChooser().getServerRandom(), tlsContext.getChooser().getClientRandom()});
        if (protocolVersion.isSSL()) {
            keyBlock = SSLUtils.calculateKeyBlockSSL3(masterSecret, seed, KeyDerivator.getSecretSetSize(protocolVersion, cipherSuite));
        } else {
            PRFAlgorithm prfAlgorithm = AlgorithmResolver.getPRFAlgorithm(protocolVersion, cipherSuite);
            keyBlock = PseudoRandomFunction.compute(prfAlgorithm, masterSecret, "key expansion", seed, KeyDerivator.getSecretSetSize(protocolVersion, cipherSuite));
        }
        LOGGER.debug("A new key block was generated: {}", (Object)keyBlock);
        KeyBlockParser parser = new KeyBlockParser(keyBlock, cipherSuite, protocolVersion);
        KeySet keySet = new KeySet();
        parser.parse(keySet);
        if (cipherSuite.isExportSymmetricCipher()) {
            KeyDerivator.deriveExportKeys(keySet, tlsContext);
        }
        return keySet;
    }

    private static void deriveExportKeys(KeySet keySet, TlsContext tlsContext) throws CryptoException {
        ProtocolVersion protocolVersion = tlsContext.getChooser().getSelectedProtocolVersion();
        CipherSuite cipherSuite = tlsContext.getChooser().getSelectedCipherSuite();
        byte[] clientRandom = tlsContext.getChooser().getClientRandom();
        byte[] serverRandom = tlsContext.getChooser().getServerRandom();
        if (protocolVersion == ProtocolVersion.SSL3) {
            KeyDerivator.deriveSSL3ExportKeys(cipherSuite, keySet, clientRandom, serverRandom);
            return;
        }
        byte[] clientAndServerRandom = DataConverter.concatenate((byte[][])new byte[][]{clientRandom, serverRandom});
        PRFAlgorithm prfAlgorithm = AlgorithmResolver.getPRFAlgorithm(protocolVersion, cipherSuite);
        int keySize = cipherSuite.getCipherAlgorithm().getExportFinalKeySize();
        keySet.setClientWriteKey(PseudoRandomFunction.compute(prfAlgorithm, keySet.getClientWriteKey(), "client write key", clientAndServerRandom, keySize));
        keySet.setServerWriteKey(PseudoRandomFunction.compute(prfAlgorithm, keySet.getServerWriteKey(), "server write key", clientAndServerRandom, keySize));
        int blockSize = cipherSuite.getCipherAlgorithm().getBlocksize();
        byte[] emptySecret = new byte[]{};
        byte[] ivBlock = PseudoRandomFunction.compute(prfAlgorithm, emptySecret, "IV block", clientAndServerRandom, 2 * blockSize);
        keySet.setClientWriteIv(Arrays.copyOfRange(ivBlock, 0, blockSize));
        keySet.setServerWriteIv(Arrays.copyOfRange(ivBlock, blockSize, 2 * blockSize));
    }

    private static byte[] md5firstNBytes(int numOfBytes, byte[] ... byteArrays) {
        byte[] md5 = MD5Utils.md5(byteArrays);
        return Arrays.copyOfRange(md5, 0, numOfBytes);
    }

    private static void deriveSSL3ExportKeys(CipherSuite cipherSuite, KeySet keySet, byte[] clientRandom, byte[] serverRandom) {
        int keySize = cipherSuite.getCipherAlgorithm().getExportFinalKeySize();
        keySet.setClientWriteKey(KeyDerivator.md5firstNBytes(keySize, keySet.getClientWriteKey(), clientRandom, serverRandom));
        keySet.setServerWriteKey(KeyDerivator.md5firstNBytes(keySize, keySet.getServerWriteKey(), serverRandom, clientRandom));
        int blockSize = cipherSuite.getCipherAlgorithm().getBlocksize();
        keySet.setClientWriteIv(KeyDerivator.md5firstNBytes(blockSize, clientRandom, serverRandom));
        keySet.setServerWriteIv(KeyDerivator.md5firstNBytes(blockSize, serverRandom, clientRandom));
    }

    private static int getSecretSetSize(ProtocolVersion protocolVersion, CipherSuite cipherSuite) throws CryptoException {
        switch (cipherSuite.getCipherType()) {
            case AEAD: {
                return KeyDerivator.getAeadSecretSetSize(protocolVersion, cipherSuite);
            }
            case BLOCK: {
                return KeyDerivator.getBlockSecretSetSize(protocolVersion, cipherSuite);
            }
            case STREAM: {
                return KeyDerivator.getStreamSecretSetSize(protocolVersion, cipherSuite);
            }
        }
        throw new CryptoException("Unknown CipherType");
    }

    private static int getBlockSecretSetSize(ProtocolVersion protocolVersion, CipherSuite cipherSuite) {
        CipherAlgorithm cipherAlg = cipherSuite.getCipherAlgorithm();
        int keySize = cipherAlg.getKeySize();
        MacAlgorithm macAlg = AlgorithmResolver.getMacAlgorithm(protocolVersion, cipherSuite);
        int secretSetSize = 2 * keySize + 2 * macAlg.getKeySize();
        if (!protocolVersion.usesExplicitIv()) {
            secretSetSize += 2 * cipherAlg.getNonceBytesFromHandshake();
        }
        return secretSetSize;
    }

    private static int getAeadSecretSetSize(ProtocolVersion protocolVersion, CipherSuite cipherSuite) {
        CipherAlgorithm cipherAlg = cipherSuite.getCipherAlgorithm();
        int keySize = cipherAlg.getKeySize();
        int saltSize = 12 - cipherAlg.getNonceBytesFromRecord();
        int secretSetSize = 2 * keySize + 2 * saltSize;
        return secretSetSize;
    }

    private static int getStreamSecretSetSize(ProtocolVersion protocolVersion, CipherSuite cipherSuite) {
        CipherAlgorithm cipherAlg = cipherSuite.getCipherAlgorithm();
        MacAlgorithm macAlg = AlgorithmResolver.getMacAlgorithm(protocolVersion, cipherSuite);
        int secretSetSize = 2 * cipherAlg.getKeySize() + 2 * macAlg.getKeySize();
        if (cipherSuite.isStreamCipherWithIV()) {
            secretSetSize += 2 * cipherAlg.getNonceBytesFromHandshake();
        }
        return secretSetSize;
    }

    private KeyDerivator() {
    }
}

