/*
 * Decompiled with CFR 0.152.
 */
package de.rub.nds.tlsattacker.core.crypto.hpke;

import de.rub.nds.modifiablevariable.util.DataConverter;
import de.rub.nds.protocol.exception.CryptoException;
import de.rub.nds.tlsattacker.core.constants.HpkeLabel;
import de.rub.nds.tlsattacker.core.constants.hpke.HpkeAeadFunction;
import de.rub.nds.tlsattacker.core.constants.hpke.HpkeKeyDerivationFunction;
import de.rub.nds.tlsattacker.core.constants.hpke.HpkeKeyEncapsulationMechanism;
import de.rub.nds.tlsattacker.core.constants.hpke.HpkeMode;
import de.rub.nds.tlsattacker.core.crypto.HKDFunction;
import de.rub.nds.tlsattacker.core.crypto.KeyShareCalculator;
import de.rub.nds.tlsattacker.core.crypto.hpke.HpkeReceiverContext;
import de.rub.nds.tlsattacker.core.crypto.hpke.HpkeSenderContext;
import de.rub.nds.tlsattacker.core.protocol.message.extension.EchConfig;
import de.rub.nds.tlsattacker.core.protocol.message.extension.keyshare.KeyShareEntry;
import java.nio.charset.StandardCharsets;
import java.util.Objects;

/**
 * Helper implementing the HPKE (RFC 9180) base-mode setup used for Encrypted Client Hello.
 *
 * <p>The KEM is Diffie-Hellman based: the DH value is computed through {@link
 * KeyShareCalculator#computeSharedSecret} for the named group of the configured KEM, and then fed
 * through ExtractAndExpand. The key schedule then derives the AEAD key, base nonce and exporter
 * secret.
 *
 * <p>Instances are stateful. Every call to {@link #setupBaseSender} or {@link #setupBaseReceiver}
 * overwrites the intermediate values (public keys, KEM context, shared secret, key schedule
 * outputs) that are exposed through the getters. The getters return the internal arrays, not
 * copies. Instances are not synchronized.
 */
public class HpkeUtil {
    private final HpkeAeadFunction hpkeAeadFunction;
    private final HpkeKeyDerivationFunction hpkeKeyDerivationFunction;
    private final HpkeKeyEncapsulationMechanism hpkeKeyEncapsulationMechanism;
    private byte[] publicKeyReceiver;
    private byte[] publicKeySender;
    private byte[] sharedSecret;
    private byte[] kemContext;
    private byte[] baseNonce;
    private byte[] key;
    private byte[] exporterSecret;
    private byte[] secret;
    private byte[] keyScheduleContext;
    /** Value meaning "no PSK supplied" (RFC 9180 default_psk); used by base mode. */
    private static final String DEFAULT_PSK = "";
    /** Value meaning "no PSK identity supplied" (RFC 9180 default_psk_id); used by base mode. */
    private static final String DEFAULT_PSK_ID = "";

    /**
     * Creates a helper for an explicit HPKE cipher suite.
     *
     * @param hpkeAeadFunction AEAD used for the resulting contexts
     * @param hpkeKeyDerivationFunction KDF used for all labeled extract/expand steps
     * @param hpkeKeyEncapsulationMechanism KEM whose named group drives the DH computation
     */
    public HpkeUtil(
            HpkeAeadFunction hpkeAeadFunction,
            HpkeKeyDerivationFunction hpkeKeyDerivationFunction,
            HpkeKeyEncapsulationMechanism hpkeKeyEncapsulationMechanism) {
        this.hpkeAeadFunction = hpkeAeadFunction;
        this.hpkeKeyDerivationFunction = hpkeKeyDerivationFunction;
        this.hpkeKeyEncapsulationMechanism = hpkeKeyEncapsulationMechanism;
    }

    /**
     * Creates a helper using the cipher suite advertised in an ECH configuration.
     *
     * @param echConfig source of the AEAD, KDF and KEM identifiers
     */
    public HpkeUtil(EchConfig echConfig) {
        this.hpkeAeadFunction = echConfig.getHpkeAeadFunction();
        this.hpkeKeyDerivationFunction = echConfig.getHpkeKeyDerivationFunction();
        this.hpkeKeyEncapsulationMechanism = echConfig.getKem();
    }

