/*
 * Decompiled with CFR 0.152.
 */
package dr.util;

import dr.util.Transform;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.ArrayList;

/**
 * Bijection between the Euclidean unit ball and the infinity-norm unit ball (the hypercube
 * {@code [-1, 1]^dim}).
 *
 * <p>The forward map sends {@code x} to {@code y} with
 * {@code y_i = x_i / prod_{j<i} sqrt(1 - y_j^2)}; the inverse map is
 * {@code x_i = y_i * prod_{j<i} sqrt(1 - y_j^2)}. Only the first {@code dim} coordinates are
 * mapped; any further entries of the returned arrays are left at zero.
 */
public class EuclideanToInfiniteNormUnitBallTransform extends Transform.MultivariateTransform {
    public static final String NAME = "sphericalTransform";
    public static final String DIMENSION = "dim";

    /**
     * XML parser for {@code <sphericalTransform dim="n"/>}.
     *
     * <p>Builds a {@code Transform.Array} of {@code FISHER_Z} over {@code n * (n + 1)} values and
     * combines it, via {@code Transform.ComposeMultivariable}, with a
     * {@code Transform.MultivariateArray} holding {@code n + 1} copies of this transform, each of
     * dimension {@code n}.
     */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        // Use the shared constant so the declared rule and the attribute read below cannot drift.
        private final XMLSyntaxRule[] rules = new XMLSyntaxRule[] {
            AttributeRule.newIntegerRule(EuclideanToInfiniteNormUnitBallTransform.DIMENSION, false)
        };

