/*
 * Decompiled with CFR 0.152.
 */
package com.googlecode.clearnlp.run;

import com.carrotsearch.hppc.cursors.ObjectCursor;
import com.googlecode.clearnlp.classification.model.StringModel;
import com.googlecode.clearnlp.classification.train.StringTrainSpace;
import com.googlecode.clearnlp.engine.EngineProcess;
import com.googlecode.clearnlp.engine.EngineSetter;
import com.googlecode.clearnlp.feature.xml.POSFtrXml;
import com.googlecode.clearnlp.pos.POSLib;
import com.googlecode.clearnlp.pos.POSNode;
import com.googlecode.clearnlp.pos.POSTagger;
import com.googlecode.clearnlp.reader.POSReader;
import com.googlecode.clearnlp.run.AbstractRun;
import com.googlecode.clearnlp.util.UTFile;
import com.googlecode.clearnlp.util.UTInput;
import com.googlecode.clearnlp.util.UTXml;
import com.googlecode.clearnlp.util.list.SortedDoubleArrayList;
import com.googlecode.clearnlp.util.map.Prob1DMap;
import com.googlecode.clearnlp.util.pair.Pair;
import java.io.FileInputStream;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.kohsuke.args4j.Option;
import org.w3c.dom.Element;

/**
 * Command-line entry point that trains part-of-speech tagging models.
 *
 * <p>Depending on the {@code -s} flag it trains a single model (domain-specific
 * or generalized) or both models, in which case a similarity threshold is
 * either supplied with {@code -t} or estimated by cross-validation over the
 * training files.
 */
public class POSTrain extends AbstractRun {
    /** Number of taggers trained when both models are requested. */
    public static final int MODEL_SIZE = 2;
    /** Value of {@code -s} selecting only the domain-specific model. */
    protected final int FLAG_DOMAIN = 0;
    /** Value of {@code -s} selecting only the generalized model. */
    protected final int FLAG_GENERAL = 1;
    /** Value of {@code -s} selecting both models with dynamic model selection. */
    protected final int FLAG_DYNAMIC = 2;
    @Option(name="-i", usage="directory containg training files (required)", required=true, metaVar="<directory>")
    protected String s_trainDir;
    @Option(name="-c", usage="configuration file (required)", required=true, metaVar="<filename>")
    protected String s_configXml;
    @Option(name="-f", usage="feature template file (required)", required=true, metaVar="<filename>")
    protected String s_featureXml;
    @Option(name="-m", usage="model file (output; required)", required=true, metaVar="<filename>")
    protected String s_modelFile;
    @Option(name="-t", usage="similarity threshold (default: -1)", required=false, metaVar="<double>")
    protected double d_threshold = -1.0;
    @Option(name="-s", usage="model type - 0|1|2 (default: 1)\n0: train only a domain-specific model\n1: train only a generalized model\n2: train both models using dynamic model selection", required=false, metaVar="<integer>")
    protected int i_flag = 1;

    /** Creates an instance without parsing arguments or starting training. */
    public POSTrain() {
    }

    /**
     * Parses the command-line arguments and immediately runs training.
     *
     * <p>Any exception raised by training is printed to standard error and
     * not rethrown.
     *
     * @param args command-line arguments matching the {@code @Option} fields
     */
    public POSTrain(String[] args) {
        this.initArgs(args);
        try {
            this.run(this.s_configXml, this.s_featureXml, this.s_trainDir, this.s_modelFile, this.d_threshold, this.i_flag);
        }
        catch (Exception e) {
            // Command-line boundary: report the failure instead of propagating it.
            e.printStackTrace();
        }
    }