    /**
     * Performs RFC 9180 SetupBaseS: encapsulates to the receiver key and runs the base-mode key
     * schedule.
     *
     * @param publicKeyReceiver serialized public key of the receiver (pkR)
     * @param info application-supplied info bound into the key schedule
     * @param keysSender ephemeral sender key share; its public key is used as {@code enc}
     * @return the sender context holding key, base nonce, sequence number 0 and exporter secret
     * @throws CryptoException if the DH computation or a KDF step fails
     */
    public HpkeSenderContext setupBaseSender(byte[] publicKeyReceiver, byte[] info, KeyShareEntry keysSender)
            throws CryptoException {
        encap(publicKeyReceiver, keysSender);
        return generateKeyScheduleSender(HpkeMode.MODE_BASE, sharedSecret, info, DEFAULT_PSK, DEFAULT_PSK_ID);
    }

    /**
     * Performs RFC 9180 SetupBaseR: decapsulates {@code enc} and runs the base-mode key schedule.
     *
     * @param enc serialized ephemeral public key received from the sender
     * @param info application-supplied info bound into the key schedule
     * @param keysReceiver receiver key share providing the private key and serialized public key
     * @return the receiver context holding key, base nonce, sequence number 0 and exporter secret
     * @throws CryptoException if the DH computation or a KDF step fails
     */
    public HpkeReceiverContext setupBaseReceiver(byte[] enc, byte[] info, KeyShareEntry keysReceiver)
            throws CryptoException {
        decap(enc, keysReceiver);
        return generateKeyScheduleReceiver(HpkeMode.MODE_BASE, sharedSecret, info, DEFAULT_PSK, DEFAULT_PSK_ID);
    }

    /**
     * DHKEM Encap: computes DH(skE, pkR) and derives the KEM shared secret over
     * {@code kem_context = enc || pkR}. Stores both public keys, the KEM context and the shared
     * secret in the instance fields.
     */
    private void encap(byte[] echServerPublicKey, KeyShareEntry keyShareEntry) throws CryptoException {
        this.publicKeyReceiver = echServerPublicKey;
        // Read the (possibly modified) public key once so that the recorded enc and the enc
        // bound into kem_context are guaranteed to be the same bytes.
        this.publicKeySender = (byte[]) keyShareEntry.getPublicKey().getValue();
        byte[] dh = KeyShareCalculator.computeSharedSecret(
                hpkeKeyEncapsulationMechanism.getNamedGroup(), keyShareEntry.getPrivateKey(), echServerPublicKey);
        this.kemContext = DataConverter.concatenate(new byte[][] {publicKeySender, echServerPublicKey});
        this.sharedSecret = extractAndExpand(dh, kemContext, true);
    }

    /**
     * DHKEM Decap: computes DH(skR, pkE) and derives the KEM shared secret over
     * {@code kem_context = enc || pkR}. Stores both public keys, the KEM context and the shared
     * secret in the instance fields.
     */
    private void decap(byte[] enc, KeyShareEntry keysReceiver) throws CryptoException {
        this.publicKeySender = enc;
        this.publicKeyReceiver = (byte[]) keysReceiver.getPublicKey().getValue();
        byte[] dh = KeyShareCalculator.computeSharedSecret(
                hpkeKeyEncapsulationMechanism.getNamedGroup(), keysReceiver.getPrivateKey(), enc);
        this.kemContext = DataConverter.concatenate(new byte[][] {enc, publicKeyReceiver});
        this.sharedSecret = extractAndExpand(dh, kemContext, true);
    }

