/*
 * Decompiled with CFR 0.152.
 */
package de.rub.nds.tlsattacker.core.record.cipher;

import de.rub.nds.modifiablevariable.util.DataConverter;
import de.rub.nds.protocol.exception.CryptoException;
import de.rub.nds.protocol.util.SilentByteArrayOutputStream;
import de.rub.nds.tlsattacker.core.constants.ProtocolMessageType;
import de.rub.nds.tlsattacker.core.constants.ProtocolVersion;
import de.rub.nds.tlsattacker.core.crypto.cipher.BaseCipher;
import de.rub.nds.tlsattacker.core.crypto.cipher.DecryptionCipher;
import de.rub.nds.tlsattacker.core.crypto.cipher.EncryptionCipher;
import de.rub.nds.tlsattacker.core.layer.context.TlsContext;
import de.rub.nds.tlsattacker.core.layer.data.Parser;
import de.rub.nds.tlsattacker.core.record.Record;
import de.rub.nds.tlsattacker.core.record.cipher.CipherState;
import de.rub.nds.tlsattacker.transport.ConnectionEndType;
import java.io.ByteArrayInputStream;
import java.math.BigInteger;
import java.util.Arrays;
import java.util.Random;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Base class for record ciphers.
 *
 * <p>Holds the encryption/decryption ciphers, the owning {@link TlsContext} and the {@link
 * CipherState}. Provides helpers shared by concrete record ciphers:
 *
 * <ul>
 *   <li>DTLS 1.3 sequence number masking;
 *   <li>construction of the additional authenticated data (AAD);
 *   <li>building and parsing the encapsulated plaintext layout {@code content || content type
 *       byte || trailing padding}.
 * </ul>
 */
public abstract class RecordCipher {
    private static final Logger LOGGER = LogManager.getLogger();

    /** Eight 0xFF bytes written at the start of the AAD of DTLS records with content type TLS12_CID. */
    private static final byte[] SEQUENCE_NUMBER_PLACEHOLDER = new byte[] {-1, -1, -1, -1, -1, -1, -1, -1};

    protected DecryptionCipher decryptCipher;
    protected EncryptionCipher encryptCipher;
    protected TlsContext tlsContext;
    private CipherState state;

    public RecordCipher(TlsContext tlsContext, CipherState state) {
        this.tlsContext = tlsContext;
        this.state = state;
    }

    /**
     * Encrypts the given record in place.
     *
     * @throws CryptoException if the underlying cipher fails
     */
    public abstract void encrypt(Record record) throws CryptoException;

    /**
     * Decrypts the given record in place.
     *
     * @throws CryptoException if the underlying cipher fails
     */
    public abstract void decrypt(Record record) throws CryptoException;

    /**
     * Masks the low-order bytes of the record's sequence number for the DTLS 1.3 unified header.
     * The result is stored via {@code record.setEncryptedSequenceNumber}.
     *
     * <p>Two bytes are encoded when the config enables the long header sequence number encoding;
     * otherwise one byte is encoded.
     *
     * <p>The mask is obtained from {@code getDtls13Mask} using the write sequence number key of
     * the local connection end and the record's protocol message bytes. A zero mask is used
     * (i.e. the bytes are left unencrypted) in these cases:
     *
     * <ul>
     *   <li>no key set is available;
     *   <li>mask generation fails;
     *   <li>the mask is too short (the missing part is zero-filled).
     * </ul>
     */
    public void encryptDtls13SequenceNumber(Record record) {
        int length = this.tlsContext.getConfig().getUseDtls13HeaderSeqNumSizeLongEncoding() ? 2 : 1;
        // Take the low-order bytes of the 48-bit sequence number. This matches the bytes that
        // collectAdditionalAuthenticatedData writes into the DTLS 1.3 AAD.
        byte[] fullSequenceNumber =
                DataConverter.longToUint48Bytes(((BigInteger) record.getSequenceNumber().getValue()).longValue());
        byte[] sequenceNumber =
                Arrays.copyOfRange(fullSequenceNumber, fullSequenceNumber.length - length, fullSequenceNumber.length);
        byte[] mask;
        if (this.getState().getKeySet() == null) {
            mask = new byte[length];
            LOGGER.warn("No keys available for DTLS 1.3 mask derivation for sequence number encryption. Using null encryption.");
        } else {
            try {
                mask = ((BaseCipher) this.encryptCipher).getDtls13Mask(
                        this.getState().getKeySet().getWriteSnKey(this.getLocalConnectionEndType()),
                        (byte[]) record.getProtocolMessageBytes().getValue());
                if (mask.length < length) {
                    mask = Arrays.copyOf(mask, length);
                    LOGGER.warn("DTLS 1.3 mask does not have enough bytes for encrypting the sequence number. Padding it to the required length with zero bytes.");
                }
            } catch (CryptoException ex) {
                LOGGER.warn("Failed to generate DTLS\u00a01.3 mask. Generating a zero\u2011byte mask.", ex);
                mask = new byte[length];
            }
        }
        record.setEncryptedSequenceNumber(xorWithMask(sequenceNumber, mask));
        LOGGER.debug("Encrypted Sequence Number: {}", record.getEncryptedSequenceNumber().getValue());
    }

