/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.coalescent.basta;

import beagle.BeagleFlag;
import beagle.basta.BastaFactory;
import beagle.basta.BeagleBasta;
import dr.evolution.tree.Tree;
import dr.evomodel.coalescent.basta.BastaLikelihood;
import dr.evomodel.coalescent.basta.BastaLikelihoodDelegate;
import dr.evomodel.coalescent.basta.ProcessOnCoalescentIntervalDelegate;
import dr.evomodel.substmodel.EigenDecomposition;
import dr.evomodel.treedatalikelihood.BufferIndexHelper;
import dr.inference.model.CompoundLikelihood;
import dr.inference.model.Likelihood;
import java.util.Arrays;
import java.util.List;

public class BeagleBastaLikelihoodDelegate
extends BastaLikelihoodDelegate.AbstractBastaLikelihoodDelegate {
    private static final int COALESCENT_PROBABILITY_INDEX = 0;
    private final BeagleBasta beagle;
    private final BufferIndexHelper eigenBufferHelper;
    private final OffsetBufferIndexHelper populationSizesBufferHelper;
    int tipCount = -1;
    int[] map;
    int used;
    private static final boolean CACHE_FRIENDLY = true;
    private boolean releaseSingleton = true;

    public BeagleBastaLikelihoodDelegate(String string, Tree tree, int n, boolean bl) {
        super(string, tree, n, bl);
        int n2 = this.maxNumCoalescentIntervals * (tree.getNodeCount() + 1);
        int n3 = this.maxNumCoalescentIntervals;
        int n4 = 5;
        long l = 0L;
        this.beagle = BastaFactory.loadBastaInstance((int)0, (int)n4, (int)this.maxNumCoalescentIntervals, (int)n2, (int)0, (int)n, (int)1, (int)2, (int)n3, (int)1, (int)1, null, (long)0L, (long)(l |= BeagleFlag.EIGEN_COMPLEX.getMask()));
        this.eigenBufferHelper = new BufferIndexHelper(1, 0);
        this.populationSizesBufferHelper = new OffsetBufferIndexHelper(1, 0, 0);
        this.beagle.setCategoryRates(new double[]{1.0});
    }

    @Override
    protected void computeBranchIntervalOperations(List<Integer> list, List<ProcessOnCoalescentIntervalDelegate.BranchIntervalOperation> list2) {
        int[] nArray = new int[list2.size() * 8];
        int[] nArray2 = new int[list.size()];
        double[] dArray = new double[list.size() - 1];
        this.vectorizeBranchIntervalOperations(list, list2, nArray, nArray2, dArray);
        int n = this.populationSizesBufferHelper.getOffsetIndex(0);
        this.beagle.updateBastaPartials(nArray, list2.size(), nArray2, list.size(), n, 0);
    }

    @Override
    String getStamp() {
        return "beagle";
    }

    @Override
    protected void computeTransitionProbabilityOperations(List<ProcessOnCoalescentIntervalDelegate.TransitionMatrixOperation> list) {
        int[] nArray = new int[list.size()];
        double[] dArray = new double[list.size()];
        this.vectorizeTransitionMatrixOperations(list, nArray, dArray);
        int n = this.eigenBufferHelper.getOffsetIndex(0);
        this.beagle.updateTransitionMatrices(n, nArray, null, null, dArray, list.size());
    }

    @Override
    protected double computeCoalescentIntervalReduction(List<Integer> list, List<ProcessOnCoalescentIntervalDelegate.BranchIntervalOperation> list2) {
        int[] nArray = new int[list2.size() * 8];
        int[] nArray2 = new int[list.size()];
        double[] dArray = new double[list.size() - 1];
        this.vectorizeBranchIntervalOperations(list, list2, nArray, nArray2, dArray);
        int n = this.populationSizesBufferHelper.getOffsetIndex(0);
        double[] dArray2 = new double[1];
        this.beagle.accumulateBastaPartials(nArray, list2.size(), nArray2, list.size(), dArray, n, 0, dArray2);
        return dArray2[0];
    }

    @Override
    protected void computeBranchIntervalOperationsGrad(List<Integer> list, List<ProcessOnCoalescentIntervalDelegate.TransitionMatrixOperation> list2, List<ProcessOnCoalescentIntervalDelegate.BranchIntervalOperation> list3) {
        int[] nArray = new int[list3.size() * 8];
        int[] nArray2 = new int[list.size()];
        double[] dArray = new double[list.size() - 1];
        this.vectorizeBranchIntervalOperations(list, list3, nArray, nArray2, dArray);
        int n = this.populationSizesBufferHelper.getOffsetIndex(0);
        this.beagle.updateBastaPartialsGrad(nArray, list3.size(), nArray2, list.size(), n, 0);
    }

    @Override
    protected void computeTransitionProbabilityOperationsGrad(List<ProcessOnCoalescentIntervalDelegate.TransitionMatrixOperation> list) {
        int[] nArray = new int[list.size()];
        double[] dArray = new double[list.size()];
        this.vectorizeTransitionMatrixOperations(list, nArray, dArray);
        this.beagle.updateTransitionMatricesGrad(nArray, dArray, list.size());
    }

    @Override
    protected double[][] computeCoalescentIntervalReductionGrad(List<Integer> list, List<ProcessOnCoalescentIntervalDelegate.BranchIntervalOperation> list2) {
        int[] nArray = new int[list2.size() * 8];
        int[] nArray2 = new int[list.size()];
        double[] dArray = new double[list.size() - 1];
        this.vectorizeBranchIntervalOperations(list, list2, nArray, nArray2, dArray);
        int n = this.populationSizesBufferHelper.getOffsetIndex(0);
        double[] dArray2 = new double[this.stateCount * this.stateCount];
        double[][] dArray3 = new double[this.stateCount][this.stateCount];
        this.beagle.accumulateBastaPartialsGrad(nArray, list2.size(), nArray2, list.size(), dArray, n, 0, dArray2);
        for (int i = 0; i < this.stateCount; ++i) {
            for (int j = 0; j < this.stateCount; ++j) {
                dArray3[i][j] = dArray2[i * this.stateCount + j];
            }
        }
        return dArray3;
    }

    @Override
    protected double[] computeCoalescentIntervalReductionGradPopSize(List<Integer> list, List<ProcessOnCoalescentIntervalDelegate.BranchIntervalOperation> list2) {
        return new double[0];
    }

    @Override
    public void setPartials(int n, double[] dArray) {
        this.beagle.setPartials(n, dArray);
    }

    @Override
    public void getPartials(int n, double[] dArray) {
        assert (n >= 0);
        assert (dArray != null);
        assert (dArray.length >= this.stateCount);
        this.beagle.getPartials(n, -1, dArray);
    }

    public void getMatrix(int n, double[] dArray) {
        assert (n >= 0);
        assert (dArray != null);
        assert (dArray.length >= this.stateCount * this.stateCount);
        this.beagle.getTransitionMatrix(n, dArray);
    }

    @Override
    public void updateEigenDecomposition(int n, EigenDecomposition eigenDecomposition, boolean bl) {
        if (bl) {
            this.eigenBufferHelper.flipOffset(0);
        }
        if (this.transpose) {
            eigenDecomposition = eigenDecomposition.transpose();
        }
        this.beagle.setEigenDecomposition(this.eigenBufferHelper.getOffsetIndex(0), eigenDecomposition.getEigenVectors(), eigenDecomposition.getInverseEigenVectors(), eigenDecomposition.getEigenValues());
    }

    @Override
    public void updatePopulationSizes(int n, double[] dArray, boolean bl) {
        if (bl) {
            this.populationSizesBufferHelper.flipOffset(0);
        }
        this.beagle.setStateFrequencies(this.populationSizesBufferHelper.getOffsetIndex(0), dArray);
    }

    private void vectorizeTransitionMatrixOperations(List<ProcessOnCoalescentIntervalDelegate.TransitionMatrixOperation> list, int[] nArray, double[] dArray) {
        int n = 0;
        for (ProcessOnCoalescentIntervalDelegate.TransitionMatrixOperation transitionMatrixOperation : list) {
            nArray[n] = transitionMatrixOperation.outputBuffer;
            dArray[n] = transitionMatrixOperation.time;
            ++n;
        }
    }

    int map(int n) {
        if (n < this.tipCount) {
            return n;
        }
        if (this.map[n] == -1) {
            this.map[n] = this.used++;
        }
        return this.map[n];
    }

    private void vectorizeBranchIntervalOperations(List<Integer> list, List<ProcessOnCoalescentIntervalDelegate.BranchIntervalOperation> list2, int[] nArray, int[] nArray2, double[] dArray) {
        int n;
        this.tipCount = this.tree.getExternalNodeCount();
        if (this.map == null) {
            this.map = new int[this.maxNumCoalescentIntervals * (this.tree.getNodeCount() + 1)];
        }
        Arrays.fill(this.map, -1);
        this.used = this.tipCount;
        int n2 = 0;
        for (ProcessOnCoalescentIntervalDelegate.BranchIntervalOperation branchIntervalOperation : list2) {
            nArray[n2] = this.map(branchIntervalOperation.outputBuffer);
            nArray[n2 + 1] = this.map(branchIntervalOperation.inputBuffer1);
            nArray[n2 + 2] = branchIntervalOperation.inputMatrix1;
            nArray[n2 + 3] = this.map(branchIntervalOperation.inputBuffer2);
            nArray[n2 + 4] = branchIntervalOperation.inputMatrix2;
            nArray[n2 + 5] = this.map(branchIntervalOperation.accBuffer1);
            nArray[n2 + 6] = this.map(branchIntervalOperation.accBuffer2);
            nArray[n2 + 7] = branchIntervalOperation.intervalNumber;
            n2 += 8;
        }
        int n3 = list.size() - 1;
        for (n = 0; n < n3; ++n) {
            int n4;
            nArray2[n] = n4 = list.get(n).intValue();
            dArray[n] = list2.get((int)n4).intervalLength;
        }
        nArray2[n] = list.get(n);
    }

    private void releaseBeagle() throws Throwable {
        if (this.beagle != null && this.releaseSingleton) {
            this.beagle.finalize();
            this.releaseSingleton = false;
        }
    }

    public static void releaseBeagleBastaLikelihoodDelegate(BastaLikelihood bastaLikelihood) throws Throwable {
        BastaLikelihoodDelegate bastaLikelihoodDelegate = bastaLikelihood.getLikelihoodDelegate();
        if (bastaLikelihoodDelegate instanceof BeagleBastaLikelihoodDelegate) {
            BeagleBastaLikelihoodDelegate beagleBastaLikelihoodDelegate = (BeagleBastaLikelihoodDelegate)bastaLikelihoodDelegate;
            beagleBastaLikelihoodDelegate.releaseBeagle();
        }
    }

    public static void releaseAllBeagleBastaInstances() throws Throwable {
        for (Likelihood likelihood : Likelihood.FULL_LIKELIHOOD_SET) {
            if (likelihood instanceof BastaLikelihood) {
                BeagleBastaLikelihoodDelegate.releaseBeagleBastaLikelihoodDelegate((BastaLikelihood)likelihood);
                continue;
            }
            if (!(likelihood instanceof CompoundLikelihood)) continue;
            for (Likelihood likelihood2 : ((CompoundLikelihood)likelihood).getLikelihoods()) {
                if (!(likelihood2 instanceof BastaLikelihood)) continue;
                BeagleBastaLikelihoodDelegate.releaseBeagleBastaLikelihoodDelegate((BastaLikelihood)likelihood2);
            }
        }
    }

    static class OffsetBufferIndexHelper
    extends BufferIndexHelper {
        public OffsetBufferIndexHelper(int n, int n2, int n3) {
            super(n, n2, n3);
        }

        @Override
        protected int computeOffset(int n) {
            return n;
        }
    }
}