    /**
     * Trains the requested model(s) and hands them to
     * {@link EngineSetter#setPOSTaggers} together with the model file name.
     *
     * <p>When {@code modId} equals {@code FLAG_DYNAMIC} two taggers are trained
     * and, if {@code threshold} is negative, the threshold is first estimated
     * with {@link #crossValidate}. Otherwise a single tagger is trained for
     * {@code modId} and its form set is cleared before it is passed on.
     *
     * @param configXml  path of the configuration XML file
     * @param featureXml path of the feature template XML file
     * @param trainDir   directory whose (sorted) files are used for training
     * @param modelFile  name of the model file passed to {@code EngineSetter}
     * @param threshold  similarity threshold; negative triggers cross-validation
     *                   in dynamic mode
     * @param modId      model type (0, 1 or 2; see the {@code -s} option)
     * @throws Exception if reading the inputs or training fails
     */
    public void run(String configXml, String featureXml, String trainDir, String modelFile, double threshold, int modId) throws Exception {
        FileInputStream configIn = null;
        FileInputStream featureIn = null;
        try {
            configIn = new FileInputStream(configXml);
            Element eConfig = UTXml.getDocumentElement(configIn);
            POSReader reader = (POSReader)this.getReader(eConfig).o1;
            featureIn = new FileInputStream(featureXml);
            POSFtrXml xml = new POSFtrXml(featureIn);
            String[] trainFiles = UTFile.getSortedFileList(trainDir);
            if (modId == FLAG_DYNAMIC) {
                if (threshold < 0.0) {
                    threshold = this.crossValidate(trainFiles, reader, xml, eConfig);
                }
                // devId of -1 matches no file index, so every file is used for training.
                POSTagger[] taggers = this.getTrainedTaggers(eConfig, reader, xml, trainFiles, -1);
                // NOTE(review): the last argument is presumably the number of models - confirm against EngineSetter.
                EngineSetter.setPOSTaggers(modelFile, featureXml, taggers, threshold, 2);
            } else {
                POSTagger[] taggers = new POSTagger[]{this.getTrainedTagger(eConfig, reader, xml, trainFiles, -1, modId)};
                taggers[0].clearFormSet();
                EngineSetter.setPOSTaggers(modelFile, featureXml, taggers, threshold, 1);
            }
        }
        finally {
            // The streams stay open for the whole run and are released here,
            // whether training succeeded or not.
            try {
                if (configIn != null) {
                    configIn.close();
                }
            }
            finally {
                if (featureIn != null) {
                    featureIn.close();
                }
            }
        }
    }

    /**
     * Trains one tagger for the given model id.
     *
     * <p>Runs three passes over the training files (lemma set, lexica,
     * training instances) and then builds the statistical model from the
     * {@code "train"} element of the configuration.
     *
     * @param eConfig  root element of the configuration XML
     * @param reader   reader used to iterate over each training file
     * @param xml      feature template
     * @param trnFiles training file paths
     * @param devId    index of the file to leave out, or a value matching no
     *                 index (e.g. -1) to use all files
     * @param modId    model id passed to the feature template and model builder
     * @return the trained tagger
     * @throws Exception if any pass or the model training fails
     */
    public POSTagger getTrainedTagger(Element eConfig, POSReader reader, POSFtrXml xml, String[] trnFiles, int devId, int modId) throws Exception {
        Set<String> sLemmas = this.getLemmaSet(reader, xml, modId, trnFiles, devId);
        Pair<Set<String>, Map<String, String>> lexica = this.getLexica(reader, xml, modId, sLemmas, trnFiles, devId);
        Set<String> sForms = lexica.o1;
        Map<String, String> mAmbi = lexica.o2;
        StringTrainSpace space = this.getTrainSpace(reader, xml, modId, sLemmas, sForms, mAmbi, trnFiles, devId);
        StringModel model = (StringModel)this.getModel(UTXml.getFirstElementByTagName(eConfig, "train"), space, modId);
        return new POSTagger(xml, sLemmas, sForms, mAmbi, model);
    }

    /**
     * Trains one tagger per model id in {@code [0, MODEL_SIZE)}.
     *
     * @param eConfig  root element of the configuration XML
     * @param reader   reader used to iterate over each training file
     * @param xml      feature template
     * @param trnFiles training file paths
     * @param devId    index of the file to leave out, or -1 to use all files
     * @return array of {@code MODEL_SIZE} taggers indexed by model id
     * @throws Exception if training any tagger fails
     */
    public POSTagger[] getTrainedTaggers(Element eConfig, POSReader reader, POSFtrXml xml, String[] trnFiles, int devId) throws Exception {
        POSTagger[] taggers = new POSTagger[MODEL_SIZE];
        for (int modId = 0; modId < MODEL_SIZE; ++modId) {
            System.out.printf("===== Training model %d =====\n", modId);
            taggers[modId] = this.getTrainedTagger(eConfig, reader, xml, trnFiles, devId, modId);
        }
        return taggers;
    }

