/*
 * Decompiled with CFR 0.152.
 */
package dr.util;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.BinaryOperator;

/**
 * Splits the index range {@code [0, taskCount)} into contiguous chunks and runs work over
 * those chunks.
 *
 * <p>When there is a single chunk the work runs inline on the calling thread. When there are
 * several chunks they are submitted to an {@link ExecutorService} that is created lazily on
 * first use: a fixed-size pool for a thread count greater than one, a cached pool for a
 * negative thread count.
 *
 * <p>This class never shuts the executor down; callers that need its threads to terminate
 * can do so through {@link #getPool()}.
 */
public class TaskPool {
    /** Lazily created executor; stays {@code null} until a multi-chunk call needs it. */
    private ExecutorService pool = null;
    private final List<TaskIndices> indices;
    private final int taskCount;
    private final int threadCount;

    /**
     * Creates a pool description for {@code taskCount} tasks split over at most
     * {@code |threadCount|} chunks.
     *
     * @param taskCount   number of task indices; a value {@code <= 0} yields no chunks
     * @param threadCount requested thread count; its absolute value bounds the number of
     *                    chunks, and its sign selects the executor type (positive: fixed
     *                    pool, negative: cached pool)
     * @throws IllegalArgumentException if {@code threadCount} is {@code 0} (the range cannot
     *                                  be divided into zero chunks) or
     *                                  {@code Integer.MIN_VALUE} (it has no absolute value)
     */
    public TaskPool(int taskCount, int threadCount) {
        if (threadCount == 0 || threadCount == Integer.MIN_VALUE) {
            throw new IllegalArgumentException("Invalid thread count: " + threadCount);
        }
        this.indices = this.setupTasks(taskCount, Math.abs(threadCount));
        this.taskCount = taskCount;
        this.threadCount = threadCount;
    }

    /**
     * @return the executor used for multi-chunk calls, or {@code null} if none has been
     *         created yet
     */
    public ExecutorService getPool() {
        return this.pool;
    }

    /**
     * @return the chunk list itself (not a copy), in ascending index order
     */
    public List<TaskIndices> getIndices() {
        return this.indices;
    }

    /**
     * @return the number of chunks, which may be smaller than the requested thread count
     */
    public int getNumThreads() {
        return this.indices.size();
    }

    /**
     * @return the task count given to the constructor
     */
    public int getNumTaxon() {
        return this.taskCount;
    }

    /**
     * Partitions {@code [0, taskCount)} into at most {@code threads} contiguous chunks of
     * {@code ceil(taskCount / threads)} indices each (the last chunk may be shorter).
     */
    private List<TaskIndices> setupTasks(int taskCount, int threads) {
        List<TaskIndices> chunks = new ArrayList<>(Math.max(0, Math.min(taskCount, threads)));
        int chunkSize = taskCount / threads;
        if (taskCount % threads != 0) {
            ++chunkSize;
        }
        // Track the running start as a long so start + chunkSize cannot wrap around for
        // task counts close to Integer.MAX_VALUE.
        long start = 0L;
        for (int task = 0; task < threads && start < taskCount; start += chunkSize, ++task) {
            int stop = (int) Math.min(start + chunkSize, (long) taskCount);
            chunks.add(new TaskIndices((int) start, stop, task));
        }
        return chunks;
    }

    /**
     * Builds the executor for the requested thread count: fixed pool when {@code > 1},
     * cached pool when negative, otherwise {@code null} (no executor).
     */
    private ExecutorService setupParallelServices(int threadCount) {
        if (threadCount > 1) {
            return Executors.newFixedThreadPool(threadCount);
        }
        if (threadCount < 0) {
            return Executors.newCachedThreadPool();
        }
        return null;
    }

    /** Returns the executor, creating it on first use. */
    private ExecutorService ensurePool() {
        if (this.pool == null) {
            this.pool = this.setupParallelServices(this.threadCount);
        }
        return this.pool;
    }

