/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.rules.multilabel.errormeasurers;

import com.yahoo.labs.samoa.instances.Prediction;
import moa.classifiers.rules.multilabel.errormeasurers.AbstractMultiTargetErrorMeasurer;

/**
 * Multi-target error measurer that maintains a faded, weighted sum of squared
 * errors per output attribute and reports root mean squared errors from it.
 *
 * <p>Each call to {@link #addPrediction(Prediction, Prediction, double)} multiplies the
 * previously accumulated statistics by {@code fadingErrorFactor} (inherited from the
 * superclass) before adding the new weighted observation.
 *
 * <p>Until some non-zero weight has been accumulated, the scalar accessors report
 * {@link Double#MAX_VALUE} as a "no information yet" value.
 */
public class RootMeanSquaredErrorMT extends AbstractMultiTargetErrorMeasurer {
    /** Faded sum of the weights passed to {@code addPrediction}. */
    private double weightSeen;
    /** Faded, weighted sum of squared errors, one slot per output; {@code null} until the first prediction. */
    private double[] sumSquaredError;
    private static final long serialVersionUID = 1L;
    /** Becomes {@code true} once the first prediction has been received and the accumulators are allocated. */
    protected boolean hasStarted;
    /** Number of outputs that had votes in the first prediction received; used to average the overall error. */
    protected int numLearnedOutputs;

    /**
     * Folds one prediction into the faded error statistics.
     *
     * <p>On the first call the per-output accumulator is sized from
     * {@code prediction.numOutputAttributes()} and the outputs that have votes are counted.
     * On every call, outputs without votes are skipped: their accumulated squared error is
     * left untouched (it is not faded), while {@code weightSeen} is always updated.
     *
     * @param prediction the predicted values; only vote index 0 of each output is read
     * @param trueClass  the true values; only vote index 0 of each output is read
     * @param weight     the weight of this observation
     */
    @Override
    public void addPrediction(Prediction prediction, Prediction trueClass, double weight) {
        int numOutputs = prediction.numOutputAttributes();
        if (!this.hasStarted) {
            this.sumSquaredError = new double[numOutputs];
            this.hasStarted = true;
            // The number of learned outputs is fixed by the first prediction seen.
            for (int i = 0; i < numOutputs; ++i) {
                if (prediction.hasVotesForAttribute(i)) {
                    ++this.numLearnedOutputs;
                }
            }
        }
        for (int i = 0; i < numOutputs; ++i) {
            if (!prediction.hasVotesForAttribute(i)) {
                continue;
            }
            double errorOutput = prediction.getVote(i, 0) - trueClass.getVote(i, 0);
            this.sumSquaredError[i] = errorOutput * errorOutput * weight + this.fadingErrorFactor * this.sumSquaredError[i];
        }
        this.weightSeen = weight + this.fadingErrorFactor * this.weightSeen;
    }

    /**
     * Returns the overall error: the square root of the summed per-output squared errors
     * divided by {@code weightSeen * numLearnedOutputs}.
     *
     * <p>NOTE(review): if no output had votes in the first prediction,
     * {@code numLearnedOutputs} is 0 and the division yields NaN or infinity.
     *
     * @return the overall root mean squared error, or {@link Double#MAX_VALUE} while the
     *         accumulated weight is zero
     */
    @Override
    public double getCurrentError() {
        if (this.weightSeen == 0.0) {
            return Double.MAX_VALUE;
        }
        double sum = 0.0;
        for (double squaredError : this.sumSquaredError) {
            sum += squaredError;
        }
        return Math.sqrt(sum / (this.weightSeen * (double)this.numLearnedOutputs));
    }

    /**
     * Returns the root mean squared error of a single output.
     *
     * @param index the index of the output attribute
     * @return the error of that output, or {@link Double#MAX_VALUE} if no prediction has been
     *         added yet or the accumulated weight is zero (mirroring {@link #getCurrentError()})
     */
    @Override
    public double getCurrentError(int index) {
        // Guard against use before the first prediction (null accumulator) and against 0/0.
        if (this.sumSquaredError == null) {
            return Double.MAX_VALUE;
        }
        return this.outputError(index);
    }

    /**
     * Returns the root mean squared error of every output.
     *
     * @return a new array with one error per output, or {@code null} if no prediction has been
     *         added yet; entries are {@link Double#MAX_VALUE} while the accumulated weight is zero
     */
    @Override
    public double[] getCurrentErrors() {
        if (this.sumSquaredError == null) {
            return null;
        }
        double[] errors = new double[this.sumSquaredError.length];
        for (int i = 0; i < errors.length; ++i) {
            errors[i] = this.outputError(i);
        }
        return errors;
    }

    /** Computes one output's error; requires the accumulator to be allocated. */
    private double outputError(int index) {
        if (this.weightSeen == 0.0) {
            return Double.MAX_VALUE;
        }
        return Math.sqrt(this.sumSquaredError[index] / this.weightSeen);
    }
}

