package weka.classifiers.meta.ensembleSelection;

import java.util.Random;
import weka.classifiers.Evaluation;
import weka.core.Instances;

/* loaded from: input_file:weka/classifiers/meta/ensembleSelection/ModelBag.class */
/**
 * A bag of candidate models for ensemble selection by greedy forward/backward hill climbing.
 *
 * <p>Models are represented purely by their predictions: {@code models[m][k]} is the prediction
 * array that model {@code m} produced for instance {@code k} of the evaluation set. An ensemble is
 * a multiset of models; its prediction for an instance is the selection-count-weighted average of
 * its members' predictions.
 *
 * <p>Models are addressed through a "virtual" order that {@link #shuffle(Random)} and
 * {@link #sortInitialize} permute. Only the first {@code bagSize} virtual positions (the bag) take
 * part in selection. Selection counts are kept in virtual order internally. They are translated
 * back to the original model order whenever the best ensemble so far is recorded.
 *
 * <p>Scores come from {@code EnsembleMetricHelper.getMetric}; a larger score is treated as better.
 * Instances are mutable and unsynchronized.
 */
public class ModelBag {
    /** Predictions of every model: {@code m_models[model][instance]} is a prediction array. */
    private final double[][][] m_models;
    /** Maps a virtual (shuffled/sorted) position to the model's index in {@link #m_models}. */
    private final int[] m_modelIndex;
    /** Number of leading virtual positions that make up the bag. */
    private final int m_bagSize;
    /** Total selections in the current ensemble (the sum of {@link #m_timesChosen}). */
    private int m_numChosen;
    /** How many times the model at each virtual position is in the current ensemble. */
    private final int[] m_timesChosen;
    /** Whether to print progress messages to standard output. */
    private final boolean m_debug;
    /** Score of the best ensemble recorded so far. */
    private double m_bestPerformance;
    /** Selection counts of the best ensemble recorded so far, in original model order. */
    private int[] m_bestTimesChosen;

    /**
     * Creates a bag over the given model predictions.
     *
     * @param models predictions of each model; {@code models[m][k]} is model {@code m}'s
     *     prediction array for instance {@code k}
     * @param bagPercent fraction in [0, 1] of the models that form the bag
     * @param debug whether to print progress messages
     * @throws IllegalArgumentException if there are no models or {@code bagPercent} is not in
     *     [0, 1]
     */
    public ModelBag(double[][][] models, double bagPercent, boolean debug) {
        if (models.length == 0) {
            throw new IllegalArgumentException("ModelBag needs at least 1 model.");
        }
        // Written as a negated range test so that NaN is rejected as well.
        if (!(bagPercent >= 0.0 && bagPercent <= 1.0)) {
            throw new IllegalArgumentException(
                    "ModelBag bag percentage must be in [0, 1], got " + bagPercent);
        }
        m_debug = debug;
        m_models = models;
        m_bagSize = (int) (models.length * bagPercent);
        m_modelIndex = new int[models.length];
        for (int i = 0; i < m_modelIndex.length; i++) {
            m_modelIndex[i] = i;
        }
        m_timesChosen = new int[models.length];
        // Use a separate array. Aliasing m_timesChosen would expose the in-progress, virtual-order
        // counts through getModelWeights() before any best ensemble has been recorded.
        m_bestTimesChosen = new int[models.length];
        // Nothing has been recorded yet, so any scored ensemble counts as an improvement,
        // whatever the range of the metric.
        m_bestPerformance = Double.NEGATIVE_INFINITY;
        m_numChosen = 0;
    }

    /** Exchanges two virtual positions, carrying each model's selection count along with it. */
    private void swap(int a, int b) {
        if (a == b) {
            return;
        }
        int index = m_modelIndex[a];
        m_modelIndex[a] = m_modelIndex[b];
        m_modelIndex[b] = index;
        int count = m_timesChosen[a];
        m_timesChosen[a] = m_timesChosen[b];
        m_timesChosen[b] = count;
    }