    /**
     * Runs every call on the executor, waits for all of them, and returns their results in
     * submission order.
     *
     * @throws IllegalStateException if the calling thread is interrupted while waiting (the
     *                               interrupt flag is restored first), or if a task failed
     *                               with a checked exception
     * @throws RuntimeException      the original unchecked exception thrown by a task
     */
    private <T> List<T> runAll(List<Callable<T>> calls) {
        try {
            List<Future<T>> futures = this.ensurePool().invokeAll(calls);
            List<T> results = new ArrayList<>(futures.size());
            for (Future<T> future : futures) {
                results.add(future.get());
            }
            return results;
        } catch (InterruptedException interrupted) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while waiting for tasks", interrupted);
        } catch (ExecutionException failed) {
            // Surface the task's own exception so the parallel path fails the same way the
            // single-chunk path does, instead of silently yielding a partial result.
            Throwable cause = failed.getCause();
            if (cause instanceof RuntimeException) {
                throw (RuntimeException) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            throw new IllegalStateException("Task failed", cause);
        }
    }

    /**
     * Applies {@code rangeCallable} to every chunk and folds the per-chunk results together,
     * left to right in chunk order, with {@code binaryOperator}.
     *
     * <p>With a single chunk the map runs on the calling thread with task id {@code 0} and
     * the operator is not used.
     *
     * @param rangeCallable  computes the result for one chunk {@code [start, stop)}
     * @param binaryOperator combines two chunk results
     * @param <E>            result type
     * @return the combined result, or {@code null} when there are no chunks
     * @throws IllegalStateException if the calling thread is interrupted while waiting
     * @throws RuntimeException      whatever unchecked exception a chunk computation threw
     */
    public <E> E mapReduce(RangeCallable<E> rangeCallable, BinaryOperator<E> binaryOperator) {
        if (this.indices.isEmpty()) {
            return null;
        }
        if (this.indices.size() == 1) {
            TaskIndices only = this.indices.get(0);
            return rangeCallable.map(only.start, only.stop, 0);
        }
        List<Callable<E>> calls = new ArrayList<>(this.indices.size());
        for (TaskIndices chunk : this.indices) {
            calls.add(() -> rangeCallable.map(chunk.start, chunk.stop, chunk.task));
        }
        List<E> partials = this.runAll(calls);
        E result = partials.get(0);
        for (int i = 1; i < partials.size(); ++i) {
            result = binaryOperator.apply(result, partials.get(i));
        }
        return result;
    }

    /**
     * Invokes {@code taskCallable} once for every index in {@code [0, taskCount)}, passing
     * the id of the chunk the index belongs to, and returns when all invocations are done.
     *
     * <p>With a single chunk the calls run on the calling thread with task id {@code 0}.
     * Does nothing when there are no chunks.
     *
     * @param taskCallable work to perform per index
     * @throws IllegalStateException if the calling thread is interrupted while waiting
     * @throws RuntimeException      whatever unchecked exception an invocation threw
     */
    public void fork(TaskCallable taskCallable) {
        if (this.indices.isEmpty()) {
            return;
        }
        if (this.indices.size() == 1) {
            TaskIndices only = this.indices.get(0);
            for (int i = only.start; i < only.stop; ++i) {
                taskCallable.execute(i, 0);
            }
            return;
        }
        List<Callable<Object>> calls = new ArrayList<>(this.indices.size());
        for (TaskIndices chunk : this.indices) {
            Runnable work = () -> {
                for (int i = chunk.start; i < chunk.stop; ++i) {
                    taskCallable.execute(i, chunk.task);
                }
            };
            calls.add(Executors.callable(work));
        }
        this.runAll(calls);
    }

    /** Computes a value for the half-open index range {@code [start, stop)} of one chunk. */
    @FunctionalInterface
    public static interface RangeCallable<E> {
        /**
         * @param var1 first index of the chunk (inclusive)
         * @param var2 end index of the chunk (exclusive)
         * @param var3 id of the chunk
         * @return the chunk's result
         */
        public E map(int var1, int var2, int var3);
    }

    /** Performs work for a single index. */
    @FunctionalInterface
    public static interface TaskCallable {
        /**
         * @param var1 the index to process
         * @param var2 id of the chunk the index belongs to
         */
        public void execute(int var1, int var2);
    }

    /** One contiguous chunk: indices {@code [start, stop)} plus the chunk's id. */
    class TaskIndices {
        final int start;
        final int stop;
        final int task;

        TaskIndices(int start, int stop, int task) {
            this.start = start;
            this.stop = stop;
            this.task = task;
        }

        /** @return {@code "<start> <stop>"} */
        @Override
        public String toString() {
            return this.start + " " + this.stop;
        }
    }
}