    /**
     * RFC 9180 VerifyPSKInputs: checks that the PSK and PSK identity are either both present or
     * both absent, and that their presence matches what {@code mode} requires.
     *
     * @throws CryptoException if the inputs are inconsistent with each other or with the mode
     */
    private void verifyPskInputs(HpkeMode mode, String psk, String pskId) throws CryptoException {
        boolean gotPsk = !Objects.equals(psk, DEFAULT_PSK);
        boolean gotPskId = !Objects.equals(pskId, DEFAULT_PSK_ID);
        if (gotPsk != gotPskId) {
            throw new CryptoException("Inconsistent PSK inputs");
        }
        // From here on gotPsk == gotPskId.
        if (gotPsk && (mode == HpkeMode.MODE_BASE || mode == HpkeMode.MODE_AUTH)) {
            throw new CryptoException("PSK input provided when not needed");
        }
        if (!gotPsk && (mode == HpkeMode.MODE_PSK || mode == HpkeMode.MODE_AUTH_PSK)) {
            throw new CryptoException("Missing required PSK input");
        }
    }

    /** Runs the key schedule and wraps its outputs in a sender context. */
    private HpkeSenderContext generateKeyScheduleSender(
            HpkeMode mode, byte[] sharedSecret, byte[] info, String psk, String pskId) throws CryptoException {
        deriveKeySchedule(mode, sharedSecret, info, psk, pskId);
        return new HpkeSenderContext(key, baseNonce, 0, exporterSecret, hpkeAeadFunction);
    }

    /** Runs the key schedule and wraps its outputs in a receiver context. */
    private HpkeReceiverContext generateKeyScheduleReceiver(
            HpkeMode mode, byte[] sharedSecret, byte[] info, String psk, String pskId) throws CryptoException {
        deriveKeySchedule(mode, sharedSecret, info, psk, pskId);
        return new HpkeReceiverContext(key, baseNonce, 0, exporterSecret, hpkeAeadFunction);
    }

    /**
     * RFC 9180 KeySchedule shared by sender and receiver.
     *
     * <p>Writes {@link #keyScheduleContext}, {@link #secret}, {@link #key}, {@link #baseNonce} and
     * {@link #exporterSecret}, in that order. PSK and PSK identity strings are encoded as US-ASCII.
     */
    private void deriveKeySchedule(HpkeMode mode, byte[] sharedSecret, byte[] info, String psk, String pskId)
            throws CryptoException {
        verifyPskInputs(mode, psk, pskId);
        byte[] pskIdHash = labeledExtract(
                HpkeLabel.EMPTY.getBytes(),
                HpkeLabel.PSK_ID_HASH.getBytes(),
                pskId.getBytes(StandardCharsets.US_ASCII),
                false);
        byte[] infoHash = labeledExtract(HpkeLabel.EMPTY.getBytes(), HpkeLabel.INFO_HASH.getBytes(), info, false);
        this.keyScheduleContext = DataConverter.concatenate(new byte[][] {mode.getByteValue(), pskIdHash, infoHash});
        this.secret = labeledExtract(
                sharedSecret, HpkeLabel.SECRET.getBytes(), psk.getBytes(StandardCharsets.US_ASCII), false);
        this.key = labeledExpand(
                secret, HpkeLabel.KEY.getBytes(), keyScheduleContext, hpkeAeadFunction.getKeyLength(), false);
        this.baseNonce = labeledExpand(
                secret, HpkeLabel.BASE_NONCE.getBytes(), keyScheduleContext, hpkeAeadFunction.getNonceLength(), false);
        this.exporterSecret = labeledExpand(
                secret,
                HpkeLabel.EXPAND.getBytes(),
                keyScheduleContext,
                hpkeKeyDerivationFunction.getHashLength(),
                false);
    }

    /**
     * Builds the suite_id used for domain separation.
     *
     * @param fromKEM {@code true} for the KEM-internal id ("KEM" || kem_id), {@code false} for the
     *     full HPKE id ("HPKE" || kem_id || kdf_id || aead_id)
     */
    private byte[] getSuiteId(boolean fromKEM) {
        byte[] kemId = hpkeKeyEncapsulationMechanism.getByteValue();
        if (fromKEM) {
            return DataConverter.concatenate(new byte[][] {HpkeLabel.KEM.getBytes(), kemId});
        }
        byte[] version = HpkeLabel.HPKE.getBytes();
        byte[] hkdfId = hpkeKeyDerivationFunction.getByteValue();
        byte[] aeadId = hpkeAeadFunction.getByteValue();
        return DataConverter.concatenate(new byte[][] {version, kemId, hkdfId, aeadId});
    }

