/*
 * Decompiled with CFR 0.152.
 */
package dr.evolution.coalescent;

import dr.evolution.coalescent.DemographicFunction;
import dr.evolution.coalescent.IntervalList;
import dr.evolution.coalescent.IntervalType;
import dr.evolution.tree.Tree;
import dr.evomodel.bigfasttree.BigFastTreeIntervals;
import dr.evomodel.coalescent.CoalescentLikelihood;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.treedatalikelihood.discrete.NodeHeightProxyParameter;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.model.GradientProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.math.Binomial;
import dr.xml.Reportable;
import java.util.Arrays;

/**
 * Exposes the gradient of a {@link CoalescentLikelihood}'s log density as a
 * {@link GradientWrtParameterProvider}, and makes it reportable/loggable.
 *
 * <p>Only {@link Wrt#NODE_HEIGHTS} currently has an implementation: the gradient is taken with
 * respect to the tree's internal node heights, represented by a {@link NodeHeightProxyParameter}.
 * In {@link Wrt#PARAMETER} mode the supplied parameter is stored and returned by
 * {@link #getParameter()}, but {@link #getGradientLogDensity()} throws because no gradient
 * provider is installed for it.
 */
public class CoalescentGradient implements GradientWrtParameterProvider, Reportable, Loggable {
    private final CoalescentLikelihood likelihood;
    private final Parameter parameter;
    private final Tree tree;
    /** Computes the gradient; {@code null} when constructed with {@link Wrt#PARAMETER}. */
    private final GradientProvider provider;
    /** Tolerance passed to {@link GradientWrtParameterProvider#getReportAndCheckForError} by {@link #getReport()}. */
    private final double tolerance;

    /**
     * @param coalescentLikelihood the likelihood whose log density is differentiated
     * @param treeModel            the tree whose node heights / intervals the likelihood uses
     * @param parameter            parameter reported in {@link Wrt#PARAMETER} mode; ignored for
     *                             {@link Wrt#NODE_HEIGHTS}, where a node-height proxy is used instead
     * @param wrt                  which quantity the gradient is taken with respect to
     * @param d                    tolerance used by {@link #getReport()}
     */
    public CoalescentGradient(CoalescentLikelihood coalescentLikelihood, TreeModel treeModel, Parameter parameter, Wrt wrt, double d) {
        this.likelihood = coalescentLikelihood;
        this.tree = treeModel;
        if (wrt == Wrt.NODE_HEIGHTS) {
            this.parameter = new NodeHeightProxyParameter("NodeHeights", treeModel, true);
            this.provider = new GradientProvider() {

                @Override
                public int getDimension() {
                    return CoalescentGradient.this.parameter.getDimension();
                }

                @Override
                public double[] getGradientLogDensity(Object object) {
                    return CoalescentGradient.this.getGradientLogDensityWrtNodeHeights();
                }
            };
        } else {
            this.parameter = parameter;
            // No analytic gradient exists for this mode yet; getGradientLogDensity() reports that.
            this.provider = null;
        }
        this.tolerance = d;
    }

    @Override
    public Likelihood getLikelihood() {
        return this.likelihood;
    }

    @Override
    public Parameter getParameter() {
        return this.parameter;
    }

    @Override
    public int getDimension() {
        return this.parameter.getDimension();
    }

    /**
     * Returns the gradient of the coalescent log likelihood.
     *
     * @throws RuntimeException              if the likelihood has a population-size model
     *                                       (not yet supported)
     * @throws UnsupportedOperationException if this instance was built with {@link Wrt#PARAMETER}
     */
    @Override
    public double[] getGradientLogDensity() {
        if (this.likelihood.getPopulationSizeModel() != null) {
            throw new RuntimeException("Not yet implemented!");
        }
        if (this.provider == null) {
            // Previously this dereferenced a null provider and failed with an uninformative NPE.
            throw new UnsupportedOperationException(
                    "Coalescent gradient with respect to a parameter is not yet implemented");
        }
        return this.provider.getGradientLogDensity(null);
    }