    /**
     * Randomly permutes the virtual order of all models (not only the bag), so that the bag
     * becomes a uniformly random subset of the models.
     *
     * @param random source of randomness
     */
    public void shuffle(Random random) {
        // Fisher-Yates gives every permutation equal probability. Swapping each position with an
        // arbitrary *other* position does not; with two models it always restores the identity.
        for (int i = m_models.length - 1; i > 0; i--) {
            swap(i, random.nextInt(i + 1));
        }
    }

    /** Translates counts indexed by virtual position into counts indexed by original model. */
    private int[] virtualToRealWeights(int[] virtualWeights) {
        int[] realWeights = new int[virtualWeights.length];
        for (int pos = 0; pos < virtualWeights.length; pos++) {
            realWeights[m_modelIndex[pos]] = virtualWeights[pos];
        }
        return realWeights;
    }

    /** Snapshots the current ensemble's selection counts as the best ensemble so far. */
    private void updateBestTimesChosen() {
        m_bestTimesChosen = virtualToRealWeights(m_timesChosen);
    }

    /** Records the current ensemble as the best one if {@code performance} beats the best so far. */
    private void recordIfImproved(double performance) {
        if (performance > m_bestPerformance) {
            updateBestTimesChosen();
            m_bestPerformance = performance;
        }
    }

    /**
     * Moves the {@code num} bag models with the highest individual scores to the front of the
     * virtual order and seeds the ensemble with them.
     *
     * <p>When {@code greedy} is false, all {@code num} models are added. When it is true, the best
     * model is added first. The next ones, in score order, are added only while each addition
     * strictly improves the ensemble's score.
     *
     * <p>The resulting ensemble and its score are recorded as the best so far.
     *
     * @param num number of top models to sort to the front; must be in [0, bag size]
     * @param greedy whether to stop adding models as soon as one fails to improve the ensemble
     * @param instances the instances the predictions were made on
     * @param metric metric index passed to {@code EnsembleMetricHelper.getMetric}
     * @return the original model indices of the {@code num} best models, best first
     * @throws IllegalArgumentException if {@code num} is negative or larger than the bag
     * @throws Exception if evaluating predictions fails
     */
    public int[] sortInitialize(int num, boolean greedy, Instances instances, int metric)
            throws Exception {
        if (num < 0 || num > m_bagSize) {
            throw new IllegalArgumentException("Cannot sort-initialize with " + num
                    + " models; the bag holds " + m_bagSize + ".");
        }
        double[] performance = getIndividualPerformance(instances, metric);
        int[] bestModels = new int[num];
        // Partial selection sort: only the first num positions need to end up in order.
        for (int i = 0; i < num; i++) {
            int maxPos = i;
            for (int j = i + 1; j < m_bagSize; j++) {
                if (performance[j] > performance[maxPos]) {
                    maxPos = j;
                }
            }
            swap(i, maxPos);
            double tmp = performance[i];
            performance[i] = performance[maxPos];
            performance[maxPos] = tmp;
            bestModels[i] = m_modelIndex[i];
            if (!greedy) {
                m_timesChosen[i]++;
                m_numChosen++;
            }
        }
        if (greedy && num > 0) {
            m_timesChosen[0]++;
            m_numChosen++;
        }
        // The greedy baseline is the score of the ensemble as seeded so far, not an absolute 0.
        double ensemblePerformance = currentPerformance(instances, metric);
        if (greedy) {
            for (int i = 1; i < num; i++) {
                double candidate =
                        evaluatePredictions(instances, computePredictions(i, true), metric);
                // Negated test so that a NaN score also ends the greedy phase.
                if (!(candidate > ensemblePerformance)) {
                    break;
                }
                ensemblePerformance = candidate;
                m_timesChosen[i]++;
                m_numChosen++;
            }
        }
        // Record the seeded ensemble together with its score. Later hill-climbing steps then
        // replace it only when they actually beat it.
        updateBestTimesChosen();
        m_bestPerformance = ensemblePerformance;
        if (m_debug) {
            System.out.println("Sort Initialization added best " + m_numChosen + " models to the bag.");
        }
        return bestModels;
    }

