/*
 * Decompiled with CFR 0.152.
 */
package de.rub.nds.tlsattacker.core.layer.impl;

import de.rub.nds.protocol.exception.CryptoException;
import de.rub.nds.protocol.exception.EndOfStreamException;
import de.rub.nds.protocol.exception.TimeoutException;
import de.rub.nds.protocol.util.SilentByteArrayOutputStream;
import de.rub.nds.tlsattacker.core.layer.AcknowledgingProtocolLayer;
import de.rub.nds.tlsattacker.core.layer.LayerConfiguration;
import de.rub.nds.tlsattacker.core.layer.LayerProcessingResult;
import de.rub.nds.tlsattacker.core.layer.constant.ImplementedLayers;
import de.rub.nds.tlsattacker.core.layer.hints.LayerProcessingHint;
import de.rub.nds.tlsattacker.core.layer.hints.QuicPacketLayerHint;
import de.rub.nds.tlsattacker.core.layer.stream.HintedInputStream;
import de.rub.nds.tlsattacker.core.layer.stream.HintedLayerInputStream;
import de.rub.nds.tlsattacker.core.quic.constants.QuicPacketType;
import de.rub.nds.tlsattacker.core.quic.constants.QuicVersion;
import de.rub.nds.tlsattacker.core.quic.crypto.QuicDecryptor;
import de.rub.nds.tlsattacker.core.quic.crypto.QuicEncryptor;
import de.rub.nds.tlsattacker.core.quic.packet.HandshakePacket;
import de.rub.nds.tlsattacker.core.quic.packet.InitialPacket;
import de.rub.nds.tlsattacker.core.quic.packet.OneRTTPacket;
import de.rub.nds.tlsattacker.core.quic.packet.QuicPacket;
import de.rub.nds.tlsattacker.core.quic.packet.RetryPacket;
import de.rub.nds.tlsattacker.core.quic.packet.StatelessResetPseudoPacket;
import de.rub.nds.tlsattacker.core.quic.packet.VersionNegotiationPacket;
import de.rub.nds.tlsattacker.core.quic.packet.ZeroRTTPacket;
import de.rub.nds.tlsattacker.core.state.Context;
import de.rub.nds.tlsattacker.core.state.quic.QuicContext;
import java.io.IOException;
import java.io.InputStream;
import java.net.PortUnreachableException;
import java.net.SocketTimeoutException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

