/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treelikelihood;

import dr.evolution.alignment.PatternList;
import dr.evolution.alignment.UncertainSiteList;
import dr.evolution.datatype.DataType;
import dr.evolution.datatype.GeneralDataType;
import dr.evolution.datatype.HiddenCodons;
import dr.evolution.datatype.HiddenDataType;
import dr.evolution.tree.MutableTreeModel;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evolution.tree.TreeTraitProvider;
import dr.evomodel.branchmodel.BranchModel;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.siteratemodel.SiteRateModel;
import dr.evomodel.tipstatesmodel.TipStatesModel;
import dr.evomodel.treelikelihood.AncestralStateTraitProvider;
import dr.evomodel.treelikelihood.BeagleTreeLikelihood;
import dr.evomodel.treelikelihood.PartialsRescalingScheme;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.math.MathUtils;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

/**
 * BEAGLE tree likelihood that can additionally reconstruct ancestral states at every node of the
 * tree passed to its constructor.
 *
 * <p>States are produced by a pre-order traversal ({@link #traverseSample}).
 * <ul>
 *   <li>The root state of each pattern is drawn from the root partials weighted by the root state
 *       frequencies.</li>
 *   <li>Every other node's state is drawn from its partials (or its tip data), weighted by the
 *       transition probabilities from the parent's already-drawn state.</li>
 *   <li>With {@code useMAP}, the most probable state is taken instead of a random draw.</li>
 *   <li>When more than one rate category exists, a category is drawn per pattern at the root and
 *       reused for the whole tree.</li>
 * </ul>
 *
 * <p>The reconstruction is exposed as an {@code int[]} node trait.
 */
