/*
 * Decompiled with CFR 0.152.
 */
package org.apache.datasketches.pig.sampling;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Locale;
import org.apache.datasketches.sampling.ReservoirItemsSketch;
import org.apache.datasketches.sampling.ReservoirItemsUnion;
import org.apache.datasketches.sampling.SamplingPigUtil;
import org.apache.pig.AccumulatorEvalFunc;
import org.apache.pig.Algebraic;
import org.apache.pig.EvalFunc;
import org.apache.pig.backend.executionengine.ExecException;
import org.apache.pig.data.BagFactory;
import org.apache.pig.data.DataBag;
import org.apache.pig.data.Tuple;
import org.apache.pig.data.TupleFactory;
import org.apache.pig.impl.logicalLayer.FrontendException;
import org.apache.pig.impl.logicalLayer.schema.Schema;

/**
 * Pig UDF that keeps a reservoir sample of at most {@code k} tuples from an input bag, backed by
 * {@link ReservoirItemsSketch}.
 *
 * <p>Every result is a 3-field tuple {@code (n, k, samples)}:
 * <ul>
 *   <li>{@code n} is a {@code long}: the number of items seen;</li>
 *   <li>{@code k} is an {@code int}: the reservoir size;</li>
 *   <li>{@code samples} is a bag holding the retained tuples.</li>
 * </ul>
 *
 * <p>The UDF can run directly ({@link #exec}), in accumulator mode ({@link #accumulate} /
 * {@link #getValue}), or algebraically through {@link Initial} and {@link IntermediateFinal}.
 */