    /**
     * Collects the lemmas that occur in more than the configured number of
     * training files.
     *
     * <p>Each lemma is counted at most once per file; a lemma is kept when its
     * file count is strictly greater than the cutoff returned by
     * {@code xml.getDocumentFrequency(modId)}.
     */
    private Set<String> getLemmaSet(POSReader reader, POSFtrXml xml, int modId, String[] trnFiles, int devId) throws Exception {
        int dfCutoff = xml.getDocumentFrequency(modId);
        Prob1DMap map = new Prob1DMap();
        System.out.println("Collecting n-gram set:");
        System.out.println("- document frequency cutoff: " + dfCutoff);
        for (int i = 0; i < trnFiles.length; ++i) {
            if (devId == i) {
                continue;
            }
            // Distinct lemmas of this file only, so a file contributes at most 1 per lemma.
            Set<String> fileLemmas = new HashSet<String>();
            reader.open(UTInput.createBufferedFileReader(trnFiles[i]));
            try {
                POSNode[] nodes;
                while ((nodes = reader.next()) != null) {
                    EngineProcess.normalizeForms(nodes);
                    for (POSNode node : nodes) {
                        fileLemmas.add(node.lemma);
                    }
                }
            }
            finally {
                reader.close();
            }
            for (String lemma : fileLemmas) {
                map.add(lemma);
            }
        }
        Set<String> kept = new HashSet<String>();
        for (ObjectCursor<String> cur : map.keys()) {
            String lemma = cur.value;
            if (map.get(lemma) > dfCutoff) {
                kept.add(lemma);
            }
        }
        System.out.printf("- lemma reduction: %d -> %d\n", map.size(), kept.size());
        return kept;
    }

    /**
     * Runs a lexicon-collecting tagger over the training files and returns the
     * resulting word-form set and ambiguity-class map.
     *
     * @return pair of (word-form set, ambiguity map) obtained from the tagger
     *         with the template's feature cutoff and ambiguity threshold
     */
    private Pair<Set<String>, Map<String, String>> getLexica(POSReader reader, POSFtrXml xml, int modId, Set<String> sLemmas, String[] trnFiles, int devId) {
        POSTagger tagger = new POSTagger(sLemmas);
        int featureCutoff = xml.getFeatureCutoff(modId);
        double ambiguityThreshold = xml.getAmbiguityThreshold(modId);
        System.out.println("Collecting lexica:");
        System.out.println("- lexica cutoff: " + featureCutoff);
        System.out.println("- ambiguity class threshold: " + ambiguityThreshold);
        for (int i = 0; i < trnFiles.length; ++i) {
            if (devId == i) {
                continue;
            }
            this.tagFile(reader, tagger, trnFiles[i]);
        }
        Set<String> sForms = tagger.getFormSet(featureCutoff);
        Map<String, String> mAmbi = tagger.getAmbiguityMap(ambiguityThreshold);
        System.out.println("- # of word-forms: " + sForms.size());
        System.out.println("- # of word-forms with ambiguity classes: " + mAmbi.size());
        return new Pair<Set<String>, Map<String, String>>(sForms, mAmbi);
    }

    /**
     * Runs a tagger bound to a fresh {@link StringTrainSpace} over the
     * training files and returns that space.
     */
    private StringTrainSpace getTrainSpace(POSReader reader, POSFtrXml xml, int modId, Set<String> sLemmas, Set<String> sForms, Map<String, String> ambiguityMap, String[] trnFiles, int devId) {
        StringTrainSpace space = new StringTrainSpace(false, xml.getLabelCutoff(modId), xml.getFeatureCutoff(modId));
        POSTagger tagger = new POSTagger(xml, sLemmas, sForms, ambiguityMap, space);
        System.out.println("Collecting training instances:");
        for (int i = 0; i < trnFiles.length; ++i) {
            if (devId == i) {
                continue;
            }
            this.tagFile(reader, tagger, trnFiles[i]);
            // One progress dot per processed file.
            System.out.print(".");
        }
        System.out.println();
        return space;
    }

