/*
 * Decompiled with CFR 0.152.
 */
package de.rub.nds.tlsattacker.core.protocol.handler;

import de.rub.nds.protocol.exception.AdjustmentException;
import de.rub.nds.protocol.exception.CryptoException;
import de.rub.nds.tlsattacker.core.constants.AlgorithmResolver;
import de.rub.nds.tlsattacker.core.constants.HKDFAlgorithm;
import de.rub.nds.tlsattacker.core.constants.Tls13KeySetType;
import de.rub.nds.tlsattacker.core.crypto.HKDFunction;
import de.rub.nds.tlsattacker.core.layer.context.TlsContext;
import de.rub.nds.tlsattacker.core.protocol.handler.HandshakeMessageHandler;
import de.rub.nds.tlsattacker.core.protocol.message.KeyUpdateMessage;
import de.rub.nds.tlsattacker.core.record.cipher.RecordCipher;
import de.rub.nds.tlsattacker.core.record.cipher.RecordCipherFactory;
import de.rub.nds.tlsattacker.core.record.cipher.cryptohelper.KeyDerivator;
import de.rub.nds.tlsattacker.core.record.cipher.cryptohelper.KeySet;
import de.rub.nds.tlsattacker.transport.ConnectionEndType;
import java.security.NoSuchAlgorithmException;
import javax.crypto.Mac;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Handles TLS 1.3 KeyUpdate messages.
 *
 * <p>When a KeyUpdate is received from the peer ({@link #adjustContext}) or after the local end
 * has serialized one ({@link #adjustContextAfterSerialize}), the application traffic secret of
 * the connection end that sent the message is replaced by {@code HKDF-Expand-Label(secret,
 * "traffic upd", "", macLength)}. A record cipher keyed from the updated secret is then installed.
 * It becomes the encryption cipher when the local end sent the message, and the decryption
 * cipher otherwise.
 */
public class KeyUpdateHandler extends HandshakeMessageHandler<KeyUpdateMessage> {
    private static final Logger LOGGER = LogManager.getLogger();

    /** HKDF label used to ratchet an application traffic secret forward. */
    private static final String TRAFFIC_UPDATE_LABEL = "traffic upd";

    /** Length in bytes of the derived write IV (12 bytes, the TLS 1.3 AEAD nonce length). */
    private static final int AEAD_IV_LENGTH = 12;

    public KeyUpdateHandler(TlsContext tlsContext) {
        super(tlsContext);
    }

    /**
     * Applies a KeyUpdate sent by the peer. KeyUpdates sent by the local end are handled in
     * {@link #adjustContextAfterSerialize} instead, so they are skipped here.
     */
    @Override
    public void adjustContext(KeyUpdateMessage message) {
        if (tlsContext.getChooser().getTalkingConnectionEnd()
                != tlsContext.getChooser().getConnectionEndType()) {
            adjustApplicationTrafficSecrets();
            setRecordCipher(Tls13KeySetType.APPLICATION_TRAFFIC_SECRETS);
        }
    }

    /** Applies a KeyUpdate the local end has just serialized. */
    @Override
    public void adjustContextAfterSerialize(KeyUpdateMessage message) {
        adjustApplicationTrafficSecrets();
        setRecordCipher(Tls13KeySetType.APPLICATION_TRAFFIC_SECRETS);
    }

    /**
     * Ratchets the application traffic secret of the talking connection end forward.
     *
     * <p>The next secret is {@code HKDF-Expand-Label(current, "traffic upd", "", macLength)},
     * where {@code macLength} is the output length of the MAC belonging to the negotiated
     * cipher suite's HKDF algorithm. The result is stored back into the context.
     *
     * @throws AdjustmentException if the MAC algorithm is unavailable or the expansion fails
     */
    private void adjustApplicationTrafficSecrets() {
        HKDFAlgorithm hkdfAlgorithm =
                AlgorithmResolver.getHKDFAlgorithm(tlsContext.getChooser().getSelectedCipherSuite());
        try {
            Mac mac = Mac.getInstance(hkdfAlgorithm.getMacAlgorithm().getJavaName());
            if (tlsContext.getChooser().getTalkingConnectionEnd() == ConnectionEndType.CLIENT) {
                byte[] clientApplicationTrafficSecret =
                        HKDFunction.expandLabel(
                                hkdfAlgorithm,
                                tlsContext.getChooser().getClientApplicationTrafficSecret(),
                                TRAFFIC_UPDATE_LABEL,
                                new byte[0],
                                mac.getMacLength(),
                                tlsContext.getChooser().getSelectedProtocolVersion());
                tlsContext.setClientApplicationTrafficSecret(clientApplicationTrafficSecret);
                LOGGER.debug(
                        "Set clientApplicationTrafficSecret in Context to {}",
                        clientApplicationTrafficSecret);
            } else {
                byte[] serverApplicationTrafficSecret =
                        HKDFunction.expandLabel(
                                hkdfAlgorithm,
                                tlsContext.getChooser().getServerApplicationTrafficSecret(),
                                TRAFFIC_UPDATE_LABEL,
                                new byte[0],
                                mac.getMacLength(),
                                tlsContext.getChooser().getSelectedProtocolVersion());
                tlsContext.setServerApplicationTrafficSecret(serverApplicationTrafficSecret);
                LOGGER.debug(
                        "Set serverApplicationTrafficSecret in Context to {}",
                        serverApplicationTrafficSecret);
            }
        } catch (CryptoException | NoSuchAlgorithmException ex) {
            throw new AdjustmentException(ex);
        }
    }

    /**
     * Generates a key set of the given type from the current context.
     *
     * @throws UnsupportedOperationException if key derivation fails
     */
    private KeySet getKeySet(TlsContext tlsContext, Tls13KeySetType keySetType) {
        try {
            LOGGER.debug("Generating new KeySet");
            return KeyDerivator.generateKeySet(
                    tlsContext, tlsContext.getChooser().getSelectedProtocolVersion(), keySetType);
        } catch (CryptoException | NoSuchAlgorithmException ex) {
            throw new UnsupportedOperationException("The specified Algorithm is not supported", ex);
        }
    }

    /**
     * Activates {@code keySetType} for the talking connection end and installs a matching record
     * cipher.
     *
     * <p>The talking end's write IV and write key are re-derived from its (already updated)
     * application traffic secret. The resulting cipher replaces the encryption cipher if the
     * local end is talking, and the decryption cipher otherwise.
     *
     * @param keySetType key set type to activate for the talking connection end
     * @throws AdjustmentException if key expansion fails
     */
    private void setRecordCipher(Tls13KeySetType keySetType) {
        try {
            HKDFAlgorithm hkdfAlgorithm =
                    AlgorithmResolver.getHKDFAlgorithm(tlsContext.getChooser().getSelectedCipherSuite());
            ConnectionEndType talkingEnd = tlsContext.getChooser().getTalkingConnectionEnd();
            // Use the same predicate as adjustApplicationTrafficSecrets(). This guarantees the key
            // material is always taken from the secret that was just ratcheted.
            boolean clientIsTalking = talkingEnd == ConnectionEndType.CLIENT;

            KeySet keySet;
            if (clientIsTalking) {
                tlsContext.setActiveClientKeySetType(keySetType);
                LOGGER.debug("Setting cipher for client to use {}", keySetType);
                keySet = getKeySet(tlsContext, tlsContext.getActiveClientKeySetType());
            } else {
                tlsContext.setActiveServerKeySetType(keySetType);
                LOGGER.debug("Setting cipher for server to use {}", keySetType);
                keySet = getKeySet(tlsContext, tlsContext.getActiveServerKeySetType());
            }

            deriveWriteKeyMaterial(keySet, hkdfAlgorithm, clientIsTalking);

            // If the local end sent the KeyUpdate, the new keys protect outgoing records.
            // Otherwise the peer sent it and the new keys protect incoming records.
            boolean localIsTalking = talkingEnd == tlsContext.getChooser().getConnectionEndType();
            RecordCipher recordCipher =
                    RecordCipherFactory.getRecordCipher(tlsContext, keySet, localIsTalking);
            if (localIsTalking) {
                tlsContext.getRecordLayer().updateEncryptionCipher(recordCipher);
            } else {
                tlsContext.getRecordLayer().updateDecryptionCipher(recordCipher);
            }
        } catch (CryptoException ex) {
            throw new AdjustmentException(ex);
        }
    }

    /**
     * Overwrites one connection end's write IV and write key in {@code keySet}.
     *
     * <p>Both values are expanded from that end's application traffic secret as currently stored
     * in the context, using the labels {@code "iv"} and {@code "key"}.
     *
     * @param keySet key set to modify in place
     * @param hkdfAlgorithm HKDF algorithm of the negotiated cipher suite
     * @param forClient {@code true} to derive the client write material, {@code false} for the
     *     server
     * @throws CryptoException if HKDF-Expand-Label fails
     */
    private void deriveWriteKeyMaterial(
            KeySet keySet, HKDFAlgorithm hkdfAlgorithm, boolean forClient) throws CryptoException {
        byte[] trafficSecret =
                forClient
                        ? tlsContext.getClientApplicationTrafficSecret()
                        : tlsContext.getServerApplicationTrafficSecret();
        int keyLength =
                tlsContext.getChooser().getSelectedCipherSuite().getCipherAlgorithm().getKeySize();
        byte[] writeIv =
                HKDFunction.expandLabel(
                        hkdfAlgorithm,
                        trafficSecret,
                        "iv",
                        new byte[0],
                        AEAD_IV_LENGTH,
                        tlsContext.getChooser().getSelectedProtocolVersion());
        byte[] writeKey =
                HKDFunction.expandLabel(
                        hkdfAlgorithm,
                        trafficSecret,
                        "key",
                        new byte[0],
                        keyLength,
                        tlsContext.getChooser().getSelectedProtocolVersion());
        if (forClient) {
            keySet.setClientWriteIv(writeIv);
            keySet.setClientWriteKey(writeKey);
        } else {
            keySet.setServerWriteIv(writeIv);
            keySet.setServerWriteKey(writeKey);
        }
    }
}

