/*
 * Decompiled with CFR 0.152.
 */
package de.rub.nds.tlsattacker.core.workflow;

import de.rub.nds.tlsattacker.core.state.State;
import de.rub.nds.tlsattacker.core.workflow.task.ITask;
import de.rub.nds.tlsattacker.core.workflow.task.StateExecutionTask;
import de.rub.nds.tlsattacker.core.workflow.task.TlsTask;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Submits {@link TlsTask}s (and {@link State}s wrapped in {@link StateExecutionTask}s) to a
 * {@link ThreadPoolExecutor} and offers blocking bulk-execution helpers.
 *
 * <p>Optional default callbacks are copied onto every submitted task that does not already define
 * its own. An optional timeout action can be armed; it is invoked whenever tasks are active but no
 * task has completed for the configured number of milliseconds.
 */
public class ParallelExecutor {
    private static final Logger LOGGER = LogManager.getLogger();

    /** Interval, in milliseconds, at which the timeout monitor re-checks the executor's progress. */
    private static final long MONITOR_POLL_INTERVAL_MILLIS = 10L;

    private final ThreadPoolExecutor executorService;
    private Callable<Integer> timeoutAction;
    private final int size;
    // volatile: written by shutdown() on the caller's thread, read by the monitor thread.
    private volatile boolean shouldShutdown = false;
    private final int reexecutions;
    private Function<State, Integer> defaultBeforeTransportPreInitCallback = null;
    private Function<State, Integer> defaultBeforeTransportInitCallback = null;
    private Function<State, Integer> defaultAfterTransportInitCallback = null;
    private Function<State, Integer> defaultAfterExecutionCallback = null;

    /**
     * Creates an executor backed by the given thread pool.
     *
     * @param size value reported by {@link #getSize()}; not used to configure {@code executorService}
     * @param reexecutions value passed to every {@link StateExecutionTask} created by this executor
     * @param executorService pool that runs the submitted tasks
     * @return a new executor
     * @throws IllegalArgumentException if {@code reexecutions} is negative
     */
    public static ParallelExecutor create(int size, int reexecutions, ThreadPoolExecutor executorService) {
        if (reexecutions < 0) {
            throw new IllegalArgumentException("Reexecutions is below zero");
        }
        return new ParallelExecutor(size, reexecutions, executorService);
    }

    /**
     * Creates an executor backed by an externally supplied pool; {@link #getSize()} will report -1.
     *
     * @param executorService pool that runs the submitted tasks
     * @param reexecutions value passed to every {@link StateExecutionTask} created by this executor
     * @return a new executor
     * @throws IllegalArgumentException if {@code reexecutions} is negative
     */
    public static ParallelExecutor create(ThreadPoolExecutor executorService, int reexecutions) {
        return ParallelExecutor.create(-1, reexecutions, executorService);
    }

    /**
     * Creates an executor with a fixed-size pool of {@code size} threads (10-day keep-alive).
     *
     * @param size number of core and maximum pool threads
     * @param reexecutions value passed to every {@link StateExecutionTask} created by this executor
     * @return a new executor
     * @throws IllegalArgumentException if {@code reexecutions} is negative or {@code size} is
     *     rejected by {@link ThreadPoolExecutor}
     */
    public static ParallelExecutor create(int size, int reexecutions) {
        return ParallelExecutor.create(size, reexecutions,
                new ThreadPoolExecutor(size, size, 10L, TimeUnit.DAYS, new LinkedBlockingDeque<Runnable>()));
    }

    /**
     * Creates an executor with a fixed-size pool of {@code size} threads built by {@code factory}
     * (5-minute keep-alive).
     *
     * @param size number of core and maximum pool threads
     * @param reexecutions value passed to every {@link StateExecutionTask} created by this executor
     * @param factory factory used by the pool to create its threads
     * @return a new executor
     * @throws IllegalArgumentException if {@code reexecutions} is negative or {@code size} is
     *     rejected by {@link ThreadPoolExecutor}
     */
    public static ParallelExecutor create(int size, int reexecutions, ThreadFactory factory) {
        return ParallelExecutor.create(size, reexecutions,
                new ThreadPoolExecutor(size, size, 5L, TimeUnit.MINUTES, new LinkedBlockingDeque<Runnable>(), factory));
    }

    protected ParallelExecutor(int size, int reexecutions, ThreadPoolExecutor executorService) {
        this.executorService = executorService;
        this.reexecutions = reexecutions;
        this.size = size;
    }

