/*
 * Decompiled with CFR 0.152.
 */
package org.cleartk.ml.svmlight;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.TreeMap;
import java.util.jar.JarEntry;
import java.util.jar.JarInputStream;
import java.util.jar.JarOutputStream;
import java.util.regex.Pattern;
import org.cleartk.ml.encoder.features.FeaturesEncoder;
import org.cleartk.ml.encoder.outcome.OutcomeEncoder;
import org.cleartk.ml.jar.ClassifierBuilder_ImplBase;
import org.cleartk.ml.jar.JarStreams;
import org.cleartk.ml.sigmoid.Sigmoid;
import org.cleartk.ml.svmlight.FitSigmoid;
import org.cleartk.ml.svmlight.SvmLightBooleanOutcomeClassifierBuilder;
import org.cleartk.ml.svmlight.SvmLightStringOutcomeClassifier;
import org.cleartk.ml.svmlight.model.SvmLightModel;
import org.cleartk.ml.util.featurevector.FeatureVector;

/**
 * Builds a {@link SvmLightStringOutcomeClassifier} from one binary SVMlight model per
 * integer-encoded label (presumably a one-vs-rest scheme; the code that writes the per-label
 * training files is not shown here).
 * <p>
 * For each label {@code n}:
 * <ul>
 * <li>training reads {@code training-data-n.svmlight};</li>
 * <li>it trains {@code training-data-n.svmlight.model} via
 * {@link SvmLightBooleanOutcomeClassifierBuilder};</li>
 * <li>it writes a Java-serialized {@link Sigmoid}, fitted by {@link FitSigmoid}, to
 * {@code training-data-n.svmlight.sigmoid}.</li>
 * </ul>
 * Packaging copies these into the model jar as {@code model-n.svmlight} and {@code model-n.sigmoid}
 * for n = 1, 2, 3, ... up to the first missing model file.
 */
