/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.meta;

import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Instance;
import java.util.ArrayList;
import java.util.List;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.core.Measurement;
import moa.options.ClassOption;
import weka.core.Utils;

public class DynamicWeightedMajority
extends AbstractClassifier
implements MultiClassClassifier {
    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Base classifiers to train.", Classifier.class, "bayes.NaiveBayes");
    public IntOption periodOption = new IntOption("period", 'p', "Period between expert removal, creation, and weight update.", 50, 1, Integer.MAX_VALUE);
    public FloatOption betaOption = new FloatOption("beta", 'b', "Factor to punish mistakes by.", 0.5, 0.0, 1.0);
    public FloatOption thetaOption = new FloatOption("theta", 't', "Minimum fraction of weight per model.", 0.01, 0.0, 1.0);
    public IntOption maxExpertsOption = new IntOption("maxExperts", 'e', "Maximum number of allowed experts.", Integer.MAX_VALUE, 2, Integer.MAX_VALUE);
    protected List<Classifier> experts;
    protected List<Double> weights;
    protected long epochs;

    @Override
    public void resetLearningImpl() {
        this.experts = new ArrayList<Classifier>(50);
        Classifier classifier = ((Classifier)this.getPreparedClassOption(this.baseLearnerOption)).copy();
        classifier.resetLearning();
        this.experts.add(classifier);
        this.weights = new ArrayList<Double>(50);
        this.weights.add(1.0);
        this.epochs = 0L;
    }

    protected void scaleWeights(double maxWeight) {
        double sf = 1.0 / maxWeight;
        for (int i = 0; i < this.weights.size(); ++i) {
            this.weights.set(i, this.weights.get(i) * sf);
        }
    }

    protected void removeExperts() {
        for (int i = this.experts.size() - 1; i >= 0; --i) {
            if (!(this.weights.get(i) < this.thetaOption.getValue())) continue;
            this.experts.remove(i);
            this.weights.remove(i);
        }
    }

    protected void removeWeakestExpert(int i) {
        this.experts.remove(i);
        this.weights.remove(i);
    }

    @Override
    public void trainOnInstanceImpl(Instance inst) {
        ++this.epochs;
        double[] Pr = new double[inst.numClasses()];
        double maxWeight = 0.0;
        double weakestExpertWeight = 1.0;
        int weakestExpertIndex = -1;
        for (int i = 0; i < this.experts.size(); ++i) {
            double[] pr = this.experts.get(i).getVotesForInstance(inst);
            int yHat = Utils.maxIndex((double[])pr);
            if (yHat != (int)inst.classValue() && this.epochs % (long)this.periodOption.getValue() == 0L) {
                this.weights.set(i, this.weights.get(i) * this.betaOption.getValue());
            }
            int n = yHat;
            Pr[n] = Pr[n] + this.weights.get(i);
            maxWeight = Math.max(maxWeight, this.weights.get(i));
            if (!(this.weights.get(i) < weakestExpertWeight)) continue;
            weakestExpertIndex = i;
            weakestExpertWeight = this.weights.get(i);
        }
        int yHat = Utils.maxIndex((double[])Pr);
        if (this.epochs % (long)this.periodOption.getValue() == 0L) {
            this.scaleWeights(maxWeight);
            this.removeExperts();
            if (yHat != (int)inst.classValue()) {
                if (this.experts.size() == this.maxExpertsOption.getValue()) {
                    this.removeWeakestExpert(weakestExpertIndex);
                }
                Classifier classifier = ((Classifier)this.getPreparedClassOption(this.baseLearnerOption)).copy();
                classifier.resetLearning();
                this.experts.add(classifier);
                this.weights.add(1.0);
            }
        }
        for (Classifier expert : this.experts) {
            expert.trainOnInstance(inst);
        }
    }

    @Override
    public boolean isRandomizable() {
        return false;
    }

    @Override
    public double[] getVotesForInstance(Instance inst) {
        double[] Pr = new double[inst.numClasses()];
        for (int i = 0; i < this.experts.size(); ++i) {
            int yHat;
            double[] pr = this.experts.get(i).getVotesForInstance(inst);
            int n = yHat = Utils.maxIndex((double[])pr);
            Pr[n] = Pr[n] + this.weights.get(i);
        }
        Utils.normalize((double[])Pr);
        return Pr;
    }

    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        Measurement[] measurements = null;
        if (this.weights != null) {
            measurements = new Measurement[]{new Measurement("members size", this.weights.size())};
        }
        return measurements;
    }

    @Override
    public void getModelDescription(StringBuilder out, int indent) {
    }
}