    /**
     * Applies the configured default callbacks to {@code task} (only where the task has none) and
     * submits it to the pool.
     *
     * @param task task to run
     * @return future completing with the task's result
     * @throws RuntimeException if the underlying pool has already been shut down
     */
    protected Future<ITask> addTask(TlsTask task) {
        if (this.executorService.isShutdown()) {
            throw new RuntimeException("Cannot add Tasks to already shutdown executor");
        }
        if (this.defaultBeforeTransportPreInitCallback != null && task.getBeforeTransportPreInitCallback() == null) {
            task.setBeforeTransportPreInitCallback(this.defaultBeforeTransportPreInitCallback);
        }
        if (this.defaultBeforeTransportInitCallback != null && task.getBeforeTransportInitCallback() == null) {
            task.setBeforeTransportInitCallback(this.defaultBeforeTransportInitCallback);
        }
        if (this.defaultAfterTransportInitCallback != null && task.getAfterTransportInitCallback() == null) {
            task.setAfterTransportInitCallback(this.defaultAfterTransportInitCallback);
        }
        if (this.defaultAfterExecutionCallback != null && task.getAfterExecutionCallback() == null) {
            task.setAfterExecutionCallback(this.defaultAfterExecutionCallback);
        }
        return this.executorService.submit(task);
    }

    /**
     * Wraps {@code state} in a {@link StateExecutionTask} using this executor's reexecution count and
     * submits it.
     *
     * @param state state to execute
     * @return future completing with the task's result
     */
    protected Future<ITask> addStateTask(State state) {
        return this.addTask(new StateExecutionTask(state, this.reexecutions));
    }

    /**
     * Submits every state and blocks until all of them have finished.
     *
     * @param stateList states to execute
     * @throws RuntimeException wrapping the first failure or interruption encountered while waiting
     */
    public void bulkExecuteStateTasks(Iterable<State> stateList) {
        List<Future<ITask>> futureList = new ArrayList<>();
        for (State state : stateList) {
            futureList.add(this.addStateTask(state));
        }
        for (Future<ITask> future : futureList) {
            awaitResult(future);
        }
    }

    /**
     * Varargs convenience for {@link #bulkExecuteStateTasks(Iterable)}.
     *
     * @param states states to execute
     */
    public void bulkExecuteStateTasks(State... states) {
        this.bulkExecuteStateTasks(Arrays.asList(states));
    }

    /**
     * Submits every task and blocks until all of them have finished.
     *
     * @param taskList tasks to execute
     * @return task results in submission order
     * @throws RuntimeException wrapping the first failure or interruption encountered while waiting
     */
    public List<ITask> bulkExecuteTasks(Iterable<TlsTask> taskList) {
        List<Future<ITask>> futureList = new ArrayList<>();
        for (TlsTask tlsTask : taskList) {
            futureList.add(this.addTask(tlsTask));
        }
        // Size the result list only now that the number of submitted tasks is known.
        List<ITask> resultList = new ArrayList<>(futureList.size());
        for (Future<ITask> future : futureList) {
            resultList.add(awaitResult(future));
        }
        return resultList;
    }

    /**
     * Varargs convenience for {@link #bulkExecuteTasks(Iterable)}.
     *
     * @param tasks tasks to execute
     * @return task results in submission order
     */
    public List<ITask> bulkExecuteTasks(TlsTask... tasks) {
        return this.bulkExecuteTasks(Arrays.asList(tasks));
    }

    /** Blocks on {@code future} and returns its result, wrapping any failure in a RuntimeException. */
    private static ITask awaitResult(Future<ITask> future) {
        try {
            return future.get();
        } catch (InterruptedException ex) {
            // Restore the interrupt flag so callers further up can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new RuntimeException("Failed to execute tasks!", ex);
        } catch (ExecutionException ex) {
            throw new RuntimeException("Failed to execute tasks!", ex);
        }
    }

    /**
     * @return the size given at creation, or -1 when created via
     *     {@link #create(ThreadPoolExecutor, int)}
     */
    public int getSize() {
        return this.size;
    }

    /** Stops the timeout monitor (if armed) and initiates an orderly shutdown of the pool. */
    public void shutdown() {
        this.shouldShutdown = true;
        this.executorService.shutdown();
    }

    /**
     * Starts a background thread that invokes the timeout action whenever tasks are active but none
     * has completed within {@code timeout} milliseconds. Does nothing (besides logging a warning) if
     * no timeout action is set. The thread runs until {@link #shutdown()} is called.
     *
     * @param timeout stall threshold in milliseconds
     */
    public void armTimeoutAction(int timeout) {
        if (this.timeoutAction == null) {
            LOGGER.warn("No TimeoutAction set, this won't do anything");
            return;
        }
        new Thread(() -> this.monitorExecution(timeout), "ParallelExecutor-timeout-monitor").start();
    }