    /**
     * Opens {@code file} with {@code reader}, feeds every node array it yields
     * to {@code tagger.tag}, and closes the reader even if tagging fails.
     */
    private void tagFile(POSReader reader, POSTagger tagger, String file) {
        reader.open(UTInput.createBufferedFileReader(file));
        try {
            POSNode[] nodes;
            while ((nodes = reader.next()) != null) {
                tagger.tag(nodes);
            }
        }
        finally {
            reader.close();
        }
    }

    /**
     * Estimates the similarity threshold by leave-one-file-out
     * cross-validation.
     *
     * <p>For each file, both taggers are trained on the remaining files and
     * evaluated on the held-out one; similarities collected by
     * {@code crossValidatePredict} are accumulated in a sorted list. The
     * threshold is the list element at index {@code round(size * 0.05)},
     * rounded up to three decimal places.
     *
     * @param trnFiles training file paths
     * @param reader   reader used to iterate over each file
     * @param xml      feature template
     * @param eConfig  root element of the configuration XML
     * @return the estimated threshold
     * @throws Exception if training fails
     */
    public double crossValidate(String[] trnFiles, POSReader reader, POSFtrXml xml, Element eConfig) throws Exception {
        SortedDoubleArrayList list = new SortedDoubleArrayList();
        for (int devId = 0; devId < trnFiles.length; ++devId) {
            System.out.printf("<== Cross validation %d ==>\n", devId);
            POSTagger[] taggers = this.getTrainedTaggers(eConfig, reader, xml, trnFiles, devId);
            this.crossValidatePredict(trnFiles[devId], reader, taggers, list);
        }
        int n = (int)Math.round((double)list.size() * 0.05);
        // NOTE(review): if no similarity was collected the list is empty and get(0)
        // is out of range - confirm the intended fallback threshold.
        double threshold = Math.ceil(list.get(n) * 1000.0) / 1000.0;
        System.out.println("Out-of-domain validation:");
        System.out.println("- threshold: " + threshold);
        return threshold;
    }

    /**
     * Tags the held-out file with every tagger, prints each tagger's accuracy,
     * and records a similarity for each sentence on which tagger 0 got
     * strictly more tags right than tagger 1.
     *
     * <p>Only strictly positive similarities are added to {@code list}.
     */
    private void crossValidatePredict(String devFile, POSReader reader, POSTagger[] taggers, SortedDoubleArrayList list) {
        int[] local = new int[MODEL_SIZE];
        int[] correct = new int[MODEL_SIZE];
        int total = 0;
        System.out.println("Predicting: " + devFile);
        reader.open(UTInput.createBufferedFileReader(devFile));
        try {
            POSNode[] nodes;
            while ((nodes = reader.next()) != null) {
                // Capture gold labels before the taggers touch the nodes.
                String[] gold = POSLib.getLabels(nodes);
                total += gold.length;
                for (int modId = 0; modId < MODEL_SIZE; ++modId) {
                    taggers[modId].tag(nodes);
                    local[modId] = this.countCorrect(nodes, gold);
                    correct[modId] += local[modId];
                }
                if (local[0] > local[1]) {
                    double sim = taggers[0].getCosineSimilarity(nodes);
                    if (sim > 0.0) {
                        list.add(sim);
                    }
                }
            }
        }
        finally {
            reader.close();
        }
        for (int modId = 0; modId < MODEL_SIZE; ++modId) {
            double accuracy = 100.0 * (double)correct[modId] / (double)total;
            System.out.printf("- accuracy %d: %7.5f (%d/%d)\n", modId, accuracy, correct[modId], total);
        }
    }

    /**
     * Counts the positions at which a node's {@code pos} equals the gold label.
     */
    private int countCorrect(POSNode[] nodes, String[] gold) {
        int correct = 0;
        for (int i = 0; i < nodes.length; ++i) {
            if (gold[i].equals(nodes[i].pos)) {
                ++correct;
            }
        }
        return correct;
    }

    /**
     * Program entry point; delegates to {@link #POSTrain(String[])}.
     *
     * @param args command-line arguments
     */
    public static void main(String[] args) {
        new POSTrain(args);
    }
}

