/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators.hmc.deprecated;

import dr.inference.distribution.MomentDistributionModel;
import dr.inference.model.LatentFactorModel;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Variable;
import dr.inference.operators.AdaptationMode;
import dr.inference.operators.hmc.deprecated.AbstractHamiltonianMCOperator;

/**
 * Hamiltonian Monte Carlo operator that jointly proposes new values for every entry of a
 * factor-loadings matrix attached to a {@link LatentFactorModel}.
 *
 * <p>Each call to {@link #doOperation()} draws a fresh momentum vector with one component per
 * loadings entry, stored factor-major (index {@code factor * ntraits + trait}). It then integrates
 * {@code nSteps} leapfrog steps of size {@code stepSize} using {@link #getGradient()} and returns
 * the kinetic-energy difference between the start and end of the trajectory.
 *
 * @deprecated Part of the {@code hmc.deprecated} package; avoid in new code.
 */
@Deprecated
public class LoadingsHamiltonianMC extends AbstractHamiltonianMCOperator {
    /** Model supplying the factors, residuals and column precision used in the likelihood gradient. */
    private final LatentFactorModel lfm;
    /** Prior on loadings entries; only {@code getMean()[0]} and {@code getScaleMatrix()[0][0]} are read. */
    private final MomentDistributionModel prior;
    /** Factor matrix, read as {@code factors.getParameterValue(factor, taxon)}. */
    private final MatrixParameterInterface factors;
    /** Loadings matrix being updated, indexed {@code (trait, factor)}. */
    private final MatrixParameterInterface loadings;
    /** Column precision of the latent factor model; only its diagonal is read. */
    private final MatrixParameterInterface precision;
    /** Factor dimension, captured from the model at construction. */
    private final int nfac;
    /** Number of taxa (column dimension of the factor matrix). */
    private final int ntaxa;
    /** Number of traits (row dimension of the column precision). */
    private final int ntraits;
    /** Leapfrog step size; fixed, since adaptation is a no-op for this operator. */
    private final double stepSize;
    /** Number of leapfrog steps per proposal. */
    private final int nSteps;

    /**
     * Creates the operator.
     *
     * @param latentFactorModel       model supplying factors, residuals and column precision
     * @param momentDistributionModel prior on the individual loadings entries
     * @param weight                  operator weight, passed to {@code setWeight}
     * @param adaptationMode          forwarded to the superclass constructor
     * @param stepSize                leapfrog step size
     * @param nSteps                  number of leapfrog steps per proposal
     * @param momentumSd              forwarded to the superclass constructor; presumably the momentum
     *                                standard deviation later returned by {@code getMomentumSd()} —
     *                                confirm against AbstractHamiltonianMCOperator
     * @param loadings                loadings matrix to update; expected to be
     *                                {@code ntraits x factor-dimension} so it matches the gradient shape
     */
    public LoadingsHamiltonianMC(LatentFactorModel latentFactorModel,
                                 MomentDistributionModel momentDistributionModel,
                                 double weight,
                                 AdaptationMode adaptationMode,
                                 double stepSize,
                                 int nSteps,
                                 double momentumSd,
                                 MatrixParameterInterface loadings) {
        super(adaptationMode, momentumSd);
        setWeight(weight);
        this.lfm = latentFactorModel;
        this.prior = momentDistributionModel;
        this.factors = latentFactorModel.getFactors();
        this.loadings = loadings;
        this.precision = latentFactorModel.getColumnPrecision();
        this.nfac = latentFactorModel.getFactorDimension();
        this.ntaxa = latentFactorModel.getFactors().getColumnDimension();
        this.ntraits = this.precision.getRowDimension();
        this.stepSize = stepSize;
        this.nSteps = nSteps;
    }

    /** Adaptation is not supported; always returns {@code 0.0}. */
    @Override
    protected double getAdaptableParameterValue() {
        return 0.0;
    }

    /** Adaptation is not supported; the value is ignored. */
    @Override
    public void setAdaptableParameterValue(double value) {
    }

    /** Adaptation is not supported; always returns {@code 0.0}. */
    @Override
    public double getRawParameter() {
        return 0.0;
    }

    @Override
    public String getOperatorName() {
        return "LoadingsHamiltonianMC";
    }