public class ReservoirSampling
extends AccumulatorEvalFunc<Tuple>
implements Algebraic {
    static final String N_ALIAS = "n";
    static final String K_ALIAS = "k";
    static final String SAMPLES_ALIAS = "samples";
    private static final int DEFAULT_TARGET_K = 1024;

    // Pig type codes, mirroring the values in org.apache.pig.data.DataType.
    // They are declared as byte because Schema.FieldSchema takes a byte type code, and int
    // literals are not implicitly narrowed when passed as method arguments.
    private static final byte PIG_TYPE_INTEGER = 10;
    private static final byte PIG_TYPE_LONG = 15;
    private static final byte PIG_TYPE_TUPLE = 110;
    private static final byte PIG_TYPE_BAG = 120;

    /** Maximum number of samples to retain. */
    private final int targetK_;
    /** Reservoir used in accumulator mode; created lazily and reset by {@link #cleanup()}. */
    private ReservoirItemsSketch<Tuple> reservoir_;

    /**
     * Creates the UDF with an explicit reservoir size.
     *
     * @param kStr decimal string giving the target reservoir size; must be at least 2
     * @throws NumberFormatException if {@code kStr} is not a valid int
     * @throws IllegalArgumentException if the parsed size is less than 2
     */
    public ReservoirSampling(String kStr) {
        this.targetK_ = parseTargetK(kStr);
    }

    /** Creates the UDF with the default reservoir size of {@value #DEFAULT_TARGET_K}. */
    ReservoirSampling() {
        this.targetK_ = DEFAULT_TARGET_K;
    }

    /**
     * Parses and validates a target reservoir size. All constructors share this helper so they
     * apply the same rule.
     *
     * @param kStr decimal string giving the target reservoir size
     * @return the parsed size
     * @throws NumberFormatException if {@code kStr} is not a valid int
     * @throws IllegalArgumentException if the parsed size is less than 2
     */
    private static int parseTargetK(String kStr) {
        int k = Integer.parseInt(kStr);
        if (k < 2) {
            throw new IllegalArgumentException("ReservoirSampling requires target reservoir size >= 2: " + k);
        }
        return k;
    }

    /**
     * Samples the bag in field 0 of {@code inputTuple}.
     *
     * @param inputTuple tuple whose first field is the bag to sample
     * @return {@code (n, k, samples)}, or {@code null} if the input is null, empty, or has a null
     *     first field
     * @throws IOException if reading the input tuple fails
     */
    public Tuple exec(Tuple inputTuple) throws IOException {
        if (inputTuple == null || inputTuple.size() < 1 || inputTuple.isNull(0)) {
            return null;
        }
        DataBag samples = (DataBag) inputTuple.get(0);
        long count = samples.size();
        // A bag that already fits in the reservoir is returned as-is; no sampling is needed.
        if (count <= targetK_) {
            return createResultTuple(count, targetK_, samples);
        }
        // Otherwise fall back to the accumulator path implemented by the superclass.
        return super.exec(inputTuple);
    }

    /**
     * Adds every tuple in the bag in field 0 of {@code inputTuple} to the running reservoir,
     * creating the reservoir on first use. Null or empty input is ignored.
     *
     * @param inputTuple tuple whose first field is a bag of items
     * @throws IOException if reading the input tuple fails
     */
    public void accumulate(Tuple inputTuple) throws IOException {
        if (inputTuple == null || inputTuple.size() < 1 || inputTuple.isNull(0)) {
            return;
        }
        DataBag samples = (DataBag) inputTuple.get(0);
        if (reservoir_ == null) {
            reservoir_ = ReservoirItemsSketch.newInstance(targetK_);
        }
        for (Tuple t : samples) {
            reservoir_.update(t);
        }
    }

    /**
     * Returns the current reservoir contents.
     *
     * @return {@code (n, k, samples)} for the running reservoir, or {@code null} if
     *     {@link #accumulate} has not created a reservoir since the last {@link #cleanup()}
     */
    public Tuple getValue() {
        if (reservoir_ == null) {
            return null;
        }
        ArrayList<Tuple> data = SamplingPigUtil.getRawSamplesAsList(reservoir_);
        DataBag sampleBag = BagFactory.getInstance().newDefaultBag(data);
        return createResultTuple(reservoir_.getN(), reservoir_.getK(), sampleBag);
    }

    /** Discards the running reservoir. */
    public void cleanup() {
        reservoir_ = null;
    }

    /**
     * Describes the output as a tuple {@code (n: long, k: int, samples: bag)}. The bag's schema is
     * taken from the input.
     *
     * @param input schema of the UDF's arguments
     * @return the output schema, or {@code null} if {@code input} is null or empty
     */
    public Schema outputSchema(Schema input) {
        if (input == null || input.size() < 1) {
            return null;
        }
        try {
            Schema source = input;
            // When the only argument is a bag, use that bag's inner schema for the samples field.
            if (source.size() == 1 && source.getField(0).type == PIG_TYPE_BAG) {
                source = source.getField(0).schema;
            }
            Schema recordSchema = new Schema();
            recordSchema.add(new Schema.FieldSchema(N_ALIAS, PIG_TYPE_LONG));
            recordSchema.add(new Schema.FieldSchema(K_ALIAS, PIG_TYPE_INTEGER));
            recordSchema.add(new Schema.FieldSchema(SAMPLES_ALIAS, source, PIG_TYPE_BAG));
            // Locale.ROOT keeps the schema name independent of the JVM's default locale.
            String schemaName = getSchemaName(getClass().getName().toLowerCase(Locale.ROOT), source);
            return new Schema(new Schema.FieldSchema(schemaName, recordSchema, PIG_TYPE_TUPLE));
        } catch (FrontendException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Builds the 3-field result tuple {@code (n, k, samples)}.
     *
     * @param n number of items seen
     * @param k reservoir size
     * @param samples bag of retained items
     * @return a new tuple holding the three values
     * @throws RuntimeException wrapping any {@link ExecException} raised while setting fields
     */
    static Tuple createResultTuple(long n, int k, DataBag samples) {
        Tuple output = TupleFactory.getInstance().newTuple(3);
        try {
            output.set(0, n);
            output.set(1, k);
            output.set(2, samples);
        } catch (ExecException e) {
            throw new RuntimeException("Pig error: " + e.getMessage(), e);
        }
        return output;
    }

    /** @return class name of the algebraic initial function */
    public String getInitial() {
        return Initial.class.getName();
    }

    /** @return class name of the algebraic intermediate function */
    public String getIntermed() {
        return IntermediateFinal.class.getName();
    }

    /** @return class name of the algebraic final function; the intermediate logic is reused */
    public String getFinal() {
        return IntermediateFinal.class.getName();
    }

    /**
     * Copies the tuples of {@code bag} into a new list, in iteration order.
     *
     * @param bag source bag
     * @return mutable list holding the bag's tuples
     */
    static ArrayList<Tuple> dataBagToArrayList(DataBag bag) {
        // Presize the list, clamping so an oversized bag cannot produce a negative capacity.
        int capacity = (int) Math.min(bag.size(), Integer.MAX_VALUE);
        ArrayList<Tuple> output = new ArrayList<>(capacity);
        for (Tuple t : bag) {
            output.add(t);
        }
        return output;
    }

    /**
     * Algebraic intermediate and final step: merges a bag of {@code (n, k, samples)} tuples into a
     * single reservoir using {@link ReservoirItemsUnion}.
     */
    public static class IntermediateFinal
    extends EvalFunc<Tuple> {
        private final int targetK_;

        /** Uses the default reservoir size of {@value ReservoirSampling#DEFAULT_TARGET_K}. */
        public IntermediateFinal() {
            this.targetK_ = DEFAULT_TARGET_K;
        }

        /**
         * @param kStr decimal string giving the target reservoir size; must be at least 2
         * @throws IllegalArgumentException if {@code kStr} is not a valid int or is less than 2
         */
        public IntermediateFinal(String kStr) {
            this.targetK_ = parseTargetK(kStr);
        }

        /**
         * Merges the partial reservoirs in field 0 of {@code inputTuple}.
         *
         * @param inputTuple tuple whose first field is a bag of {@code (n, k, samples)} tuples
         * @return the merged {@code (n, k, samples)}, or {@code null} for null or empty input
         * @throws IOException if reading a tuple fails
         */
        public Tuple exec(Tuple inputTuple) throws IOException {
            if (inputTuple == null || inputTuple.size() < 1 || inputTuple.isNull(0)) {
                return null;
            }
            ReservoirItemsUnion<Tuple> union = ReservoirItemsUnion.newInstance(targetK_);
            DataBag outerBag = (DataBag) inputTuple.get(0);
            for (Tuple reservoir : outerBag) {
                long n = (Long) reservoir.get(0);
                int k = (Integer) reservoir.get(1);
                DataBag samples = (DataBag) reservoir.get(2);
                if (n <= k && k <= targetK_) {
                    // The bag holds all n items (n <= k) and k is within this union's target,
                    // so each item is added one at a time.
                    for (Tuple t : samples) {
                        union.update(t);
                    }
                } else {
                    // Otherwise merge the partial reservoir as a whole, together with its n and k.
                    union.update(n, k, dataBagToArrayList(samples));
                }
            }
            ReservoirItemsSketch<Tuple> result = union.getResult();
            ArrayList<Tuple> data = SamplingPigUtil.getRawSamplesAsList(result);
            DataBag sampleBag = BagFactory.getInstance().newDefaultBag(data);
            return createResultTuple(result.getN(), result.getK(), sampleBag);
        }
    }

    /**
     * Algebraic initial step: turns the incoming bag into a single {@code (n, k, samples)} tuple.
     * If the bag already fits in the reservoir, it is passed through unchanged.
     */
    public static class Initial
    extends EvalFunc<Tuple> {
        private final int targetK_;

        /** Uses the default reservoir size of {@value ReservoirSampling#DEFAULT_TARGET_K}. */
        public Initial() {
            this.targetK_ = DEFAULT_TARGET_K;
        }

        /**
         * @param kStr decimal string giving the target reservoir size; must be at least 2
         * @throws IllegalArgumentException if {@code kStr} is not a valid int or is less than 2
         */
        public Initial(String kStr) {
            this.targetK_ = parseTargetK(kStr);
        }

        /**
         * Samples the bag in field 0 of {@code inputTuple}.
         *
         * @param inputTuple tuple whose first field is a bag of records
         * @return {@code (n, k, samples)}, or {@code null} for null or empty input
         * @throws IOException if reading the input tuple fails
         */
        public Tuple exec(Tuple inputTuple) throws IOException {
            if (inputTuple == null || inputTuple.size() < 1 || inputTuple.isNull(0)) {
                return null;
            }
            DataBag records = (DataBag) inputTuple.get(0);
            long n = records.size();
            if (n <= targetK_) {
                // Small enough to keep every record; no sketch needed.
                return createResultTuple(n, targetK_, records);
            }
            ReservoirItemsSketch<Tuple> reservoir = ReservoirItemsSketch.newInstance(targetK_);
            for (Tuple t : records) {
                reservoir.update(t);
            }
            ArrayList<Tuple> data = SamplingPigUtil.getRawSamplesAsList(reservoir);
            DataBag sampleBag = BagFactory.getInstance().newDefaultBag(data);
            return createResultTuple(n, reservoir.getK(), sampleBag);
        }
    }
}