    /** LabeledExtract(salt, label, ikm) = Extract(salt, "HPKE-v1" || suite_id || label || ikm). */
    private byte[] labeledExtract(byte[] salt, byte[] label, byte[] ikm, boolean fromKem) throws CryptoException {
        byte[] labeledIkm = DataConverter.concatenate(
                new byte[][] {HpkeLabel.HPKE_VERSION_1.getBytes(), getSuiteId(fromKem), label, ikm});
        return HKDFunction.extract(hpkeKeyDerivationFunction.getHkdfAlgorithm(), salt, labeledIkm);
    }

    /**
     * LabeledExpand(prk, label, info, L) = Expand(prk, I2OSP(L, 2) || "HPKE-v1" || suite_id ||
     * label || info, L).
     */
    private byte[] labeledExpand(byte[] prk, byte[] label, byte[] info, int l, boolean fromKem)
            throws CryptoException {
        byte[] labeledInfo = DataConverter.concatenate(new byte[][] {
            DataConverter.longToBytes((long) l, 2),
            HpkeLabel.HPKE_VERSION_1.getBytes(),
            getSuiteId(fromKem),
            label,
            info
        });
        return HKDFunction.expand(hpkeKeyDerivationFunction.getHkdfAlgorithm(), prk, labeledInfo, l);
    }

    /** DHKEM ExtractAndExpand: derives the KEM shared secret (length Nsecret) from the DH output. */
    private byte[] extractAndExpand(byte[] dh, byte[] kemContext, boolean fromKem) throws CryptoException {
        byte[] eaePrk = labeledExtract(
                HpkeLabel.EMPTY.getBytes(), HpkeLabel.EXTRACT_AND_EXPAND.getBytes(), dh, fromKem);
        return labeledExpand(
                eaePrk,
                HpkeLabel.SHARED_SECRET.getBytes(),
                kemContext,
                hpkeKeyEncapsulationMechanism.getSecretLength(),
                fromKem);
    }

    /**
     * Returns the index of the first occurrence of {@code smallerArray} inside {@code outerArray},
     * or -1 if it does not occur. An empty {@code smallerArray} matches at index 0.
     *
     * @deprecated not used by this class; kept only for backward compatibility.
     */
    @Deprecated
    public static int indexOf(byte[] outerArray, byte[] smallerArray) {
        int lastStart = outerArray.length - smallerArray.length;
        outer:
        for (int start = 0; start <= lastStart; start++) {
            for (int offset = 0; offset < smallerArray.length; offset++) {
                if (outerArray[start + offset] != smallerArray[offset]) {
                    continue outer;
                }
            }
            return start;
        }
        return -1;
    }

    /** @return the KEM shared secret from the last setup call (internal array, not a copy) */
    public byte[] getSharedSecret() {
        return this.sharedSecret;
    }

    /** @return the sender's serialized ephemeral public key ({@code enc}) from the last setup call */
    public byte[] getPublicKeySender() {
        return this.publicKeySender;
    }

    /** @return {@code enc || pkR} from the last setup call */
    public byte[] getKemContext() {
        return this.kemContext;
    }

    /** @return the AEAD base nonce from the last key schedule */
    public byte[] getBaseNonce() {
        return this.baseNonce;
    }

    /** @return the exporter secret from the last key schedule */
    public byte[] getExporterSecret() {
        return this.exporterSecret;
    }

    /** @return the key schedule {@code secret} from the last key schedule */
    public byte[] getSecret() {
        return this.secret;
    }

    /** @return {@code mode || psk_id_hash || info_hash} from the last key schedule */
    public byte[] getKeyScheduleContext() {
        return this.keyScheduleContext;
    }

    /** @return the receiver's serialized public key from the last setup call */
    public byte[] getPublicKeyReceiver() {
        return this.publicKeyReceiver;
    }

    /** @return the AEAD key from the last key schedule */
    public byte[] getKey() {
        return this.key;
    }
}