        @Override
        public String getParserName() {
            return EuclideanToInfiniteNormUnitBallTransform.NAME;
        }

        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            final int dim = xo.getIntegerAttribute(EuclideanToInfiniteNormUnitBallTransform.DIMENSION);
            final Transform.Array fisherZ = new Transform.Array(Transform.FISHER_Z, dim * (dim + 1), null);
            final ArrayList<Transform.MultivariableTransform> balls =
                    new ArrayList<Transform.MultivariableTransform>(dim + 1);
            for (int k = 0; k < dim + 1; ++k) {
                balls.add(new EuclideanToInfiniteNormUnitBallTransform(dim));
            }
            final Transform.MultivariateArray ballArray = new Transform.MultivariateArray(balls);
            return new Transform.ComposeMultivariable(fisherZ, ballArray);
        }

        @Override
        public String getParserDescription() {
            return "A spherical transform using Fisher Z and LKJ.";
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }

        @Override
        public Class getReturnType() {
            return Transform.ComposeMultivariable.class;
        }
    };

    /**
     * @param dim number of coordinates mapped, passed to {@code Transform.MultivariateTransform}
     */
    public EuclideanToInfiniteNormUnitBallTransform(int dim) {
        super(dim);
    }

    /**
     * Maps a point of the Euclidean unit ball into the infinity-norm unit ball.
     *
     * @param values point whose squared Euclidean norm is at most 1 (checked only with {@code -ea})
     * @return array of the same length; entry {@code i < dim} is {@code values[i]} divided by the
     *     running product {@code prod_{j<i} sqrt(1 - y_j^2)}
     */
    @Override
    public double[] transform(double[] values) {
        assert isInEuclideanUnitBall(values) : "Initial vector is not in the Euclidean unit ball.";
        final double[] result = new double[values.length];
        // Running product prod_{j<i} sqrt(1 - y_j^2): the radius still available to coordinate i.
        double scale = 1.0;
        for (int i = 0; i < this.dim; ++i) {
            final double y = values[i] / scale;
            result[i] = y;
            scale *= Math.sqrt(1.0 - y * y);
        }
        return result;
    }

    /**
     * Maps a point of the infinity-norm unit ball back into the Euclidean unit ball.
     *
     * @param values point with every one of the first {@code dim} entries in {@code [-1, 1]}
     *     (checked only with {@code -ea})
     * @return array of the same length; entry {@code i < dim} is
     *     {@code values[i] * prod_{j<i} sqrt(1 - values[j]^2)}
     */
    @Override
    public double[] inverse(double[] values) {
        assert isInInfiniteUnitBall(values) : "Initial vector is not in the infinite-norm unit ball.";
        final double[] result = new double[values.length];
        double scale = 1.0;
        for (int i = 0; i < this.dim; ++i) {
            final double y = values[i];
            result[i] = y * scale;
            scale *= Math.sqrt(1.0 - y * y);
        }
        return result;
    }

    /** Returns whether the squared Euclidean norm of all entries is at most 1 (closed ball). */
    private boolean isInEuclideanUnitBall(double[] values) {
        return squaredNorm(values) <= 1.0;
    }

    /**
     * Returns whether the squared Euclidean norm of all entries is strictly below 1 (open ball).
     * Points on the unit sphere are excluded: there the forward map's running product reaches zero
     * and later coordinates would be divided by zero.
     */
    private boolean isInStrictEuclideanUnitBall(double[] values) {
        return squaredNorm(values) < 1.0;
    }

    /** The interior of the domain is the open Euclidean unit ball. */
    @Override
    public boolean isInInteriorDomain(double[] values) {
        return isInStrictEuclideanUnitBall(values);
    }

    /** Returns whether each of the first {@code dim} entries lies in {@code [-1, 1]} (NaN fails). */
    private boolean isInInfiniteUnitBall(double[] values) {
        for (int i = 0; i < this.dim; ++i) {
            if (!(values[i] >= -1.0 && values[i] <= 1.0)) {
                return false;
            }
        }
        return true;
    }

    /** Returns the sum of squares of all entries of {@code values}. */
    public static double squaredNorm(double[] values) {
        return squaredNorm(values, 0, values.length);
    }

    /**
     * Returns the sum of squares of {@code values[offset .. offset + length - 1]}.
     *
     * @param values source array
     * @param offset index of the first entry included
     * @param length number of entries included
     */
    public static double squaredNorm(double[] values, int offset, int length) {
        double sum = 0.0;
        for (int i = 0; i < length; ++i) {
            final double v = values[offset + i];
            sum += v * v;
        }
        return sum;
    }

    /** Returns {@code sqrt(1 - squaredNorm(values))}; NaN when the norm exceeds 1. */
    public static double projection(double[] values) {
        return projection(values, 0, values.length);
    }

    /**
     * Returns {@code sqrt(1 - squaredNorm(values, offset, length))}; NaN when that norm exceeds 1.
     */
    public static double projection(double[] values, int offset, int length) {
        return Math.sqrt(1.0 - squaredNorm(values, offset, length));
    }

    /** Not supported by this transform. */
    @Override
    public double[] inverse(double[] values, int from, int to, double sum) {
        throw new UnsupportedOperationException("Not relevant.");
    }

    @Override
    public String getTransformName() {
        return "EuclideanToInfiniteNormUnitBallTransform";
    }

    /** Not implemented. */
    @Override
    public double[] gradient(double[] values, int from, int to) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /** Not implemented. */
    @Override
    public double[] gradientInverse(double[] values, int from, int to) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Log-Jacobian of the forward map at {@code values} (a point of the Euclidean ball).
     *
     * <p>The inverse map is triangular with diagonal {@code prod_{j<i} sqrt(1 - y_j^2)}, so its
     * log-determinant is {@code 0.5 * sum_i (dim - i - 1) * log(1 - y_i^2)}. This method returns
     * the negation of that, with {@code y = transform(values)}.
     */
    @Override
    protected double getLogJacobian(double[] values) {
        final double[] y = transform(values);
        double sum = 0.0;
        // The last coordinate has weight 0, so it is skipped.
        for (int i = 0; i < this.dim - 1; ++i) {
            sum += (double) (this.dim - i - 1) * Math.log(1.0 - y[i] * y[i]);
        }
        return -0.5 * sum;
    }

    /**
     * Gradient of {@code 0.5 * sum_i (dim - i - 1) * log(1 - values[i]^2)} with respect to
     * {@code values}. Here {@code values} are the coordinates in the infinity-norm ball.
     *
     * @return array of the same length; entry {@code i < dim - 1} is
     *     {@code -(dim - i - 1) * values[i] / (1 - values[i]^2)}, all others are zero
     */
    @Override
    protected double[] getGradientLogJacobianInverse(double[] values) {
        final double[] gradient = new double[values.length];
        for (int i = 0; i < this.dim - 1; ++i) {
            final double y = values[i];
            gradient[i] = -(double) (this.dim - i - 1) * y / (1.0 - y * y);
        }
        return gradient;
    }

    /**
     * Jacobian of the inverse map at {@code values} (coordinates in the infinity-norm ball).
     *
     * @return {@code dim x dim} matrix whose entry {@code [i][k]} is {@code d x_k / d y_i} for
     *     {@code k >= i}, where {@code x = inverse(y)}; entries with {@code k < i} are zero
     */
    @Override
    public double[][] computeJacobianMatrixInverse(double[] values) {
        final double[][] jacobian = new double[this.dim][this.dim];
        for (int i = 0; i < this.dim; ++i) {
            // prod_{j<i} sqrt(1 - y_j^2), which is also d x_i / d y_i.
            double factor = 1.0;
            for (int j = 0; j < i; ++j) {
                factor *= Math.sqrt(1.0 - values[j] * values[j]);
            }
            jacobian[i][i] = factor;
            // Swap the factor sqrt(1 - y_i^2) for its derivative -y_i / sqrt(1 - y_i^2).
            // Every later x_k contains this factor exactly once.
            final double yi = values[i];
            factor *= -yi / Math.sqrt(1.0 - yi * yi);
            for (int k = i + 1; k < this.dim; ++k) {
                final double yk = values[k];
                jacobian[i][k] = yk * factor;
                factor *= Math.sqrt(1.0 - yk * yk);
            }
        }
        return jacobian;
    }
}