    /**
     * Unmasks the record's encrypted DTLS 1.3 header sequence number. The result is stored as an
     * unsigned value via {@code record.setSequenceNumber}.
     *
     * <p>The mask is obtained from {@code getDtls13Mask} using the read sequence number key of
     * the local connection end and the record's protocol message bytes. A zero mask is used in
     * these cases:
     *
     * <ul>
     *   <li>no key set is available;
     *   <li>mask generation fails;
     *   <li>the mask is too short (the missing part is zero-filled).
     * </ul>
     */
    public void decryptDtls13SequenceNumber(Record record) {
        byte[] encryptedSequenceNumber = (byte[]) record.getEncryptedSequenceNumber().getValue();
        byte[] mask;
        if (this.getState().getKeySet() == null) {
            // Mirrors encryptDtls13SequenceNumber instead of failing with a NullPointerException.
            mask = new byte[encryptedSequenceNumber.length];
            LOGGER.warn("No keys available for DTLS 1.3 mask derivation for sequence number decryption. Using null encryption.");
        } else {
            try {
                mask = ((BaseCipher) this.decryptCipher).getDtls13Mask(
                        this.getState().getKeySet().getReadSnKey(this.getLocalConnectionEndType()),
                        (byte[]) record.getProtocolMessageBytes().getValue());
                if (mask.length < encryptedSequenceNumber.length) {
                    LOGGER.warn("DTLS 1.3 mask does not have enough bytes for decrypting the sequence number. Padding it to the required length with zero bytes.");
                    mask = Arrays.copyOf(mask, encryptedSequenceNumber.length);
                }
            } catch (CryptoException ex) {
                LOGGER.warn("Failed to generate DTLS\u00a01.3 mask. Generating a zero\u2011byte mask.", ex);
                mask = new byte[encryptedSequenceNumber.length];
            }
        }
        record.setSequenceNumber(new BigInteger(1, xorWithMask(encryptedSequenceNumber, mask)));
        LOGGER.debug("Decrypted Sequence Number: {}", record.getSequenceNumber().getValue());
    }

    /**
     * Returns {@code data[i] ^ mask[i]} for every index of {@code data}.
     * {@code mask} must be at least as long as {@code data}.
     */
    private static byte[] xorWithMask(byte[] data, byte[] mask) {
        byte[] result = new byte[data.length];
        for (int i = 0; i < data.length; ++i) {
            result[i] = (byte) (data[i] ^ mask[i]);
        }
        return result;
    }