    /**
     * Adds {@code weight} selections of every bag model to the current ensemble. The result is
     * recorded as the best ensemble so far; the recorded best score is not changed.
     *
     * @param weight number of selections to add per bag model
     */
    public void weightAll(int weight) {
        for (int pos = 0; pos < m_bagSize; pos++) {
            m_timesChosen[pos] += weight;
        }
        m_numChosen += weight * m_bagSize;
        updateBestTimesChosen();
    }

    /**
     * Adds to the ensemble the bag model whose addition yields the highest score. If that score
     * beats the best so far, the new ensemble is recorded as the best.
     *
     * @param withReplacement whether models already in the ensemble may be selected again
     * @param instances the instances the predictions were made on
     * @param metric metric index passed to {@code EnsembleMetricHelper.getMetric}
     * @throws Exception if evaluating predictions fails
     */
    public void forwardSelect(boolean withReplacement, Instances instances, int metric)
            throws Exception {
        double bestScore = Double.NEGATIVE_INFINITY;
        int bestPos = -1;
        for (int pos = 0; pos < m_bagSize; pos++) {
            if (withReplacement || m_timesChosen[pos] == 0) {
                double score = evaluatePredictions(instances, computePredictions(pos, true), metric);
                if (score > bestScore) {
                    bestScore = score;
                    bestPos = pos;
                }
            }
        }
        if (bestPos == -1) {
            if (m_debug) {
                System.out.println("Couldn't add model.  No action performed.");
            }
            return;
        }
        m_timesChosen[bestPos]++;
        m_numChosen++;
        recordIfImproved(bestScore);
    }

    /**
     * Removes from the ensemble one selection of the model whose removal yields the highest
     * score. Does nothing if the ensemble holds at most one selection. If the new score beats the
     * best so far, the new ensemble is recorded as the best.
     *
     * @param instances the instances the predictions were made on
     * @param metric metric index passed to {@code EnsembleMetricHelper.getMetric}
     * @throws Exception if evaluating predictions fails
     */
    public void backwardEliminate(Instances instances, int metric) throws Exception {
        // Removing the last selection would leave an empty ensemble with nothing to average.
        if (m_numChosen <= 1) {
            return;
        }
        double bestScore = Double.NEGATIVE_INFINITY;
        int bestPos = -1;
        for (int pos = 0; pos < m_bagSize; pos++) {
            if (m_timesChosen[pos] > 0) {
                double score = evaluatePredictions(instances, computePredictions(pos, false), metric);
                if (score > bestScore) {
                    bestScore = score;
                    bestPos = pos;
                }
            }
        }
        if (bestPos == -1) {
            if (m_debug) {
                System.out.println("Couldn't remove model.  No action performed.");
            }
            return;
        }
        m_timesChosen[bestPos]--;
        m_numChosen--;
        if (m_debug) {
            System.out.println("Removing model " + m_modelIndex[bestPos] + " (" + bestPos + ") " + bestScore);
        }
        recordIfImproved(bestScore);
    }

    /**
     * Performs whichever single addition or removal of a bag model yields the highest score.
     * Removals are considered only while the ensemble holds more than one selection. On a tie
     * at the same position, the removal wins. If the new score beats the best so far, the new
     * ensemble is recorded as the best.
     *
     * @param withReplacement whether models already in the ensemble may be added again
     * @param instances the instances the predictions were made on
     * @param metric metric index passed to {@code EnsembleMetricHelper.getMetric}
     * @throws Exception if evaluating predictions fails
     */
    public void forwardSelectOrBackwardEliminate(boolean withReplacement, Instances instances,
            int metric) throws Exception {
        double bestScore = Double.NEGATIVE_INFINITY;
        int bestPos = -1;
        boolean bestIsAddition = true;
        // Same guard as backwardEliminate: removing the only selection would divide by zero.
        boolean canRemove = m_numChosen > 1;
        for (int pos = 0; pos < m_bagSize; pos++) {
            if (canRemove && m_timesChosen[pos] > 0) {
                double score = evaluatePredictions(instances, computePredictions(pos, false), metric);
                if (score > bestScore) {
                    bestScore = score;
                    bestPos = pos;
                    bestIsAddition = false;
                }
            }
            if (withReplacement || m_timesChosen[pos] == 0) {
                double score = evaluatePredictions(instances, computePredictions(pos, true), metric);
                if (score > bestScore) {
                    bestScore = score;
                    bestPos = pos;
                    bestIsAddition = true;
                }
            }
        }
        if (bestPos == -1) {
            if (m_debug) {
                System.out.println("Couldn't add or remove model.  No action performed.");
            }
            return;
        }
        int delta = bestIsAddition ? 1 : -1;
        m_timesChosen[bestPos] += delta;
        m_numChosen += delta;
        recordIfImproved(bestScore);
    }

