/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.types.tests;

import cc.mallet.classify.Classifier;
import cc.mallet.classify.MaxEntTrainer;
import cc.mallet.classify.Trial;
import cc.mallet.pipe.FeatureSequence2FeatureVector;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.Target2Label;
import cc.mallet.pipe.TokenSequence2FeatureSequence;
import cc.mallet.pipe.iterator.RandomTokenSequenceIterator;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.InstanceList;
import cc.mallet.types.PagedInstanceList;
import cc.mallet.util.Randoms;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import junit.textui.TestRunner;

/**
 * Checks that training a MaxEnt classifier on a {@link PagedInstanceList} gives
 * (nearly) the same test-set accuracy as training on an ordinary in-memory
 * {@link InstanceList} fed with the same randomly generated data.
 */
public class TestPagedInstanceList
extends TestCase {
    public TestPagedInstanceList(String name) {
        super(name);
    }

    /** Builds a suite containing every {@code test*} method of this class. */
    public static Test suite() {
        return new TestSuite(TestPagedInstanceList.class);
    }

    /**
     * Creates an alphabet holding the entries {@code "feature0"} .. {@code "feature<size-1>"}.
     *
     * @param size number of entries to add
     * @return the populated alphabet
     */
    private static Alphabet dictOfSize(int size) {
        Alphabet ret = new Alphabet();
        for (int i = 0; i < size; ++i) {
            ret.lookupIndex("feature" + i);
        }
        return ret;
    }

    /**
     * Trains once on an in-memory list and once on a paged list, then checks that the
     * test accuracies agree to within 0.01.
     *
     * <p>The paged list swaps pages to disk. A fresh temporary directory is used for
     * that, instead of the current working directory, and it is removed afterwards so
     * the test leaves no files behind.
     *
     * @throws IOException if the temporary swap directory cannot be created
     */
    public void testRandomTrained() throws IOException {
        SerialPipes p = new SerialPipes(new Pipe[]{new TokenSequence2FeatureSequence(), new FeatureSequence2FeatureVector(), new Target2Label()});
        double inMemoryAccuracy = this.testRandomTrainedOn(new InstanceList(p));
        File swapDir = Files.createTempDirectory("TestPagedInstanceList").toFile();
        try {
            double pagedAccuracy = this.testRandomTrainedOn(new PagedInstanceList(p, 700, 200, swapDir));
            assertEquals(inMemoryAccuracy, pagedAccuracy, 0.01);
        } finally {
            deleteSwapDir(swapDir);
        }
    }

    /**
     * Best-effort removal of a swap directory and the files directly inside it.
     * Return values of {@code delete()} are deliberately ignored: a leftover temp
     * file must not turn a passing test into a failure.
     */
    private static void deleteSwapDir(File dir) {
        File[] files = dir.listFiles();
        if (files != null) {
            for (File f : files) {
                f.delete();
            }
        }
        dir.delete();
    }

    /**
     * Creates a random token-sequence generator over {@code fd} and {@code classNames}.
     * The numeric arguments are the generator parameters this test has always used,
     * so the training and testing data come from identically configured generators.
     */
    private static RandomTokenSequenceIterator randomInstances(Randoms r, Alphabet fd, String[] classNames) {
        return new RandomTokenSequenceIterator(r, new Dirichlet(fd, 2.0), 30.0, 0.0, 10.0, 200.0, classNames);
    }

    /**
     * Fills {@code training} with random instances and trains a MaxEnt classifier on
     * it. It then builds a second random test set with the same pipe and evaluates
     * the classifier on that set.
     *
     * <p>A fixed seed is used, and the training generator is drained before the test
     * generator is created. Every call therefore produces the same data.
     *
     * @param training an empty instance list whose pipe accepts token sequences
     * @return the classifier's accuracy on the generated test set
     */
    private double testRandomTrainedOn(InstanceList training) {
        MaxEntTrainer trainer = new MaxEntTrainer();
        Alphabet fd = dictOfSize(3);
        String[] classNames = new String[]{"class0", "class1", "class2"};
        Randoms r = new Randoms(1);
        // Order matters: both generators draw from the same Randoms stream.
        training.addThruPipe(randomInstances(r, fd, classNames));
        InstanceList testing = new InstanceList(training.getPipe());
        testing.addThruPipe(randomInstances(r, fd, classNames));
        System.out.println("Training set size = " + training.size());
        System.out.println("Testing set size = " + testing.size());
        Classifier classifier = trainer.train(training);
        System.out.println("Accuracy on training set:");
        System.out.println(classifier.getClass().getName() + ": " + new Trial(classifier, training).getAccuracy());
        System.out.println("Accuracy on testing set:");
        double testAcc = new Trial(classifier, testing).getAccuracy();
        System.out.println(classifier.getClass().getName() + ": " + testAcc);
        return testAcc;
    }

    /**
     * Runs the named test methods, or the whole suite when no names are given.
     *
     * @param args optional test method names
     */
    public static void main(String[] args) throws Throwable {
        TestSuite theSuite;
        if (args.length > 0) {
            theSuite = new TestSuite();
            for (String testName : args) {
                theSuite.addTest(new TestPagedInstanceList(testName));
            }
        } else {
            theSuite = (TestSuite) suite();
        }
        TestRunner.run(theSuite);
    }
}