    /**
     * Runs one leapfrog trajectory on the loadings matrix.
     *
     * <p>The loadings parameter is modified in place, and a change event fires after every position
     * step.
     *
     * @return initial kinetic energy minus final kinetic energy; presumably the caller adds the
     *         posterior (potential-energy) difference itself — confirm against the MCMC driver
     */
    @Override
    public double doOperation() {
        final double[][] initialGradient = getGradient();
        drawMomentum(nfac * ntraits);
        final double initialKinetic = kineticEnergy();

        // Leapfrog: half momentum step, then alternating full position / full momentum steps,
        // then a closing half momentum step.
        kickMomentum(initialGradient, stepSize / 2.0);
        for (int step = 0; step < nSteps; ++step) {
            driftLoadings(stepSize);
            // Skip the full momentum step after the last position update: the trailing half step
            // below completes the trajectory, keeping the integrator symmetric and reversible.
            if (step < nSteps - 1) {
                kickMomentum(getGradient(), stepSize);
            }
        }
        kickMomentum(getGradient(), stepSize / 2.0);

        return initialKinetic - kineticEnergy();
    }

    /** Returns the sum of {@code p * p / (2 * sd * sd)} over the current momentum vector. */
    private double kineticEnergy() {
        final double sd = getMomentumSd();
        final double denominator = 2.0 * sd * sd;
        double energy = 0.0;
        for (double p : momentum) {
            energy += p * p / denominator;
        }
        return energy;
    }

    /** Applies {@code momentum[f * ntraits + t] -= scale * gradient[t][f]} to every entry. */
    private void kickMomentum(double[][] gradient, double scale) {
        for (int factor = 0; factor < nfac; ++factor) {
            for (int trait = 0; trait < ntraits; ++trait) {
                momentum[factor * ntraits + trait] -= scale * gradient[trait][factor];
            }
        }
    }

    /** Moves every loadings entry by {@code scale} times its momentum, then fires one change event. */
    private void driftLoadings(double scale) {
        for (int factor = 0; factor < nfac; ++factor) {
            for (int trait = 0; trait < ntraits; ++trait) {
                final double updated = loadings.getParameterValue(trait, factor)
                        + scale * momentum[factor * ntraits + trait];
                // Quiet set avoids one event per entry; a single ALL_VALUES_CHANGED follows.
                loadings.setParameterValueQuietly(trait, factor, updated);
            }
        }
        loadings.fireParameterChangedEvent(-1, Variable.ChangeType.ALL_VALUES_CHANGED);
    }

    /**
     * Computes the likelihood part of the gradient with respect to the loadings.
     *
     * <p>Entry {@code [t][f]} is {@code -precision[t][t] * sum_i residual(i, t) * factors(f, i)}.
     *
     * @return a fresh {@code ntraits x nfac} array
     */
    private double[][] getLFMDerivative() {
        final double[] residual = lfm.getResidual();
        final double[][] derivative = new double[ntraits][nfac];
        for (int taxon = 0; taxon < ntaxa; ++taxon) {
            for (int trait = 0; trait < ntraits; ++trait) {
                // Residual for (taxon, trait) uses a stride of ntraits per taxon; the previous stride
                // of ntaxa overran the array whenever ntaxa > ntraits + 1.
                // NOTE(review): assumes getResidual() is stored taxon-major — confirm against
                // LatentFactorModel.
                final double r = residual[taxon * ntraits + trait];
                for (int factor = 0; factor < nfac; ++factor) {
                    derivative[trait][factor] -= r * factors.getParameterValue(factor, taxon);
                }
            }
        }
        for (int trait = 0; trait < ntraits; ++trait) {
            final double traitPrecision = precision.getParameterValue(trait, trait);
            for (int factor = 0; factor < nfac; ++factor) {
                derivative[trait][factor] *= traitPrecision;
            }
        }
        return derivative;
    }

    /**
     * Returns the gradient used for the momentum updates: the likelihood derivative plus a
     * per-entry prior term {@code 2 / x + (x - mean) / scale}.
     *
     * <p>An entry equal to zero yields an infinite prior term.
     */
    private double[][] getGradient() {
        final double[][] gradient = getLFMDerivative();
        final double priorMean = prior.getMean()[0];
        final double priorScale = prior.getScaleMatrix()[0][0];
        for (int i = 0; i < loadings.getRowDimension(); ++i) {
            for (int j = 0; j < loadings.getColumnDimension(); ++j) {
                final double value = loadings.getParameterValue(i, j);
                // NOTE(review): if the likelihood term is the gradient of the NEGATIVE log-likelihood
                // (i.e. residual = data - loadings * factors), the x^2 factor of a moment prior should
                // contribute -2/x here, not +2/x. Also confirm that getScaleMatrix()[0][0] is a
                // variance rather than a precision.
                gradient[i][j] += 2.0 / value + (value - priorMean) / priorScale;
            }
        }
        return gradient;
    }
}

