/*
 * Decompiled with CFR 0.152.
 */
package de.rub.nds.tlsattacker.core.layer.impl;

import de.rub.nds.modifiablevariable.util.DataConverter;
import de.rub.nds.protocol.exception.EndOfStreamException;
import de.rub.nds.protocol.exception.TimeoutException;
import de.rub.nds.protocol.util.SilentByteArrayOutputStream;
import de.rub.nds.tlsattacker.core.constants.ExtensionType;
import de.rub.nds.tlsattacker.core.constants.HandshakeMessageType;
import de.rub.nds.tlsattacker.core.constants.ProtocolMessageType;
import de.rub.nds.tlsattacker.core.constants.ProtocolVersion;
import de.rub.nds.tlsattacker.core.layer.LayerConfiguration;
import de.rub.nds.tlsattacker.core.layer.LayerProcessingResult;
import de.rub.nds.tlsattacker.core.layer.ProtocolLayer;
import de.rub.nds.tlsattacker.core.layer.constant.ImplementedLayers;
import de.rub.nds.tlsattacker.core.layer.context.TlsContext;
import de.rub.nds.tlsattacker.core.layer.data.Handler;
import de.rub.nds.tlsattacker.core.layer.data.Serializer;
import de.rub.nds.tlsattacker.core.layer.hints.LayerProcessingHint;
import de.rub.nds.tlsattacker.core.layer.hints.QuicFrameLayerHint;
import de.rub.nds.tlsattacker.core.layer.hints.RecordLayerHint;
import de.rub.nds.tlsattacker.core.layer.impl.QuicFrameLayer;
import de.rub.nds.tlsattacker.core.layer.stream.HintedInputStream;
import de.rub.nds.tlsattacker.core.layer.stream.HintedLayerInputStream;
import de.rub.nds.tlsattacker.core.protocol.MessageFactory;
import de.rub.nds.tlsattacker.core.protocol.ProtocolMessage;
import de.rub.nds.tlsattacker.core.protocol.ProtocolMessageHandler;
import de.rub.nds.tlsattacker.core.protocol.ProtocolMessageParser;
import de.rub.nds.tlsattacker.core.protocol.ProtocolMessagePreparator;
import de.rub.nds.tlsattacker.core.protocol.message.AckMessage;
import de.rub.nds.tlsattacker.core.protocol.message.AlertMessage;
import de.rub.nds.tlsattacker.core.protocol.message.ApplicationMessage;
import de.rub.nds.tlsattacker.core.protocol.message.ChangeCipherSpecMessage;
import de.rub.nds.tlsattacker.core.protocol.message.CoreClientHelloMessage;
import de.rub.nds.tlsattacker.core.protocol.message.HandshakeMessage;
import de.rub.nds.tlsattacker.core.protocol.message.HeartbeatMessage;
import de.rub.nds.tlsattacker.core.protocol.message.ServerHelloMessage;
import de.rub.nds.tlsattacker.core.protocol.message.UnknownHandshakeMessage;
import de.rub.nds.tlsattacker.core.protocol.message.UnknownMessage;
import de.rub.nds.tlsattacker.core.state.Context;
import de.rub.nds.tlsattacker.transport.ConnectionEndType;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Protocol layer that converts between {@link ProtocolMessage} objects and raw bytes.
 *
 * <p>On the sending side it prepares and serializes the configured messages and hands the bytes
 * to the lower layer. It attaches a {@link QuicFrameLayerHint} when a {@link QuicFrameLayer} is
 * part of the layer stack, and a {@link RecordLayerHint} otherwise. On the receiving side it reads
 * the lower layer's data stream and picks the kind of message to parse from the attached
 * {@link RecordLayerHint}.
 */
public class MessageLayer extends ProtocolLayer<Context, LayerProcessingHint, ProtocolMessage> {
    private static final Logger LOGGER = LogManager.getLogger();

    /** Context this layer operates on. */
    private final Context context;

    /** TLS view of {@link #context}, looked up once at construction time. */
    private final TlsContext tlsContext;

    /**
     * Creates a message layer operating on the given context.
     *
     * @param context the context to operate on; its TLS context is cached by this constructor
     */
    public MessageLayer(Context context) {
        super(ImplementedLayers.MESSAGE);
        this.context = context;
        this.tlsContext = context.getTlsContext();
    }

