/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ctakes.temporal.ae.feature.selection;

import com.google.common.base.Function;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multiset;
import com.google.common.collect.Ordering;
import com.google.common.collect.Sets;
import com.google.common.collect.Table;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.net.URI;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import org.apache.ctakes.temporal.ae.feature.selection.FeatureSelection;
import org.cleartk.ml.Feature;
import org.cleartk.ml.Instance;
import org.cleartk.ml.feature.transform.TransformableFeature;

/**
 * Feature selection that ranks the features wrapped in {@link TransformableFeature}s by an
 * alpha agreement score (see {@code AlphaScorer}) and keeps the highest-ranked fraction.
 *
 * <p>The {@code threshold} passed to the constructors is the fraction of ranked features to keep.
 * The kept count is {@code round(totalFeatures * threshold)}, clamped to {@code [0, totalFeatures]}.
 * The single-argument constructor uses {@code 0.0}, which keeps no features.
 *
 * <p>Outcomes whose {@code toString()} is {@code "B"}, {@code "I"}, or the configured positive class
 * count as positive. Without a positive class, {@code "O"} is the only accepted negative label.
 * With a positive class, every other outcome counts as negative.
 *
 * @param <OUTCOME_T> outcome type of the training instances
 */
public class BinaryAlphaFeatureSelection<OUTCOME_T>
extends FeatureSelection<OUTCOME_T> {
    /** Fraction of ranked features to keep; values outside [0, 1] are effectively clamped. */
    private final double featureSelectionThreshold;
    /** Number of selected features from the last {@link #train} or {@link #load}. */
    private int numFeatures = 0;
    /** Outcome label treated as positive, or {@code null} to rely on B/I/O labels only. */
    private final String positiveClass;
    /** Scorer built by {@link #train}; {@code null} until the instance has been trained. */
    private AlphaScorer<OUTCOME_T> alphaFunction;
    /** Features ranked below the cut-off by the last call to {@link #train}. */
    private LinkedHashSet<String> discardedFeatureNames;

    public BinaryAlphaFeatureSelection(String name) {
        this(name, 0.0);
    }

    public BinaryAlphaFeatureSelection(String name, double threshold) {
        this(name, threshold, null);
    }

    public BinaryAlphaFeatureSelection(String name, double threshold, String posiClas) {
        super(name);
        this.featureSelectionThreshold = threshold;
        this.positiveClass = posiClas;
    }

    /**
     * Returns whether the feature's name is among the selected feature names.
     */
    @Override
    public boolean apply(Feature feature) {
        return this.selectedFeatureNames.contains(this.getFeatureName(feature));
    }

    /**
     * Scores every feature wrapped inside a transformable feature and keeps the top-ranked
     * fraction; the rest are remembered as discarded so {@link #save} can write them out.
     *
     * @param instances training instances; only transformable features are counted
     */
    public void train(Iterable<Instance<OUTCOME_T>> instances) {
        this.alphaFunction = new AlphaScorer<OUTCOME_T>(this.positiveClass);
        for (Instance<OUTCOME_T> instance : instances) {
            OUTCOME_T outcome = instance.getOutcome();
            for (Feature feature : instance.getFeatures()) {
                if (!this.isTransformable(feature)) {
                    continue;
                }
                for (Feature untransformedFeature : ((TransformableFeature) feature).getFeatures()) {
                    this.alphaFunction.update(this.getFeatureName(untransformedFeature), outcome, 1);
                }
            }
        }
        Set<String> featureNames = this.alphaFunction.featValueClassCount.rowKeySet();
        // Highest alpha first.
        Ordering<String> ordering = Ordering.natural().onResultOf(this.alphaFunction).reverse();
        // Sort once: the selected and discarded sets are two slices of the same ranking.
        List<String> ranked = ordering.immutableSortedCopy(featureNames);
        int totalFeatures = ranked.size();
        long requested = Math.round(totalFeatures * this.featureSelectionThreshold);
        // Clamp so a threshold outside [0, 1] cannot make subList throw.
        this.numFeatures = (int) Math.max(0L, Math.min(requested, (long) totalFeatures));
        this.selectedFeatureNames = Sets.newLinkedHashSet(ranked.subList(0, this.numFeatures));
        this.discardedFeatureNames = Sets.newLinkedHashSet(ranked.subList(this.numFeatures, totalFeatures));
        this.isTrained = true;
    }

    /**
     * Writes the selected features and their scores to {@code uri}, one
     * {@code name<TAB>score} line per feature. Discarded features are written to a sibling
     * file whose name replaces the extension with {@code _discarded.dat}.
     *
     * @throws IllegalStateException if the instance was not trained in this process
     * @throws IOException if either file cannot be written
     */
    public void save(URI uri) throws IOException {
        if (!this.isTrained) {
            throw new IllegalStateException("Cannot save before training");
        }
        if (this.alphaFunction == null || this.discardedFeatureNames == null) {
            // Only load() was called: the scores needed for the output are not available.
            throw new IllegalStateException("Cannot save feature scores that were loaded rather than trained");
        }
        File out = new File(uri);
        File discardOut = new File(discardedFeaturesPath(uri.getPath()));
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(out));
             BufferedWriter discardWriter = new BufferedWriter(new FileWriter(discardOut))) {
            this.writeScores(writer, this.selectedFeatureNames);
            this.writeScores(discardWriter, this.discardedFeatureNames);
        }
    }

    /** Writes one {@code name<TAB>score} line per feature using the trained scorer. */
    private void writeScores(BufferedWriter writer, Iterable<String> featureNames) throws IOException {
        for (String feature : featureNames) {
            writer.append(String.format("%s\t%f\n", feature, this.alphaFunction.score(feature)));
        }
    }

    /** Replaces the file-name extension (if any) of {@code uriPath} with {@code _discarded.dat}. */
    private static String discardedFeaturesPath(String uriPath) {
        int lastSlash = uriPath.lastIndexOf('/');
        int lastDot = uriPath.lastIndexOf('.');
        // Only strip a dot that belongs to the file name, not one inside a parent directory name.
        String base = lastDot > lastSlash ? uriPath.substring(0, lastDot) : uriPath;
        return base + "_discarded.dat";
    }

    /**
     * Reads the selected feature names written by {@link #save}. The first tab-separated
     * column of each line is the feature name.
     *
     * <p>All lines are read. The file contains exactly the selected features, and a freshly
     * constructed instance has no trained feature count to limit by.
     *
     * @throws IOException if the file cannot be read
     */
    public void load(URI uri) throws IOException {
        this.selectedFeatureNames = Sets.newLinkedHashSet();
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(uri)))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.isEmpty()) {
                    // save() never writes empty lines; ignore e.g. a trailing blank line.
                    continue;
                }
                this.selectedFeatureNames.add(line.split("\t")[0]);
            }
        }
        this.numFeatures = this.selectedFeatureNames.size();
        this.isTrained = true;
    }

    /**
     * Accumulates (feature, outcome) co-occurrence counts and scores a feature with an alpha
     * agreement coefficient. The formula {@code 1 - (2n - 1) * o01 / (n0 * n1)} over a 2x2
     * coincidence matrix appears to be Krippendorff's alpha for binary data. The two "coders" are
     * feature presence and class membership. It is computed once with the positive class as value 1
     * and once with the negative class, and the larger value is returned.
     */
    private static class AlphaScorer<OUTCOME_T>
    implements Function<String, Double> {
        /** Per-outcome total of the occurrences passed to {@link #update}. */
        private final Multiset<OUTCOME_T> classCounts = HashMultiset.create();
        /** Occurrence counts keyed by (feature name, outcome). */
        private final Table<String, OUTCOME_T, Integer> featValueClassCount = HashBasedTable.create();
        /** Outcome label treated as positive, or {@code null} to rely on B/I/O labels only. */
        private final String positiveClass;

        public AlphaScorer(String posiClas) {
            this.positiveClass = posiClas;
        }

        /** Records {@code occurrences} observations of {@code featureName} with {@code outcome}. */
        public void update(String featureName, OUTCOME_T outcome, int occurrences) {
            Integer count = this.featValueClassCount.get(featureName, outcome);
            int previous = count == null ? 0 : count;
            this.featValueClassCount.put(featureName, outcome, previous + occurrences);
            // NOTE(review): this is incremented per feature occurrence, not per instance, so
            // score() treats the number of feature occurrences seen with a class as its
            // "number of instances". Kept unchanged to preserve existing rankings; confirm intent.
            this.classCounts.add(outcome, occurrences);
        }

        @Override
        public Double apply(String featureName) {
            return this.score(featureName);
        }

        /**
         * Returns the larger of the positive-class and negative-class alpha scores for
         * {@code featureName}.
         *
         * @throws IllegalStateException if no positive class is configured and an outcome is not
         *     one of B/I/O, or if the internal coincidence counts are inconsistent
         */
        public double score(String featureName) {
            // Coincidence counts with "positive class" as value 1.
            double o11 = 0.0;
            double o01 = 0.0;
            double o00 = 0.0;
            // Coincidence counts with "negative class" as value 1.
            double on11 = 0.0;
            double on01 = 0.0;
            double on00 = 0.0;
            double n = 0.0;
            for (OUTCOME_T clas : this.classCounts.elementSet()) {
                Integer agreement = this.featValueClassCount.get(featureName, clas);
                int numAgreement = agreement == null ? 0 : agreement;
                int numInstanceInThisClass = this.classCounts.count(clas);
                int numDisagreement = numInstanceInThisClass - numAgreement;
                n += numInstanceInThisClass;
                String label = String.valueOf(clas);
                // Each agreeing unit adds 2 to a diagonal cell; each disagreeing unit adds 1 to o01.
                if (this.isPositive(label)) {
                    o11 += 2.0 * numAgreement;
                    o01 += numDisagreement;
                    on00 += 2.0 * numDisagreement;
                    on01 += numAgreement;
                } else if (this.isNegative(label)) {
                    o00 += 2.0 * numDisagreement;
                    o01 += numAgreement;
                    on11 += 2.0 * numAgreement;
                    on01 += numDisagreement;
                } else {
                    throw new IllegalStateException("Please define a positive class label for alpha calculation"
                            + " (unexpected outcome '" + label + "').");
                }
            }
            double n0 = o00 + o01;
            double n1 = o11 + o01;
            // Sanity check: the marginals must account for all 2n pairable values.
            if (n0 + n1 != 2.0 * n) {
                throw new IllegalStateException("Alpha Calculation is wrong.");
            }
            double alphaPositive = 1.0 - (2.0 * n - 1.0) * o01 / (n0 * n1);
            n0 = on00 + on01;
            n1 = on11 + on01;
            double alphaNegative = 1.0 - (2.0 * n - 1.0) * on01 / (n0 * n1);
            return Math.max(alphaNegative, alphaPositive);
        }

        /** B and I (BIO scheme) and the configured positive class are positive. */
        private boolean isPositive(String label) {
            return label.equals("B") || label.equals("I") || label.equals(this.positiveClass);
        }

        /**
         * With a configured positive class every non-positive outcome is negative; otherwise only
         * "O" is accepted as negative.
         */
        private boolean isNegative(String label) {
            return this.positiveClass != null || label.equals("O");
        }
    }
}