public class AncestralStateBeagleTreeLikelihood extends BeagleTreeLikelihood
        implements TreeTraitProvider, AncestralStateTraitProvider {

    /** Registry of the tree traits provided by this likelihood (the reconstructed-state trait). */
    protected TreeTraitProvider.Helper treeTraits = new TreeTraitProvider.Helper();
    private final DataType dataType;
    /** Reconstructed states, indexed [node number][pattern]. */
    private int[][] reconstructedStates;
    private int[][] storedReconstructedStates;
    /** True while {@link #reconstructedStates} matches the current likelihood state. */
    protected boolean areStatesRedrawn = false;
    protected boolean storedAreStatesRedrawn = false;
    /** Take the most probable state instead of sampling one. */
    private boolean useMAP = false;
    /** If false, {@link #calculateLogLikelihood()} returns the joint likelihood of the drawn states. */
    private boolean returnMarginalLogLikelihood = true;
    private double jointLogLikelihood;
    private double storedJointLogLikelihood;
    /** Observed tip states, indexed [tip][pattern]; allocated only when ambiguities are not used. */
    private int[][] tipStates;
    /** Tip partials, indexed [tip][pattern * stateCount + state]; allocated only with ambiguities. */
    private double[][] tipPartials;
    /** Scratch transition-matrix buffer, laid out [category][parentState][childState]. */
    private double[] probabilities;
    /** Scratch partials buffer, laid out [category][pattern][state]. */
    private double[] partials;
    /** Rate category drawn for each pattern at the root; only filled when categoryCount > 1. */
    protected int[] rateCategory = null;
    /** If true, conditional weights handed to the sampler are log values. */
    private final boolean conditionalProbabilitiesInLogSpace;
    private final CodeFormatter formatter;

    /**
     * Creates the likelihood and registers a node trait holding the reconstructed states.
     *
     * <p>The first ten arguments are forwarded unchanged to {@link BeagleTreeLikelihood}.
     *
     * @param dataType                            data type used to decode and format states
     * @param tag                                 name under which the state trait is registered
     * @param useMAP                              pick the most probable state rather than sampling
     * @param returnMarginalLikelihood            if false, report the joint likelihood of the drawn states
     * @param conditionalProbabilitiesInLogSpace  sample from log-space conditional weights
     */
    public AncestralStateBeagleTreeLikelihood(PatternList patternList, MutableTreeModel treeModel,
                                              BranchModel branchModel, SiteRateModel siteRateModel,
                                              BranchRateModel branchRateModel, TipStatesModel tipStatesModel,
                                              boolean useAmbiguities, PartialsRescalingScheme rescalingScheme,
                                              boolean delayRescalingUntilUnderflow,
                                              Map<Set<String>, Parameter> partialsRestrictions,
                                              DataType dataType, final String tag, boolean useMAP,
                                              boolean returnMarginalLikelihood,
                                              boolean conditionalProbabilitiesInLogSpace) {
        super(patternList, treeModel, branchModel, siteRateModel, branchRateModel, tipStatesModel,
                useAmbiguities, rescalingScheme, delayRescalingUntilUnderflow, partialsRestrictions);
        this.conditionalProbabilitiesInLogSpace = conditionalProbabilitiesInLogSpace;
        this.dataType = dataType;
        this.probabilities = new double[this.stateCount * this.stateCount * this.categoryCount];
        this.partials = new double[this.stateCount * this.patternCount * this.categoryCount];

        // Cache the tip data in whichever representation the sampler will read at the tips.
        final boolean tipsAsPartials = this.useAmbiguities();
        if (tipsAsPartials) {
            this.tipPartials = new double[this.tipCount][];
        } else {
            this.tipStates = new int[this.tipCount][];
        }
        for (int tip = 0; tip < this.tipCount; tip++) {
            int taxonIndex = patternList.getTaxonIndex(treeModel.getTaxonId(tip));
            if (tipsAsPartials) {
                this.tipPartials[tip] = this.getPartials(patternList, taxonIndex);
            } else {
                this.tipStates[tip] = this.getStates(patternList, taxonIndex);
            }
        }

        this.reconstructedStates = new int[treeModel.getNodeCount()][this.patternCount];
        this.storedReconstructedStates = new int[treeModel.getNodeCount()][this.patternCount];
        this.useMAP = useMAP;
        this.returnMarginalLogLikelihood = returnMarginalLikelihood;
        // Hidden-code stripping is disabled for the formatted trait output.
        this.formatter = new CodeFormatter(dataType, false);

        this.treeTraits.addTrait(new TreeTrait.IA() {

            @Override
            public String getTraitName() {
                return tag;
            }

            @Override
            public TreeTrait.Intent getIntent() {
                return TreeTrait.Intent.NODE;
            }

            @Override
            public Class getTraitClass() {
                return int[].class;
            }

            @Override
            public int[] getTrait(Tree tree, NodeRef node) {
                return AncestralStateBeagleTreeLikelihood.this.getStatesForNode(tree, node);
            }

            @Override
            public String getTraitString(Tree tree, NodeRef node) {
                return AncestralStateBeagleTreeLikelihood.formattedState(
                        AncestralStateBeagleTreeLikelihood.this.getStatesForNode(tree, node),
                        AncestralStateBeagleTreeLikelihood.this.formatter);
            }
        });
    }

    /**
     * Same as the full constructor, with linear-space (not log-space) conditional probabilities.
     */
    public AncestralStateBeagleTreeLikelihood(PatternList patternList, MutableTreeModel treeModel,
                                              BranchModel branchModel, SiteRateModel siteRateModel,
                                              BranchRateModel branchRateModel, TipStatesModel tipStatesModel,
                                              boolean useAmbiguities, PartialsRescalingScheme rescalingScheme,
                                              boolean delayRescalingUntilUnderflow,
                                              Map<Set<String>, Parameter> partialsRestrictions,
                                              DataType dataType, String tag, boolean useMAP,
                                              boolean returnMarginalLikelihood) {
        this(patternList, treeModel, branchModel, siteRateModel, branchRateModel, tipStatesModel,
                useAmbiguities, rescalingScheme, delayRescalingUntilUnderflow, partialsRestrictions,
                dataType, tag, useMAP, returnMarginalLikelihood, false);
    }

    /**
     * Builds the tip partials of one taxon, laid out [pattern][state].
     *
     * <p>For an {@link UncertainSiteList} the list fills the partials itself. Otherwise each
     * observed state contributes 1.0 for every state in its state set and 0.0 elsewhere.
     */
    private double[] getPartials(PatternList patternList, int taxonIndex) {
        final double[] tipPartial = new double[this.patternCount * this.stateCount];
        final boolean uncertain = patternList instanceof UncertainSiteList;
        for (int pattern = 0, offset = 0; pattern < this.patternCount; pattern++, offset += this.stateCount) {
            if (uncertain) {
                ((UncertainSiteList) patternList).fillPartials(taxonIndex, pattern, tipPartial, offset);
            } else {
                boolean[] stateSet = this.dataType.getStateSet(patternList.getPatternState(taxonIndex, pattern));
                for (int s = 0; s < this.stateCount; s++) {
                    tipPartial[offset + s] = stateSet[s] ? 1.0 : 0.0;
                }
            }
        }
        return tipPartial;
    }

    /** Returns the observed state of one taxon for every pattern. */
    private int[] getStates(PatternList patternList, int taxonIndex) {
        final int[] states = new int[this.patternCount];
        for (int pattern = 0; pattern < this.patternCount; pattern++) {
            states[pattern] = patternList.getPatternState(taxonIndex, pattern);
        }
        return states;
    }

    @Override
    public BranchModel getBranchModel() {
        return this.branchModel;
    }

    @Override
    public TreeTrait[] getTreeTraits() {
        return this.treeTraits.getTreeTraits();
    }

    @Override
    public TreeTrait getTreeTrait(String key) {
        return this.treeTraits.getTreeTrait(key);
    }

    /** Lets the superclass handle the event, then re-fires it via {@code fireModelChanged(model)}. */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        super.handleModelChangedEvent(model, object, index);
        this.fireModelChanged(model);
    }

    /**
     * Returns the reconstructed states of {@code node}, one per pattern.
     *
     * <p>If needed, this first recomputes the likelihood and/or redraws the states.
     *
     * @throws RuntimeException if {@code tree} is not the tree model given to the constructor
     */
    public int[] getStatesForNode(Tree tree, NodeRef node) {
        if (tree != this.treeModel) {
            throw new RuntimeException("Can only reconstruct states on treeModel given to constructor");
        }
        if (!this.likelihoodKnown) {
            // NOTE(review): the returned value is discarded here. Confirm the superclass does not
            // expect its cached log likelihood to be refreshed when likelihoodKnown is set.
            this.calculateLogLikelihood();
            this.likelihoodKnown = true;
        }
        if (!this.areStatesRedrawn) {
            this.redrawAncestralStates();
        }
        return this.reconstructedStates[node.getNumber()];
    }

    /** Requests two scale buffers more than there are internal nodes. */
    @Override
    protected int getScaleBufferCount() {
        return this.internalNodeCount + 2;
    }

    /**
     * Chooses an index from unnormalised weights.
     *
     * <p>With MAP, the first index of the largest weight is returned. Otherwise an index is drawn
     * at random, treating the weights as log values when
     * {@link #conditionalProbabilitiesInLogSpace} is set.
     */
    private int drawChoice(double[] weights) {
        if (this.useMAP) {
            int best = 0;
            for (int i = 1; i < weights.length; ++i) {
                if (weights[i] > weights[best]) {
                    best = i;
                }
            }
            return best;
        }
        return this.conditionalProbabilitiesInLogSpace
                ? MathUtils.randomChoiceLogPDF(weights)
                : MathUtils.randomChoicePDF(weights);
    }

    @Override
    public void makeDirty() {
        super.makeDirty();
        this.areStatesRedrawn = false;
    }

    /** Redraws the states of every node from the root and resets the joint log likelihood. */
    public void redrawAncestralStates() {
        this.jointLogLikelihood = 0.0;
        this.traverseSample(this.treeModel, this.treeModel.getRoot(), null, null);
        this.areStatesRedrawn = true;
    }

    /**
     * Computes the superclass log likelihood and invalidates the current reconstruction.
     *
     * <p>When the marginal likelihood is not requested, the states are redrawn immediately and
     * their joint log likelihood is returned instead.
     */
    @Override
    protected double calculateLogLikelihood() {
        this.areStatesRedrawn = false;
        double marginal = super.calculateLogLikelihood();
        if (this.returnMarginalLogLikelihood) {
            return marginal;
        }
        this.redrawAncestralStates();
        return this.jointLogLikelihood;
    }

    @Override
    public String formattedState(int[] states) {
        return AncestralStateBeagleTreeLikelihood.formattedState(states, this.formatter);
    }

    /** Formats {@code states} as a double-quoted string of codes, using {@code formatter}. */
    private static String formattedState(int[] states, CodeFormatter formatter) {
        StringBuilder sb = new StringBuilder();
        sb.append("\"");
        formatter.reset();
        for (int state : states) {
            sb.append(formatter.getCodeString(state));
        }
        sb.append("\"");
        return sb.toString();
    }

    /** Copies the BEAGLE transition matrix of the branch above {@code nodeNum} into {@code matrix}. */
    protected void getMatrix(int nodeNum, double[] matrix) {
        this.beagle.getTransitionMatrix(this.substitutionModelDelegate.getMatrixIndex(nodeNum), matrix);
    }

    /** Replaces the cached and BEAGLE tip states of tip {@code tipNum} and marks the model dirty. */
    public void setTipStates(int tipNum, int[] states) {
        System.arraycopy(states, 0, this.tipStates[tipNum], 0, states.length);
        this.beagle.setTipStates(tipNum, states);
        this.makeDirty();
    }

    /** Copies the cached tip states of {@code tipNum} into {@code states}. */
    public void getTipStates(int tipNum, int[] states) {
        System.arraycopy(this.tipStates[tipNum], 0, states, 0, states.length);
    }

    @Override
    public void storeState() {
        super.storeState();
        if (this.areStatesRedrawn) {
            for (int i = 0; i < this.reconstructedStates.length; ++i) {
                System.arraycopy(this.reconstructedStates[i], 0, this.storedReconstructedStates[i], 0,
                        this.reconstructedStates[i].length);
            }
        }
        this.storedAreStatesRedrawn = this.areStatesRedrawn;
        this.storedJointLogLikelihood = this.jointLogLikelihood;
    }

    @Override
    public void restoreState() {
        super.restoreState();
        // Swap rather than copy; the stored buffer is rewritten on the next storeState.
        int[][] tmp = this.reconstructedStates;
        this.reconstructedStates = this.storedReconstructedStates;
        this.storedReconstructedStates = tmp;
        this.areStatesRedrawn = this.storedAreStatesRedrawn;
        this.jointLogLikelihood = this.storedJointLogLikelihood;
    }

    /**
     * Draws the states of {@code node} and, recursively, of all of its descendants.
     *
     * @param parentState     states drawn at the parent (null at the root)
     * @param rateCategories  rate category per pattern (null when none has been drawn); the root
     *                        draws a fresh array when categoryCount > 1
     */
    public void traverseSample(Tree tree, NodeRef node, int[] parentState, int[] rateCategories) {
        final int nodeNum = node.getNumber();
        final NodeRef parent = tree.getParent(node);

        if (tree.isExternal(node)) {
            if (this.useAmbiguities()) {
                this.sampleTipFromPartials(nodeNum, parentState, rateCategories);
            } else {
                this.sampleTipFromStates(nodeNum, parentState, rateCategories);
            }
            this.hookCalculation(tree, parent, node, parentState, this.reconstructedStates[nodeNum], null,
                    rateCategories);
            return;
        }

        final int[] state = new int[this.patternCount];
        int[] categories = rateCategories;
        if (parent == null) {
            categories = this.sampleRoot(nodeNum, state, rateCategories);
        } else {
            this.sampleInternalNode(nodeNum, parentState, state, rateCategories);
            this.hookCalculation(tree, parent, node, parentState, state, this.probabilities, rateCategories);
        }
        this.traverseSample(tree, tree.getChild(node, 0), state, categories);
        this.traverseSample(tree, tree.getChild(node, 1), state, categories);
    }

    /** Rate category of {@code pattern}, or 0 when no categories were drawn. */
    private static int categoryOf(int[] rateCategories, int pattern) {
        return rateCategories == null ? 0 : rateCategories[pattern];
    }

    /** Offset of the transition row from {@code parentState[pattern]} in {@link #probabilities}. */
    private int transitionRowOffset(int[] parentState, int[] rateCategories, int pattern) {
        return categoryOf(rateCategories, pattern) * this.stateCount * this.stateCount
                + parentState[pattern] * this.stateCount;
    }

    /**
     * Draws the root states into {@code state}.
     *
     * <p>When categoryCount > 1 it first draws one rate category per pattern.
     *
     * @return the category array to pass down the tree (the input array when only one category exists)
     */
    private int[] sampleRoot(int nodeNum, int[] state, int[] rateCategories) {
        this.getPartials(nodeNum, this.partials);
        final boolean multipleCategories = this.categoryCount > 1;
        int[] categories = rateCategories;
        double[] categoryWeights = null;
        double[] categoryProportions = null;
        if (multipleCategories) {
            categories = new int[this.patternCount];
            categoryWeights = new double[this.categoryCount];
            categoryProportions = this.siteRateModel.getCategoryProportions();
        }
        final double[] conditional = new double[this.stateCount];
        final double[] rootFrequencies = this.substitutionModelDelegate.getRootStateFrequencies();

        for (int pattern = 0; pattern < this.patternCount; ++pattern) {
            if (multipleCategories) {
                for (int c = 0; c < this.categoryCount; ++c) {
                    int offset = c * this.stateCount * this.patternCount + pattern * this.stateCount;
                    double sum = 0.0;
                    for (int s = 0; s < this.stateCount; ++s) {
                        sum += this.partials[offset + s];
                    }
                    double weight = sum * categoryProportions[c];
                    // drawChoice reads log weights in log-space mode, so convert these too.
                    categoryWeights[c] = this.conditionalProbabilitiesInLogSpace ? Math.log(weight) : weight;
                }
                categories[pattern] = this.drawChoice(categoryWeights);
            }

            int offset = categoryOf(categories, pattern) * this.stateCount * this.patternCount
                    + pattern * this.stateCount;
            for (int s = 0; s < this.stateCount; ++s) {
                conditional[s] = this.conditionalProbabilitiesInLogSpace
                        ? Math.log(this.partials[offset + s]) + Math.log(rootFrequencies[s])
                        : this.partials[offset + s] * rootFrequencies[s];
            }
            try {
                state[pattern] = this.drawChoice(conditional);
            } catch (Error error) {
                // Best-effort fallback kept from the original: report the failure and use state 0.
                System.err.println(error.toString());
                System.err.println("Please report error to Marc");
                state[pattern] = 0;
            }
            this.reconstructedStates[nodeNum][pattern] = state[pattern];
            if (!this.returnMarginalLogLikelihood) {
                this.jointLogLikelihood += Math.log(rootFrequencies[state[pattern]]);
            }
        }

        if (multipleCategories) {
            if (this.rateCategory == null) {
                this.rateCategory = new int[this.patternCount];
            }
            System.arraycopy(categories, 0, this.rateCategory, 0, this.patternCount);
        }
        return categories;
    }

    /** Draws the states of a non-root internal node into {@code state}. */
    private void sampleInternalNode(int nodeNum, int[] parentState, int[] state, int[] rateCategories) {
        // Reusing the scratch buffer is safe: a node no longer needs its partials once its
        // own states are drawn, which happens before its children are visited.
        this.getPartials(nodeNum, this.partials);
        this.getMatrix(nodeNum, this.probabilities);
        final double[] conditional = new double[this.stateCount];
        for (int pattern = 0; pattern < this.patternCount; ++pattern) {
            int category = categoryOf(rateCategories, pattern);
            int rowOffset = this.transitionRowOffset(parentState, rateCategories, pattern);
            int partialsOffset = category * this.stateCount * this.patternCount + pattern * this.stateCount;
            for (int s = 0; s < this.stateCount; ++s) {
                conditional[s] = this.conditionalProbabilitiesInLogSpace
                        ? Math.log(this.partials[partialsOffset + s]) + Math.log(this.probabilities[rowOffset + s])
                        : this.partials[partialsOffset + s] * this.probabilities[rowOffset + s];
            }
            state[pattern] = this.drawChoice(conditional);
            this.reconstructedStates[nodeNum][pattern] = state[pattern];
            if (!this.returnMarginalLogLikelihood) {
                // Use the row of the sampled rate category (the original always read category 0).
                this.jointLogLikelihood += Math.log(this.probabilities[rowOffset + state[pattern]]);
            }
        }
    }

    /** Draws the states of a tip whose data are stored as partials. */
    private void sampleTipFromPartials(int nodeNum, int[] parentState, int[] rateCategories) {
        this.getMatrix(nodeNum, this.probabilities);
        final double[] tipPartial = this.tipPartials[nodeNum];
        final int[] nodeState = this.reconstructedStates[nodeNum];
        final double[] conditional = new double[this.stateCount];
        for (int pattern = 0; pattern < this.patternCount; ++pattern) {
            int rowOffset = this.transitionRowOffset(parentState, rateCategories, pattern);
            for (int s = 0; s < this.stateCount; ++s) {
                conditional[s] = this.conditionalProbabilitiesInLogSpace
                        ? Math.log(this.probabilities[rowOffset + s]) + Math.log(tipPartial[pattern * this.stateCount + s])
                        : this.probabilities[rowOffset + s] * tipPartial[pattern * this.stateCount + s];
            }
            nodeState[pattern] = this.drawChoice(conditional);
            if (!this.returnMarginalLogLikelihood) {
                this.jointLogLikelihood += Math.log(this.probabilities[rowOffset + nodeState[pattern]]);
            }
        }
    }

    /**
     * Copies the observed tip states and resolves ambiguous ones.
     *
     * <p>Each ambiguous state is drawn from the transition row of the parent's state. That row is
     * restricted to the state's state set when {@code useAmbiguities} is set and the state is not
     * the unknown state.
     */
    private void sampleTipFromStates(int nodeNum, int[] parentState, int[] rateCategories) {
        final int[] nodeState = this.reconstructedStates[nodeNum];
        this.getTipStates(nodeNum, nodeState);
        final double[] conditional = new double[this.stateCount];
        // The branch matrix does not change across patterns; fetch it at most once.
        boolean matrixLoaded = false;
        for (int pattern = 0; pattern < this.patternCount; ++pattern) {
            int observed = nodeState[pattern];
            if (this.dataType.isAmbiguousState(observed)) {
                if (!matrixLoaded) {
                    this.getMatrix(nodeNum, this.probabilities);
                    matrixLoaded = true;
                }
                int rowOffset = this.transitionRowOffset(parentState, rateCategories, pattern);
                System.arraycopy(this.probabilities, rowOffset, conditional, 0, this.stateCount);
                if (this.useAmbiguities && !this.dataType.isUnknownState(observed)) {
                    boolean[] stateSet = this.dataType.getStateSet(observed);
                    for (int s = 0; s < this.stateCount; ++s) {
                        if (!stateSet[s]) {
                            conditional[s] = 0.0;
                        }
                    }
                }
                if (this.conditionalProbabilitiesInLogSpace) {
                    for (int s = 0; s < this.stateCount; ++s) {
                        conditional[s] = Math.log(conditional[s]);
                    }
                }
                nodeState[pattern] = this.drawChoice(conditional);
            }
            if (!this.returnMarginalLogLikelihood) {
                if (!matrixLoaded) {
                    this.getMatrix(nodeNum, this.probabilities);
                    matrixLoaded = true;
                }
                int rowOffset = this.transitionRowOffset(parentState, rateCategories, pattern);
                this.jointLogLikelihood += Math.log(this.probabilities[rowOffset + nodeState[pattern]]);
            }
        }
    }

    /**
     * Extension point called after a non-root node's states are drawn. It is a no-op here.
     *
     * @param tree                     the tree being traversed
     * @param parentNode               parent of {@code childNode}
     * @param childNode                node whose states were just drawn
     * @param parentStates             states drawn at the parent
     * @param childStates              states drawn at the child
     * @param transitionProbabilities  branch matrix buffer for internal nodes; null at tips
     * @param rateCategories           rate category per pattern, or null
     */
    protected void hookCalculation(Tree tree, NodeRef parentNode, NodeRef childNode, int[] parentStates,
                                   int[] childStates, double[] transitionProbabilities, int[] rateCategories) {
    }

    /** Turns state indices into code strings, separating codes with a space for general data types. */
    private static class CodeFormatter {
        private final Function<String, String> appender;
        private final Function<Integer, String> getter;
        private boolean first = true;

        CodeFormatter(DataType dataType, boolean stripHiddenCode) {
            this.appender = dataType instanceof GeneralDataType
                    ? code -> code + " "
                    : Function.identity();
            if (dataType instanceof HiddenCodons) {
                this.getter = stripHiddenCode
                        ? ((HiddenCodons) dataType)::getTripletWithoutHiddenCode
                        : dataType::getTriplet;
            } else if (dataType instanceof HiddenDataType && stripHiddenCode) {
                this.getter = ((HiddenDataType) dataType)::getCodeWithoutHiddenState;
            } else {
                this.getter = dataType::getCode;
            }
        }

        /** Code for {@code state}; every code after the first gets the separator appended. */
        String getCodeString(int state) {
            String code = this.getter.apply(state);
            if (this.first) {
                this.first = false;
                return code;
            }
            return this.appender.apply(code);
        }

        /** Starts a new string, so the next code is treated as the first. */
        void reset() {
            this.first = true;
        }
    }
}