    /**
     * Builds the additional authenticated data for the given record and protocol version.
     *
     * <ul>
     *   <li>TLS 1.3: content type, record version, 2-byte length.
     *   <li>DTLS 1.3: unified header first byte, optional connection ID, the low 1 or 2 bytes of
     *       the sequence number, then an optional 2-byte length.
     *   <li>Other versions: see the branch below.
     * </ul>
     *
     * <p>For other versions the AAD is assembled from these parts, in order:
     *
     * <ol>
     *   <li>A prefix, which depends on the record kind:
     *       <ul>
     *         <li>DTLS TLS12_CID records: placeholder, TLS12_CID type, CID length;
     *         <li>other DTLS records: 2-byte epoch and 6-byte sequence number;
     *         <li>non-DTLS records: 8-byte sequence number.
     *       </ul>
     *   <li>The content type.
     *   <li>The version (omitted for SSL).
     *   <li>For TLS12_CID records only: epoch, sequence number and CID.
     *   <li>A 2-byte length.
     * </ol>
     */
    protected final byte[] collectAdditionalAuthenticatedData(Record record, ProtocolVersion protocolVersion) {
        SilentByteArrayOutputStream stream = new SilentByteArrayOutputStream();
        if (protocolVersion.isTLS13()) {
            stream.write((int) ((Byte) record.getContentType().getValue()).byteValue());
            stream.write((byte[]) record.getProtocolVersion().getValue());
            writeRecordLength(stream, record);
            return stream.toByteArray();
        }
        if (protocolVersion.isDTLS13()) {
            byte firstByte = (Byte) record.getUnifiedHeader().getValue();
            stream.write((int) firstByte);
            if (record.isUnifiedHeaderCidPresent()) {
                stream.write((byte[]) record.getConnectionId().getValue());
            }
            byte[] sequenceNumberBytes =
                    DataConverter.longToUint48Bytes(((BigInteger) record.getSequenceNumber().getValue()).longValue());
            int sequenceNumberLength = record.isUnifiedHeaderSqnLong() ? 2 : 1;
            stream.write(sequenceNumberBytes, sequenceNumberBytes.length - sequenceNumberLength, sequenceNumberLength);
            if (record.isUnifiedHeaderLengthPresent()) {
                writeRecordLength(stream, record);
            }
            return stream.toByteArray();
        }
        boolean isDtlsCidRecord = protocolVersion.isDTLS()
                && ProtocolMessageType.getContentType((Byte) record.getContentType().getValue())
                        == ProtocolMessageType.TLS12_CID;
        if (isDtlsCidRecord) {
            stream.write(SEQUENCE_NUMBER_PLACEHOLDER);
            stream.write((int) ProtocolMessageType.TLS12_CID.getValue());
            stream.write(((byte[]) record.getConnectionId().getValue()).length);
        } else if (protocolVersion.isDTLS()) {
            writeDtlsEpochAndSequenceNumber(stream, record);
        } else {
            stream.write(DataConverter.longToUint64Bytes(((BigInteger) record.getSequenceNumber().getValue()).longValue()));
        }
        stream.write((int) ((Byte) record.getContentType().getValue()).byteValue());
        byte[] version = !protocolVersion.isSSL() ? (byte[]) record.getProtocolVersion().getValue() : new byte[] {};
        stream.write(version);
        if (isDtlsCidRecord) {
            writeDtlsEpochAndSequenceNumber(stream, record);
            stream.write((byte[]) record.getConnectionId().getValue());
        }
        int length = record.getComputations().getAuthenticatedNonMetaData() == null
                        || record.getComputations().getAuthenticatedNonMetaData().getOriginalValue() == null
                ? ((byte[]) record.getComputations().getPlainRecordBytes().getValue()).length
                : ((byte[]) record.getComputations().getAuthenticatedNonMetaData().getValue()).length;
        stream.write(DataConverter.intToBytes(length, 2));
        return stream.toByteArray();
    }

    /**
     * Writes a 2-byte length. The record's explicit length field is used when set; otherwise the
     * length of its clean protocol message bytes is used.
     */
    private static void writeRecordLength(SilentByteArrayOutputStream stream, Record record) {
        int length;
        if (record.getLength() != null && record.getLength().getValue() != null) {
            length = (Integer) record.getLength().getValue();
        } else {
            length = ((byte[]) record.getCleanProtocolMessageBytes().getValue()).length;
        }
        stream.write(DataConverter.intToBytes(length, 2));
    }

    /** Writes the record epoch (2 bytes) followed by its sequence number (6 bytes). */
    private static void writeDtlsEpochAndSequenceNumber(SilentByteArrayOutputStream stream, Record record) {
        stream.write(DataConverter.intToBytes((int) ((Integer) record.getEpoch().getValue()).shortValue(), 2));
        stream.write(DataConverter.longToUint48Bytes(((BigInteger) record.getSequenceNumber().getValue()).longValue()));
    }

