/*
 * Decompiled with CFR 0.152.
 */
package de.rub.nds.tlsattacker.core.protocol.preparator;

import de.rub.nds.modifiablevariable.util.DataConverter;
import de.rub.nds.protocol.constants.PointFormat;
import de.rub.nds.protocol.crypto.CyclicGroup;
import de.rub.nds.protocol.crypto.ec.EllipticCurve;
import de.rub.nds.protocol.crypto.ec.EllipticCurveSECP256R1;
import de.rub.nds.protocol.crypto.ec.Point;
import de.rub.nds.protocol.crypto.ec.PointFormatter;
import de.rub.nds.protocol.crypto.ec.RFC7748Curve;
import de.rub.nds.protocol.util.SilentByteArrayOutputStream;
import de.rub.nds.tlsattacker.core.constants.ECPointFormat;
import de.rub.nds.tlsattacker.core.constants.EllipticCurveType;
import de.rub.nds.tlsattacker.core.constants.NamedGroup;
import de.rub.nds.tlsattacker.core.constants.SignatureAndHashAlgorithm;
import de.rub.nds.tlsattacker.core.protocol.message.ECDHEServerKeyExchangeMessage;
import de.rub.nds.tlsattacker.core.protocol.message.ServerKeyExchangeMessage;
import de.rub.nds.tlsattacker.core.protocol.preparator.ServerKeyExchangePreparator;
import de.rub.nds.tlsattacker.core.protocol.preparator.selection.SignatureAndHashAlgorithmSelector;
import de.rub.nds.tlsattacker.core.workflow.chooser.Chooser;
import java.math.BigInteger;
import java.util.HashSet;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Preparator for ECDHE ServerKeyExchange messages.
 *
 * <p>Selects a named group and point format, derives the server's ephemeral public key from the
 * configured ephemeral EC private key, and signs the concatenated client/server randoms together
 * with the serialized EC parameters.
 *
 * @param <T> the concrete ECDHE ServerKeyExchange message type being prepared
 */