    /**
     * Prepares, serializes and sends every configured message that has not been processed yet.
     *
     * <p>Consecutive handshake messages are buffered so they can be handed to the lower layer in
     * one call. The buffer is flushed before any non-handshake message, and also before any
     * message whose protocol type differs from the buffered one. This guarantees that bytes of
     * different protocol types are never sent under a single hint.
     *
     * @return the result of this layer
     * @throws IOException if the lower layer fails to send data
     */
    @Override
    public LayerProcessingResult<ProtocolMessage> sendConfiguration() throws IOException {
        LayerConfiguration<ProtocolMessage> configuration = this.getLayerConfiguration();
        ProtocolMessageType runningProtocolMessageType = null;
        List<byte[]> bufferedMessages = new LinkedList<>();
        if (configuration != null && configuration.getContainerList() != null) {
            for (ProtocolMessage message : this.getUnprocessedConfiguredContainers()) {
                if (this.containerAlreadyUsedByHigherLayer(message)
                        || !this.prepareDataContainer(message, this.context)) {
                    continue;
                }
                // Only handshake messages may share a send call, and only with messages of the
                // same protocol type. Otherwise a buffered message that was not flushed right
                // away (e.g. an alert) would be sent under the following message's type.
                if (!message.isHandshakeMessage()
                        || message.getProtocolMessageType() != runningProtocolMessageType) {
                    this.flushCollectedMessages(runningProtocolMessageType, bufferedMessages, false);
                }
                runningProtocolMessageType = message.getProtocolMessageType();
                this.processMessage(message, bufferedMessages);
                this.addProducedContainer(message);
            }
        }
        this.flushCollectedMessages(runningProtocolMessageType, bufferedMessages, false);
        return this.getLayerResult();
    }

    /**
     * Serializes a prepared message and appends its bytes to the send buffer.
     *
     * <p>Steps, in order:
     *
     * <ol>
     *   <li>Stores the serialized bytes as the message's complete resulting message.
     *   <li>Calls the handler's {@code updateDigest}.
     *   <li>Adjusts the context, if the message allows it.
     *   <li>Buffers the bytes, flushing at once when
     *       {@link #mustFlushCollectedMessagesImmediately(ProtocolMessage)} requires it.
     *   <li>Runs {@code adjustContextAfterSerialize}, if the message allows context adjustment.
     * </ol>
     *
     * @param message the prepared message to serialize
     * @param bufferedMessages buffer of serialized messages waiting to be sent
     * @throws IOException if flushing to the lower layer fails
     */
    private void processMessage(ProtocolMessage message, List<byte[]> bufferedMessages) throws IOException {
        Serializer serializer = message.getSerializer(this.context);
        byte[] serializedMessage = serializer.serialize();
        message.setCompleteResultingMessage(serializedMessage);
        Handler handler = message.getHandler(this.context);
        ((ProtocolMessageHandler) handler).updateDigest(message, true);
        if (message.getAdjustContext()) {
            handler.adjustContext(message);
        }
        bufferedMessages.add((byte[]) message.getCompleteResultingMessage().getValue());
        if (this.mustFlushCollectedMessagesImmediately(message)) {
            // Client hellos and plain ServerHelloMessage instances are flagged as "first message"
            // for the QUIC frame hint.
            boolean isFirstMessage =
                    message instanceof CoreClientHelloMessage
                            || message.getClass() == ServerHelloMessage.class;
            this.flushCollectedMessages(message.getProtocolMessageType(), bufferedMessages, isFirstMessage);
        }
        if (message.getAdjustContext()) {
            ((ProtocolMessageHandler) handler).adjustContextAfterSerialize(message);
        }
    }

    /**
     * Concatenates all buffered messages, sends them to the lower layer in a single call, and
     * clears the buffer. Does nothing when the buffer is empty.
     *
     * @param runningProtocolMessageType protocol type used for the lower-layer hint
     * @param bufferedMessages buffer of serialized messages; cleared after sending
     * @param isFirstMessage passed to the {@link QuicFrameLayerHint} when QUIC is in use
     * @throws IOException if the lower layer fails to send
     */
    private void flushCollectedMessages(ProtocolMessageType runningProtocolMessageType, List<byte[]> bufferedMessages, boolean isFirstMessage) throws IOException {
        if (bufferedMessages.isEmpty()) {
            return;
        }
        byte[] allBufferedMessageBytes = this.collectBufferedBytes(bufferedMessages);
        LOGGER.debug(
                "Handing {} serialized message(s) ({} bytes) down to lower layer",
                bufferedMessages.size(),
                allBufferedMessageBytes.length);
        if (this.isQuicFrameLayerPresent()) {
            this.getLowerLayer()
                    .sendData(
                            new QuicFrameLayerHint(runningProtocolMessageType, isFirstMessage),
                            allBufferedMessageBytes);
        } else {
            this.getLowerLayer()
                    .sendData(new RecordLayerHint(runningProtocolMessageType), allBufferedMessageBytes);
        }
        bufferedMessages.clear();
    }

