/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify.constraints.pr;

import cc.mallet.classify.constraints.pr.MaxEntFLPRConstraints;
import cc.mallet.types.FeatureVector;
import com.carrotsearch.hppc.IntIntHashMap;
import com.carrotsearch.hppc.cursors.IntObjectCursor;

public class MaxEntL2FLPRConstraints
extends MaxEntFLPRConstraints {
    /**
     * Maps a constrained feature index to its dense position in
     * {@code [0, constraints.size())}; used to locate that feature's parameters.
     */
    private final IntIntHashMap constraintIndices = new IntIntHashMap();
    /** When true, expectations and scores are divided by the constraint's {@code count}. */
    private final boolean normalize;

    /**
     * Creates an L2-penalized feature/label constraint set.
     *
     * @param numFeatures number of input features (forwarded to the superclass)
     * @param numLabels   number of labels
     * @param useValues   whether feature values (rather than presence) are used when scoring
     * @param normalize   whether to divide expectations and scores by each constraint's count
     */
    public MaxEntL2FLPRConstraints(int numFeatures, int numLabels, boolean useValues, boolean normalize) {
        super(numFeatures, numLabels, useValues);
        this.normalize = normalize;
    }

    /**
     * Registers (or replaces) the constraint for feature {@code fi}.
     *
     * @param fi     feature index
     * @param ex     per-label target values
     * @param weight L2 penalty weight for this constraint
     */
    @Override
    public void addConstraint(int fi, double[] ex, double weight) {
        this.constraints.put(fi, new MaxEntL2FLPRConstraint(ex, weight));
        // Re-adding a feature replaces its constraint but must keep its existing position.
        // Assigning size() again would give an index equal to constraints.size(), whose
        // parameter slots collide with the next label's block and overrun numDimensions().
        if (!this.constraintIndices.containsKey(fi)) {
            this.constraintIndices.put(fi, this.constraintIndices.size());
        }
    }

    /**
     * Returns the number of auxiliary parameters: one per (constraint, label) pair.
     */
    @Override
    public int numDimensions() {
        return this.constraints.size() * this.numLabels;
    }

    /**
     * Computes sum over constraints and labels of
     * {@code target[li] * param - param^2 / (2 * weight)}.
     * The parameter for constraint position {@code ci} and label {@code li}
     * is read from {@code parameters[ci + li * numConstraints]}.
     *
     * @param parameters auxiliary parameter vector of length {@link #numDimensions()}
     * @return the auxiliary objective contribution
     */
    @Override
    public double getAuxiliaryValueContribution(double[] parameters) {
        final int numConstraints = this.constraints.size();
        double value = 0.0;
        for (IntObjectCursor fi : this.constraints) {
            MaxEntFLPRConstraints.MaxEntFLPRConstraint constraint =
                    (MaxEntFLPRConstraints.MaxEntFLPRConstraint) fi.value;
            int ci = this.constraintIndices.get(fi.key);
            for (int li = 0; li < this.numLabels; ++li) {
                double param = parameters[ci + li * numConstraints];
                value += constraint.target[li] * param;
                value -= param * param / (2.0 * constraint.weight);
            }
        }
        return value;
    }

    /**
     * Writes, for every (constraint, label) slot,
     * {@code target[li] - expectation[li] / norm - param / weight} into {@code gradient}.
     * {@code norm} is the constraint's count when normalizing, otherwise 1.
     *
     * @param parameters auxiliary parameter vector of length {@link #numDimensions()}
     * @param gradient   output array, same layout as {@code parameters}; overwritten in place
     */
    @Override
    public void getGradient(double[] parameters, double[] gradient) {
        final int numConstraints = this.constraints.size();
        for (IntObjectCursor fi : this.constraints) {
            MaxEntFLPRConstraints.MaxEntFLPRConstraint constraint =
                    (MaxEntFLPRConstraints.MaxEntFLPRConstraint) fi.value;
            int ci = this.constraintIndices.get(fi.key);
            double norm = this.normalize ? constraint.count : 1.0;
            for (int li = 0; li < this.numLabels; ++li) {
                int n = ci + li * numConstraints;
                double param = parameters[n];
                gradient[n] = constraint.target[li] - constraint.expectation[li] / norm - param / constraint.weight;
            }
        }
    }

    /**
     * Returns the negated L2 penalty
     * {@code -sum weight * (target[li] - expectation[li] / norm)^2 / 2}
     * over all constraints and labels.
     * {@code norm} is the constraint's count when normalizing, otherwise 1.
     */
    @Override
    public double getCompleteValueContribution() {
        double value = 0.0;
        for (IntObjectCursor fi : this.constraints) {
            MaxEntFLPRConstraints.MaxEntFLPRConstraint constraint =
                    (MaxEntFLPRConstraints.MaxEntFLPRConstraint) fi.value;
            double norm = this.normalize ? constraint.count : 1.0;
            for (int li = 0; li < this.numLabels; ++li) {
                double diff = constraint.target[li] - constraint.expectation[li] / norm;
                value -= constraint.weight * Math.pow(diff, 2.0) / 2.0;
            }
        }
        return value;
    }

    /**
     * Scores {@code label} by summing, over the cached feature indices, the parameter of
     * that feature's constraint for {@code label}, divided by the normalizer. When
     * {@code useValues} is set, each term is also multiplied by the cached feature value.
     * NOTE(review): {@code indexCache}/{@code valueCache} are presumably filled by the
     * superclass for {@code input}; {@code input} itself is not read here.
     *
     * @param input      the instance being scored (not read directly by this method)
     * @param label      label index
     * @param parameters auxiliary parameter vector of length {@link #numDimensions()}
     * @return the score contribution
     */
    @Override
    public double getScore(FeatureVector input, int label, double[] parameters) {
        final int numConstraints = this.constraints.size();
        double score = 0.0;
        for (int i = 0; i < this.indexCache.size(); ++i) {
            int fi = this.indexCache.get(i);
            int ci = this.constraintIndices.get(fi);
            double param = parameters[ci + label * numConstraints];
            double norm = this.normalize
                    ? ((MaxEntFLPRConstraints.MaxEntFLPRConstraint) this.constraints.get(fi)).count
                    : 1.0;
            if (this.useValues) {
                score += param * this.valueCache.get(i) / norm;
            } else {
                score += param / norm;
            }
        }
        return score;
    }

    /**
     * L2 variant of the feature/label constraint; adds no state beyond the superclass.
     * Kept as an inner (non-static) class to match the superclass's constraint type.
     */
    protected class MaxEntL2FLPRConstraint
    extends MaxEntFLPRConstraints.MaxEntFLPRConstraint {
        public MaxEntL2FLPRConstraint(double[] target, double weight) {
            super(target, weight);
        }
    }
}