    /**
     * Returns the selection counts of the best ensemble recorded so far.
     *
     * @return a copy of the counts, indexed by original model index
     */
    public int[] getModelWeights() {
        return m_bestTimesChosen.clone();
    }

    /** Returns the predictions of the model at the given virtual position. */
    private double[][] model(int position) {
        return m_models[m_modelIndex[position]];
    }

    /**
     * Sums the predictions of the bag models in the current ensemble, each weighted by its
     * selection count. Rows are shaped like the first model's predictions.
     */
    private double[][] weightedPredictionSum() {
        double[][] template = m_models[0];
        double[][] sum = new double[template.length][];
        for (int k = 0; k < template.length; k++) {
            sum[k] = new double[template[k].length];
        }
        for (int pos = 0; pos < m_bagSize; pos++) {
            int count = m_timesChosen[pos];
            if (count > 0) {
                double[][] predictions = model(pos);
                for (int k = 0; k < sum.length; k++) {
                    for (int c = 0; c < sum[k].length; c++) {
                        sum[k][c] += predictions[k][c] * count;
                    }
                }
            }
        }
        return sum;
    }

    /**
     * Averages the current ensemble's predictions after hypothetically adding ({@code add}) or
     * removing one selection of the model at {@code position}. The ensemble is not modified.
     *
     * @throws IllegalStateException if the hypothetical ensemble would have no selections
     */
    private double[][] computePredictions(int position, boolean add) {
        int delta = add ? 1 : -1;
        int divisor = m_numChosen + delta;
        if (divisor <= 0) {
            throw new IllegalStateException(
                    "Cannot average predictions of an ensemble with " + divisor + " selections.");
        }
        double[][] predictions = weightedPredictionSum();
        double[][] candidate = model(position);
        for (int k = 0; k < predictions.length; k++) {
            for (int c = 0; c < predictions[k].length; c++) {
                predictions[k][c] = (predictions[k][c] + delta * candidate[k][c]) / divisor;
            }
        }
        return predictions;
    }

    /** Scores the ensemble as it currently stands; negative infinity when it is empty. */
    private double currentPerformance(Instances instances, int metric) throws Exception {
        if (m_numChosen <= 0) {
            return Double.NEGATIVE_INFINITY;
        }
        double[][] predictions = weightedPredictionSum();
        for (double[] row : predictions) {
            for (int c = 0; c < row.length; c++) {
                row[c] /= m_numChosen;
            }
        }
        return evaluatePredictions(instances, predictions, metric);
    }

    /**
     * Scores a set of predictions against {@code instances}, where {@code predictions[k]} is the
     * prediction array for instance {@code k}.
     */
    private double evaluatePredictions(Instances instances, double[][] predictions, int metric)
            throws Exception {
        Evaluation evaluation = new Evaluation(instances);
        for (int k = 0; k < instances.numInstances(); k++) {
            evaluation.evaluateModelOnceAndRecordPrediction(predictions[k], instances.instance(k));
        }
        return EnsembleMetricHelper.getMetric(evaluation, metric);
    }

    /**
     * Scores each bag model on its own.
     *
     * @param instances the instances the predictions were made on
     * @param metric metric index passed to {@code EnsembleMetricHelper.getMetric}
     * @return one score per virtual bag position
     * @throws Exception if evaluating predictions fails
     */
    public double[] getIndividualPerformance(Instances instances, int metric) throws Exception {
        double[] performance = new double[m_bagSize];
        for (int pos = 0; pos < m_bagSize; pos++) {
            performance[pos] = evaluatePredictions(instances, model(pos), metric);
        }
        return performance;
    }
}