    /**
     * Computes d(log likelihood)/d(height) for every internal node.
     *
     * <p>Coalescent intervals are walked in order. Consecutive coalescent events separated by
     * zero-length intervals are treated as one tied group. The group's accumulated gradient is
     * split equally among its nodes.
     *
     * @return an array of length {@code tree.getInternalNodeCount()}, indexed by node number minus
     *         the external node count (assumes internal nodes are numbered after the tips — TODO
     *         confirm against the tree implementation); all {@code NaN} if the log likelihood is
     *         negative infinity
     * @throws RuntimeException if the likelihood's interval list is not a {@link BigFastTreeIntervals}
     */
    private double[] getGradientLogDensityWrtNodeHeights() {
        // Queried first (as before) so the likelihood and its intervals are up to date.
        double logLikelihood = this.likelihood.getLogLikelihood();
        double[] gradient = new double[this.tree.getInternalNodeCount()];
        if (logLikelihood == Double.NEGATIVE_INFINITY) {
            Arrays.fill(gradient, Double.NaN);
            return gradient;
        }

        IntervalList intervalList = this.likelihood.getIntervalList();
        if (!(intervalList instanceof BigFastTreeIntervals)) {
            throw new RuntimeException("CoalescentGradient requires BigFastTreeIntervals but found "
                    + (intervalList == null ? "null" : intervalList.getClass().getName()));
        }
        BigFastTreeIntervals intervals = (BigFastTreeIntervals) intervalList;
        DemographicFunction demographicFunction = this.likelihood.getDemoModel().getDemographicFunction();

        int intervalCount = intervals.getIntervalCount();
        int externalNodeCount = this.tree.getExternalNodeCount();

        int tieCount = 1;             // number of coalescent events in the current tied group
        double groupGradient = 0.0;   // gradient accumulated for the current tied group

        for (int i = 0; i < intervalCount; ++i) {
            if (intervals.getIntervalType(i) != IntervalType.COALESCENT) {
                continue;
            }
            // Time at which this coalescent interval ends.
            double time = intervals.getIntervalTime(i + 1);
            int lineages = intervals.getLineageCount(i);
            double lineagePairs = Binomial.choose2(lineages);
            double intensityGradient = demographicFunction.getIntensityGradient(time);

            groupGradient += demographicFunction.getLogDemographicGradient(time);
            if (intervals.getInterval(i) != 0.0) {
                groupGradient -= lineagePairs * intensityGradient;
            } else {
                // Zero-length interval: this event shares its time with the previous one.
                ++tieCount;
            }

            // Keep accumulating while the group continues: this is the last interval, or the
            // next interval also has zero length.
            boolean groupContinues = i >= intervalCount - 1 || intervals.getInterval(i + 1) == 0.0;
            if (groupContinues) {
                continue;
            }

            int nextLineages = intervals.getLineageCount(i + 1);
            groupGradient += Binomial.choose2(nextLineages) * intensityGradient;
            // Assumes the tied events occupy the tieCount consecutive intervals ending at i.
            for (int j = 0; j < tieCount; ++j) {
                int nodeNumber = intervals.getNodeNumbersForInterval(i - j)[1];
                gradient[nodeNumber - externalNodeCount] = groupGradient / (double) tieCount;
            }
            groupGradient = 0.0;
            tieCount = 1;
        }

        // The final group is never flushed inside the loop; assign it to the last tieCount
        // coalescent intervals, scanning backwards from the end.
        int remaining = tieCount;
        for (int i = intervalCount - 1; remaining > 0; --i) {
            if (intervals.getIntervalType(i) == IntervalType.COALESCENT) {
                int nodeNumber = intervals.getNodeNumbersForInterval(i)[1];
                gradient[nodeNumber - externalNodeCount] = groupGradient / (double) tieCount;
                --remaining;
            }
        }
        return gradient;
    }

    /** Builds the report via {@link GradientWrtParameterProvider#getReportAndCheckForError} with bounds [0, +inf). */
    @Override
    public String getReport() {
        return GradientWrtParameterProvider.getReportAndCheckForError(this, 0.0, Double.POSITIVE_INFINITY, this.tolerance);
    }

    @Override
    public LogColumn[] getColumns() {
        return Loggable.getColumnsFromReport(this, "CoalescentGradient check");
    }

    /** Quantity with respect to which the gradient is taken. */
    public enum Wrt {
        /** Internal node heights of the tree (implemented). */
        NODE_HEIGHTS,
        /** The parameter passed to the constructor (gradient not yet implemented). */
        PARAMETER;

    }
}