    /**
     * Returns whether the layer stack of this layer's context contains a {@link QuicFrameLayer}.
     */
    private boolean isQuicFrameLayerPresent() {
        return this.context.getLayerStack().getLayer(QuicFrameLayer.class) != null;
    }

    /**
     * Concatenates the buffered byte arrays in order.
     *
     * @param bufferedMessages serialized messages
     * @return all bytes, concatenated
     */
    private byte[] collectBufferedBytes(List<byte[]> bufferedMessages) {
        SilentByteArrayOutputStream byteStream = new SilentByteArrayOutputStream();
        for (byte[] message : bufferedMessages) {
            byteStream.write(message);
        }
        return byteStream.toByteArray();
    }

    /**
     * Decides whether the send buffer must be flushed right after the given message was buffered.
     *
     * <p>It returns {@code true} when:
     *
     * <ul>
     *   <li>the config disables sending handshake messages within a single record;
     *   <li>the message is a ChangeCipherSpec;
     *   <li>with TLS 1.3 selected, the message is:
     *       <ul>
     *         <li>a ServerHello without the HelloRetryRequest random;
     *         <li>a Finished, KeyUpdate or EndOfEarlyData message;
     *         <li>a ClientHello sent by a client that proposed the early_data extension.
     *       </ul>
     * </ul>
     *
     * @param message the message that was just buffered
     * @return whether to flush immediately
     */
    private boolean mustFlushCollectedMessagesImmediately(ProtocolMessage message) {
        if (!this.context.getConfig().getSendHandshakeMessagesWithinSingleRecord().booleanValue()) {
            return true;
        }
        if (message.getProtocolMessageType() == ProtocolMessageType.CHANGE_CIPHER_SPEC) {
            return true;
        }
        if (message.isHandshakeMessage()
                && this.tlsContext.getSelectedProtocolVersion() == ProtocolVersion.TLS13) {
            HandshakeMessageType handshakeType =
                    ((HandshakeMessage) message).getHandshakeMessageType();
            if (handshakeType == HandshakeMessageType.SERVER_HELLO) {
                return !((ServerHelloMessage) message).hasTls13HelloRetryRequestRandom();
            }
            if (handshakeType == HandshakeMessageType.FINISHED
                    || handshakeType == HandshakeMessageType.KEY_UPDATE
                    || handshakeType == HandshakeMessageType.END_OF_EARLY_DATA) {
                return true;
            }
            if (handshakeType == HandshakeMessageType.CLIENT_HELLO
                    && this.context.getChooser().getConnectionEndType() == ConnectionEndType.CLIENT
                    && this.tlsContext.isExtensionProposed(ExtensionType.EARLY_DATA)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Sends bytes from a higher layer as application data.
     *
     * <p>The first unprocessed configured {@link ApplicationMessage} (or a new one when none is
     * configured) gets the bytes as its data config and is recorded as produced. Any pre-configured
     * content is replaced, with a warning. The bytes themselves go to the lower layer unchanged,
     * with an APPLICATION_DATA hint. The {@code hint} argument is not used.
     *
     * @param hint unused
     * @param additionalData the bytes to send
     * @return the result of this layer
     * @throws IOException if the lower layer fails to send
     */
    @Override
    public LayerProcessingResult<ProtocolMessage> sendData(LayerProcessingHint hint, byte[] additionalData) throws IOException {
        LayerConfiguration<ProtocolMessage> configuration = this.getLayerConfiguration();
        ApplicationMessage applicationMessage = this.getConfiguredApplicationMessage(configuration);
        if (applicationMessage == null) {
            applicationMessage = new ApplicationMessage();
        } else if (applicationMessage.getDataConfig() != null) {
            LOGGER.warn("Found Application message with pre configured content while sending HTTP message. Configured content will be replaced.");
        }
        applicationMessage.setDataConfig(additionalData);
        if (this.isQuicFrameLayerPresent()) {
            this.getLowerLayer()
                    .sendData(new QuicFrameLayerHint(ProtocolMessageType.APPLICATION_DATA), additionalData);
        } else {
            this.getLowerLayer()
                    .sendData(new RecordLayerHint(ProtocolMessageType.APPLICATION_DATA), additionalData);
        }
        this.addProducedContainer(applicationMessage);
        return this.getLayerResult();
    }

    /**
     * Returns the first unprocessed configured message of type APPLICATION_DATA.
     *
     * @param configuration the layer configuration; may be {@code null}
     * @return the configured application message, or {@code null} if there is none
     */
    public ApplicationMessage getConfiguredApplicationMessage(LayerConfiguration<ProtocolMessage> configuration) {
        if (configuration == null || configuration.getContainerList() == null) {
            return null;
        }
        for (ProtocolMessage configuredMessage : this.getUnprocessedConfiguredContainers()) {
            if (configuredMessage.getProtocolMessageType() == ProtocolMessageType.APPLICATION_DATA) {
                return (ApplicationMessage) configuredMessage;
            }
        }
        return null;
    }

    /**
     * Reads messages from the lower layer until {@code shouldContinueProcessing()} returns false.
     *
     * <p>Returns early when the lower layer yields no data or no data stream. A timeout marks the
     * layer as having reached a timeout. An end of stream simply stops reading.
     *
     * @return the result of this layer
     */
    @Override
    public LayerProcessingResult<ProtocolMessage> receiveData() {
        try {
            do {
                HintedInputStream dataStream;
                try {
                    dataStream = this.getLowerLayer().getDataStream();
                    if (dataStream.available() == 0) {
                        LOGGER.warn("The lower layer did not produce any data.");
                        return this.getLayerResult();
                    }
                } catch (IOException e) {
                    LOGGER.warn("The lower layer did not produce a data stream: ", e);
                    return this.getLayerResult();
                }
                LayerProcessingHint tempHint = dataStream.getHint();
                if (tempHint == null) {
                    LOGGER.warn("The TLS message layer requires a processing hint. E.g. a record type. Parsing as an unknown message");
                    this.readUnknownProtocolData();
                } else if (tempHint instanceof RecordLayerHint) {
                    this.readMessageForHint((RecordLayerHint) tempHint);
                }
                // Hints of any other type are not parsed here; the loop condition decides whether
                // to keep going.
            } while (this.shouldContinueProcessing());
        } catch (TimeoutException ex) {
            LOGGER.debug("Received a timeout");
            LOGGER.trace(ex);
            this.setReachedTimeout(true);
        } catch (EndOfStreamException ex) {
            LOGGER.debug("Reached end of stream, cannot parse more messages");
            LOGGER.trace(ex);
        }
        return this.getLayerResult();
    }

    /**
     * Parses one message of the kind indicated by the record layer hint's type.
     *
     * <p>A missing ({@code null}) or unhandled type is parsed as an unknown message and logged.
     * The {@code null} case is checked before the switch, because switching on a {@code null}
     * enum throws a {@link NullPointerException}.
     *
     * @param hint the record layer hint describing the data to parse
     */
    public void readMessageForHint(RecordLayerHint hint) {
        ProtocolMessageType type = hint.getType();
        if (type == null) {
            this.readUnknownProtocolData();
            LOGGER.warn("Undefined record layer type ({})", "null");
            return;
        }
        switch (type) {
            case ALERT:
                this.readAlertProtocolData();
                break;
            case APPLICATION_DATA:
                this.readAppDataProtocolData();
                break;
            case CHANGE_CIPHER_SPEC:
                this.readCcsProtocolData(hint.getEpoch());
                break;
            case HANDSHAKE:
                this.readHandshakeProtocolData();
                break;
            case HEARTBEAT:
                this.readHeartbeatProtocolData();
                break;
            case ACK:
                this.readAckProtocolData();
                break;
            case UNKNOWN:
                this.readUnknownProtocolData();
                break;
            default:
                this.readUnknownProtocolData();
                LOGGER.warn("Undefined record layer type ({})", type);
                break;
        }
    }

    /** Parses an {@link AlertMessage} from the lower layer's data. */
    private void readAlertProtocolData() {
        AlertMessage message = new AlertMessage();
        this.readDataContainer(message, this.context);
    }

    /**
     * Parses an {@link ApplicationMessage}, then calls {@code removeDrainedInputStream()} on the
     * lower layer.
     *
     * @return the parsed application message
     */
    private ApplicationMessage readAppDataProtocolData() {
        ApplicationMessage message = new ApplicationMessage();
        this.readDataContainer(message, this.context);
        this.getLowerLayer().removeDrainedInputStream();
        return message;
    }

    /**
     * Parses a {@link ChangeCipherSpecMessage}.
     *
     * <p>For DTLS, the behavior depends on whether a CCS was already received for this epoch:
     *
     * <ul>
     *   <li>Already received, and the config enables {@code isIgnoreRetransmittedCcsInDtls}: the
     *       message is parsed without adjusting the context.
     *   <li>Otherwise: the epoch is recorded as received.
     * </ul>
     *
     * @param epoch epoch from the record layer hint
     */
    private void readCcsProtocolData(Integer epoch) {
        ChangeCipherSpecMessage message = new ChangeCipherSpecMessage();
        ProtocolVersion selectedVersion = this.tlsContext.getSelectedProtocolVersion();
        if (selectedVersion != null && selectedVersion.isDTLS()) {
            if (this.tlsContext.getDtlsReceivedChangeCipherSpecEpochs().contains(epoch)
                    && this.tlsContext.getConfig().isIgnoreRetransmittedCcsInDtls().booleanValue()) {
                message.setAdjustContext(false);
            } else {
                this.tlsContext.addDtlsReceivedChangeCipherSpecEpochs(epoch);
            }
        }
        this.readDataContainer(message, this.context);
    }

    /**
     * Reads one handshake message from the lower layer's data stream. The wire format is a 1-byte
     * type, a 3-byte length, then a payload of that length.
     *
     * <p>Failure handling:
     *
     * <ul>
     *   <li>If the header or payload cannot be read, the bytes read so far are appended to this
     *       layer's unread bytes and nothing is produced.
     *   <li>If parsing, preparing or handling the payload throws a {@link RuntimeException}, an
     *       {@link UnknownHandshakeMessage} is produced instead. It carries the assumed type and
     *       the raw payload.
     * </ul>
     */
    private void readHandshakeProtocolData() {
        SilentByteArrayOutputStream readBytesStream = new SilentByteArrayOutputStream();
        HintedInputStream handshakeStream;
        byte type;
        HandshakeMessage handshakeMessage;
        int length;
        byte[] payload;
        try {
            handshakeStream = this.getLowerLayer().getDataStream();
            type = handshakeStream.readByte();
            readBytesStream.write(new byte[] {type});
            handshakeMessage =
                    MessageFactory.generateHandshakeMessage(
                            HandshakeMessageType.getMessageType(type), this.tlsContext);
            handshakeMessage.setType(type);
            byte[] lengthBytes = handshakeStream.readChunk(3);
            length = DataConverter.bytesToInt(lengthBytes);
            readBytesStream.write(lengthBytes);
            handshakeMessage.setLength(length);
            payload = handshakeStream.readChunk(length);
            readBytesStream.write(payload);
        } catch (IOException ex) {
            LOGGER.error("Could not parse message header. Setting bytes as unread: ", ex);
            this.setUnreadBytes(
                    DataConverter.concatenate(
                            new byte[][] {this.getUnreadBytes(), readBytesStream.toByteArray()}));
            return;
        }
        ProtocolMessageHandler handler = handshakeMessage.getHandler(this.context);
        handshakeMessage.setMessageContent(payload);
        try {
            // Rebuild the complete message (type + 3-byte length + payload) as read from the wire
            handshakeMessage.setCompleteResultingMessage(
                    DataConverter.concatenate(
                            new byte[][] {{type}, DataConverter.intToBytes(length, 3), payload}));
            ProtocolMessageParser parser =
                    handshakeMessage.getParser(this.context, new ByteArrayInputStream(payload));
            parser.parse(handshakeMessage);
            ProtocolMessagePreparator preparator = handshakeMessage.getPreparator(this.context);
            preparator.prepareAfterParse();
            if (this.context.getChooser().getSelectedProtocolVersion().isDTLS()) {
                handshakeMessage.setMessageSequence(
                        ((RecordLayerHint) handshakeStream.getHint()).getMessageSequence());
            }
            handler.updateDigest(handshakeMessage, false);
            handler.adjustContext(handshakeMessage);
            this.addProducedContainer(handshakeMessage);
        } catch (RuntimeException ex) {
            LOGGER.warn("Failed to parse HandshakeMessage using assumed type {}", HandshakeMessageType.getMessageType(type));
            LOGGER.trace(ex);
            UnknownHandshakeMessage message = new UnknownHandshakeMessage();
            message.setAssumedType(type);
            message.setData(payload);
            this.addProducedContainer(message);
        }
    }

    /** Parses a {@link HeartbeatMessage} from the lower layer's data. */
    private void readHeartbeatProtocolData() {
        HeartbeatMessage message = new HeartbeatMessage();
        this.readDataContainer(message, this.context);
    }

    /** Parses an {@link AckMessage} from the lower layer's data. */
    private void readAckProtocolData() {
        AckMessage message = new AckMessage();
        this.readDataContainer(message, this.context);
    }

    /**
     * Parses an {@link UnknownMessage}, then calls {@code removeDrainedInputStream()} on the lower
     * layer.
     */
    private void readUnknownProtocolData() {
        UnknownMessage message = new UnknownMessage();
        this.readDataContainer(message, this.context);
        this.getLowerLayer().removeDrainedInputStream();
    }

    /**
     * Reads from the lower layer until application data arrives for the higher layer.
     *
     * <p>Non-application-data messages with a record layer hint are parsed and reading continues.
     * Reading stops in these cases:
     *
     * <ul>
     *   <li>Application data arrives: it is passed to the higher layer with {@code hint}.
     *   <li>The stream has no hint (parsed as an unknown message).
     *   <li>The stream has a hint of a different type.
     *   <li>A timeout or end of stream occurs.
     * </ul>
     *
     * @param hint hint attached to application data passed to the higher layer
     */
    @Override
    public void receiveMoreDataForHint(LayerProcessingHint hint) {
        boolean continueProcessing;
        do {
            try {
                HintedInputStream dataStream;
                try {
                    dataStream = this.getLowerLayer().getDataStream();
                } catch (IOException e) {
                    LOGGER.warn("The lower layer did not produce a data stream: ", e);
                    return;
                }
                LayerProcessingHint inputStreamHint = dataStream.getHint();
                if (inputStreamHint == null) {
                    LOGGER.warn("The TLS message layer requires a processing hint. E.g. a record type. Parsing as an unknown message");
                    this.readUnknownProtocolData();
                    continueProcessing = false;
                } else if (inputStreamHint instanceof RecordLayerHint) {
                    RecordLayerHint recordLayerHint = (RecordLayerHint) inputStreamHint;
                    if (recordLayerHint.getType() == ProtocolMessageType.APPLICATION_DATA) {
                        ApplicationMessage receivedAppData = this.readAppDataProtocolData();
                        this.passToHigherLayer(receivedAppData, hint);
                        continueProcessing = false;
                    } else {
                        this.readMessageForHint(recordLayerHint);
                        continueProcessing = true;
                    }
                } else {
                    continueProcessing = false;
                }
            } catch (TimeoutException ex) {
                LOGGER.debug("Received a timeout");
                LOGGER.trace(ex);
                continueProcessing = false;
            } catch (EndOfStreamException ex) {
                LOGGER.debug("Reached end of stream, cannot parse more messages");
                LOGGER.trace(ex);
                continueProcessing = false;
            }
        } while (continueProcessing);
    }

    /**
     * Appends the data of a received application message to this layer's input stream for the
     * higher layer.
     *
     * <p>The stream is created on first use. On later calls its hint is replaced with the given
     * one.
     *
     * @param receivedAppData the parsed application message
     * @param hint hint to attach to the stream
     */
    public void passToHigherLayer(ApplicationMessage receivedAppData, LayerProcessingHint hint) {
        LOGGER.debug("Passing the following Application Data to higher layer: {}", receivedAppData.getData().getValue());
        if (this.currentInputStream == null) {
            this.currentInputStream = new HintedLayerInputStream(hint, this);
        } else {
            this.currentInputStream.setHint(hint);
        }
        this.currentInputStream.extendStream((byte[]) receivedAppData.getData().getValue());
    }
}