public class QuicPacketLayer
extends AcknowledgingProtocolLayer<Context, QuicPacketLayerHint, QuicPacket> {
    private static final Logger LOGGER = LogManager.getLogger();
    // Context handed to the constructor; passed on to packet preparators, parsers, serializers and handlers.
    private final Context context;
    // QUIC-specific state taken from the context in the constructor.
    private final QuicContext quicContext;
    // Removes header protection and decrypts received packets.
    private final QuicDecryptor decryptor;
    // Encrypts outgoing packets and adds header protection.
    private final QuicEncryptor encryptor;
    // Received packets grouped by packet type; the constructor registers an empty list for every QuicPacketType.
    private final Map<QuicPacketType, ArrayList<QuicPacket>> receivedPacketBuffer = new HashMap<QuicPacketType, ArrayList<QuicPacket>>();
    // While true, sendAck returns immediately without sending anything.
    private boolean temporarilyDisabledAcks = false;

    /**
     * Creates the QUIC packet layer for the given context.
     *
     * <p>Sets up the decryptor and encryptor on the context's QUIC context and registers an empty
     * receive buffer for every {@link QuicPacketType}, so later buffer lookups never return
     * {@code null}.
     *
     * @param context the context this layer operates on
     */
    public QuicPacketLayer(Context context) {
        super(ImplementedLayers.QUICPACKET);
        this.context = context;
        this.quicContext = context.getQuicContext();
        this.decryptor = new QuicDecryptor(context.getQuicContext());
        this.encryptor = new QuicEncryptor(context.getQuicContext());
        for (QuicPacketType quicPacketType : QuicPacketType.values()) {
            this.receivedPacketBuffer.put(quicPacketType, new ArrayList<>());
        }
    }

    /**
     * Writes and sends every configured packet that has not been processed yet.
     *
     * <p>Frame-container packets that {@code isEmptyPacket} reports as empty are skipped. A
     * {@link CryptoException} while writing a packet is logged and the remaining packets are still
     * handled.
     *
     * @return the current layer result
     * @throws IOException if the lower layer fails to send
     */
    @Override
    public LayerProcessingResult<QuicPacket> sendConfiguration() throws IOException {
        LayerConfiguration layerConfiguration = this.getLayerConfiguration();
        if (layerConfiguration == null || layerConfiguration.getContainerList() == null) {
            return this.getLayerResult();
        }
        for (QuicPacket configuredPacket : this.getUnprocessedConfiguredContainers()) {
            boolean skipAsEmpty =
                    configuredPacket.getPacketType().isFrameContainer()
                            && this.isEmptyPacket(configuredPacket);
            if (skipAsEmpty) {
                continue;
            }
            try {
                byte[] serializedPacket = this.writePacket(configuredPacket);
                this.addProducedContainer(configuredPacket);
                this.getLowerLayer().sendData(null, serializedPacket);
            } catch (CryptoException ex) {
                LOGGER.error((Object) ex);
            }
        }
        return this.getLayerResult();
    }

    /**
     * Wraps the given data into a single QUIC packet and sends it through the lower layer.
     *
     * <p>If an unprocessed configured packet exists, the first one is used as the carrier.
     * Otherwise a fresh packet is created according to the packet type in the hint. A Handshake
     * hint is downgraded to Initial while the handshake secrets are not initialized.
     *
     * @param hint expected to be a {@link QuicPacketLayerHint}; otherwise the type is UNKNOWN
     * @param data the payload to place into the packet
     * @return the current layer result
     * @throws IOException if the lower layer fails to send
     * @throws UnsupportedOperationException if a new packet is needed for an unsupported type
     */
    @Override
    public LayerProcessingResult<QuicPacket> sendData(LayerProcessingHint hint, byte[] data) throws IOException {
        QuicPacketType hintedType;
        if (hint instanceof QuicPacketLayerHint quicHint) {
            hintedType = quicHint.getQuicPacketType();
        } else {
            hintedType = QuicPacketType.UNKNOWN;
            LOGGER.warn("Sending packet without a LayerProcessing hint. Using UNKNOWN as the type.");
        }
        if (hintedType == QuicPacketType.HANDSHAKE_PACKET && !this.quicContext.isHandshakeSecretsInitialized()) {
            LOGGER.debug("Processing Hint was Handshake Packet, but Handshake Secrets are not initialized yet. Downgrading to Initial Packet.");
            hintedType = QuicPacketType.INITIAL_PACKET;
        }
        List<QuicPacket> configuredPackets = this.getUnprocessedConfiguredContainers();
        try {
            QuicPacket carrier;
            if (this.getLayerConfiguration().getContainerList() != null && !configuredPackets.isEmpty()) {
                carrier = configuredPackets.getFirst();
            } else {
                switch (hintedType) {
                    case INITIAL_PACKET:
                        carrier = new InitialPacket();
                        break;
                    case HANDSHAKE_PACKET:
                        carrier = new HandshakePacket();
                        break;
                    case ONE_RTT_PACKET:
                        carrier = new OneRTTPacket();
                        break;
                    case ZERO_RTT_PACKET:
                        carrier = new ZeroRTTPacket();
                        break;
                    case RETRY_PACKET:
                        carrier = new RetryPacket();
                        break;
                    case VERSION_NEGOTIATION:
                        carrier = new VersionNegotiationPacket();
                        break;
                    default:
                        throw new UnsupportedOperationException("Unknown Packet - Not supported yet.");
                }
            }
            byte[] wireBytes = this.writePacket(data, carrier);
            this.addProducedContainer(carrier);
            this.getLowerLayer().sendData(null, wireBytes);
        } catch (CryptoException ex) {
            LOGGER.error((Object) ex);
        }
        return this.getLayerResult();
    }

    /**
     * Reads packets from the lower layer until {@code shouldContinueProcessing()} returns false.
     *
     * <p>Timeouts, an unreachable port, and the end of the stream stop reading and are only logged.
     * Any other {@link IOException} is logged as a warning. None of them is rethrown.
     *
     * @return the current layer result
     */
    @Override
    public LayerProcessingResult<QuicPacket> receiveData() {
        try {
            do {
                HintedInputStream dataStream = this.getLowerLayer().getDataStream();
                this.readPackets(dataStream);
            } while (this.shouldContinueProcessing());
        } catch (TimeoutException | SocketTimeoutException ex) {
            LOGGER.debug("Received a timeout");
            LOGGER.trace((Object) ex);
        } catch (PortUnreachableException ex) {
            LOGGER.debug("Destination port unreachable");
            LOGGER.trace((Object) ex);
        } catch (EndOfStreamException ex) {
            LOGGER.debug("Reached end of stream, cannot parse more messages");
            LOGGER.trace((Object) ex);
        } catch (IOException ex) {
            // Must stay last: the more specific IOException subtypes are handled above.
            LOGGER.warn("The lower layer did not produce a data stream: ", (Throwable) ex);
        }
        return this.getLayerResult();
    }

    /**
     * Performs one read pass on the lower layer's data stream.
     *
     * <p>Timeouts, an unreachable port, and the end of the stream are logged and not rethrown. The
     * hint is not inspected.
     *
     * @param hint unused by this implementation
     * @throws IOException for any other I/O failure
     */
    @Override
    public void receiveMoreDataForHint(LayerProcessingHint hint) throws IOException {
        try {
            this.readPackets(this.getLowerLayer().getDataStream());
        } catch (TimeoutException | SocketTimeoutException timeout) {
            LOGGER.debug("Received a timeout");
            LOGGER.trace((Object) timeout);
        } catch (PortUnreachableException unreachable) {
            LOGGER.debug("Received a ICMP Port Unreachable");
            LOGGER.trace((Object) unreachable);
        } catch (EndOfStreamException endOfStream) {
            LOGGER.debug("Reached end of stream, cannot parse more messages");
            LOGGER.trace((Object) endOfStream);
        }
    }

    /**
     * Reads one packet from the given stream, buffers it, and forwards the payload of the next
     * processable buffered packet to this layer's input stream.
     *
     * <p>Steps:
     *
     * <ol>
     *   <li>If the first byte is zero, the remaining available bytes are consumed and nothing is
     *       parsed.
     *   <li>Otherwise the packet type is derived from the first byte (and, for long headers, from
     *       the version bytes) and the packet is parsed. A long-header packet whose version differs
     *       from the context's version is ignored and the method returns immediately.
     *   <li>The parsed packet is recorded as a stateless reset, discarded because of a mismatching
     *       connection ID (if configured), or stored in the receive buffer.
     *   <li>Buffered packets are decrypted where secrets are available, and the payload of the
     *       first packet of the next processable type is appended to the current input stream.
     * </ol>
     *
     * @param dataStream the lower layer's stream to read from
     * @throws IOException if reading from the stream fails
     * @throws EndOfStreamException if the stream has no bytes available
     * @throws UnsupportedOperationException for 0-RTT packets and packets of UNKNOWN type
     * @throws IllegalStateException for any other packet type without a reader
     */
    private void readPackets(InputStream dataStream) throws IOException {
        if (dataStream.available() == 0) {
            throw new EndOfStreamException();
        }
        int firstByte = dataStream.read();
        if (firstByte == 0) {
            // A zero first byte is not parsed as a packet; drop whatever else is available.
            dataStream.readNBytes(dataStream.available());
        } else {
            QuicPacketType packetType;
            byte[] versionBytes = new byte[0];
            if (QuicPacketType.isLongHeaderPacket(firstByte)) {
                versionBytes = dataStream.readNBytes(4);
                QuicVersion quicVersion = QuicVersion.getFromVersionBytes(versionBytes);
                if (quicVersion == QuicVersion.NULL_VERSION) {
                    packetType = QuicPacketType.VERSION_NEGOTIATION;
                } else if (quicVersion != this.quicContext.getQuicVersion()) {
                    LOGGER.warn("Received packet with unexpected QUIC version, ignoring it.");
                    return;
                } else {
                    packetType = QuicPacketType.getPacketTypeFromFirstByte(quicVersion, firstByte);
                }
            } else {
                packetType = QuicPacketType.getPacketTypeFromFirstByte(this.quicContext.getQuicVersion(), firstByte);
            }
            QuicPacket readPacket = switch (packetType) {
                case INITIAL_PACKET -> this.readInitialPacket(firstByte, versionBytes, dataStream);
                case HANDSHAKE_PACKET -> this.readHandshakePacket(firstByte, versionBytes, dataStream);
                case ONE_RTT_PACKET -> this.readOneRTTPacket(firstByte, dataStream);
                case RETRY_PACKET -> this.readRetryPacket(firstByte, dataStream);
                case VERSION_NEGOTIATION -> this.readVersionNegotiationPacket(dataStream);
                case ZERO_RTT_PACKET, UNKNOWN -> throw new UnsupportedOperationException("Unknown Packet - Not supported yet.");
                default -> throw new IllegalStateException("Received a Packet of Unknown Type");
            };
            if (this.isStatelessResetPacket(readPacket)) {
                this.quicContext.setReceivedStatelessResetToken(true);
                this.addProducedContainer(new StatelessResetPseudoPacket());
                this.quicContext.getReceivedPackets().add(QuicPacketType.STATELESS_RESET);
            } else if (this.context.getConfig().isDiscardPacketsWithMismatchedSCID()
                    && !Arrays.equals(readPacket.getDestinationConnectionId().getValue(), this.context.getQuicContext().getSourceConnectionId())) {
                LOGGER.debug("Discarding QUIC Packet with mismatching SCID.");
            } else {
                this.receivedPacketBuffer.get(packetType).add(readPacket);
            }
        }

        // Decrypt whatever became decryptable before choosing what to pass on.
        this.decryptInitialPacketsInBuffer();
        this.decryptHandshakePacketsInBuffer();
        this.decryptOneRRTPacketsInBuffer();

        SilentByteArrayOutputStream outputStream = new SilentByteArrayOutputStream();
        QuicPacketType packetTypeToProcess = this.getPacketTypeToProcessNext();
        if (packetTypeToProcess != null) {
            // The list is the live buffer entry, so removing from it updates the buffer directly.
            ArrayList<QuicPacket> packets = this.receivedPacketBuffer.get(packetTypeToProcess);
            QuicPacket packet = packets.remove(0);
            LOGGER.debug("Processing {} Packet: {}", packetTypeToProcess, packet.getPlainPacketNumber());
            outputStream.write(packet.getUnprotectedPayload().getValue());
            this.quicContext.getReceivedPackets().add(packet.getPacketType());
        }

        // The stream is extended even with zero bytes so that it exists after the first read.
        if (this.currentInputStream == null) {
            this.currentInputStream = new HintedLayerInputStream(null, this);
        }
        this.currentInputStream.extendStream(outputStream.toByteArray());
    }

    /**
     * Sets the given data as the packet's unprotected payload and then writes the packet.
     *
     * @param data the payload to set on the packet
     * @param packet the packet to fill and write
     * @return the bytes returned by {@link #writePacket(QuicPacket)}
     * @throws CryptoException if writing the packet throws it
     */
    private byte[] writePacket(byte[] data, QuicPacket packet) throws CryptoException {
        packet.setUnprotectedPayload(data);
        return this.writePacket(packet);
    }

    /**
     * Dispatches to the type-specific write method based on the packet's reported packet type.
     *
     * @param packet the packet to write
     * @return the bytes produced by the type-specific write method
     * @throws CryptoException if the type-specific write method throws it
     * @throws UnsupportedOperationException if the packet type has no write method
     */
    private byte[] writePacket(QuicPacket packet) throws CryptoException {
        QuicPacketType type = packet.getPacketType();
        switch (type) {
            case INITIAL_PACKET:
                return this.writeInitialPacket((InitialPacket) packet);
            case HANDSHAKE_PACKET:
                return this.writeHandshakePacket((HandshakePacket) packet);
            case ONE_RTT_PACKET:
                return this.writeOneRTTPacket((OneRTTPacket) packet);
            case ZERO_RTT_PACKET:
                return this.writeZeroRTTPacket((ZeroRTTPacket) packet);
            case RETRY_PACKET:
                return this.writeRetryPacket((RetryPacket) packet);
            case VERSION_NEGOTIATION:
                return this.writeVersionNegotiationPacket((VersionNegotiationPacket) packet);
            default:
                throw new UnsupportedOperationException("Unknown Packet - Not supported yet.");
        }
    }

    /**
     * Produces the bytes of an Initial packet.
     *
     * <p>The calls run in a fixed order: the packet's preparator, the encryptor's Initial
     * encryption, the flag update with the encoded packet number, the encryptor's Initial header
     * protection, and finally the packet's serializer.
     *
     * @param packet the Initial packet to write
     * @return the bytes returned by the packet's serializer
     * @throws CryptoException if one of the called steps throws it
     */
    private byte[] writeInitialPacket(InitialPacket packet) throws CryptoException {
        packet.getPreparator(this.context).prepare();
        this.encryptor.encryptInitialPacket(packet);
        packet.updateFlagsWithEncodedPacketNumber();
        this.encryptor.addHeaderProtectionInitial(packet);
        return packet.getSerializer(this.context).serialize();
    }

    /**
     * Produces the bytes of a Handshake packet.
     *
     * <p>The calls run in a fixed order: the packet's preparator, the encryptor's Handshake
     * encryption, the flag update with the encoded packet number, the encryptor's Handshake header
     * protection, and finally the packet's serializer.
     *
     * @param packet the Handshake packet to write
     * @return the bytes returned by the packet's serializer
     * @throws CryptoException if one of the called steps throws it
     */
    private byte[] writeHandshakePacket(HandshakePacket packet) throws CryptoException {
        packet.getPreparator(this.context).prepare();
        this.encryptor.encryptHandshakePacket(packet);
        packet.updateFlagsWithEncodedPacketNumber();
        this.encryptor.addHeaderProtectionHandshake(packet);
        return packet.getSerializer(this.context).serialize();
    }

    /**
     * Produces the bytes of a 1-RTT packet.
     *
     * <p>The calls run in a fixed order: the packet's preparator, the encryptor's 1-RTT encryption,
     * the flag update with the encoded packet number, the encryptor's 1-RTT header protection, and
     * finally the packet's serializer.
     *
     * @param packet the 1-RTT packet to write
     * @return the bytes returned by the packet's serializer
     * @throws CryptoException if one of the called steps throws it
     */
    private byte[] writeOneRTTPacket(OneRTTPacket packet) throws CryptoException {
        packet.getPreparator(this.context).prepare();
        this.encryptor.encryptOneRRTPacket(packet);
        packet.updateFlagsWithEncodedPacketNumber();
        this.encryptor.addHeaderProtectionOneRRT(packet);
        return packet.getSerializer(this.context).serialize();
    }

    /**
     * Produces the bytes of a 0-RTT packet.
     *
     * <p>The calls run in a fixed order: the packet's preparator, the encryptor's 0-RTT encryption,
     * the flag update with the encoded packet number, the encryptor's 0-RTT header protection, and
     * finally the packet's serializer.
     *
     * @param packet the 0-RTT packet to write
     * @return the bytes returned by the packet's serializer
     * @throws CryptoException if one of the called steps throws it
     */
    private byte[] writeZeroRTTPacket(ZeroRTTPacket packet) throws CryptoException {
        packet.getPreparator(this.context).prepare();
        this.encryptor.encryptZeroRTTPacket(packet);
        packet.updateFlagsWithEncodedPacketNumber();
        this.encryptor.addHeaderProtectionZeroRTT(packet);
        return packet.getSerializer(this.context).serialize();
    }

    /**
     * Produces the bytes of a Retry packet by running its preparator and then its serializer. No
     * encryptor call is made for this packet type.
     *
     * @param packet the Retry packet to write
     * @return the bytes returned by the packet's serializer
     */
    private byte[] writeRetryPacket(RetryPacket packet) {
        packet.getPreparator(this.context).prepare();
        return packet.getSerializer(this.context).serialize();
    }

    /**
     * Produces the bytes of a Version Negotiation packet by running its preparator and then its
     * serializer. No encryptor call is made for this packet type.
     *
     * @param packet the Version Negotiation packet to write
     * @return the bytes returned by the packet's serializer
     */
    private byte[] writeVersionNegotiationPacket(VersionNegotiationPacket packet) {
        packet.getPreparator(this.context).prepare();
        return packet.getSerializer(this.context).serialize();
    }

    /**
     * Creates an Initial packet from the already consumed first byte and version bytes and lets
     * the packet's parser read the rest from the stream. No decryption happens here.
     *
     * @param flags the first byte of the packet; narrowed to a byte
     * @param versionBytes the version bytes read from the header
     * @param dataStream the stream positioned after the version bytes
     * @return the parsed packet
     */
    private InitialPacket readInitialPacket(int flags, byte[] versionBytes, InputStream dataStream) {
        InitialPacket packet = new InitialPacket((byte)flags, versionBytes);
        packet.getParser(this.context, dataStream).parse(packet);
        return packet;
    }

    /**
     * Decrypts a buffered Initial packet and registers it with the context.
     *
     * <p>The calls run in a fixed order: header protection removal, conversion of the complete
     * protected header, payload decryption, recording of the plain packet number in the QUIC
     * context, context adjustment through the packet's handler, and registration as a produced
     * container.
     *
     * <p>NOTE(review): the method name contains a typo ("Intitial"); renaming it requires updating
     * its caller as well.
     *
     * @param packet the Initial packet to decrypt
     * @return the same packet instance
     * @throws CryptoException if one of the decryptor calls throws it
     */
    private InitialPacket decryptIntitialPacket(InitialPacket packet) throws CryptoException {
        this.decryptor.removeHeaderProtectionInitial(packet);
        packet.convertCompleteProtectedHeader();
        this.decryptor.decryptInitialPacket(packet);
        this.quicContext.addReceivedInitialPacketNumber(packet.getPlainPacketNumber());
        packet.getHandler(this.context).adjustContext(packet);
        this.addProducedContainer(packet);
        return packet;
    }

    /**
     * Creates a Handshake packet from the already consumed first byte and version bytes and lets
     * the packet's parser read the rest from the stream. No decryption happens here.
     *
     * @param flags the first byte of the packet; narrowed to a byte
     * @param versionBytes the version bytes read from the header
     * @param dataStream the stream positioned after the version bytes
     * @return the parsed packet
     */
    private HandshakePacket readHandshakePacket(int flags, byte[] versionBytes, InputStream dataStream) {
        HandshakePacket packet = new HandshakePacket((byte)flags, versionBytes);
        packet.getParser(this.context, dataStream).parse(packet);
        return packet;
    }

    /**
     * Decrypts a buffered Handshake packet and registers it with the context.
     *
     * <p>The calls run in a fixed order: header protection removal, conversion of the complete
     * protected header, payload decryption, recording of the plain packet number in the QUIC
     * context, context adjustment through the packet's handler, and registration as a produced
     * container.
     *
     * @param packet the Handshake packet to decrypt
     * @return the same packet instance
     * @throws CryptoException if one of the decryptor calls throws it
     */
    private HandshakePacket decryptHandshakePacket(HandshakePacket packet) throws CryptoException {
        this.decryptor.removeHeaderProtectionHandshake(packet);
        packet.convertCompleteProtectedHeader();
        this.decryptor.decryptHandshakePacket(packet);
        this.quicContext.addReceivedHandshakePacketNumber(packet.getPlainPacketNumber());
        packet.getHandler(this.context).adjustContext(packet);
        this.addProducedContainer(packet);
        return packet;
    }

    /**
     * Creates a 1-RTT packet from the already consumed first byte and lets the packet's parser
     * read the rest from the stream. No decryption happens here.
     *
     * @param flags the first byte of the packet; narrowed to a byte
     * @param dataStream the stream positioned after the first byte
     * @return the parsed packet
     */
    private OneRTTPacket readOneRTTPacket(int flags, InputStream dataStream) {
        OneRTTPacket packet = new OneRTTPacket((byte)flags);
        packet.getParser(this.context, dataStream).parse(packet);
        return packet;
    }

    /**
     * Decrypts a buffered 1-RTT packet and registers it with the context.
     *
     * <p>The calls run in a fixed order: header protection removal, conversion of the complete
     * protected header, payload decryption, recording of the plain packet number in the QUIC
     * context, context adjustment through the packet's handler, and registration as a produced
     * container.
     *
     * @param packet the 1-RTT packet to decrypt
     * @return the same packet instance
     * @throws CryptoException if one of the decryptor calls throws it
     */
    private OneRTTPacket decryptOneRTTPacket(OneRTTPacket packet) throws CryptoException {
        this.decryptor.removeHeaderProtectionOneRTT(packet);
        packet.convertCompleteProtectedHeader();
        this.decryptor.decryptOneRTTPacket(packet);
        this.quicContext.addReceivedOneRTTPacketNumber(packet.getPlainPacketNumber());
        packet.getHandler(this.context).adjustContext(packet);
        this.addProducedContainer(packet);
        return packet;
    }

    /**
     * Parses a Retry packet and processes it immediately: after parsing, the packet's handler
     * adjusts the context and the packet is registered as a produced container. No decryptor call
     * is made for this packet type.
     *
     * @param flags the first byte of the packet; narrowed to a byte
     * @param dataStream the stream the parser reads from
     * @return the parsed packet
     */
    private RetryPacket readRetryPacket(int flags, InputStream dataStream) {
        RetryPacket packet = new RetryPacket((byte)flags);
        packet.getParser(this.context, dataStream).parse(packet);
        packet.getHandler(this.context).adjustContext(packet);
        this.addProducedContainer(packet);
        return packet;
    }

    /**
     * Parses a Version Negotiation packet and processes it immediately: after parsing, the
     * packet's handler adjusts the context and the packet is registered as a produced container.
     * No decryptor call is made for this packet type.
     *
     * @param dataStream the stream the parser reads from
     * @return the parsed packet
     */
    private VersionNegotiationPacket readVersionNegotiationPacket(InputStream dataStream) {
        VersionNegotiationPacket packet = new VersionNegotiationPacket();
        packet.getParser(this.context, dataStream).parse(packet);
        packet.getHandler(this.context).adjustContext(packet);
        this.addProducedContainer(packet);
        return packet;
    }

    /**
     * Decrypts all buffered Initial packets that have no unprotected payload yet and re-sorts the
     * Initial buffer by plain packet number.
     *
     * <p>Does nothing while the buffer is empty or the initial secrets are not initialized. If a
     * packet cannot be decrypted, the exception propagates and the buffer entry is left as it was.
     *
     * @throws CryptoException if decrypting a packet fails
     */
    private void decryptInitialPacketsInBuffer() {
        ArrayList<QuicPacket> buffered = this.receivedPacketBuffer.get(QuicPacketType.INITIAL_PACKET);
        if (buffered.isEmpty() || !this.quicContext.isInitialSecretsInitialized()) {
            return;
        }
        // Build an ArrayList explicitly: Collectors.toList() gives no guarantee on the returned
        // List type, so casting its result to ArrayList is not safe.
        ArrayList<QuicPacket> decrypted = new ArrayList<>(buffered.size());
        for (QuicPacket packet : buffered) {
            try {
                decrypted.add(packet.getUnprotectedPayload() == null ? this.decryptIntitialPacket((InitialPacket) packet) : packet);
            } catch (CryptoException ex) {
                throw new CryptoException("Could not decrypt packet", ex);
            }
        }
        decrypted.sort(Comparator.comparingInt(QuicPacket::getPlainPacketNumber));
        this.receivedPacketBuffer.put(QuicPacketType.INITIAL_PACKET, decrypted);
    }

    /**
     * Decrypts all buffered Handshake packets that have no unprotected payload yet and re-sorts
     * the Handshake buffer by plain packet number.
     *
     * <p>Does nothing while the buffer is empty or the handshake secrets are not initialized. If a
     * packet cannot be decrypted, the exception propagates and the buffer entry is left as it was.
     *
     * @throws CryptoException if decrypting a packet fails
     */
    private void decryptHandshakePacketsInBuffer() {
        ArrayList<QuicPacket> buffered = this.receivedPacketBuffer.get(QuicPacketType.HANDSHAKE_PACKET);
        if (buffered.isEmpty() || !this.quicContext.isHandshakeSecretsInitialized()) {
            return;
        }
        // Build an ArrayList explicitly: Collectors.toList() gives no guarantee on the returned
        // List type, so casting its result to ArrayList is not safe.
        ArrayList<QuicPacket> decrypted = new ArrayList<>(buffered.size());
        for (QuicPacket packet : buffered) {
            try {
                decrypted.add(packet.getUnprotectedPayload() == null ? this.decryptHandshakePacket((HandshakePacket) packet) : packet);
            } catch (CryptoException ex) {
                throw new CryptoException("Could not decrypt packet", ex);
            }
        }
        decrypted.sort(Comparator.comparingInt(QuicPacket::getPlainPacketNumber));
        this.receivedPacketBuffer.put(QuicPacketType.HANDSHAKE_PACKET, decrypted);
    }

    /**
     * Decrypts all buffered 1-RTT packets that have no unprotected payload yet and re-sorts the
     * 1-RTT buffer by plain packet number.
     *
     * <p>Does nothing while the buffer is empty or the application secrets are not initialized. If
     * a packet cannot be decrypted, the exception propagates and the buffer entry is left as it
     * was.
     *
     * @throws CryptoException if decrypting a packet fails
     */
    private void decryptOneRRTPacketsInBuffer() {
        ArrayList<QuicPacket> buffered = this.receivedPacketBuffer.get(QuicPacketType.ONE_RTT_PACKET);
        if (buffered.isEmpty() || !this.quicContext.isApplicationSecretsInitialized()) {
            return;
        }
        // Build an ArrayList explicitly: Collectors.toList() gives no guarantee on the returned
        // List type, so casting its result to ArrayList is not safe.
        ArrayList<QuicPacket> decrypted = new ArrayList<>(buffered.size());
        for (QuicPacket packet : buffered) {
            try {
                decrypted.add(packet.getUnprotectedPayload() == null ? this.decryptOneRTTPacket((OneRTTPacket) packet) : packet);
            } catch (CryptoException ex) {
                throw new CryptoException("Could not decrypt packet", ex);
            }
        }
        decrypted.sort(Comparator.comparingInt(QuicPacket::getPlainPacketNumber));
        this.receivedPacketBuffer.put(QuicPacketType.ONE_RTT_PACKET, decrypted);
    }

    /**
     * Chooses the packet type whose buffer should be drained next.
     *
     * <p>Initial packets are chosen while only the initial secrets are initialized, Handshake
     * packets while the handshake secrets but not the application secrets are initialized, and
     * 1-RTT packets once the application secrets are initialized. A type is only chosen if its
     * buffer is non-empty.
     *
     * @return the packet type to process next, or {@code null} if nothing is processable
     */
    private QuicPacketType getPacketTypeToProcessNext() {
        QuicPacketType next = null;
        if (!this.receivedPacketBuffer.get(QuicPacketType.INITIAL_PACKET).isEmpty()
                && this.quicContext.isInitialSecretsInitialized()
                && !this.quicContext.isHandshakeSecretsInitialized()) {
            next = QuicPacketType.INITIAL_PACKET;
        } else if (!this.receivedPacketBuffer.get(QuicPacketType.HANDSHAKE_PACKET).isEmpty()
                && this.quicContext.isHandshakeSecretsInitialized()
                && !this.quicContext.isApplicationSecretsInitialized()) {
            next = QuicPacketType.HANDSHAKE_PACKET;
        } else if (!this.receivedPacketBuffer.get(QuicPacketType.ONE_RTT_PACKET).isEmpty()
                && this.quicContext.isApplicationSecretsInitialized()) {
            next = QuicPacketType.ONE_RTT_PACKET;
        }
        return next;
    }

    /**
     * Tells whether a packet carries an empty payload and may therefore be skipped.
     *
     * @param packet the packet to inspect
     * @return {@code true} only if the config does not request using all provided QUIC packets
     *     and the packet has a non-null unprotected payload of length zero
     */
    private boolean isEmptyPacket(QuicPacket packet) {
        if (this.context.getConfig().isUseAllProvidedQuicPackets()) {
            return false;
        }
        if (packet.getUnprotectedPayload() == null) {
            return false;
        }
        return packet.getUnprotectedPayload().getValue().length == 0;
    }

    /**
     * Sends the given ACK data in a new packet of the same type as the most recently received
     * packet (Initial, Handshake, or 1-RTT). Nothing is sent for other types.
     *
     * <p>Does nothing while ACKs are temporarily disabled or while no packet has been received
     * yet. During sending, the talking connection end is set to the local end. Afterwards it is
     * always set to the peer of the local end, even if sending failed.
     *
     * @param data the ACK payload to send
     */
    @Override
    public void sendAck(byte[] data) {
        if (this.temporarilyDisabledAcks) {
            return;
        }
        if (this.quicContext.getReceivedPackets().isEmpty()) {
            // getLast() would throw NoSuchElementException on an empty list.
            LOGGER.debug("No packet received yet, not sending an ACK");
            return;
        }
        this.context.setTalkingConnectionEndType(this.context.getConnection().getLocalConnectionEndType());
        try {
            QuicPacketType lastReceived = this.quicContext.getReceivedPackets().getLast();
            QuicPacket ackPacket = null;
            if (lastReceived == QuicPacketType.INITIAL_PACKET) {
                ackPacket = new InitialPacket();
            } else if (lastReceived == QuicPacketType.HANDSHAKE_PACKET) {
                ackPacket = new HandshakePacket();
            } else if (lastReceived == QuicPacketType.ONE_RTT_PACKET) {
                ackPacket = new OneRTTPacket();
            }
            if (ackPacket != null) {
                this.getLowerLayer().sendData(null, this.writePacket(data, ackPacket));
            }
        } catch (CryptoException | IOException e) {
            LOGGER.error("Could not send ACK", e);
        } finally {
            // Hand the talking role back to the peer even if an unexpected exception escapes.
            this.context.setTalkingConnectionEndType(this.context.getConnection().getLocalConnectionEndType().getPeer());
        }
    }

    /**
     * Empties the receive buffer of every packet type. The per-type lists stay registered in the
     * buffer map; only their contents are removed.
     */
    public void clearReceivedPacketBuffer() {
        for (ArrayList<QuicPacket> bufferedPackets : this.receivedPacketBuffer.values()) {
            bufferedPackets.clear();
        }
    }

    /**
     * Checks whether the last 16 bytes of the packet's protected packet number and payload form a
     * stateless reset token known to the QUIC context.
     *
     * <p>Retry and Version Negotiation packets are never treated as stateless resets, and neither
     * are packets with fewer than 16 such bytes.
     *
     * @param packet the parsed packet to inspect
     * @return {@code true} if the trailing 16 bytes are a known stateless reset token
     */
    private boolean isStatelessResetPacket(QuicPacket packet) {
        QuicPacketType type = packet.getPacketType();
        if (type == QuicPacketType.RETRY_PACKET || type == QuicPacketType.VERSION_NEGOTIATION) {
            return false;
        }
        byte[] protectedBytes = packet.getProtectedPacketNumberAndPayload().getValue();
        int tokenStart = protectedBytes.length - 16;
        if (tokenStart < 0) {
            return false;
        }
        byte[] trailingToken = Arrays.copyOfRange(protectedBytes, tokenStart, protectedBytes.length);
        if (!this.quicContext.isStatelessResetToken(trailingToken)) {
            return false;
        }
        LOGGER.debug("Received a Stateless Reset Packet with Token {}", (Object) trailingToken);
        return true;
    }

    /**
     * Enables or disables the sending of ACKs by this layer.
     *
     * @param temporarilyDisabledAcks if {@code true}, {@code sendAck} returns without sending
     */
    public void setTemporarilyDisabledAcks(boolean temporarilyDisabledAcks) {
        this.temporarilyDisabledAcks = temporarilyDisabledAcks;
    }
}

