/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.inference;

import cc.mallet.grmm.inference.AbstractInferencer;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.FactorGraph;
import cc.mallet.grmm.types.TableFactor;
import cc.mallet.grmm.types.Variable;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;

/**
 * Exact inference by variable elimination.
 *
 * <p>Every request runs a fresh elimination pass over duplicates of the model's
 * factors, so the model itself is never modified. Variables are eliminated in the
 * iteration order of {@link FactorGraph#variablesSet()}; no ordering heuristic is
 * applied, so cost can grow quickly on densely connected models.
 *
 * <p>{@link #computeMarginals(FactorGraph)} only records the model; the actual
 * elimination happens lazily in {@link #lookupMarginal(Variable)}.
 */
public class VariableElimination extends AbstractInferencer {
    /** Model recorded by the last call to computeMarginals; not serialized. */
    transient FactorGraph mdlCurrent;
    private static final long serialVersionUID = 1L;

    /**
     * Removes from {@code allPhi} every factor that mentions {@code node}, plus every
     * variable-free (constant) factor, and returns the product of the removed factors.
     *
     * @param allPhi working set of factors; modified in place
     * @param node   variable whose factors are to be combined
     * @return product of the removed factors (via {@link TableFactor#multiplyAll})
     */
    private Factor eliminate(Collection<Factor> allPhi, Variable node) {
        HashSet<Factor> phiSet = new HashSet<Factor>();
        for (Iterator<Factor> it = allPhi.iterator(); it.hasNext(); ) {
            Factor phi = it.next();
            // Constant factors (e.g. the result of summing out an isolated variable) are
            // folded into the next product so their scale reaches the final marginal.
            if (phi.varSet().isEmpty() || phi.containsVar(node)) {
                phiSet.add(phi);
                it.remove();
            }
        }
        return TableFactor.multiplyAll(phiSet);
    }

    /**
     * Computes the unnormalized marginal of {@code query} by eliminating every other
     * variable of {@code model}.
     *
     * @param model factor graph to query; its factors are duplicated, not modified
     * @param query variable whose marginal is wanted; must belong to {@code model}
     * @return a factor over {@code query} alone, not normalized
     */
    public Factor unnormalizedMarginal(FactorGraph model, Variable query) {
        HashSet<Factor> allPhi = new HashSet<Factor>();
        Iterator i = model.factorsIterator();
        while (i.hasNext()) {
            Factor factor = (Factor) i.next();
            allPhi.add(factor.duplicate());
        }

        // variablesSet() may be a raw Set, so iterate as Object and cast.
        for (Object o : model.variablesSet()) {
            Variable node = (Variable) o;
            if (node == query) {
                continue; // the query variable is handled last
            }
            Factor product = this.eliminate(allPhi, node);
            // Always sum the variable out. If the product mentions only this variable
            // (an isolated component), the result is a constant that the next
            // eliminate() absorbs. Previously such a product was kept as-is and then
            // never multiplied back in. If no factor mentioned the variable at all,
            // there is nothing to sum out.
            Factor reduced = product.containsVar(node) ? product.marginalizeOut(node) : product;
            allPhi.add(reduced);
        }

        // Everything left either mentions only the query or is a constant.
        Factor marginal = this.eliminate(allPhi, query);
        assert marginal.containsVar(query);
        assert marginal.varSet().size() == 1;
        return marginal;
    }

    /**
     * Computes the normalization constant of {@code m}: the sum of the unnormalized
     * marginal of one of its variables.
     *
     * @param m model to evaluate
     * @return sum of the unnormalized marginal of the first variable in
     *         {@code m.variablesSet()}
     * @throws IllegalArgumentException if {@code m} has no variables
     */
    public double computeNormalizationFactor(FactorGraph m) {
        Iterator vars = m.variablesSet().iterator();
        if (!vars.hasNext()) {
            throw new IllegalArgumentException(
                "Cannot compute the normalization factor of a model with no variables");
        }
        Variable var = (Variable) vars.next();
        Factor marginal = this.unnormalizedMarginal(m, var);
        return marginal.sum();
    }

    /**
     * Records {@code m} as the model for subsequent {@link #lookupMarginal} calls.
     * No inference is performed here.
     *
     * @param m model to query later
     */
    public void computeMarginals(FactorGraph m) {
        this.mdlCurrent = m;
    }

    /**
     * Returns the normalized marginal of {@code var} in the model last passed to
     * {@link #computeMarginals(FactorGraph)}, running a full elimination pass.
     *
     * @param var variable of interest
     * @return normalized factor over {@code var}
     * @throws IllegalStateException if no model has been recorded (including after
     *         deserialization, since the model field is transient)
     */
    public Factor lookupMarginal(Variable var) {
        if (this.mdlCurrent == null) {
            throw new IllegalStateException(
                "computeMarginals must be called before lookupMarginal");
        }
        Factor marginal = this.unnormalizedMarginal(this.mdlCurrent, var);
        marginal.normalize();
        return marginal;
    }

    /** Default serialization; the recorded model is transient and is not written. */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.defaultWriteObject();
    }

    /** Default deserialization; computeMarginals must be called again before lookups. */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
    }
}

