/*
 * Decompiled with CFR 0.152.
 */
package stats.glm;

import cern.colt.list.IntArrayList;
import cern.colt.matrix.DoubleFactory1D;
import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.doublealgo.Transform;
import cern.colt.matrix.impl.AbstractMatrix;
import cern.colt.matrix.linalg.Algebra;
import cern.colt.matrix.linalg.CholeskyDecomposition;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import org.biojava.stats.svm.ItemValue;
import org.biojava.stats.svm.SVMTarget;
import org.biojava.stats.svm.TrainingEvent;
import org.biojava.stats.svm.TrainingListener;
import stats.glm.BasisFunction;
import stats.glm.BasisSource;
import stats.glm.GLMClassificationModel;
import stats.glm.GLMRegressionModel;
import stats.glm.GLMTrainer;
import stats.glm.IRLSException;
import stats.glm.SLMTrainingContext;

public class SLMTrainer
implements GLMTrainer {
    /** Factory used for every vector (1-D matrix) allocated by the trainers in this file. */
    private static DoubleFactory1D mf1d = DoubleFactory1D.dense;
    /** Factory used for every 2-D matrix allocated by the trainers in this file. */
    private static DoubleFactory2D mf2d = DoubleFactory2D.dense;
    /** Shared linear-algebra helper (mult / inverse / solve). */
    private static Algebra alg = Algebra.DEFAULT;
    /** Not referenced anywhere in the visible part of this file. */
    private static final double eps = 2.2E-16;
    /** Upper bound on the number of passes of the main loop in both train() methods. */
    private int maxCycles = 2000;
    /** Number of final cycles (counted back from maxCycles) during which no new basis is added. */
    private int cleanupCycles = 0;
    /** Adding of bases is switched off once the working set holds at least this many entries. */
    private int maxBasis = Integer.MAX_VALUE;
    /**
     * Adding of bases is switched (back) on while the working set is smaller than this
     * and the basis source still reports candidates. With the default value that size
     * test is effectively always true.
     */
    private int minBasis = Integer.MAX_VALUE;
    /** Maximum number of bases drawn from the source before the first training cycle. */
    private int initBasis = 5;
    /** A basis whose updated alpha is not below this value is dropped from the working set. */
    private double maxAlpha = 1.0E10;
    /** Only read by the classification trainer, in the smoothing term of its alpha update. */
    private double minAlpha = 0.25;
    /** Alpha assigned to every basis when it enters the working set. */
    private double initialAlpha = 1.0;

    /**
     * Sets the maximum number of training cycles run by the train loops.
     *
     * @param n new value for the cycle limit; stored without validation
     */
    public void setMaxCycles(int n) {
        maxCycles = n;
    }

    /**
     * Sets how many of the last cycles run without adding new basis functions.
     *
     * @param n new cleanup-cycle count; stored without validation
     */
    public void setCleanupCycles(int n) {
        cleanupCycles = n;
    }

    /**
     * Sets the maximum number of basis functions drawn before the first cycle.
     *
     * @param n new initial working-set size limit; stored without validation
     */
    public void setInitialBasis(int n) {
        initBasis = n;
    }

    /**
     * Sets the working-set size below which adding of bases is switched back on.
     *
     * @param n new lower size threshold; stored without validation
     */
    public void setMinBasis(int n) {
        minBasis = n;
    }

    /**
     * Sets the working-set size at which adding of bases is switched off.
     *
     * @param n new upper size threshold; stored without validation
     */
    public void setMaxBasis(int n) {
        maxBasis = n;
    }

    /**
     * Sets the alpha value at or above which a basis is dropped from the working set.
     *
     * @param d new pruning threshold; stored without validation
     */
    public void setMaxAlpha(double d) {
        maxAlpha = d;
    }

    /**
     * Sets the constant used by the classification trainer's alpha smoothing term.
     *
     * @param d new value; stored without validation
     */
    public void setMinAlpha(double d) {
        minAlpha = d;
    }

    /**
     * Sets the alpha given to each basis function when it joins the working set.
     *
     * @param d new starting alpha; stored without validation
     */
    public void setInitialAlpha(double d) {
        initialAlpha = d;
    }

    /**
     * Trains a regression model by building a fresh regression training context
     * for the given arguments and running it to completion.
     *
     * @param sVMTarget        training items and their target values
     * @param basisSource      supplier of candidate basis functions
     * @param trainingListener receiver of per-cycle and completion notifications
     * @return the model produced by the context's train() call
     */
    public GLMRegressionModel trainRegression(SVMTarget sVMTarget, BasisSource basisSource, TrainingListener trainingListener) {
        return new RVMTrainContext(sVMTarget, basisSource, trainingListener).train();
    }

    /**
     * Trains a classification model by building a fresh classification training
     * context for the given arguments and running it to completion.
     *
     * @param sVMTarget        training items and their target values
     * @param basisSource      supplier of candidate basis functions
     * @param trainingListener receiver of per-cycle and completion notifications
     * @return the model produced by the context's train() call
     */
    public GLMClassificationModel trainClassification(SVMTarget sVMTarget, BasisSource basisSource, TrainingListener trainingListener) {
        return new RVMClassContext(sVMTarget, basisSource, trainingListener).train();
    }

    private class RVMTrainContext
    implements SLMTrainingContext {
        /** Number of completed training cycles; incremented at the end of every pass of the loop in train(). */
        private int cycle = 0;
        /** Training data given to the constructor; train() reads items() and itemTargets() from it. */
        private SVMTarget starget;
        /** Candidate-basis supplier: the constructor argument wrapped in a UniqueBasisSource. */
        private BasisSource basisSource;
        /** Receives the trainingCycleComplete / trainingComplete calls made by train(). */
        private TrainingListener listener;
        /** Event object created once in the constructor (with this context as its argument) and passed to every listener call. */
        private TrainingEvent tevent;
        /** Basis functions currently in the model (raw list of BasisFunction); null until train() starts. */
        private List workingSet;
        /** Weight vector from the most recent solve in train(); null until the first cycle has run. */
        private DoubleMatrix1D mu;

        /**
         * Creates a regression training context.
         *
         * @param sVMTarget        training data, stored as-is
         * @param basisSource      candidate supplier; wrapped in a UniqueBasisSource before being stored
         * @param trainingListener listener notified during training, stored as-is
         */
        RVMTrainContext(SVMTarget sVMTarget, BasisSource basisSource, TrainingListener trainingListener) {
            starget = sVMTarget;
            // The parameter shadows the field, so this one assignment needs the qualifier.
            this.basisSource = new UniqueBasisSource(basisSource);
            listener = trainingListener;
            // The event carries this context so listeners can query it.
            tevent = new TrainingEvent(this);
        }

        /**
         * Runs the regression training loop and returns the fitted model.
         *
         * <p>Outline of one cycle:
         * <ol>
         *   <li>solve for the covariance {@code sigma} and weights {@code mu} from the current
         *       design matrix, per-basis alphas and noise scale;</li>
         *   <li>re-estimate every alpha as {@code gamma / mu^2} with
         *       {@code gamma = 1 - alpha * sigma[i][i]}, and re-estimate the noise scale from the
         *       residual sum of squares;</li>
         *   <li>drop every basis whose new alpha is not below {@code maxAlpha};</li>
         *   <li>optionally append one new basis from the source (at most every other cycle, and
         *       never during the last {@code cleanupCycles} cycles);</li>
         *   <li>notify the listener.</li>
         * </ol>
         *
         * <p>Unlike the decompiled original, {@code mu} is kept index-aligned with the working set
         * whenever bases are dropped or appended (new bases get weight 0 until the next solve), and
         * if the very last cycle changed the working set the posterior is solved once more so the
         * returned weights and covariance match the returned basis list. The per-cycle arithmetic
         * is otherwise unchanged.
         *
         * @return a GLMRegressionModel built from the final working set, weights, covariance and
         *         noise scale; weights and covariance are {@code null} if {@code maxCycles <= 0}
         */
        GLMRegressionModel train() {
            boolean growing = false;
            // True right after a basis was appended: the next cycle then skips appending.
            boolean justGrew = false;
            // Whether the current cycle pruned or appended anything (checked after the loop).
            boolean basisChanged = false;

            // Unpack items and target values in the iteration order of itemTargets().
            int numItems = this.starget.items().size();
            ArrayList<Object> items = new ArrayList<Object>();
            DoubleMatrix1D targets = mf1d.make(numItems);
            int row = 0;
            Iterator iterator = this.starget.itemTargets().iterator();
            while (iterator.hasNext()) {
                ItemValue itemValue = (ItemValue) iterator.next();
                items.add(itemValue.getItem());
                targets.set(row++, itemValue.getValue());
            }

            // Seed the working set with up to initBasis candidates.
            this.workingSet = new ArrayList();
            while (this.basisSource.hasNext(this) && this.workingSet.size() < SLMTrainer.this.initBasis) {
                this.workingSet.add(this.basisSource.next(this));
            }

            // Design matrix: one row per item, one column per basis function.
            int numBasis = this.workingSet.size();
            DoubleMatrix2D design = mf2d.make(numItems, numBasis);
            for (int r = 0; r < numItems; ++r) {
                for (int c = 0; c < numBasis; ++c) {
                    design.set(r, c, ((BasisFunction) this.workingSet.get(c)).evaluate(items.get(r)));
                }
            }

            DoubleMatrix1D alpha = mf1d.make(numBasis, SLMTrainer.this.initialAlpha);
            double noise = 0.001;
            this.mu = null;
            DoubleMatrix2D sigma = null;

            while (this.cycle < SLMTrainer.this.maxCycles) {
                basisChanged = false;
                numBasis = alpha.size();
                sigma = this.solvePosterior(design, alpha, targets, noise);

                // Alpha re-estimation; indices whose alpha stays below maxAlpha survive.
                double gammaSum = 0.0;
                IntArrayList kept = new IntArrayList();
                for (int i = 0; i < numBasis; ++i) {
                    double gamma = 1.0 - alpha.get(i) * sigma.get(i, i);
                    double updatedAlpha = gamma / Math.pow(this.mu.get(i), 2.0);
                    alpha.set(i, updatedAlpha);
                    if (updatedAlpha < SLMTrainer.this.maxAlpha) {
                        kept.add(i);
                    }
                    gammaSum += gamma;
                }

                // Noise re-estimation from the residuals of the current fit.
                DoubleMatrix1D residual = Transform.minus(alg.mult(design, this.mu), targets);
                double squaredError = 0.0;
                for (int i = 0; i < residual.size(); ++i) {
                    squaredError += Math.pow(residual.get(i), 2.0);
                }
                noise = Math.sqrt(squaredError / ((double) numItems - gammaSum));

                // Prune: rebuild alpha, mu, design and working set from the surviving indices.
                if (kept.size() < numBasis) {
                    int keptCount = kept.size();
                    DoubleMatrix1D prunedAlpha = mf1d.make(keptCount);
                    DoubleMatrix1D prunedMu = mf1d.make(keptCount);
                    DoubleMatrix2D prunedDesign = mf2d.make(numItems, keptCount);
                    List prunedBasis = new ArrayList();
                    for (int j = 0; j < keptCount; ++j) {
                        int source = kept.get(j);
                        prunedAlpha.set(j, alpha.get(source));
                        prunedMu.set(j, this.mu.get(source));
                        for (int r = 0; r < numItems; ++r) {
                            prunedDesign.set(r, j, design.get(r, source));
                        }
                        prunedBasis.add(this.workingSet.get(source));
                    }
                    alpha = prunedAlpha;
                    this.mu = prunedMu;
                    design = prunedDesign;
                    this.workingSet = prunedBasis;
                    basisChanged = true;
                }

                // (Re-)enable growth while the working set is below minBasis and candidates remain.
                if (this.workingSet.size() < SLMTrainer.this.minBasis && !growing && this.basisSource.hasNext(this)) {
                    growing = true;
                    justGrew = false;
                }

                if (growing && this.cycle < SLMTrainer.this.maxCycles - SLMTrainer.this.cleanupCycles) {
                    if (justGrew) {
                        // Give the previously appended basis one cycle before appending another.
                        justGrew = false;
                    } else {
                        int oldCount = this.workingSet.size();
                        this.workingSet.add(this.basisSource.next(this));
                        if (this.workingSet.size() >= SLMTrainer.this.maxBasis || !this.basisSource.hasNext(this)) {
                            growing = false;
                        }
                        numBasis = this.workingSet.size();
                        DoubleMatrix1D grownAlpha = mf1d.make(numBasis);
                        // make() zero-fills, so appended bases start with weight 0.
                        DoubleMatrix1D grownMu = mf1d.make(numBasis);
                        DoubleMatrix2D grownDesign = mf2d.make(numItems, numBasis);
                        for (int c = 0; c < oldCount; ++c) {
                            grownAlpha.set(c, alpha.get(c));
                            grownMu.set(c, this.mu.get(c));
                            for (int r = 0; r < numItems; ++r) {
                                grownDesign.set(r, c, design.get(r, c));
                            }
                        }
                        for (int c = oldCount; c < numBasis; ++c) {
                            grownAlpha.set(c, SLMTrainer.this.initialAlpha);
                            for (int r = 0; r < numItems; ++r) {
                                grownDesign.set(r, c, ((BasisFunction) this.workingSet.get(c)).evaluate(items.get(r)));
                            }
                        }
                        alpha = grownAlpha;
                        this.mu = grownMu;
                        design = grownDesign;
                        justGrew = true;
                        basisChanged = true;
                    }
                }
                ++this.cycle;
                this.listener.trainingCycleComplete(this.tevent);
            }

            // If the last cycle changed the basis set, sigma (and the zero placeholders in mu)
            // still describe the previous set; solve once more so the model is self-consistent.
            if (basisChanged) {
                if (this.workingSet.isEmpty()) {
                    this.mu = mf1d.make(0);
                    sigma = mf2d.make(0, 0);
                } else {
                    sigma = this.solvePosterior(design, alpha, targets, noise);
                }
            }

            this.listener.trainingComplete(this.tevent);
            return new GLMRegressionModel(this.workingSet, this.mu, sigma, noise);
        }

        /**
         * Computes {@code sigma = inverse(design' * B * design + diag(alpha))} with
         * {@code B = noise^-2 * I}, stores {@code mu = sigma * design' * B * targets} in
         * {@link #mu}, and returns {@code sigma}.
         */
        private DoubleMatrix2D solvePosterior(DoubleMatrix2D design, DoubleMatrix1D alpha, DoubleMatrix1D targets, double noise) {
            DoubleMatrix2D prior = mf2d.diagonal(alpha);
            DoubleMatrix2D precision = Transform.mult(mf2d.identity(design.rows()), Math.pow(noise, -2.0));
            DoubleMatrix2D sigma = alg.inverse(Transform.plus(alg.mult(alg.mult(design.viewDice(), precision), design), prior));
            this.mu = alg.mult(alg.mult(alg.mult(sigma, design.viewDice()), precision), targets);
            return sigma;
        }

        /**
         * @return the number of training cycles completed so far
         */
        public int getCurrentCycle() {
            return cycle;
        }

        /**
         * @return the live working-set list itself (not a copy); null before train() has started
         */
        public List getBasisList() {
            return workingSet;
        }

        /**
         * Looks up the current weight of a basis function.
         *
         * <p>The weight vector is only assigned once a training cycle has solved for it, and
         * it may not yet cover bases appended to the working set since that solve. In both
         * situations this method reports 0.0 instead of failing, so basis sources and
         * listeners may query the context at any point during training.
         *
         * @param basisFunction the basis function to look up
         * @return its weight, or 0.0 if it is not in the working set or has no weight yet
         */
        public double getWeightForBasis(BasisFunction basisFunction) {
            if (this.workingSet == null || this.mu == null) {
                return 0.0;
            }
            int index = this.workingSet.indexOf(basisFunction);
            if (index < 0 || index >= this.mu.size()) {
                return 0.0;
            }
            return this.mu.get(index);
        }

        /**
         * @return the training target passed to the constructor
         */
        public SVMTarget getTarget() {
            return starget;
        }

        /**
         * Not supported by the regression context.
         *
         * @throws UnsupportedOperationException always
         */
        public double getDeviation() {
            throw new UnsupportedOperationException();
        }

        /**
         * Not supported by the regression context.
         *
         * @throws UnsupportedOperationException always
         */
        public GLMClassificationModel freezeModel() {
            throw new UnsupportedOperationException();
        }
    }

    private class RVMClassContext
    implements SLMTrainingContext {
        /** Number of completed training cycles; incremented at the end of every pass of the loop in train(). */
        private int cycle = 0;
        /** Training data given to the constructor; train() reads items() and itemTargets() from it. */
        private SVMTarget starget;
        /** Candidate-basis supplier: the constructor argument wrapped in a UniqueBasisSource. */
        private BasisSource basisSource;
        /** Receives the trainingCycleComplete / trainingComplete calls made by train(). */
        private TrainingListener listener;
        /** Event object created once in the constructor (with this context as its argument) and passed to every listener call. */
        private TrainingEvent tevent;
        /** Basis functions currently in the model (raw list of BasisFunction); null until train() starts. */
        private List workingSet;
        /** Current weight vector; null until train() allocates it after seeding the working set. */
        private DoubleMatrix1D mu;
        /** Lower bound (2^-15) on the step scale: the step-halving loop in irls_new stops once the scale is no longer above it. */
        private final double LAMBDA_MIN = Math.pow(2.0, -15.0);
        /** Relative-improvement threshold of 1.0E-9, the value the outer loop test in irls_new compares against. */
        private final double IRLS_TOL = 1.0E-9;
        /** Value irls_new_dev held when the most recent irls_new call started. */
        private double irls_new_olddev = 0.0;
        /** Objective value stored by the most recent irls_new call that returned normally; also returned by getDeviation(). */
        private double irls_new_dev = 0.0;

        /**
         * Creates a classification training context.
         *
         * @param sVMTarget        training data, stored as-is
         * @param basisSource      candidate supplier; wrapped in a UniqueBasisSource before being stored
         * @param trainingListener listener notified during training, stored as-is
         */
        RVMClassContext(SVMTarget sVMTarget, BasisSource basisSource, TrainingListener trainingListener) {
            starget = sVMTarget;
            // The parameter shadows the field, so this one assignment needs the qualifier.
            this.basisSource = new UniqueBasisSource(basisSource);
            listener = trainingListener;
            // The event carries this context so listeners can query it.
            tevent = new TrainingEvent(this);
        }

        /**
         * Runs the classification training loop and returns the fitted model.
         *
         * <p>Outline of one cycle:
         * <ol>
         *   <li>re-estimate the weights with {@code irls_new}, warm-started from the current
         *       {@code mu}; on an {@code IRLSException} the last weight is zeroed (so that basis
         *       is dropped below), growth is switched off, and a second consecutive failure
         *       aborts training;</li>
         *   <li>while growing, a relative change of more than 5% in the objective reported by
         *       {@code irls_new} marks the cycle unstable and postpones the next append;</li>
         *   <li>every basis with a non-zero weight gets
         *       {@code alpha = a + minAlpha * exp(-a / minAlpha)} with {@code a = 1 / mu^2};
         *       bases with weight exactly 0, or with alpha not below {@code maxAlpha} (unless
         *       IRLS just failed), are dropped;</li>
         *   <li>optionally one new basis is appended with weight 0 and alpha
         *       {@code initialAlpha} (never during the last {@code cleanupCycles} cycles);</li>
         *   <li>the listener is notified.</li>
         * </ol>
         *
         * <p>Relative to the decompiled original this version restores compilable local types,
         * removes bookkeeping that was written but never read, keeps the IRLS exception as the
         * cause of the abort exception, and does not index position -1 when IRLS fails on an
         * empty working set. The numerical procedure and console output are unchanged.
         *
         * @return a GLMClassificationModel built from the final working set and weights
         * @throws RuntimeException if irls_new fails in two consecutive cycles
         */
        GLMClassificationModel train() {
            boolean growing = false;
            // When set, the next growth opportunity is skipped and the flag is cleared.
            boolean skipGrowth = false;
            // True while the most recent irls_new call failed.
            boolean irlsFailed = false;

            // Unpack items and target values in the iteration order of itemTargets().
            int numItems = this.starget.items().size();
            ArrayList<Object> items = new ArrayList<Object>();
            DoubleMatrix1D targets = mf1d.make(numItems);
            int row = 0;
            Iterator iterator = this.starget.itemTargets().iterator();
            while (iterator.hasNext()) {
                ItemValue itemValue = (ItemValue) iterator.next();
                items.add(itemValue.getItem());
                targets.setQuick(row++, itemValue.getValue());
            }

            // Seed the working set; growth starts enabled iff the source still has candidates.
            this.workingSet = new ArrayList();
            while (this.basisSource.hasNext(this) && this.workingSet.size() < SLMTrainer.this.initBasis) {
                this.workingSet.add(this.basisSource.next(this));
                growing = this.basisSource.hasNext(this);
            }

            // Design matrix: one row per item, one column per basis function.
            int numBasis = this.workingSet.size();
            DoubleMatrix2D design = mf2d.make(numItems, numBasis);
            for (int r = 0; r < numItems; ++r) {
                for (int c = 0; c < numBasis; ++c) {
                    design.setQuick(r, c, ((BasisFunction) this.workingSet.get(c)).evaluate(items.get(r)));
                }
            }

            DoubleMatrix1D alpha = mf1d.make(numBasis, SLMTrainer.this.initialAlpha);
            this.mu = mf1d.make(numBasis);

            while (this.cycle < SLMTrainer.this.maxCycles) {
                numBasis = alpha.size();
                try {
                    this.mu = this.irls_new(targets, design, alpha, this.mu);
                    irlsFailed = false;
                }
                catch (IRLSException iRLSException) {
                    iRLSException.printStackTrace();
                    if (irlsFailed) {
                        throw new RuntimeException("Hopeless IRLS failure!", iRLSException);
                    }
                    System.out.println("*** Nasty things happened.  Backing out basis");
                    // Zeroing the last weight makes the pruning pass below discard that basis.
                    if (numBasis > 0) {
                        this.mu.set(numBasis - 1, 0.0);
                    }
                    growing = false;
                    irlsFailed = true;
                }

                if (growing && Math.abs(this.irls_new_dev - this.irls_new_olddev) > this.irls_new_olddev * 0.05) {
                    System.out.println("\n*** Unstable cycle: " + this.irls_new_olddev + " -> " + this.irls_new_dev);
                    skipGrowth = true;
                }

                // Alpha update and selection of the bases that survive this cycle.
                IntArrayList kept = new IntArrayList();
                for (int i = 0; i < numBasis; ++i) {
                    if (this.mu.get(i) == 0.0) {
                        System.out.println("Presumed new basis.  Toasted.");
                    } else {
                        double updatedAlpha = 1.0 / Math.pow(this.mu.get(i), 2.0);
                        updatedAlpha += SLMTrainer.this.minAlpha * Math.exp(-1.0 / SLMTrainer.this.minAlpha * updatedAlpha);
                        alpha.setQuick(i, updatedAlpha);
                        if (updatedAlpha < SLMTrainer.this.maxAlpha || irlsFailed) {
                            kept.add(i);
                        }
                    }
                }

                // Prune: rebuild alpha, mu, design and working set from the surviving indices.
                if (kept.size() < numBasis) {
                    int keptCount = kept.size();
                    List prunedBasis = new ArrayList();
                    DoubleMatrix1D prunedAlpha = mf1d.make(keptCount);
                    DoubleMatrix1D prunedMu = mf1d.make(keptCount);
                    DoubleMatrix2D prunedDesign = mf2d.make(numItems, keptCount);
                    for (int j = 0; j < keptCount; ++j) {
                        int source = kept.get(j);
                        prunedAlpha.setQuick(j, alpha.getQuick(source));
                        prunedMu.setQuick(j, this.mu.getQuick(source));
                        for (int r = 0; r < numItems; ++r) {
                            prunedDesign.setQuick(r, j, design.getQuick(r, source));
                        }
                        prunedBasis.add(this.workingSet.get(source));
                    }
                    alpha = prunedAlpha;
                    this.mu = prunedMu;
                    design = prunedDesign;
                    this.workingSet = prunedBasis;
                }

                // (Re-)enable growth while the working set is below minBasis and candidates remain.
                if (this.workingSet.size() < SLMTrainer.this.minBasis && !growing && this.basisSource.hasNext(this)) {
                    growing = true;
                    skipGrowth = false;
                }

                if (growing && this.cycle < SLMTrainer.this.maxCycles - SLMTrainer.this.cleanupCycles) {
                    if (skipGrowth) {
                        skipGrowth = false;
                    } else {
                        int oldCount = this.workingSet.size();
                        this.workingSet.add(this.basisSource.next(this));
                        if (this.workingSet.size() >= SLMTrainer.this.maxBasis || !this.basisSource.hasNext(this)) {
                            growing = false;
                        }
                        numBasis = this.workingSet.size();
                        DoubleMatrix1D grownAlpha = mf1d.make(numBasis);
                        // make() zero-fills, so appended bases start with weight 0.
                        DoubleMatrix1D grownMu = mf1d.make(numBasis);
                        DoubleMatrix2D grownDesign = mf2d.make(numItems, numBasis);
                        for (int c = 0; c < oldCount; ++c) {
                            grownAlpha.setQuick(c, alpha.getQuick(c));
                            grownMu.setQuick(c, this.mu.getQuick(c));
                            for (int r = 0; r < numItems; ++r) {
                                grownDesign.setQuick(r, c, design.getQuick(r, c));
                            }
                        }
                        for (int c = oldCount; c < numBasis; ++c) {
                            grownAlpha.setQuick(c, SLMTrainer.this.initialAlpha);
                            for (int r = 0; r < numItems; ++r) {
                                grownDesign.setQuick(r, c, ((BasisFunction) this.workingSet.get(c)).evaluate(items.get(r)));
                            }
                        }
                        alpha = grownAlpha;
                        this.mu = grownMu;
                        design = grownDesign;
                        skipGrowth = true;
                    }
                }
                ++this.cycle;
                this.listener.trainingCycleComplete(this.tevent);
            }
            this.listener.trainingComplete(this.tevent);
            return new GLMClassificationModel(this.workingSet, this.mu);
        }

        /**
         * Copies a vector into a freshly allocated single-column matrix.
         *
         * @param doubleMatrix1D the vector to copy
         * @return a new matrix with size() rows and one column holding the same values
         */
        DoubleMatrix2D copy2D(DoubleMatrix1D doubleMatrix1D) {
            final int length = doubleMatrix1D.size();
            final DoubleMatrix2D column = mf2d.make(length, 1);
            for (int row = 0; row < length; ++row) {
                column.setQuick(row, 0, doubleMatrix1D.getQuick(row));
            }
            return column;
        }

        /**
         * Returns a vector view of a matrix that has exactly one row or exactly one column.
         * The single-row case is checked first.
         *
         * @param doubleMatrix2D the matrix to view
         * @return the view of row 0 or of column 0
         * @throws RuntimeException if the matrix has neither one row nor one column
         */
        DoubleMatrix1D view1D(DoubleMatrix2D doubleMatrix2D) {
            DoubleMatrix1D vector;
            if (doubleMatrix2D.rows() == 1) {
                vector = doubleMatrix2D.viewRow(0);
            } else if (doubleMatrix2D.columns() == 1) {
                vector = doubleMatrix2D.viewColumn(0);
            } else {
                throw new RuntimeException("Matrix must be a vector.");
            }
            return vector;
        }

        /**
         * Re-estimates the weight vector by damped Newton steps on the objective
         * <pre>
         *   dev(w) = (1/n) * sum_i -log p_i(w)^[t_i == 1] (1 - p_i(w))^[t_i != 1]
         *          + (1/(2n)) * sum_j alpha_j * w_j^2,
         *   p(w)   = 1 / (1 + exp(-design * w)).
         * </pre>
         * Each outer iteration forms the gradient {@code g = design' (t - p) - diag(alpha) w}
         * and the matrix {@code H = design' diag(p (1 - p)) design + diag(alpha)}, solves
         * {@code H * step = g} through a Cholesky factor, and tries {@code w + lambda * step}
         * for {@code lambda = 1, 1/2, 1/4, ...} until the objective does not increase.
         *
         * <p>Fix relative to the decompiled original: {@code getL()} yields the lower factor with
         * {@code H = L * L'}, so the step is {@code L' \ (L \ g)}. The original solved the two
         * triangular systems in the opposite order, which gives {@code (L' L)^-1 g} instead of
         * {@code H^-1 g} (still a descent direction, but not the Newton step).
         *
         * <p>Side effects: on entry the previous value of {@code irls_new_dev} is copied to
         * {@code irls_new_olddev}; on normal return {@code irls_new_dev} receives the last
         * objective value computed. The vector passed as the starting weights is not modified.
         *
         * @param doubleMatrix1D  target values t (an entry equal to 1.0 selects the -log p term)
         * @param doubleMatrix2D  design matrix, one row per item and one column per basis
         * @param doubleMatrix1D2 per-basis penalty weights alpha
         * @param doubleMatrix1D3 starting weights
         * @return the last accepted weight vector
         * @throws IRLSException if the objective is infinite after more than 100 consecutive accepted steps
         */
        DoubleMatrix1D irls_new(DoubleMatrix1D doubleMatrix1D, DoubleMatrix2D doubleMatrix2D, DoubleMatrix1D doubleMatrix1D2, DoubleMatrix1D doubleMatrix1D3) throws IRLSException {
            this.irls_new_olddev = this.irls_new_dev;

            final DoubleMatrix1D targets = doubleMatrix1D;
            final DoubleMatrix2D design = doubleMatrix2D;
            final DoubleMatrix1D alpha = doubleMatrix1D2;
            DoubleMatrix1D weights = doubleMatrix1D3;

            // Consecutive accepted steps that produced an infinite objective.
            int infiniteRuns = 0;
            DoubleMatrix2D prior = mf2d.diagonal(alpha);
            DoubleMatrix2D designT = design.viewDice();
            int numItems = targets.size();
            int numBasis = design.columns();

            // Predicted probabilities at the starting weights.
            DoubleMatrix1D prob = mf1d.make(numItems);
            DoubleMatrix1D activation = alg.mult(design, weights);
            for (int i = 0; i < numItems; ++i) {
                prob.setQuick(i, this.logit(activation.getQuick(i)));
            }

            double prevDev = Double.POSITIVE_INFINITY;
            double dev = Double.POSITIVE_INFINITY;
            // Iterate while nothing finite has been seen yet or the relative improvement exceeds IRLS_TOL.
            while (dev == Double.POSITIVE_INFINITY || prevDev - dev > dev * this.IRLS_TOL) {
                prevDev = dev;

                DoubleMatrix1D variance = mf1d.make(numItems);
                DoubleMatrix1D error = mf1d.make(numItems);
                for (int i = 0; i < numItems; ++i) {
                    double p = prob.getQuick(i);
                    variance.setQuick(i, p * (1.0 - p));
                    error.setQuick(i, targets.getQuick(i) - p);
                }

                // design' with every column i scaled by variance[i], i.e. design' * diag(variance).
                DoubleMatrix2D weightedDesignT = design.viewDice().copy();
                for (int b = 0; b < numBasis; ++b) {
                    for (int i = 0; i < numItems; ++i) {
                        weightedDesignT.setQuick(b, i, weightedDesignT.getQuick(b, i) * variance.getQuick(i));
                    }
                }

                DoubleMatrix1D gradient = Transform.minus(alg.mult(designT, error), alg.mult(prior, weights));
                DoubleMatrix2D hessian = Transform.plus(alg.mult(weightedDesignT, design), prior);
                // NOTE(review): isSymmetricPositiveDefinite() is not checked; confirm H cannot lose definiteness here.
                DoubleMatrix2D lower = new CholeskyDecomposition(hessian).getL();
                // H = L L'  =>  H^-1 g = L' \ (L \ g): forward solve with L first, then back solve with L'.
                DoubleMatrix1D step = this.view1D(alg.solve(lower.viewDice(), alg.solve(lower, this.copy2D(gradient))));

                double lambda = 1.0;
                while (lambda > this.LAMBDA_MIN) {
                    DoubleMatrix1D candidate = mf1d.make(numBasis);
                    for (int b = 0; b < numBasis; ++b) {
                        candidate.setQuick(b, weights.getQuick(b) + step.getQuick(b) * lambda);
                    }

                    // Probabilities are overwritten with the candidate's even if the step is later rejected.
                    DoubleMatrix1D candidateActivation = alg.mult(design, candidate);
                    double negLogLik = 0.0;
                    for (int i = 0; i < numItems; ++i) {
                        double p = this.logit(candidateActivation.getQuick(i));
                        prob.set(i, p);
                        negLogLik += targets.getQuick(i) == 1.0 ? -Math.log(p) : -Math.log(1.0 - p);
                    }
                    double penalty = 0.0;
                    for (int b = 0; b < numBasis; ++b) {
                        penalty += alpha.getQuick(b) * Math.pow(candidate.getQuick(b), 2.0);
                    }
                    dev = negLogLik / (double) numItems + penalty / (double) (2 * numItems);

                    if (dev > prevDev) {
                        // Objective got worse: halve the step and try again.
                        lambda /= 2.0;
                        continue;
                    }

                    // Step accepted; lambda = 0 both ends this loop and marks the acceptance.
                    weights = candidate;
                    lambda = 0.0;
                    if (Double.isInfinite(dev)) {
                        if (++infiniteRuns > 100) {
                            throw new IRLSException("IRLS not getting anywhere :(");
                        }
                        if (infiniteRuns > 10) {
                            // Restart from zero weights with probabilities (t + 0.5) / 2.
                            System.out.println("Starting again");
                            for (int i = 0; i < numItems; ++i) {
                                prob.setQuick(i, (targets.get(i) + 0.5) / 2.0);
                            }
                            for (int b = 0; b < numBasis; ++b) {
                                weights.setQuick(b, 0.0);
                            }
                        }
                    } else {
                        infiniteRuns = 0;
                    }
                }

                if (lambda > 0.0) {
                    // No scale above LAMBDA_MIN improved the objective.
                    // NOTE(review): dev here belongs to the last rejected candidate, not to the
                    // returned weights (whose objective is prevDev) - confirm this is intended.
                    System.out.println("Stopping due to backoff limit");
                    this.irls_new_dev = dev;
                    return weights;
                }
            }
            this.irls_new_dev = dev;
            return weights;
        }

        /**
         * Evaluates {@code 1 / (1 + exp(-d))}. Despite the method name this is the
         * logistic (sigmoid) function, mapping a real activation to a value in [0, 1].
         *
         * @param d the activation
         * @return the logistic function of d
         */
        private double logit(double d) {
            final double decay = Math.exp(-d);
            return 1.0 / (1.0 + decay);
        }

        /**
         * Evaluates {@code log(d / (1 - d))}, the inverse of this class's {@code logit} method.
         * Not called from the visible code.
         *
         * @param d a probability-like value
         * @return the log-odds of d
         */
        private double invlogit(double d) {
            final double odds = d / (1.0 - d);
            return Math.log(odds);
        }

        /**
         * Evaluates {@code exp(-d) / (1 + exp(-d))^2}, the derivative with respect to d of
         * this class's {@code logit} method. Not called from the visible code.
         *
         * @param d the activation
         * @return the derivative value at d
         */
        private double diflogit(double d) {
            final double decay = Math.exp(-d);
            final double denominator = Math.pow(1.0 + decay, 2.0);
            return decay / denominator;
        }

        /**
         * @return the objective value stored by the most recent irls_new call (0.0 before any call)
         */
        public double getDeviation() {
            return irls_new_dev;
        }

        /**
         * @return the number of training cycles completed so far
         */
        public int getCurrentCycle() {
            return cycle;
        }

        /**
         * @return the live working-set list itself (not a copy); null before train() has started
         */
        public List getBasisList() {
            return workingSet;
        }

        /**
         * Looks up the current weight of a basis function.
         *
         * <p>The weight vector does not exist yet while train() is still seeding the working
         * set, a phase in which the basis source is already handed this context. In that
         * situation, and for any index the weight vector does not cover, this method reports
         * 0.0 instead of failing.
         *
         * @param basisFunction the basis function to look up
         * @return its weight, or 0.0 if it is not in the working set or has no weight yet
         */
        public double getWeightForBasis(BasisFunction basisFunction) {
            if (this.workingSet == null || this.mu == null) {
                return 0.0;
            }
            int index = this.workingSet.indexOf(basisFunction);
            if (index < 0 || index >= this.mu.size()) {
                return 0.0;
            }
            return this.mu.get(index);
        }

        /**
         * @return the training target passed to the constructor
         */
        public SVMTarget getTarget() {
            return starget;
        }

        /**
         * Not supported by this context.
         *
         * @throws UnsupportedOperationException always
         */
        public GLMClassificationModel freezeModel() {
            throw new UnsupportedOperationException();
        }
    }

    /**
     * BasisSource decorator that never hands out a basis function already contained in the
     * training context's current basis list.
     */
    private static class UniqueBasisSource
    implements BasisSource {
        /** The wrapped source that actually produces candidates. */
        private final BasisSource child;

        /**
         * @param basisSource the source to wrap
         */
        public UniqueBasisSource(BasisSource basisSource) {
            this.child = basisSource;
        }

        /**
         * Returns the next candidate from the wrapped source that is non-null and not yet in
         * {@code sLMTrainingContext.getBasisList()}.
         *
         * <p>The first request is always forwarded to the wrapped source. After a rejected
         * candidate the wrapped source is only asked again while it reports more candidates;
         * the decompiled original kept asking unconditionally and could therefore loop forever
         * once only duplicates (or nulls) were left.
         *
         * @param sLMTrainingContext the context whose basis list defines what counts as a duplicate
         * @return a basis function not present in the context's basis list
         * @throws java.util.NoSuchElementException if the wrapped source runs out before
         *         producing a usable candidate
         */
        public BasisFunction next(SLMTrainingContext sLMTrainingContext) {
            BasisFunction candidate = this.child.next(sLMTrainingContext);
            while (candidate == null || sLMTrainingContext.getBasisList().contains(candidate)) {
                if (!this.child.hasNext(sLMTrainingContext)) {
                    throw new java.util.NoSuchElementException("Basis source exhausted: every remaining candidate was already in the working set");
                }
                candidate = this.child.next(sLMTrainingContext);
            }
            return candidate;
        }

        /**
         * Delegates to the wrapped source. A {@code true} result therefore only means the
         * wrapped source has candidates left, not that any of them is new.
         *
         * @param sLMTrainingContext the current training context
         * @return whatever the wrapped source reports
         */
        public boolean hasNext(SLMTrainingContext sLMTrainingContext) {
            return this.child.hasNext(sLMTrainingContext);
        }
    }
}