    /**
     * Counts the consecutive zero bytes at the end of the array.
     *
     * @return 0 for an empty array, {@code plainRecordBytes.length} if every byte is zero
     */
    private int countTrailingZeroBytes(byte[] plainRecordBytes) {
        int counter = 0;
        // Stop at index 0; the previous upper-bound check let the index run to -1 on all-zero input.
        for (int i = plainRecordBytes.length - 1; i >= 0 && plainRecordBytes[i] == 0; --i) {
            ++counter;
        }
        return counter;
    }

    /**
     * Concatenates the clean protocol message bytes, the content type byte and the padding. The
     * padding is empty if no padding is set in the record computations.
     */
    protected byte[] encapsulateRecordBytes(Record record) {
        byte[] padding = record.getComputations().getPadding() != null
                ? (byte[]) record.getComputations().getPadding().getValue()
                : new byte[] {};
        return DataConverter.concatenate(new byte[][] {
            (byte[]) record.getCleanProtocolMessageBytes().getValue(),
            {(Byte) record.getContentType().getValue()},
            padding
        });
    }

    /**
     * Splits decrypted bytes into content, content type and trailing zero padding, and stores
     * them in the record.
     *
     * <p>The last non-zero byte is taken as the content type. If the input is empty or consists
     * only of zero bytes, the whole input is stored as clean bytes and a warning is logged.
     */
    protected void parseEncapsulatedRecordBytes(byte[] plainRecordBytes, Record record) {
        int numberOfPaddingBytes = this.countTrailingZeroBytes(plainRecordBytes);
        if (numberOfPaddingBytes == plainRecordBytes.length) {
            LOGGER.warn("Record contains ONLY padding and no content type. Setting clean bytes == plainbytes");
            record.setCleanProtocolMessageBytes(plainRecordBytes);
            return;
        }
        PlaintextParser parser = new PlaintextParser(this, plainRecordBytes);
        byte[] cleanBytes = parser.parseByteArrayField(plainRecordBytes.length - numberOfPaddingBytes - 1);
        byte[] contentType = parser.parseByteArrayField(1);
        byte[] padding = parser.parseByteArrayField(numberOfPaddingBytes);
        record.getComputations().setPadding(padding);
        record.setCleanProtocolMessageBytes(cleanBytes);
        record.setContentType(contentType[0]);
        record.setContentMessageType(ProtocolMessageType.getContentType(contentType[0]));
    }

    public CipherState getState() {
        return this.state;
    }

    public void setState(CipherState state) {
        this.state = state;
    }

    /** Returns the local connection end type of the context's connection. */
    public ConnectionEndType getLocalConnectionEndType() {
        return this.tlsContext.getContext().getConnection().getLocalConnectionEndType();
    }

    /** Returns the connection end type as reported by the context's chooser. */
    public ConnectionEndType getConnectionEndType() {
        return this.tlsContext.getChooser().getConnectionEndType();
    }

    /** Returns the default additional padding configured in the context's config. */
    public Integer getDefaultAdditionalPadding() {
        return this.tlsContext.getConfig().getDefaultAdditionalPadding();
    }

    /** Returns the talking connection end type tracked by the context. */
    public ConnectionEndType getTalkingConnectionEndType() {
        return this.tlsContext.getTalkingConnectionEndType();
    }

    /** Returns the context's random source. */
    public Random getRandom() {
        return this.tlsContext.getRandom();
    }

    /** Sequential byte-array parser over the decrypted record bytes. */
    class PlaintextParser
    extends Parser<Object> {
        /**
         * @param outer unused; kept so existing {@code new PlaintextParser(this, bytes)} calls
         *     still compile
         * @param array bytes to parse
         */
        public PlaintextParser(RecordCipher outer, byte[] array) {
            super(new ByteArrayInputStream(array));
        }

        @Override
        public void parse(Object t) {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        /** Reads the next {@code length} bytes; overridden so this class can expose it publicly. */
        @Override
        public byte[] parseByteArrayField(int length) {
            return super.parseByteArrayField(length);
        }
    }
}