public class SvmLightStringOutcomeClassifierBuilder
extends ClassifierBuilder_ImplBase<SvmLightStringOutcomeClassifier, FeatureVector, String, Integer> {

    /**
     * Matches per-label training data files such as {@code training-data-3.svmlight}. The dot is
     * escaped so only a literal ".svmlight" suffix matches. The pattern is compiled once rather
     * than on every file name.
     */
    private static final Pattern LABEL_TRAINING_DATA_FILE =
            Pattern.compile("training-data-\\d+\\.svmlight");

    /** Per-label SVMlight models loaded by {@link #unpackageClassifier}, keyed by integer label. */
    private TreeMap<Integer, SvmLightModel> models;

    /** Per-label sigmoids loaded by {@link #unpackageClassifier}, keyed by integer label. */
    private TreeMap<Integer, Sigmoid> sigmoids;

    /**
     * Returns the builder's label-independent training data file,
     * {@code training-data-allfalse.svmlight}.
     * <p>
     * This name does not match the per-label pattern, so {@link #trainClassifier} never trains on
     * it directly. NOTE(review): its exact role depends on callers not visible here.
     *
     * @param dir the training directory
     * @return {@code dir/training-data-allfalse.svmlight}
     */
    public File getTrainingDataFile(File dir) {
        return new File(dir, "training-data-allfalse.svmlight");
    }

    /**
     * Returns the training data file for a single integer-encoded label.
     *
     * @param dir the training directory
     * @param label the integer-encoded outcome
     * @return {@code dir/training-data-<label>.svmlight}
     */
    public File getTrainingDataFile(File dir, int label) {
        return new File(dir, String.format("training-data-%d.svmlight", label));
    }

    /**
     * Trains one binary SVMlight model per {@code training-data-<n>.svmlight} file in {@code dir}.
     * For each model it fits a sigmoid and serializes it next to the model as
     * {@code training-data-<n>.svmlight.sigmoid}.
     *
     * @param dir the training directory
     * @param args extra arguments forwarded to the SVMlight trainer
     * @throws IOException if {@code dir} cannot be listed or a sigmoid file cannot be written
     * @throws Exception propagated from SVMlight training or sigmoid fitting
     */
    public void trainClassifier(File dir, String ... args) throws Exception {
        SvmLightBooleanOutcomeClassifierBuilder builder = new SvmLightBooleanOutcomeClassifierBuilder();
        // listFiles() returns null (not an empty array) when dir is not a readable directory.
        File[] files = dir.listFiles();
        if (files == null) {
            throw new IOException(String.format("cannot list training data directory %s", dir));
        }
        for (File file : files) {
            if (!LABEL_TRAINING_DATA_FILE.matcher(file.getName()).matches()) {
                continue;
            }
            builder.trainClassifier(dir, file, args);
            Sigmoid sigmoid = FitSigmoid.fit(new File(file.toString() + ".model"), file);
            writeSigmoid(sigmoid, new File(file.toString() + ".sigmoid"));
        }
    }

    /**
     * Java-serializes {@code sigmoid} to {@code sigmoidFile}. The underlying file stream is always
     * closed, even when the serialization stream cannot be created or the write fails.
     */
    private static void writeSigmoid(Sigmoid sigmoid, File sigmoidFile) throws IOException {
        FileOutputStream fileOut = new FileOutputStream(sigmoidFile);
        try {
            ObjectOutputStream objectOut = new ObjectOutputStream(fileOut);
            objectOut.writeObject(sigmoid);
            // Closing the object stream flushes its internal buffer into fileOut.
            objectOut.close();
        } finally {
            // Idempotent if objectOut.close() already closed it.
            fileOut.close();
        }
    }

    /**
     * Writes the base-class entries, then a {@code model-<n>.svmlight} / {@code model-<n>.sigmoid}
     * entry pair for each label n = 1, 2, ... up to the first label with no trained model file.
     *
     * @param dir the training directory containing the trained model and sigmoid files
     * @param modelStream the jar being written
     * @throws IOException if a model has no matching sigmoid file, or on any write failure
     */
    protected void packageClassifier(File dir, JarOutputStream modelStream) throws IOException {
        super.packageClassifier(dir, modelStream);
        for (int label = 1; ; ++label) {
            File modelFile = new File(dir, String.format("training-data-%d.svmlight.model", label));
            if (!modelFile.exists()) {
                break;
            }
            File sigmoidFile = new File(dir, String.format("training-data-%d.svmlight.sigmoid", label));
            // Fail with a clear message: unpackageClassifier expects every model to be
            // followed by its sigmoid entry.
            if (!sigmoidFile.exists()) {
                throw new IOException(String.format("missing sigmoid file %s for model %s", sigmoidFile, modelFile));
            }
            JarStreams.putNextJarEntry(modelStream, String.format("model-%d.svmlight", label), modelFile);
            JarStreams.putNextJarEntry(modelStream, String.format("model-%d.sigmoid", label), sigmoidFile);
        }
    }

    /**
     * Reads the base-class entries, then alternating {@code model-<n>.svmlight} /
     * {@code model-<n>.sigmoid} entries for n = 1, 2, ... until the jar has no more entries.
     *
     * @param modelStream the jar being read, positioned just after the base-class entries
     * @throws IOException if entries are out of order, a sigmoid cannot be deserialized, or no
     *     models are present
     */
    protected void unpackageClassifier(JarInputStream modelStream) throws IOException {
        super.unpackageClassifier(modelStream);
        this.models = new TreeMap<Integer, SvmLightModel>();
        this.sigmoids = new TreeMap<Integer, Sigmoid>();
        int label = 1;
        SvmLightModel model;
        while ((model = SvmLightStringOutcomeClassifierBuilder.getNextModel(modelStream, label)) != null) {
            this.models.put(label, model);
            // NOTE(review): presumably JarStreams verifies the entry name; its return value is unused here.
            JarStreams.getNextJarEntry(modelStream, String.format("model-%d.sigmoid", label));
            try {
                // SECURITY: Java-native deserialization. Only load model jars from trusted sources.
                // The ObjectInputStream is deliberately not closed, because that would close the
                // shared jar stream.
                this.sigmoids.put(label, (Sigmoid) new ObjectInputStream(modelStream).readObject());
            } catch (ClassNotFoundException e) {
                throw new IOException(e);
            }
            ++label;
        }
        if (this.models.isEmpty()) {
            throw new IOException(String.format("no models found in %s", modelStream));
        }
    }

    /**
     * Creates the classifier from the encoders and the per-label models and sigmoids loaded by
     * {@link #unpackageClassifier}.
     */
    protected SvmLightStringOutcomeClassifier newClassifier() {
        return new SvmLightStringOutcomeClassifier((FeaturesEncoder<FeatureVector>)this.featuresEncoder, (OutcomeEncoder<String, Integer>)this.outcomeEncoder, this.models, this.sigmoids);
    }

    /**
     * Reads the next jar entry as the SVMlight model for {@code label}.
     *
     * @param modelStream the jar being read
     * @param label the label whose model is expected next
     * @return the parsed model, or {@code null} if the jar has no more entries
     * @throws IOException if the next entry is not named {@code model-<label>.svmlight}
     */
    private static SvmLightModel getNextModel(JarInputStream modelStream, int label) throws IOException {
        JarEntry entry = modelStream.getNextJarEntry();
        if (entry == null) {
            return null;
        }
        String expectedName = String.format("model-%d.svmlight", label);
        if (!entry.getName().equals(expectedName)) {
            throw new IOException(String.format("expected next jar entry to be %s, found %s", expectedName, entry.getName()));
        }
        return SvmLightModel.fromInputStream(modelStream);
    }
}

