/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.discrete;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.tree.TreeChangedEvent;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.treedatalikelihood.LikelihoodTreeTraversal;
import dr.evomodel.treedatalikelihood.ProcessOnTreeDelegate;
import dr.evomodel.treedatalikelihood.SimulationTreeTraversal;
import dr.evomodel.treedatalikelihood.TreeTraversal;
import dr.evomodel.treedatalikelihood.discrete.AbstractNodeHeightTransformDelegate;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.matrixAlgebra.Vector;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class NodeHeightToRatiosTransformDelegate
extends AbstractNodeHeightTransformDelegate {
    protected Parameter ratios;
    private final LikelihoodTreeTraversal postOrderTraversal;
    protected final SimulationTreeTraversal preOrderTraversal;
    protected Map<Integer, Epoch> nodeEpochMap = new HashMap<Integer, Epoch>();
    private List<Epoch> epochs = new ArrayList<Epoch>();
    private boolean ratiosKnown = false;
    protected boolean epochKnown = false;
    private boolean DEBUG = false;

    public NodeHeightToRatiosTransformDelegate(TreeModel treeModel, Parameter parameter, Parameter parameter2, BranchRateModel branchRateModel) {
        super(treeModel, parameter);
        this.ratios = parameter2;
        this.postOrderTraversal = new LikelihoodTreeTraversal(this.tree, branchRateModel, TreeTraversal.TraversalType.POST_ORDER);
        this.preOrderTraversal = new SimulationTreeTraversal(this.tree, branchRateModel, TreeTraversal.TraversalType.PRE_ORDER);
        this.updateRatios();
        this.addVariable(parameter2);
        this.constructEpochs();
    }

    @Override
    public void modelRestored(Model model) {
        this.epochKnown = true;
        this.ratiosKnown = true;
    }

    @Override
    public void storeState() {
        this.ratios.storeParameterValues();
    }

    @Override
    public void restoreState() {
        this.ratios.restoreParameterValues();
    }

    protected void constructEpochs() {
        this.nodeEpochMap.clear();
        this.epochs.clear();
        this.postOrderTraversal.updateAllNodes();
        this.postOrderTraversal.dispatchTreeTraversalCollectBranchAndNodeOperations();
        List<ProcessOnTreeDelegate.NodeOperation> list = this.postOrderTraversal.getNodeOperations();
        for (ProcessOnTreeDelegate.NodeOperation nodeOperation : list) {
            NodeRef nodeRef = this.tree.getNode(nodeOperation.getNodeNumber());
            NodeRef nodeRef2 = this.tree.getNode(nodeOperation.getLeftChild());
            NodeRef nodeRef3 = this.tree.getNode(nodeOperation.getRightChild());
            double d = this.getAnchorTipHeight(nodeRef2);
            double d2 = this.getAnchorTipHeight(nodeRef3);
            if (this.tree.isRoot(nodeRef)) {
                if (!this.tree.isExternal(nodeRef2)) {
                    this.nodeEpochMap.get(nodeRef2.getNumber()).endEpoch(nodeRef, null);
                }
                if (this.tree.isExternal(nodeRef3)) continue;
                this.nodeEpochMap.get(nodeRef3.getNumber()).endEpoch(nodeRef, null);
                continue;
            }
            if (d2 > d) {
                this.addToEpoch(nodeRef, nodeRef3, nodeRef2);
                continue;
            }
            this.addToEpoch(nodeRef, nodeRef2, nodeRef3);
        }
        this.epochKnown = true;
    }

    private void addToEpoch(NodeRef nodeRef, NodeRef nodeRef2, NodeRef nodeRef3) {
        Epoch epoch = this.nodeEpochMap.get(nodeRef2.getNumber());
        if (epoch == null) {
            if (!this.tree.isExternal(nodeRef2)) {
                throw new RuntimeException("Internal node should be assigned to an epoch already.");
            }
            epoch = new Epoch(nodeRef2);
        }
        epoch.addInternalNode(nodeRef);
        this.nodeEpochMap.put(nodeRef.getNumber(), epoch);
        Epoch epoch2 = this.nodeEpochMap.get(nodeRef3.getNumber());
        if (epoch2 != null) {
            epoch2.endEpoch(nodeRef, epoch);
        }
    }

    private double getAnchorTipHeight(NodeRef nodeRef) {
        double d = this.tree.getNodeHeight(nodeRef);
        if (this.nodeEpochMap.containsKey(nodeRef.getNumber())) {
            d = this.nodeEpochMap.get(nodeRef.getNumber()).getAnchorTipHeight();
        }
        return d;
    }

    public double[] getRatios() {
        this.updateRatios();
        return this.ratios.getParameterValues();
    }

    @Override
    public double[] setMaskByHeightDifference(double d) {
        double[] dArray = new double[this.ratios.getDimension()];
        if (!this.epochKnown) {
            this.constructEpochs();
        }
        for (Epoch epoch : this.epochs) {
            for (int n : epoch.getInternalNodes()) {
                NodeRef nodeRef = this.tree.getNode(n);
                int n2 = this.getRatiosIndex(nodeRef);
                double d2 = this.getNodePartial(nodeRef);
                dArray[n2] = d2 < d ? 0.0 : 1.0;
            }
        }
        return dArray;
    }

    @Override
    public double[] setMaskByRatio(double d) {
        double[] dArray = new double[this.ratios.getDimension()];
        for (int i = 0; i < this.ratios.getDimension(); ++i) {
            if (!(this.ratios.getParameterValue(i) > d) || !(this.ratios.getParameterValue(i) < 1.0 - d)) continue;
            dArray[i] = 1.0;
        }
        return dArray;
    }

    @Override
    public void setNodeHeights(double[] dArray) {
        super.setNodeHeights(dArray);
        this.ratiosKnown = false;
    }

    private void checkNan(double[] dArray) {
        for (int i = 0; i < dArray.length; ++i) {
            if (!Double.isNaN(dArray[i])) continue;
            System.err.println("wrong");
        }
    }

    protected void updateRatios() {
        if (!this.ratiosKnown) {
            if (!this.epochKnown) {
                this.constructEpochs();
            }
            for (Epoch epoch : this.epochs) {
                double d = this.tree.getNodeHeight(this.tree.getNode(epoch.getConnectingNodeNumber()));
                double d2 = epoch.getAnchorTipHeight();
                for (int n : epoch.getInternalNodes()) {
                    NodeRef nodeRef = this.tree.getNode(n);
                    int n2 = this.getRatiosIndex(nodeRef);
                    double d3 = this.tree.getNodeHeight(nodeRef);
                    this.ratios.setParameterValueQuietly(n2, (d3 - d2) / (d - d2));
                    d = d3;
                }
            }
            this.ratiosKnown = true;
        }
    }

    public void setRatios(double[] dArray) {
        for (int i = 0; i < dArray.length; ++i) {
            this.ratios.setParameterValueQuietly(i, dArray[i]);
        }
        this.ratiosKnown = true;
    }

    protected void updateNodeHeights() {
        if (!this.epochKnown) {
            this.constructEpochs();
        }
        this.preOrderUpdateNodeHeights(this.tree, this.tree.getRoot(), null);
        this.tree.pushTreeChangedEvent(TreeChangedEvent.create(false, true));
        this.ratiosKnown = true;
    }

    private void preOrderUpdateNodeHeights(Tree tree, NodeRef nodeRef, NodeRef nodeRef2) {
        Object object;
        if (nodeRef2 != null && !tree.isExternal(nodeRef)) {
            object = this.nodeEpochMap.get(nodeRef.getNumber());
            double d = this.ratios.getParameterValue(this.getRatiosIndex(nodeRef));
            double d2 = d * (tree.getNodeHeight(nodeRef2) - ((Epoch)object).getAnchorTipHeight()) + ((Epoch)object).getAnchorTipHeight();
            this.nodeHeights.setParameterValueQuietly(this.getNodeHeightIndex(nodeRef), d2);
        }
        if (!tree.isExternal(nodeRef)) {
            object = tree.getChild(nodeRef, 0);
            NodeRef nodeRef3 = tree.getChild(nodeRef, 1);
            this.preOrderUpdateNodeHeights(tree, (NodeRef)object, nodeRef);
            this.preOrderUpdateNodeHeights(tree, nodeRef3, nodeRef);
        }
    }

    protected int getNodeHeightIndex(NodeRef nodeRef) {
        return this.getRatiosIndex(nodeRef);
    }

    protected int getRatiosIndex(NodeRef nodeRef) {
        return this.indexHelper.getParameterIndexFromNodeNumber(nodeRef.getNumber()) - this.tree.getExternalNodeCount();
    }

    @Override
    protected void handleModelChangedEvent(Model model, Object object, int n) {
        TreeChangedEvent treeChangedEvent;
        if (model == this.tree && object instanceof TreeChangedEvent && (treeChangedEvent = (TreeChangedEvent)object).isTreeChanged()) {
            this.ratiosKnown = false;
            this.epochKnown = false;
        }
    }

    @Override
    protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        if (variable == this.nodeHeights) {
            this.ratiosKnown = false;
        }
    }

    @Override
    double[] transform(double[] dArray) {
        this.setNodeHeights(dArray);
        this.updateRatios();
        return this.getRatios();
    }

    @Override
    double[] inverse(double[] dArray) {
        this.setRatios(dArray);
        this.updateNodeHeights();
        return this.getNodeHeights().getParameterValues();
    }

    @Override
    String getReport() {
        this.updateRatios();
        StringBuilder stringBuilder = new StringBuilder();
        stringBuilder.append("NodeHeights: ").append(new Vector(this.getNodeHeights().getParameterValues()));
        stringBuilder.append("\n");
        return stringBuilder.toString();
    }

    @Override
    Parameter getParameter() {
        this.updateRatios();
        return this.ratios;
    }

    @Override
    double getLogJacobian(double[] dArray) {
        double d = 0.0;
        for (int i = this.tree.getExternalNodeCount(); i < this.tree.getNodeCount(); ++i) {
            NodeRef nodeRef = this.tree.getNode(i);
            if (this.tree.isRoot(nodeRef)) continue;
            d += Math.log(this.getNodePartial(nodeRef));
        }
        return -d;
    }

    protected int getNodeHeightGradientIndex(NodeRef nodeRef) {
        return nodeRef.getNumber() - this.tree.getExternalNodeCount();
    }

    @Override
    double[] updateGradientLogDensity(double[] dArray, double[] dArray2) {
        double[] dArray3 = this.getLogTimeArray();
        double[] dArray4 = this.updateGradientUnWeightedLogDensity(dArray3);
        double[] dArray5 = this.updateGradientUnWeightedLogDensity(dArray);
        for (int i = 0; i < this.ratios.getDimension(); ++i) {
            int n = i;
            dArray5[n] = dArray5[n] + (dArray4[i] - 1.0 / this.ratios.getParameterValue(i));
        }
        return dArray5;
    }

    protected double[] getLogTimeArray() {
        double[] dArray = new double[this.tree.getInternalNodeCount()];
        for (int i = 0; i < this.tree.getInternalNodeCount(); ++i) {
            int n = i + this.tree.getExternalNodeCount();
            NodeRef nodeRef = this.tree.getNode(n);
            if (this.tree.isRoot(nodeRef)) continue;
            dArray[i] = 1.0 / (this.tree.getNodeHeight(nodeRef) - this.nodeEpochMap.get(n).getAnchorTipHeight());
        }
        return dArray;
    }

    @Override
    double[] updateGradientUnWeightedLogDensity(double[] dArray, double[] dArray2, int n, int n2) {
        return this.updateGradientUnWeightedLogDensity(dArray);
    }

    private double[] updateGradientUnWeightedLogDensity(double[] dArray) {
        this.updateRatios();
        double[] dArray2 = new double[this.ratios.getDimension()];
        this.postOrderTraversal.updateAllNodes();
        this.postOrderTraversal.dispatchTreeTraversalCollectBranchAndNodeOperations();
        List<ProcessOnTreeDelegate.NodeOperation> list = this.postOrderTraversal.getNodeOperations();
        for (ProcessOnTreeDelegate.NodeOperation nodeOperation : list) {
            NodeRef nodeRef = this.tree.getNode(nodeOperation.getNodeNumber());
            NodeRef nodeRef2 = this.tree.getNode(nodeOperation.getLeftChild());
            NodeRef nodeRef3 = this.tree.getNode(nodeOperation.getRightChild());
            int n = this.getRatiosIndex(nodeRef);
            if (this.tree.isRoot(nodeRef)) continue;
            double d = this.getNodePartial(nodeRef);
            int n2 = n;
            dArray2[n2] = dArray2[n2] + d * dArray[this.getNodeHeightGradientIndex(nodeRef)];
            int n3 = n;
            dArray2[n3] = dArray2[n3] + this.getEpochGradientAddition(nodeRef, nodeRef2, dArray2);
            int n4 = n;
            dArray2[n4] = dArray2[n4] + this.getEpochGradientAddition(nodeRef, nodeRef3, dArray2);
        }
        return dArray2;
    }

    private double getNodePartial(NodeRef nodeRef) {
        return this.tree.getNodeHeight(this.tree.getParent(nodeRef)) - this.nodeEpochMap.get(nodeRef.getNumber()).getAnchorTipHeight();
    }

    private double getEpochGradientAddition(NodeRef nodeRef, NodeRef nodeRef2, double[] dArray) {
        int n = this.getRatiosIndex(nodeRef2);
        int n2 = this.getRatiosIndex(nodeRef);
        if (n < 0) {
            return 0.0;
        }
        if (this.nodeEpochMap.get(nodeRef2.getNumber()) == this.nodeEpochMap.get(nodeRef.getNumber())) {
            return dArray[n] * this.ratios.getParameterValue(n) / this.ratios.getParameterValue(n2);
        }
        return dArray[n] * this.ratios.getParameterValue(n) / (this.tree.getNodeHeight(nodeRef) - this.nodeEpochMap.get(nodeRef2.getNumber()).getAnchorTipHeight()) * this.getNodePartial(nodeRef);
    }

    protected class Epoch
    implements Comparable {
        private final int anchorTipNodeNumber;
        private List<Integer> internalNodes = new ArrayList<Integer>();
        private Epoch lastEpoch;
        private int connectingNodeNumber;

        private Epoch(NodeRef nodeRef) {
            this.anchorTipNodeNumber = nodeRef.getNumber();
            NodeHeightToRatiosTransformDelegate.this.epochs.add(this);
        }

        public double getAnchorTipHeight() {
            return NodeHeightToRatiosTransformDelegate.this.tree.getNodeHeight(NodeHeightToRatiosTransformDelegate.this.tree.getNode(this.anchorTipNodeNumber));
        }

        public void endEpoch(NodeRef nodeRef, Epoch epoch) {
            this.lastEpoch = epoch;
            this.connectingNodeNumber = nodeRef.getNumber();
        }

        public void addInternalNode(NodeRef nodeRef) {
            this.internalNodes.add(0, nodeRef.getNumber());
        }

        public List<Integer> getInternalNodes() {
            return this.internalNodes;
        }

        public int getConnectingNodeNumber() {
            return this.connectingNodeNumber;
        }

        public int compareTo(Object object) {
            return Double.compare(this.getAnchorTipHeight(), ((Epoch)object).getAnchorTipHeight());
        }
    }
}