    /**
     * Polls the pool until shutdown, calling the timeout action when progress stalls.
     *
     * <p>The timeout window is restarted whenever the pool is idle or the completed-task count has
     * changed since the last poll. The loop sleeps between polls instead of spinning.
     */
    private void monitorExecution(int timeout) {
        long timeoutTime = System.currentTimeMillis() + (long) timeout;
        long lastCompletedCount = 0L;
        while (!this.shouldShutdown) {
            long completedCount = this.executorService.getCompletedTaskCount();
            if (this.executorService.getActiveCount() == 0 || completedCount != lastCompletedCount) {
                // Idle or making progress: restart the timeout window.
                timeoutTime = System.currentTimeMillis() + (long) timeout;
                lastCompletedCount = completedCount;
            } else if (System.currentTimeMillis() > timeoutTime) {
                LOGGER.debug("Timeout");
                try {
                    int exitCode = this.timeoutAction.call();
                    if (exitCode != 0) {
                        throw new RuntimeException("TimeoutAction did terminate with code " + exitCode);
                    }
                    timeoutTime = System.currentTimeMillis() + (long) timeout;
                } catch (Exception e) {
                    // Best effort: log and keep monitoring; the action is retried on a later poll.
                    LOGGER.warn("TimeoutAction did not succeed", (Throwable) e);
                }
            }
            try {
                Thread.sleep(MONITOR_POLL_INTERVAL_MILLIS);
            } catch (InterruptedException e) {
                // Being interrupted is treated as a request to stop monitoring.
                Thread.currentThread().interrupt();
                return;
            }
        }
    }

    /** @return the reexecution count passed to each {@link StateExecutionTask} */
    public int getReexecutions() {
        return this.reexecutions;
    }

    /** @return the action invoked on a stall, or {@code null} if none is set */
    public Callable<Integer> getTimeoutAction() {
        return this.timeoutAction;
    }

    /**
     * Sets the action invoked on a stall; a non-zero return value is logged as a failure.
     *
     * @param timeoutAction action to invoke, or {@code null} to clear it
     */
    public void setTimeoutAction(Callable<Integer> timeoutAction) {
        this.timeoutAction = timeoutAction;
    }

    /** @return default callback applied to tasks lacking a before-transport-pre-init callback */
    public Function<State, Integer> getDefaultBeforeTransportPreInitCallback() {
        return this.defaultBeforeTransportPreInitCallback;
    }

    /** @param defaultBeforeTransportPreInitCallback applied to submitted tasks that have none */
    public void setDefaultBeforeTransportPreInitCallback(Function<State, Integer> defaultBeforeTransportPreInitCallback) {
        this.defaultBeforeTransportPreInitCallback = defaultBeforeTransportPreInitCallback;
    }

    /** @return default callback applied to tasks lacking a before-transport-init callback */
    public Function<State, Integer> getDefaultBeforeTransportInitCallback() {
        return this.defaultBeforeTransportInitCallback;
    }

    /** @param defaultBeforeTransportInitCallback applied to submitted tasks that have none */
    public void setDefaultBeforeTransportInitCallback(Function<State, Integer> defaultBeforeTransportInitCallback) {
        this.defaultBeforeTransportInitCallback = defaultBeforeTransportInitCallback;
    }

    /** @return default callback applied to tasks lacking an after-transport-init callback */
    public Function<State, Integer> getDefaultAfterTransportInitCallback() {
        return this.defaultAfterTransportInitCallback;
    }

    /** @param defaultAfterTransportInitCallback applied to submitted tasks that have none */
    public void setDefaultAfterTransportInitCallback(Function<State, Integer> defaultAfterTransportInitCallback) {
        this.defaultAfterTransportInitCallback = defaultAfterTransportInitCallback;
    }

    /** @return default callback applied to tasks lacking an after-execution callback */
    public Function<State, Integer> getDefaultAfterExecutionCallback() {
        return this.defaultAfterExecutionCallback;
    }

    /** @param defaultAfterExecutionCallback applied to submitted tasks that have none */
    public void setDefaultAfterExecutionCallback(Function<State, Integer> defaultAfterExecutionCallback) {
        this.defaultAfterExecutionCallback = defaultAfterExecutionCallback;
    }
}

