/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators.factorAnalysis;

import dr.inference.distribution.NormalStatisticsProvider;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Parameter;
import dr.inference.model.ScaledMatrixParameter;
import dr.inference.model.Variable;
import dr.inference.model.VariableListener;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.inference.operators.factorAnalysis.FactorAnalysisStatisticsProvider;
import dr.inference.operators.rejection.RejectionOperator;
import dr.inferencexml.operators.factorAnalysis.LoadingsOperatorParserUtilities;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.CholeskyDecomposition;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.math.matrixAlgebra.Vector;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.ElementRule;
import dr.xml.Reportable;
import dr.xml.XMLObject;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;

public class LoadingsScaleGibbsOperator
extends SimpleMCMCOperator
implements GibbsOperator,
RejectionOperator.RejectionProvider,
VariableListener,
Reportable {
    private final Parameter sccaleComponent;
    private final MatrixParameterInterface matrixComponent;
    private final FactorAnalysisStatisticsProvider statisticsProvider;
    private final int nTraits;
    private final int nFactors;
    private final NormalStatisticsProvider prior;
    private final double[] mean;
    private final double[][] variance;
    private final Parameter[] listeningParameters;
    private boolean needToUpdateStatistics = true;
    private double[][] cholesky;
    private static final String LOADINGS_SCALE_OPERATOR = "loadingsScaleGibbsOperator";
    public static AbstractXMLObjectParser PARSER = new AbstractXMLObjectParser(){

        @Override
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            FactorAnalysisStatisticsProvider factorAnalysisStatisticsProvider = LoadingsOperatorParserUtilities.parseAdaptorAndStatistics((XMLObject)xMLObject);
            NormalStatisticsProvider normalStatisticsProvider = (NormalStatisticsProvider)xMLObject.getChild(NormalStatisticsProvider.class);
            MatrixParameterInterface matrixParameterInterface = factorAnalysisStatisticsProvider.getAdaptor().getLoadings();
            if (!(matrixParameterInterface instanceof ScaledMatrixParameter)) {
                throw new XMLParseException("The loadings matrix is of class" + matrixParameterInterface.getClass() + ". It must be of class " + ScaledMatrixParameter.class);
            }
            return new LoadingsScaleGibbsOperator(factorAnalysisStatisticsProvider, normalStatisticsProvider);
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return XMLSyntaxRule.Utils.concatenate(LoadingsOperatorParserUtilities.statisticsProviderRules, new XMLSyntaxRule[]{new ElementRule(NormalStatisticsProvider.class)});
        }

        @Override
        public String getParserDescription() {
            return "Gibbs operator for scale component of loadings.";
        }

        @Override
        public Class getReturnType() {
            return LoadingsScaleGibbsOperator.class;
        }

        @Override
        public String getParserName() {
            return LoadingsScaleGibbsOperator.LOADINGS_SCALE_OPERATOR;
        }
    };

    LoadingsScaleGibbsOperator(FactorAnalysisStatisticsProvider factorAnalysisStatisticsProvider, NormalStatisticsProvider normalStatisticsProvider) {
        ScaledMatrixParameter scaledMatrixParameter = (ScaledMatrixParameter)factorAnalysisStatisticsProvider.getAdaptor().getLoadings();
        this.sccaleComponent = scaledMatrixParameter.getScaleParameter();
        this.matrixComponent = scaledMatrixParameter.getMatrixParameter();
        this.statisticsProvider = factorAnalysisStatisticsProvider;
        this.nTraits = this.matrixComponent.getRowDimension();
        this.nFactors = this.matrixComponent.getColumnDimension();
        this.prior = normalStatisticsProvider;
        this.mean = new double[this.nFactors];
        this.variance = new double[this.nFactors][this.nFactors];
        for (Parameter parameter : this.listeningParameters = factorAnalysisStatisticsProvider.getAdaptor().getLoadingsDependentParameter()) {
            parameter.addParameterListener(this);
        }
    }

    @Override
    public String getOperatorName() {
        return LOADINGS_SCALE_OPERATOR;
    }

    @Override
    public double doOperation() {
        double[] dArray = this.getProposedUpdate();
        for (int i = 0; i < this.nFactors; ++i) {
            this.sccaleComponent.setParameterValueQuietly(i, dArray[i]);
        }
        this.sccaleComponent.fireParameterChangedEvent();
        return 0.0;
    }

    @Override
    public double[] getProposedUpdate() {
        if (this.needToUpdateStatistics) {
            this.statisticsProvider.getAdaptor().drawFactors();
            this.updateMeanAndVariance();
            try {
                this.cholesky = new CholeskyDecomposition(this.variance).getL();
            }
            catch (IllegalDimension illegalDimension) {
                illegalDimension.printStackTrace();
            }
            this.needToUpdateStatistics = false;
        }
        double[] dArray = MultivariateNormalDistribution.nextMultivariateNormalCholesky(this.mean, this.cholesky);
        return dArray;
    }

    @Override
    public Parameter getParameter() {
        return this.sccaleComponent;
    }

    private void updateMeanAndVariance() {
        int n;
        int n2;
        double d;
        int n3;
        double[][] dArray = new double[this.nFactors][this.nFactors];
        double[][] dArray2 = new double[this.nFactors][this.nFactors];
        double[] dArray3 = new double[this.nFactors];
        for (n3 = 0; n3 < this.nTraits; ++n3) {
            d = this.statisticsProvider.getAdaptor().getColumnPrecision(n3);
            this.statisticsProvider.getFactorInnerProduct(n3, this.nFactors, dArray2);
            this.statisticsProvider.getFactorTraitProduct(n3, this.nFactors, this.mean);
            for (int i = 0; i < this.nFactors; ++i) {
                double d2 = this.matrixComponent.getParameterValue(n3, i);
                double[] dArray4 = dArray[i];
                int n4 = i;
                dArray4[n4] = dArray4[n4] + dArray2[i][i] * d2 * d2 * d;
                for (int j = i + 1; j < this.nFactors; ++j) {
                    double d3 = this.matrixComponent.getParameterValue(n3, j);
                    double[] dArray5 = dArray[i];
                    int n5 = j;
                    dArray5[n5] = dArray5[n5] + dArray2[i][j] * d2 * d3 * d;
                    dArray[j][i] = dArray[i][j];
                }
                this.statisticsProvider.getFactorTraitProduct(n3, this.nFactors, this.mean);
                int n6 = i;
                dArray3[n6] = dArray3[n6] + this.mean[i] * d2 * d;
            }
        }
        for (n3 = 0; n3 < this.nFactors; ++n3) {
            d = this.prior.getNormalSD(n3);
            double d4 = 1.0 / (d * d);
            double[] dArray6 = dArray[n3];
            int n7 = n3;
            dArray6[n7] = dArray6[n7] + d4;
            int n8 = n3;
            dArray3[n8] = dArray3[n8] + d4 * this.prior.getNormalMean(n3);
        }
        SymmetricMatrix symmetricMatrix = new SymmetricMatrix(dArray).inverse();
        for (n2 = 0; n2 < this.nFactors; ++n2) {
            for (n = 0; n < this.nFactors; ++n) {
                this.variance[n2][n] = symmetricMatrix.component(n2, n);
            }
        }
        for (n2 = 0; n2 < this.nFactors; ++n2) {
            this.mean[n2] = 0.0;
            for (n = 0; n < this.nFactors; ++n) {
                int n9 = n2;
                this.mean[n9] = this.mean[n9] + this.variance[n2][n] * dArray3[n];
            }
        }
    }

    @Override
    public void variableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        for (Parameter parameter : this.listeningParameters) {
            if (variable != parameter) continue;
            this.needToUpdateStatistics = true;
            break;
        }
    }

    @Override
    public String getReport() {
        this.updateMeanAndVariance();
        StringBuilder stringBuilder = new StringBuilder();
        stringBuilder.append(this.getOperatorName() + "Report:\n");
        stringBuilder.append("Scale mean:\n");
        stringBuilder.append(new Vector(this.mean));
        stringBuilder.append("\n\n");
        stringBuilder.append("Scale covariance:\n");
        stringBuilder.append(new Matrix(this.variance));
        stringBuilder.append("\n\n");
        return stringBuilder.toString();
    }
}