public class ECDHEServerKeyExchangePreparator<T extends ECDHEServerKeyExchangeMessage>
        extends ServerKeyExchangePreparator<T> {
    private static final Logger LOGGER = LogManager.getLogger();

    /** The message being prepared, kept with its concrete type for use by subclasses. */
    protected final T msg;

    public ECDHEServerKeyExchangePreparator(Chooser chooser, T msg) {
        super(chooser, msg);
        this.msg = msg;
    }

    /**
     * Prepares all message fields in order: key exchange computations (including the ephemeral
     * private key from the config), curve type, EC parameters and public key, the signature and
     * hash algorithm, the signature itself, and its length.
     */
    @Override
    public void prepareHandshakeMessageContents() {
        this.msg.prepareKeyExchangeComputations();
        this.msg
                .getKeyExchangeComputations()
                .setPrivateKey(this.chooser.getConfig().getDefaultServerEphemeralEcPrivateKey());
        this.prepareCurveType(this.msg);
        this.prepareEcDhParams();
        SignatureAndHashAlgorithm signHashAlgo =
                SignatureAndHashAlgorithmSelector.selectSignatureAndHashAlgorithm(this.chooser, false);
        this.prepareSignatureAndHashAlgorithm(this.msg, signHashAlgo);
        byte[] signature =
                this.generateSignature(signHashAlgo, this.generateSignatureContents(this.msg));
        this.prepareSignature(this.msg, signature);
        this.prepareSignatureLength(this.msg);
    }

    /**
     * Selects the named group and point format and stores them in the computations. Computes the
     * server public key from the computations' private key and sets it (with its length) on the
     * message. Finally prepares the concatenated client/server random.
     */
    protected void prepareEcDhParams() {
        NamedGroup namedGroup = this.selectNamedGroup(this.msg);
        this.msg.getKeyExchangeComputations().setNamedGroup(namedGroup.getValue());
        this.prepareNamedGroup(this.msg);
        // Re-read the group from the computations instead of reusing the selected value,
        // presumably so that a modification of the modifiable variable is honored.
        namedGroup =
                NamedGroup.getNamedGroup(
                        (byte[]) this.msg.getKeyExchangeComputations().getNamedGroup().getValue());
        if (namedGroup == null) {
            LOGGER.warn("Could not deserialize group from computations. Using default group instead");
            namedGroup = this.chooser.getConfig().getDefaultSelectedNamedGroup();
        }

        ECPointFormat pointFormat = this.selectPointFormat(this.msg);
        this.msg.getKeyExchangeComputations().setEcPointFormat(pointFormat.getValue());
        // Same re-read pattern as for the named group above.
        pointFormat =
                ECPointFormat.getECPointFormat(
                        (Byte) this.msg.getKeyExchangeComputations().getEcPointFormat().getValue());
        if (pointFormat == null) {
            LOGGER.warn(
                    "Could not deserialize point format from computations. Using default point format instead");
            pointFormat = this.chooser.getConfig().getDefaultSelectedPointFormat();
        }

        EllipticCurve curve;
        CyclicGroup group = namedGroup.getGroupParameters().getGroup();
        if (group instanceof EllipticCurve) {
            curve = (EllipticCurve) group;
        } else {
            LOGGER.warn("Selected group is not an EllipticCurve. Using SECP256R1");
            curve = new EllipticCurveSECP256R1();
        }
        LOGGER.debug("NamedGroup: {} ", namedGroup.name());

        byte[] publicKeyBytes;
        if (namedGroup.isMontgomery()) {
            // The curve may be the SECP256R1 fallback from above, so check before casting.
            if (curve instanceof RFC7748Curve) {
                RFC7748Curve rfcCurve = (RFC7748Curve) curve;
                publicKeyBytes =
                        rfcCurve.computePublicKey(
                                (BigInteger)
                                        this.msg.getKeyExchangeComputations().getPrivateKey().getValue());
            } else {
                LOGGER.warn(
                        "Montgomery group {} is not backed by an RFC 7748 curve. Using empty public key instead",
                        namedGroup);
                publicKeyBytes = new byte[] {};
            }
        } else if (namedGroup.isEcGroup()) {
            Point publicKey =
                    curve.mult(
                            (BigInteger) this.msg.getKeyExchangeComputations().getPrivateKey().getValue(),
                            curve.getBasePoint());
            publicKeyBytes =
                    PointFormatter.formatToByteArray(
                            namedGroup.getGroupParameters(),
                            publicKey,
                            (PointFormat) pointFormat.getFormat());
        } else {
            LOGGER.warn(
                    "Could not set public key. The selected curve is probably not a real curve. Using empty public key instead");
            publicKeyBytes = new byte[] {};
        }
        this.msg.setPublicKey(publicKeyBytes);
        this.msg.setPublicKeyLength(((byte[]) this.msg.getPublicKey().getValue()).length);
        this.prepareClientServerRandom(this.msg);
    }

    /**
     * Selects the EC point format to announce.
     *
     * <p>If settings are enforced, the configured default is used. Otherwise the configured
     * default is used when both sides support it. Failing that, the first server-supported format
     * (in configured order) that the client also supports is used. With no common format, the
     * method falls back to the default and logs a warning.
     *
     * @param msg the message being prepared (unused here; kept for subclasses)
     * @return the selected point format
     */
    protected ECPointFormat selectPointFormat(T msg) {
        ECPointFormat defaultFormat = this.chooser.getConfig().getDefaultSelectedPointFormat();
        if (this.chooser.getConfig().isEnforceSettings().booleanValue()) {
            return defaultFormat;
        }
        HashSet<ECPointFormat> clientSet =
                new HashSet<>(this.chooser.getClientSupportedPointFormats());
        ECPointFormat firstCommon = null;
        // Walk the server's configured order rather than a HashSet, whose iteration order is
        // unspecified and would make the selection nondeterministic across runs.
        for (ECPointFormat candidate :
                this.chooser.getConfig().getDefaultServerSupportedPointFormats()) {
            if (!clientSet.contains(candidate)) {
                continue;
            }
            if (candidate == defaultFormat) {
                return defaultFormat;
            }
            if (firstCommon == null) {
                firstCommon = candidate;
            }
        }
        if (firstCommon == null) {
            LOGGER.warn("No common ECPointFormat - falling back to default");
            return defaultFormat;
        }
        return firstCommon;
    }

    /**
     * Selects the named group for the ECDHE key exchange.
     *
     * <p>If settings are enforced, the configured default is used. Otherwise the configured
     * default is used when both sides support it. Failing that, the first server group (in
     * configured order) that the client also supports is used. With no common group, the method
     * falls back to the default and logs a warning. A result that is not an EC group, or is a GOST
     * group, is replaced by SECP256R1.
     *
     * @param msg the message being prepared (unused here; kept for subclasses)
     * @return the selected named group
     */
    protected NamedGroup selectNamedGroup(T msg) {
        NamedGroup defaultGroup = this.chooser.getConfig().getDefaultSelectedNamedGroup();
        NamedGroup namedGroup;
        if (this.chooser.getConfig().isEnforceSettings().booleanValue()) {
            namedGroup = defaultGroup;
        } else {
            HashSet<NamedGroup> clientSet =
                    new HashSet<>(this.chooser.getClientSupportedNamedGroups());
            NamedGroup firstCommon = null;
            boolean defaultIsCommon = false;
            // Walk the server's configured order for a deterministic, preference-respecting pick.
            for (NamedGroup candidate : this.chooser.getConfig().getDefaultServerNamedGroups()) {
                if (!clientSet.contains(candidate)) {
                    continue;
                }
                if (candidate == defaultGroup) {
                    defaultIsCommon = true;
                    break;
                }
                if (firstCommon == null) {
                    firstCommon = candidate;
                }
            }
            if (defaultIsCommon) {
                namedGroup = defaultGroup;
            } else if (firstCommon != null) {
                namedGroup = firstCommon;
            } else {
                LOGGER.warn("No common NamedGroup - falling back to default");
                namedGroup = defaultGroup;
            }
        }
        if (!namedGroup.isEcGroup() || namedGroup.isGost()) {
            NamedGroup previousNamedGroup = namedGroup;
            namedGroup = NamedGroup.SECP256R1;
            LOGGER.warn(
                    "NamedGroup {} is not suitable for ECDHEServerKeyExchange message. Using {} instead",
                    previousNamedGroup,
                    namedGroup);
        }
        return namedGroup;
    }

    /**
     * Builds the data to be signed: the client/server random followed by the curve type byte, the
     * named group bytes, the public key length, and the public key bytes.
     *
     * @param msg the message whose fields are serialized
     * @return the bytes to sign
     * @throws UnsupportedOperationException for explicit-curve or unknown curve types
     */
    protected byte[] generateSignatureContents(T msg) {
        EllipticCurveType curveType = this.chooser.getEcCurveType();
        SilentByteArrayOutputStream ecParams = new SilentByteArrayOutputStream();
        switch (curveType) {
            case EXPLICIT_PRIME:
            case EXPLICIT_CHAR2:
                throw new UnsupportedOperationException(
                        "Signing of explicit curves not implemented yet.");
            case NAMED_CURVE:
                ecParams.write((int) curveType.getValue());
                ecParams.write((byte[]) msg.getNamedGroup().getValue());
                break;
            default:
                throw new UnsupportedOperationException(
                        "Unsupported curve type: " + String.valueOf(curveType));
        }
        // NOTE(review): presumably write(int) emits a single byte, as ByteArrayOutputStream does,
        // matching the one-byte ECPoint length prefix - confirm SilentByteArrayOutputStream semantics.
        ecParams.write(((Integer) msg.getPublicKeyLength().getValue()).intValue());
        ecParams.write((byte[]) msg.getPublicKey().getValue());
        return DataConverter.concatenate(
                new byte[][] {
                    (byte[]) msg.getKeyExchangeComputations().getClientServerRandom().getValue(),
                    ecParams.toByteArray()
                });
    }

    /** Stores the byte encoding of {@code signHashAlgo} in the message. */
    protected void prepareSignatureAndHashAlgorithm(T msg, SignatureAndHashAlgorithm signHashAlgo) {
        msg.setSignatureAndHashAlgorithm(signHashAlgo.getByteValue());
        LOGGER.debug("SignatureAndHashAlgorithm: {}", msg.getSignatureAndHashAlgorithm().getValue());
    }

    /** Stores client random concatenated with server random in the computations. */
    protected void prepareClientServerRandom(T msg) {
        msg.getKeyExchangeComputations()
                .setClientServerRandom(
                        DataConverter.concatenate(
                                new byte[][] {
                                    this.chooser.getClientRandom(), this.chooser.getServerRandom()
                                }));
        LOGGER.debug(
                "ClientServerRandom: {}",
                msg.getKeyExchangeComputations().getClientServerRandom().getValue());
    }

    /** Stores the computed signature in the message. */
    protected void prepareSignature(T msg, byte[] signature) {
        msg.setSignature(signature);
        LOGGER.debug("Signature: {}", msg.getSignature().getValue());
    }

    /** Sets the signature length field from the current signature value. */
    protected void prepareSignatureLength(T msg) {
        msg.setSignatureLength(((byte[]) msg.getSignature().getValue()).length);
        LOGGER.debug("SignatureLength: {}", msg.getSignatureLength().getValue());
    }

    /** Sets the public key length field from the current public key value. */
    protected void prepareSerializedPublicKeyLength(T msg) {
        msg.setPublicKeyLength(((byte[]) msg.getPublicKey().getValue()).length);
        LOGGER.debug("SerializedPublicKeyLength: {}", msg.getPublicKeyLength().getValue());
    }

    /** Always announces a named curve; explicit curve parameters are not prepared here. */
    protected void prepareCurveType(T msg) {
        msg.setCurveType(EllipticCurveType.NAMED_CURVE.getValue());
    }

    /**
     * Copies the named group from the computations into the message. If the stored bytes do not
     * map to a known group, falls back to the configured default group.
     */
    protected void prepareNamedGroup(T msg) {
        NamedGroup group =
                NamedGroup.getNamedGroup(
                        (byte[]) msg.getKeyExchangeComputations().getNamedGroup().getValue());
        if (group == null) {
            LOGGER.warn("Could not deserialize group from computations. Using default group instead");
            group = this.chooser.getConfig().getDefaultSelectedNamedGroup();
        }
        msg.setNamedGroup(group.getValue());
    }
}

