/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators.hmc;

import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.PrecisionColumnProvider;
import dr.inference.hmc.PrecisionMatrixVectorProductProvider;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.loggers.NumberColumn;
import dr.inference.model.Parameter;
import dr.inference.operators.hmc.AbstractParticleOperator;
import dr.inference.operators.hmc.MinimumTravelInformation;
import dr.math.MathUtils;
import dr.math.matrixAlgebra.ReadableVector;
import dr.math.matrixAlgebra.WrappedVector;

public class BouncyParticleOperator
extends AbstractParticleOperator
implements Loggable {
    private WrappedVector storedVelocity;
    private final double refreshmentRate = 0.0;

    public BouncyParticleOperator(GradientWrtParameterProvider gradientWrtParameterProvider, PrecisionMatrixVectorProductProvider precisionMatrixVectorProductProvider, PrecisionColumnProvider precisionColumnProvider, double d, AbstractParticleOperator.Options options, Parameter parameter) {
        super(gradientWrtParameterProvider, precisionMatrixVectorProductProvider, precisionColumnProvider, d, options, parameter);
    }

    @Override
    public String getOperatorName() {
        return "Bouncy particle operator";
    }

    @Override
    double integrateTrajectory(WrappedVector wrappedVector, WrappedVector wrappedVector2) {
        WrappedVector wrappedVector3 = this.drawInitialVelocity();
        WrappedVector wrappedVector4 = this.getInitialGradient();
        WrappedVector wrappedVector5 = this.getPrecisionProduct(wrappedVector3);
        AbstractParticleOperator.BounceState bounceState = new AbstractParticleOperator.BounceState(this.drawTotalTravelTime());
        this.initializeNumEvent();
        while (bounceState.remainingTime > 0.0) {
            if (bounceState.type == AbstractParticleOperator.Type.BOUNDARY) {
                this.updateAction(wrappedVector5, wrappedVector3, bounceState.index);
            } else {
                wrappedVector5 = this.getPrecisionProduct(wrappedVector3);
            }
            double d = -ReadableVector.Utils.innerProduct(wrappedVector3, wrappedVector4);
            double d2 = ReadableVector.Utils.innerProduct(wrappedVector3, wrappedVector5);
            double d3 = Math.max(0.0, -d / d2);
            double d4 = d3 * d3 / 2.0 * d2 + d3 * d;
            double d5 = this.getBounceTime(d2, d, d4);
            MinimumTravelInformation minimumTravelInformation = this.getTimeToBoundary(wrappedVector, wrappedVector3);
            double d6 = this.getRefreshTime();
            bounceState = this.doBounce(bounceState.remainingTime, d5, minimumTravelInformation, d6, wrappedVector, wrappedVector3, wrappedVector4, wrappedVector5);
            this.recordOneMoreEvent();
        }
        this.storedVelocity = wrappedVector3;
        return 0.0;
    }

    private AbstractParticleOperator.BounceState doBounce(double d, double d2, MinimumTravelInformation minimumTravelInformation, double d3, WrappedVector wrappedVector, WrappedVector wrappedVector2, WrappedVector wrappedVector3, WrappedVector wrappedVector4) {
        AbstractParticleOperator.BounceState bounceState;
        double d4 = minimumTravelInformation.time;
        int n = minimumTravelInformation.index;
        if (d < Math.min(d4, d2)) {
            BouncyParticleOperator.updatePosition(wrappedVector, wrappedVector2, d);
            bounceState = new AbstractParticleOperator.BounceState(AbstractParticleOperator.Type.NONE, -1, 0.0);
        } else {
            int n2;
            AbstractParticleOperator.Type type;
            if (d3 < Math.min(d4, d2)) {
                type = AbstractParticleOperator.Type.REFRESHMENT;
                n2 = -1;
                BouncyParticleOperator.updatePosition(wrappedVector, wrappedVector2, d3);
                BouncyParticleOperator.updateGradient(wrappedVector3, d3, wrappedVector4);
                this.refreshVelocity(wrappedVector2);
            } else if (d4 < d2) {
                type = AbstractParticleOperator.Type.BOUNDARY;
                n2 = n;
                BouncyParticleOperator.updatePosition(wrappedVector, wrappedVector2, d4);
                BouncyParticleOperator.updateGradient(wrappedVector3, d4, wrappedVector4);
                wrappedVector.set(n, 0.0);
                wrappedVector2.set(n, -1.0 * wrappedVector2.get(n));
                d -= d4;
            } else {
                type = AbstractParticleOperator.Type.GRADIENT;
                n2 = -1;
                BouncyParticleOperator.updatePosition(wrappedVector, wrappedVector2, d2);
                BouncyParticleOperator.updateGradient(wrappedVector3, d2, wrappedVector4);
                BouncyParticleOperator.updateVelocity(wrappedVector2, wrappedVector3, this.preconditioning.mass);
                d -= d2;
            }
            bounceState = new AbstractParticleOperator.BounceState(type, n2, d);
        }
        return bounceState;
    }

    private WrappedVector drawInitialVelocity() {
        WrappedVector wrappedVector = this.preconditioning.mass;
        double[] dArray = new double[wrappedVector.getDim()];
        int n = dArray.length;
        for (int i = 0; i < n; ++i) {
            dArray[i] = MathUtils.nextGaussian() / Math.sqrt(wrappedVector.get(i));
        }
        if (this.mask != null) {
            this.applyMask(dArray);
        }
        return new WrappedVector.Raw(dArray);
    }

    private MinimumTravelInformation getTimeToBoundary(ReadableVector readableVector, ReadableVector readableVector2) {
        assert (readableVector.getDim() == readableVector2.getDim());
        int n = -1;
        double d = Double.MAX_VALUE;
        int n2 = readableVector.getDim();
        for (int i = 0; i < n2; ++i) {
            double d2 = Math.abs(readableVector.get(i) / readableVector2.get(i));
            if (!(d2 > 0.0) || !this.headingTowardsBoundary(readableVector.get(i), readableVector2.get(i), i) || !(d2 < d)) continue;
            n = i;
            d = d2;
        }
        return new MinimumTravelInformation(d, n);
    }

    private double getRefreshTime() {
        return Double.POSITIVE_INFINITY;
    }

    private double getBounceTime(double d, double d2, double d3) {
        double d4 = d / 2.0;
        double d5 = d2;
        double d6 = -d3 - MathUtils.nextExponential(1.0);
        return (-d5 + Math.sqrt(d5 * d5 - 4.0 * d4 * d6)) / 2.0 / d4;
    }

    private static void updateVelocity(WrappedVector wrappedVector, WrappedVector wrappedVector2, ReadableVector readableVector) {
        ReadableVector.Quotient quotient = new ReadableVector.Quotient(wrappedVector2, readableVector);
        double d = ReadableVector.Utils.innerProduct(wrappedVector, wrappedVector2);
        double d2 = ReadableVector.Utils.innerProduct((ReadableVector)wrappedVector2, quotient);
        int n = wrappedVector.getDim();
        for (int i = 0; i < n; ++i) {
            wrappedVector.set(i, wrappedVector.get(i) - 2.0 * d / d2 * quotient.get(i));
        }
    }

    private void refreshVelocity(WrappedVector wrappedVector) {
        WrappedVector wrappedVector2 = this.preconditioning.mass;
        int n = wrappedVector.getDim();
        for (int i = 0; i < n; ++i) {
            wrappedVector.set(i, MathUtils.nextGaussian() / Math.sqrt(wrappedVector2.get(i)));
        }
        if (this.mask != null) {
            this.applyMask(wrappedVector);
        }
    }

    @Override
    public LogColumn[] getColumns() {
        LogColumn[] logColumnArray = new LogColumn[this.preconditioning.mass.getDim()];
        for (int i = 0; i < this.preconditioning.mass.getDim(); ++i) {
            final int n = i;
            logColumnArray[i] = new NumberColumn("v" + n){

                @Override
                public double getDoubleValue() {
                    if (BouncyParticleOperator.this.storedVelocity != null) {
                        return BouncyParticleOperator.this.storedVelocity.get(n);
                    }
                    return 0.0;
                }
            };
        }
        return logColumnArray;
    }
}

