/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.trees;

import weka.classifiers.Classifier;
import weka.classifiers.Sourcable;
import weka.classifiers.rules.ZeroR;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.ContingencyTables;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

public class DecisionStump
extends Classifier
implements WeightedInstancesHandler,
Sourcable {
    /** Serialization version identifier. */
    static final long serialVersionUID = 1618384535950391L;
    /** Index of the attribute selected for the split; assigned at the end of the attribute scan in buildClassifier. */
    private int m_AttIndex;
    /**
     * Split point on the selected attribute: the index of the chosen value when the
     * attribute is nominal, or the numeric cut point otherwise.
     */
    private double m_SplitPoint;
    /**
     * Per-branch statistics, one row per branch: row 0 for instances that match the split
     * (equal to the chosen value, or at/below the cut point), row 1 for the remaining
     * non-missing instances, row 2 for instances whose split attribute is missing.
     */
    private double[][] m_Distribution;
    /** Working copy of the training data during building; replaced by {@code new Instances(m_Instances, 0)} once building finishes. */
    private Instances m_Instances;
    /** Fallback model built when the data contains only the class attribute; null otherwise. */
    private Classifier m_ZeroR;

    /**
     * Returns a short, human-readable description of this classifier.
     *
     * @return the description text
     */
    public String globalInfo() {
        final String description = "Class for building and using a decision stump. "
            + "Usually used in conjunction with a boosting algorithm. "
            + "Does regression (based on mean-squared error) or classification (based on entropy). "
            + "Missing is treated as a separate value.";
        return description;
    }

    /**
     * Describes which kinds of data this classifier accepts: starts from the
     * superclass capabilities, disables everything, then enables the supported
     * attribute and class types.
     *
     * @return the capabilities of this classifier
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities caps = super.getCapabilities();
        caps.disableAll();

        // Attribute capabilities first, then class capabilities.
        Capabilities.Capability[] supported = {
            Capabilities.Capability.NOMINAL_ATTRIBUTES,
            Capabilities.Capability.NUMERIC_ATTRIBUTES,
            Capabilities.Capability.DATE_ATTRIBUTES,
            Capabilities.Capability.MISSING_VALUES,
            Capabilities.Capability.NOMINAL_CLASS,
            Capabilities.Capability.NUMERIC_CLASS,
            Capabilities.Capability.DATE_CLASS,
            Capabilities.Capability.MISSING_CLASS_VALUES
        };
        for (Capabilities.Capability capability : supported) {
            caps.enable(capability);
        }
        return caps;
    }

    /**
     * Builds the decision stump: evaluates every non-class attribute, keeps the split
     * with the lowest criterion value, and stores the per-branch statistics.
     *
     * <p>If the data contains only the class attribute, a ZeroR model is built instead.
     *
     * @param instances the training data; it is copied, and the copy has instances with
     *        a missing class removed
     * @throws Exception if the capability check fails or a split search fails
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        this.getCapabilities().testWithFail(instances);

        // Work on a private copy so the caller's data is left untouched.
        Instances data = new Instances(instances);
        data.deleteWithMissingClass();

        if (data.numAttributes() == 1) {
            System.err.println("Cannot build model (only class attribute present in data!), using ZeroR model instead!");
            this.m_ZeroR = new ZeroR();
            this.m_ZeroR.buildClassifier(data);
            return;
        }
        this.m_ZeroR = null;

        double[][] winningDist = new double[3][data.numClasses()];
        this.m_Instances = new Instances(data);
        final boolean nominalClass = this.m_Instances.classAttribute().isNominal();
        final int width = nominalClass ? this.m_Instances.numClasses() : 1;

        double lowestScore = Double.MAX_VALUE;
        double winningPoint = -Double.MAX_VALUE;
        int winningAtt = -1;
        boolean nothingChosenYet = true;

        for (int att = 0; att < this.m_Instances.numAttributes(); att++) {
            if (att == this.m_Instances.classIndex()) {
                continue;
            }
            // Fresh scratch space; the split finders fill (and may replace) it.
            this.m_Distribution = new double[3][width];
            double score;
            if (this.m_Instances.attribute(att).isNominal()) {
                score = this.findSplitNominal(att);
            } else {
                score = this.findSplitNumeric(att);
            }
            // The first attribute examined is always accepted, whatever its score.
            if (nothingChosenYet || score < lowestScore) {
                lowestScore = score;
                winningAtt = att;
                winningPoint = this.m_SplitPoint;
                for (int row = 0; row < 3; row++) {
                    System.arraycopy(this.m_Distribution[row], 0, winningDist[row], 0, width);
                }
            }
            nothingChosenYet = false;
        }

        this.m_AttIndex = winningAtt;
        this.m_SplitPoint = winningPoint;
        this.m_Distribution = winningDist;

        if (nominalClass) {
            // Turn class counts into probabilities; an empty branch borrows the
            // counts of the "missing" branch (row 2) before normalizing.
            for (double[] branch : this.m_Distribution) {
                double branchTotal = Utils.sum(branch);
                if (branchTotal == 0.0) {
                    double[] missingBranch = this.m_Distribution[2];
                    System.arraycopy(missingBranch, 0, branch, 0, missingBranch.length);
                    Utils.normalize(branch);
                } else {
                    Utils.normalize(branch, branchTotal);
                }
            }
        }

        // Drop the training instances; keep only an empty data set with capacity 0.
        this.m_Instances = new Instances(this.m_Instances, 0);
    }

    /**
     * Returns the stored statistics of the branch the given instance falls into
     * (or delegates to the ZeroR fallback model when one is in use).
     *
     * <p>The returned array is a copy: callers may modify it (for example normalize
     * it in place) without altering the model's internal branch statistics.
     *
     * @param instance the instance to classify
     * @return a fresh array holding the distribution for the instance's branch
     * @throws Exception if the branch cannot be determined or the fallback model fails
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (this.m_ZeroR != null) {
            return this.m_ZeroR.distributionForInstance(instance);
        }
        // Defensive copy: handing out the internal row would let callers corrupt the model.
        return this.m_Distribution[this.whichSubset(instance)].clone();
    }

    /**
     * Generates Java source for a class named {@code className} containing a
     * {@code public static double classify(Object[] i)} method that reproduces this stump.
     *
     * <p>When the ZeroR fallback model is in use, source generation is delegated to it
     * (if it implements {@link Sourcable}), because the stump's own fields do not
     * describe the active model in that case.
     *
     * @param className the name of the class to generate
     * @return the generated source code
     * @throws IllegalStateException if no model has been built, or the fallback model
     *         cannot be turned into source
     * @throws Exception if the delegated source generation fails
     */
    @Override
    public String toSource(String className) throws Exception {
        if (this.m_ZeroR != null) {
            if (this.m_ZeroR instanceof Sourcable) {
                return ((Sourcable) this.m_ZeroR).toSource(className);
            }
            throw new IllegalStateException("Cannot generate source: the fallback model is not Sourcable.");
        }
        if (this.m_Instances == null || this.m_Distribution == null) {
            throw new IllegalStateException("Decision Stump: No model built yet.");
        }

        Attribute c = this.m_Instances.classAttribute();
        Attribute splitAtt = this.m_Instances.attribute(this.m_AttIndex);

        StringBuilder text = new StringBuilder("class ");
        text.append(className).append(" {\n  public static double classify(Object[] i) {\n");
        // Keep an attribute name containing "*/" from terminating the generated comment early.
        text.append("    /* ").append(String.valueOf(splitAtt.name()).replace("*/", "* /")).append(" */\n");
        text.append("    if (i[").append(this.m_AttIndex);
        text.append("] == null) { return ");
        text.append(this.sourceClass(c, this.m_Distribution[2])).append(";");
        if (splitAtt.isNominal()) {
            text.append(" } else if (((String)i[").append(this.m_AttIndex);
            text.append("]).equals(\"");
            // The value ends up inside a Java string literal, so it must be escaped.
            text.append(escapeForJavaLiteral(splitAtt.value((int)this.m_SplitPoint)));
            text.append("\")");
        } else {
            text.append(" } else if (((Double)i[").append(this.m_AttIndex);
            text.append("]).doubleValue() <= ").append(this.m_SplitPoint);
        }
        text.append(") { return ");
        text.append(this.sourceClass(c, this.m_Distribution[0])).append(";");
        text.append(" } else { return ");
        text.append(this.sourceClass(c, this.m_Distribution[1])).append(";");
        text.append(" }\n  }\n}\n");
        return text.toString();
    }

    /** Escapes backslashes, double quotes and common control characters so {@code raw} can sit inside a Java string literal. */
    private static String escapeForJavaLiteral(String raw) {
        String value = String.valueOf(raw);
        StringBuilder out = new StringBuilder(value.length() + 8);
        for (int k = 0; k < value.length(); k++) {
            char ch = value.charAt(k);
            switch (ch) {
                case '\\':
                    out.append("\\\\");
                    break;
                case '"':
                    out.append("\\\"");
                    break;
                case '\n':
                    out.append("\\n");
                    break;
                case '\r':
                    out.append("\\r");
                    break;
                case '\t':
                    out.append("\\t");
                    break;
                default:
                    out.append(ch);
                    break;
            }
        }
        return out.toString();
    }

    /**
     * Renders the prediction of one branch as a literal for generated source code.
     *
     * @param c the class attribute
     * @param dist the statistics of the branch
     * @return the index of the largest entry for a nominal class, otherwise the first entry
     */
    private String sourceClass(Attribute c, double[] dist) {
        return c.isNominal()
            ? Integer.toString(Utils.maxIndex(dist))
            : Double.toString(dist[0]);
    }

    /**
     * Returns a textual description of the model: the ZeroR fallback (with a warning
     * header) if one is in use, a placeholder if nothing has been built, or the
     * prediction of each branch followed, for a nominal class, by the class
     * distribution of each branch.
     *
     * @return the description of the classifier
     */
    public String toString() {
        if (this.m_ZeroR != null) {
            String simpleName = this.getClass().getName().replaceAll(".*\\.", "");
            StringBuilder fallback = new StringBuilder();
            fallback.append(simpleName).append("\n");
            // Regex "." matches every character, producing an underline of '=' signs.
            fallback.append(simpleName.replaceAll(".", "=")).append("\n\n");
            fallback.append("Warning: No model could be built, hence ZeroR model is used:\n\n");
            fallback.append(this.m_ZeroR.toString());
            return fallback.toString();
        }
        if (this.m_Instances == null) {
            return "Decision Stump: No model built yet.";
        }
        try {
            Attribute att = this.m_Instances.attribute(this.m_AttIndex);

            // Branch labels are identical in both report sections, so build them once.
            String matchLabel;
            String restLabel;
            if (att.isNominal()) {
                String chosenValue = att.value((int)this.m_SplitPoint);
                matchLabel = att.name() + " = " + chosenValue;
                restLabel = att.name() + " != " + chosenValue;
            } else {
                matchLabel = att.name() + " <= " + this.m_SplitPoint;
                restLabel = att.name() + " > " + this.m_SplitPoint;
            }
            String missingLabel = att.name() + " is missing";

            StringBuilder report = new StringBuilder();
            report.append("Decision Stump\n\n");
            report.append("Classifications\n\n");
            report.append(matchLabel).append(" : ").append(this.printClass(this.m_Distribution[0]));
            report.append(restLabel).append(" : ").append(this.printClass(this.m_Distribution[1]));
            report.append(missingLabel).append(" : ").append(this.printClass(this.m_Distribution[2]));

            if (this.m_Instances.classAttribute().isNominal()) {
                report.append("\nClass distributions\n\n");
                report.append(matchLabel).append("\n").append(this.printDist(this.m_Distribution[0]));
                report.append(restLabel).append("\n").append(this.printDist(this.m_Distribution[1]));
                report.append(missingLabel).append("\n").append(this.printDist(this.m_Distribution[2]));
            }
            return report.toString();
        }
        catch (Exception e) {
            return "Can't print decision stump classifier!";
        }
    }

    /**
     * Formats a class distribution as two tab-separated lines: the class labels,
     * then the corresponding entries of {@code dist}. Yields an empty string when
     * the class attribute is not nominal.
     *
     * @param dist the distribution to print
     * @return the formatted distribution
     * @throws Exception declared for callers; not thrown explicitly here
     */
    private String printDist(double[] dist) throws Exception {
        Attribute classAtt = this.m_Instances.classAttribute();
        if (!classAtt.isNominal()) {
            return "";
        }
        final int classCount = this.m_Instances.numClasses();
        StringBuilder out = new StringBuilder();
        for (int k = 0; k < classCount; k++) {
            out.append(classAtt.value(k)).append("\t");
        }
        out.append("\n");
        for (int k = 0; k < classCount; k++) {
            out.append(dist[k]).append("\t");
        }
        out.append("\n");
        return out.toString();
    }

    /**
     * Formats the prediction of one branch, followed by a newline: the label of the
     * largest entry for a nominal class, otherwise the first entry of {@code dist}.
     *
     * @param dist the statistics of the branch
     * @return the prediction as text, terminated by a newline
     * @throws Exception declared for callers; not thrown explicitly here
     */
    private String printClass(double[] dist) throws Exception {
        Attribute classAtt = this.m_Instances.classAttribute();
        String prediction = classAtt.isNominal()
            ? classAtt.value(Utils.maxIndex(dist))
            : String.valueOf(dist[0]);
        return prediction + "\n";
    }

    /**
     * Finds the best split on a nominal attribute, dispatching on the class type.
     *
     * @param index the index of the attribute to split on
     * @return the criterion value of the best split found (lower is better)
     * @throws Exception if the split search fails
     */
    private double findSplitNominal(int index) throws Exception {
        return this.m_Instances.classAttribute().isNominal()
            ? this.findSplitNominalNominal(index)
            : this.findSplitNominalNumeric(index);
    }

    /**
     * Finds the best "value vs. rest" split on a nominal attribute for a nominal class,
     * scoring each candidate value with the row-conditioned entropy of the three-branch
     * count table. Sets {@code m_SplitPoint} to the best value's index and replaces
     * {@code m_Distribution} with the winning class counts.
     *
     * @param index the index of the attribute to split on
     * @return the entropy of the best split found
     * @throws Exception if the computation fails
     */
    private double findSplitNominalNominal(int index) throws Exception {
        final Attribute att = this.m_Instances.attribute(index);
        final int valueCount = att.numValues();
        final int classCount = this.m_Instances.numClasses();
        final int missingRow = valueCount; // extra row collecting instances with a missing value

        double[][] counts = new double[valueCount + 1][classCount];
        double[] nonMissingTotals = new double[classCount];
        double[][] winner = new double[3][classCount];
        double lowestEntropy = Double.MAX_VALUE;
        int missingSeen = 0;

        // Weighted class counts per attribute value.
        for (int k = 0; k < this.m_Instances.numInstances(); k++) {
            Instance inst = this.m_Instances.instance(k);
            int row;
            if (inst.isMissing(index)) {
                missingSeen++;
                row = missingRow;
            } else {
                row = (int)inst.value(index);
            }
            counts[row][(int)inst.classValue()] += inst.weight();
        }

        // Class totals over all non-missing values.
        for (int v = 0; v < valueCount; v++) {
            for (int cls = 0; cls < classCount; cls++) {
                nonMissingTotals[cls] += counts[v][cls];
            }
        }

        System.arraycopy(counts[missingRow], 0, this.m_Distribution[2], 0, classCount);

        // Try each value as the "equals" branch; everything else forms the other branch.
        for (int v = 0; v < valueCount; v++) {
            for (int cls = 0; cls < classCount; cls++) {
                this.m_Distribution[0][cls] = counts[v][cls];
                this.m_Distribution[1][cls] = nonMissingTotals[cls] - counts[v][cls];
            }
            double entropy = ContingencyTables.entropyConditionedOnRows(this.m_Distribution);
            if (entropy < lowestEntropy) {
                lowestEntropy = entropy;
                this.m_SplitPoint = v;
                for (int row = 0; row < 3; row++) {
                    System.arraycopy(this.m_Distribution[row], 0, winner[row], 0, classCount);
                }
            }
        }

        // Without missing values in training, the missing branch uses the overall counts.
        if (missingSeen == 0) {
            System.arraycopy(nonMissingTotals, 0, winner[2], 0, classCount);
        }
        this.m_Distribution = winner;
        return lowestEntropy;
    }

    /**
     * Finds the best "value vs. rest" split on a nominal attribute for a numeric class,
     * scoring each candidate value with {@link #variance}. Sets {@code m_SplitPoint} to
     * the best value's index and replaces {@code m_Distribution} with the per-branch
     * weighted means. Returns early, leaving the accumulated sums in place, when the
     * total instance weight is not positive.
     *
     * @param index the index of the attribute to split on
     * @return the criterion value of the best split found
     * @throws Exception if the computation fails
     */
    private double findSplitNominalNumeric(int index) throws Exception {
        final int valueCount = this.m_Instances.attribute(index).numValues();

        double[] squaresByValue = new double[valueCount];
        double[] sumsByValue = new double[valueCount];
        double[] weightsByValue = new double[valueCount];
        double[] sumsSquares = new double[3];
        double[] sumOfWeights = new double[3];
        double[][] winner = new double[3][1];
        double lowestScore = Double.MAX_VALUE;
        double totalWeight = 0.0;
        double totalSum = 0.0;

        // Accumulate weighted sums per value; missing values go straight to branch 2.
        for (int k = 0; k < this.m_Instances.numInstances(); k++) {
            Instance inst = this.m_Instances.instance(k);
            double target = inst.classValue();
            double weight = inst.weight();
            if (inst.isMissing(index)) {
                this.m_Distribution[2][0] += target * weight;
                sumsSquares[2] += target * target * weight;
                sumOfWeights[2] += weight;
            } else {
                int v = (int)inst.value(index);
                weightsByValue[v] += weight;
                sumsByValue[v] += target * weight;
                squaresByValue[v] += target * target * weight;
            }
            totalWeight += weight;
            totalSum += target * weight;
        }
        if (totalWeight <= 0.0) {
            return lowestScore;
        }

        // Totals over the non-missing values only.
        double knownWeight = 0.0;
        double knownSquares = 0.0;
        double knownSum = 0.0;
        for (int v = 0; v < valueCount; v++) {
            knownWeight += weightsByValue[v];
            knownSquares += squaresByValue[v];
            knownSum += sumsByValue[v];
        }

        for (int v = 0; v < valueCount; v++) {
            this.m_Distribution[0][0] = sumsByValue[v];
            sumsSquares[0] = squaresByValue[v];
            sumOfWeights[0] = weightsByValue[v];
            this.m_Distribution[1][0] = knownSum - sumsByValue[v];
            sumsSquares[1] = knownSquares - squaresByValue[v];
            sumOfWeights[1] = knownWeight - weightsByValue[v];

            double score = this.variance(this.m_Distribution, sumsSquares, sumOfWeights);
            if (score < lowestScore) {
                lowestScore = score;
                this.m_SplitPoint = v;
                // Store branch means; a branch without weight falls back to the overall mean.
                for (int branch = 0; branch < 3; branch++) {
                    if (sumOfWeights[branch] > 0.0) {
                        winner[branch][0] = this.m_Distribution[branch][0] / sumOfWeights[branch];
                    } else {
                        winner[branch][0] = totalSum / totalWeight;
                    }
                }
            }
        }
        this.m_Distribution = winner;
        return lowestScore;
    }

    /**
     * Finds the best split on a numeric attribute, dispatching on the class type.
     *
     * @param index the index of the attribute to split on
     * @return the criterion value of the best split found (lower is better)
     * @throws Exception if the split search fails
     */
    private double findSplitNumeric(int index) throws Exception {
        return this.m_Instances.classAttribute().isNominal()
            ? this.findSplitNumericNominal(index)
            : this.findSplitNumericNumeric(index);
    }

    /**
     * Finds the best cut point on a numeric attribute for a nominal class, scoring each
     * candidate with the row-conditioned entropy of the three-branch count table.
     * Sorts {@code m_Instances} by the attribute, sets {@code m_SplitPoint} when a cut
     * point is found, and replaces {@code m_Distribution} with the winning class counts.
     *
     * @param index the index of the attribute to split on
     * @return the entropy of the best split, or {@code Double.MAX_VALUE} if none was found
     * @throws Exception if the computation fails
     */
    private double findSplitNumericNominal(int index) throws Exception {
        final int classCount = this.m_Instances.numClasses();
        final int instanceCount = this.m_Instances.numInstances();

        double lowestEntropy = Double.MAX_VALUE;
        int missingSeen = 0;
        double[] nonMissingTotals = new double[classCount];
        double[][] winner = new double[3][classCount];

        // Start with every non-missing instance in the upper branch (row 1).
        for (int k = 0; k < instanceCount; k++) {
            Instance inst = this.m_Instances.instance(k);
            int branch;
            if (inst.isMissing(index)) {
                branch = 2;
            } else {
                branch = 1;
            }
            this.m_Distribution[branch][(int)inst.classValue()] += inst.weight();
            if (branch == 2) {
                missingSeen++;
            }
        }
        System.arraycopy(this.m_Distribution[1], 0, nonMissingTotals, 0, classCount);

        // The unsplit state is the answer if no cut point turns up.
        for (int row = 0; row < 3; row++) {
            System.arraycopy(this.m_Distribution[row], 0, winner[row], 0, classCount);
        }

        this.m_Instances.sort(index);

        // The bound skips the last (missingSeen + 1) positions; this presumes the sort
        // places missing values at the end -- confirm against Instances.sort.
        final int lastCandidate = instanceCount - (missingSeen + 1);
        for (int k = 0; k < lastCandidate; k++) {
            Instance current = this.m_Instances.instance(k);
            Instance next = this.m_Instances.instance(k + 1);

            // Move the current instance from the upper to the lower branch.
            int cls = (int)current.classValue();
            this.m_Distribution[0][cls] += current.weight();
            this.m_Distribution[1][cls] -= current.weight();

            // Only a strict increase in value allows a cut between the two instances.
            if (current.value(index) < next.value(index)) {
                double midpoint = (current.value(index) + next.value(index)) / 2.0;
                double entropy = ContingencyTables.entropyConditionedOnRows(this.m_Distribution);
                if (entropy < lowestEntropy) {
                    this.m_SplitPoint = midpoint;
                    lowestEntropy = entropy;
                    for (int row = 0; row < 3; row++) {
                        System.arraycopy(this.m_Distribution[row], 0, winner[row], 0, classCount);
                    }
                }
            }
        }

        // Without missing values in training, the missing branch uses the overall counts.
        if (missingSeen == 0) {
            System.arraycopy(nonMissingTotals, 0, winner[2], 0, classCount);
        }
        this.m_Distribution = winner;
        return lowestEntropy;
    }

    /**
     * Finds the best cut point on a numeric attribute for a numeric class, scoring each
     * candidate with {@link #variance}. Sorts {@code m_Instances} by the attribute, sets
     * {@code m_SplitPoint} when a cut point is found, and replaces {@code m_Distribution}
     * with the per-branch weighted means.
     *
     * <p>If no cut point exists (for example, all non-missing values are equal), the
     * stored means are those of the unsplit data rather than zeros, so a stump that
     * ends up using this attribute still predicts a sensible value. The returned
     * criterion stays {@code Double.MAX_VALUE} in that case, so attribute selection is
     * unaffected.
     *
     * <p>Returns early, leaving the accumulated sums in place, when the total instance
     * weight is not positive.
     *
     * @param index the index of the attribute to split on
     * @return the criterion value of the best split, or {@code Double.MAX_VALUE} if none was found
     * @throws Exception if the computation fails
     */
    private double findSplitNumericNumeric(int index) throws Exception {
        double bestVal = Double.MAX_VALUE;
        int numMissing = 0;
        double[] sumsSquares = new double[3];
        double[] sumOfWeights = new double[3];
        double[][] bestDist = new double[3][1];
        double totalSum = 0.0;
        double totalSumOfWeights = 0.0;
        final int numInstances = this.m_Instances.numInstances();

        // Start with every non-missing instance in the upper branch (1); missing ones go to branch 2.
        for (int i = 0; i < numInstances; i++) {
            Instance inst = this.m_Instances.instance(i);
            double target = inst.classValue();
            double weight = inst.weight();
            int branch;
            if (inst.isMissing(index)) {
                branch = 2;
                ++numMissing;
            } else {
                branch = 1;
            }
            this.m_Distribution[branch][0] += target * weight;
            sumsSquares[branch] += target * target * weight;
            sumOfWeights[branch] += weight;
            totalSumOfWeights += weight;
            totalSum += target * weight;
        }
        if (totalSumOfWeights <= 0.0) {
            return bestVal;
        }
        final double overallMean = totalSum / totalSumOfWeights;

        // Seed the result with the unsplit means so that "no cut point found" does not
        // leave all-zero predictions behind.
        this.storeBranchMeans(bestDist, sumOfWeights, overallMean);

        this.m_Instances.sort(index);

        // The bound skips the last (numMissing + 1) positions; this presumes the sort
        // places missing values at the end -- confirm against Instances.sort.
        final int lastCandidate = numInstances - (numMissing + 1);
        for (int i = 0; i < lastCandidate; i++) {
            Instance inst = this.m_Instances.instance(i);
            Instance instPlusOne = this.m_Instances.instance(i + 1);
            double target = inst.classValue();
            double weight = inst.weight();

            // Move the current instance from the upper to the lower branch.
            this.m_Distribution[0][0] += target * weight;
            sumsSquares[0] += target * target * weight;
            sumOfWeights[0] += weight;
            this.m_Distribution[1][0] -= target * weight;
            sumsSquares[1] -= target * target * weight;
            sumOfWeights[1] -= weight;

            // Only a strict increase in value allows a cut between the two instances.
            if (inst.value(index) < instPlusOne.value(index)) {
                double currCutPoint = (inst.value(index) + instPlusOne.value(index)) / 2.0;
                double currVal = this.variance(this.m_Distribution, sumsSquares, sumOfWeights);
                if (currVal < bestVal) {
                    this.m_SplitPoint = currCutPoint;
                    bestVal = currVal;
                    this.storeBranchMeans(bestDist, sumOfWeights, overallMean);
                }
            }
        }
        this.m_Distribution = bestDist;
        return bestVal;
    }

    /** Writes each branch's weighted mean into {@code target}; a branch without weight gets {@code fallbackMean}. */
    private void storeBranchMeans(double[][] target, double[] sumOfWeights, double fallbackMean) {
        for (int j = 0; j < 3; j++) {
            target[j][0] = sumOfWeights[j] > 0.0 ? this.m_Distribution[j][0] / sumOfWeights[j] : fallbackMean;
        }
    }

    /**
     * Sums, over all branches with positive weight, the sum of squares minus the
     * squared sum divided by the weight.
     *
     * @param s per-branch sums (first column is used)
     * @param sS per-branch sums of squares
     * @param sumOfWeights per-branch total weights
     * @return the accumulated value; branches without positive weight contribute nothing
     */
    private double variance(double[][] s, double[] sS, double[] sumOfWeights) {
        double total = 0.0;
        for (int branch = 0; branch < s.length; branch++) {
            double weight = sumOfWeights[branch];
            if (weight > 0.0) {
                double branchSum = s[branch][0];
                total += sS[branch] - branchSum * branchSum / weight;
            }
        }
        return total;
    }

    /**
     * Determines which branch an instance falls into.
     *
     * @param instance the instance to route
     * @return 2 if the split attribute is missing; otherwise 0 if the instance matches the
     *         split (value index equals the split point for a nominal attribute, value at
     *         or below the split point otherwise), else 1
     * @throws Exception declared for callers; not thrown explicitly here
     */
    private int whichSubset(Instance instance) throws Exception {
        final int att = this.m_AttIndex;
        if (instance.isMissing(att)) {
            return 2;
        }
        boolean matchesSplit;
        if (instance.attribute(att).isNominal()) {
            matchesSplit = (double)((int)instance.value(att)) == this.m_SplitPoint;
        } else {
            matchesSplit = instance.value(att) <= this.m_SplitPoint;
        }
        return matchesSplit ? 0 : 1;
    }

    /**
     * Returns the revision string of this class.
     *
     * @return the value extracted from the revision keyword by {@code RevisionUtils.extract}
     */
    @Override
    public String getRevision() {
        final String revisionKeyword = "$Revision: 5535 $";
        return RevisionUtils.extract(revisionKeyword);
    }

    /**
     * Command-line entry point: hands a new DecisionStump and the arguments to
     * {@code runClassifier}.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        DecisionStump stump = new DecisionStump();
        DecisionStump.runClassifier(stump, argv);
    }
}

