package edu.sysu.pmglab.optimizer;

import edu.sysu.pmglab.container.array.Array;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;
import java.util.function.BiFunction;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:edu/sysu/pmglab/optimizer/Optimizer.class */
/**
 * Population-based (genetic-algorithm style) minimizer over a box-constrained
 * real vector.
 *
 * <p>Each generation is sorted by ascending loss, the best {@code eliteProb * popSize}
 * individuals are carried over unchanged, and the rest of the next generation is
 * produced by randomly chosen generator functions that receive the elite set.
 * The search stops after {@code maxIter} generations or as soon as the best loss
 * is {@code <= bestObj}.
 *
 * <p>Instances are created through {@link Builder}.
 */
public class Optimizer {
    private static final Logger logger = LoggerFactory.getLogger("GA-Optimizer");

    /** Objective to minimize; receives a candidate vector and returns its loss. */
    private final Function<double[], Double> fitness;
    /** Offspring producers; one is picked uniformly at random per new individual. */
    private final BiFunction<Unit[], Random, double[]>[] generator;
    private final int popSize;
    private final int maxIter;
    /** Fraction of the population kept as elites between generations. */
    private final double eliteProb;
    /** Whether progress is logged through {@link #logger}. */
    private final boolean display;
    /** Early-stop threshold: the run ends once the best loss is {@code <=} this value. */
    private final double bestObj;
    private final Random random;
    /** Per-dimension {@code [lower, upper]} pairs. */
    private final double[][] bounds;
    /** Number of decimal places every candidate is rounded (down) to. */
    private final int precision;
    /** User-supplied starting vectors. */
    private final ArrayList<double[]> initPops;

    /**
     * Fluent builder for {@link Optimizer}.
     *
     * <p>The default generators read the builder's {@code bounds} and {@code stepRatio}
     * at the time they are invoked, so setters may be called in any order before
     * {@link #build()}.
     */
    public static class Builder {
        private final Function<double[], Double> fitness;
        private BiFunction<Unit[], Random, double[]>[] generator;
        private double[][] bounds;
        private final ArrayList<double[]> initPops = new ArrayList<>();
        private int popSize = 50;
        private int maxIter = 200;
        private double eliteProb = 0.2d;
        private double stepRatio = 0.01d;
        private double bestObj = 0.0d;
        private int precision = 6;
        private boolean display = false;
        private long seed = System.nanoTime();

        /**
         * Creates a builder with bounds {@code [0, 1]} in every dimension and four
         * default generators (in this order): random restart, crossover, step
         * mutation, and parent averaging.
         *
         * @param fitness   objective to minimize
         * @param dimension number of variables
         */
        public Builder(Function<double[], Double> fitness, int dimension) {
            this.fitness = fitness;
            this.bounds = new double[dimension][2];
            for (double[] range : this.bounds) {
                range[0] = 0.0d;
                range[1] = 1.0d;
            }

            // Generic array creation is not allowed; the raw array is only ever
            // filled with correctly typed lambdas below.
            @SuppressWarnings({"unchecked", "rawtypes"})
            BiFunction<Unit[], Random, double[]>[] defaults = new BiFunction[4];
            defaults[0] = (elites, rnd) -> Unit.generate(rnd, this.bounds);
            defaults[1] = Builder::crossover;
            defaults[2] = this::mutate;
            defaults[3] = Builder::average;
            this.generator = defaults;
        }

        /**
         * Recombines two randomly drawn elites.
         *
         * <p>One variable: the mean of both parents. Two variables: the first gene of
         * one parent and the second gene of the other. Otherwise a single-point
         * crossover that keeps every gene at its own index, so per-dimension bounds
         * of the parents carry over to the child.
         */
        private static double[] crossover(Unit[] elites, Random rnd) {
            int dim = elites[0].vec.length;
            if (dim == 1) {
                double a = elites[rnd.nextInt(elites.length)].vec[0];
                double b = elites[rnd.nextInt(elites.length)].vec[0];
                return new double[]{(a + b) / 2.0d};
            }
            if (dim == 2) {
                double first = elites[rnd.nextInt(elites.length)].vec[0];
                double second = elites[rnd.nextInt(elites.length)].vec[1];
                return new double[]{first, second};
            }

            double[] child = new double[dim];
            // Cut point in [1, dim - 2]: both parents contribute at least one gene.
            int cut = rnd.nextInt(dim - 2) + 1;
            Unit head = elites[rnd.nextInt(elites.length)];
            Unit tail = elites[rnd.nextInt(elites.length)];
            System.arraycopy(head.vec, 0, child, 0, cut);
            System.arraycopy(tail.vec, cut, child, cut, dim - cut);
            return child;
        }

        /**
         * Copies a random elite and, independently per gene, moves it one step up
         * (about 1/3 of the time), one step down (about 1/3) or leaves it unchanged.
         * A step is {@code stepRatio * (upper - lower)}; results are clamped to the
         * bounds.
         */
        private double[] mutate(Unit[] elites, Random rnd) {
            double[] child = elites[rnd.nextInt(elites.length)].vec.clone();
            for (int k = 0; k < child.length; k++) {
                float roll = rnd.nextFloat();
                if (roll < 0.333333d) {
                    child[k] = child[k] + (this.stepRatio * (this.bounds[k][1] - this.bounds[k][0]));
                    if (child[k] > this.bounds[k][1]) {
                        child[k] = this.bounds[k][1];
                    }
                } else if (roll < 0.666666d) {
                    child[k] = child[k] - (this.stepRatio * (this.bounds[k][1] - this.bounds[k][0]));
                    if (child[k] < this.bounds[k][0]) {
                        child[k] = this.bounds[k][0];
                    }
                }
            }
            return child;
        }

        /** Returns the gene-wise mean of two randomly drawn elites. */
        private static double[] average(Unit[] elites, Random rnd) {
            Unit a = elites[rnd.nextInt(elites.length)];
            Unit b = elites[rnd.nextInt(elites.length)];
            double[] child = new double[a.vec.length];
            for (int k = 0; k < child.length; k++) {
                child[k] = (a.vec[k] + b.vec[k]) / 2.0d;
            }
            return child;
        }

        /** Replaces all generators (including the defaults) with the given array. */
        public Builder setGenerator(BiFunction<Unit[], Random, double[]>[] biFunctionArr) {
            this.generator = biFunctionArr;
            return this;
        }

        /** Appends the given generators to the current ones. */
        @SafeVarargs
        public final Builder addGenerator(BiFunction<Unit[], Random, double[]>... biFunctionArr) {
            @SuppressWarnings({"unchecked", "rawtypes"})
            BiFunction<Unit[], Random, double[]>[] merged =
                    new BiFunction[this.generator.length + biFunctionArr.length];
            System.arraycopy(this.generator, 0, merged, 0, this.generator.length);
            System.arraycopy(biFunctionArr, 0, merged, this.generator.length, biFunctionArr.length);
            this.generator = merged;
            return this;
        }

        /**
         * Sets the number of decimal places candidates are rounded to.
         *
         * @throws UnsupportedOperationException if {@code precision} is negative
         */
        public Builder setPrecision(int precision) {
            // Validate the argument itself (the check used to look at popSize).
            if (precision < 0) {
                throw new UnsupportedOperationException("precision >= 0");
            }
            this.precision = precision;
            return this;
        }

        /** Adds a starting vector that is evaluated as part of the first generation. */
        public Builder addInitPop(double[] dArr) {
            this.initPops.add(dArr);
            return this;
        }

        /**
         * Sets the population size.
         *
         * @throws UnsupportedOperationException if {@code popSize < 10}
         */
        public Builder setPopSize(int popSize) {
            if (popSize < 10) {
                throw new UnsupportedOperationException("popSize >= 10");
            }
            this.popSize = popSize;
            return this;
        }

        /**
         * Sets the maximum number of generations.
         *
         * @throws UnsupportedOperationException if {@code maxIter < 1}
         */
        public Builder setMaxIter(int maxIter) {
            if (maxIter < 1) {
                throw new UnsupportedOperationException("maxIter >= 1");
            }
            this.maxIter = maxIter;
            return this;
        }

        /**
         * Sets the fraction of the population carried over as elites.
         *
         * @throws UnsupportedOperationException unless {@code 0 < eliteProb < 1}
         */
        public Builder setEliteProb(double eliteProb) {
            if (eliteProb <= 0.0d || eliteProb >= 1.0d) {
                throw new UnsupportedOperationException("0 < eliteProb < 1");
            }
            this.eliteProb = eliteProb;
            return this;
        }

        /**
         * Sets the mutation step as a fraction of each dimension's range.
         *
         * @throws UnsupportedOperationException unless {@code 0 < stepRatio < 1}
         */
        public Builder setStepRatio(double stepRatio) {
            if (stepRatio <= 0.0d || stepRatio >= 1.0d) {
                throw new UnsupportedOperationException("0 < stepRatio < 1");
            }
            this.stepRatio = stepRatio;
            return this;
        }

        /** Enables or disables progress logging. */
        public Builder setDisplay(boolean display) {
            this.display = display;
            return this;
        }

        /**
         * Sets the early-stop threshold from a {@code float}. The value is widened to
         * {@code double}; use {@link #setBestObj(double)} to avoid float rounding.
         */
        public Builder setBestObj(float f) {
            this.bestObj = f;
            return this;
        }

        /** Sets the early-stop threshold: the run ends once the best loss is {@code <=} it. */
        public Builder setBestObj(double bestObj) {
            this.bestObj = bestObj;
            return this;
        }

        /** Sets the seed of the random number generator used by the optimizer. */
        public Builder setSeed(long seed) {
            this.seed = seed;
            return this;
        }

        /** Sets the per-dimension {@code [lower, upper]} pairs; the array is not copied. */
        public Builder setBounds(double[][] dArr) {
            this.bounds = dArr;
            return this;
        }

        /** Creates the optimizer from the current settings. */
        public Optimizer build() {
            return new Optimizer(this);
        }
    }

    /**
     * One individual: a candidate vector and its loss. Natural ordering is
     * ascending loss, so the best individual sorts first.
     */
    public static class Unit implements Comparable<Unit> {
        public double[] vec;
        public double lossValue;

        public Unit(double[] dArr, double d) {
            this.vec = dArr;
            this.lossValue = d;
        }

        /**
         * Draws a vector uniformly at random inside the given bounds.
         *
         * @param random source of randomness
         * @param dArr   per-dimension {@code [lower, upper]} pairs
         * @return a new vector with one value per row of {@code dArr}
         */
        public static double[] generate(Random random, double[][] dArr) {
            double[] sample = new double[dArr.length];
            for (int k = 0; k < sample.length; k++) {
                sample[k] = (random.nextDouble() * (dArr[k][1] - dArr[k][0])) + dArr[k][0];
            }
            return sample;
        }

        @Override
        public int compareTo(Unit unit) {
            return Double.compare(this.lossValue, unit.lossValue);
        }
    }

    private Optimizer(Builder builder) {
        this.fitness = builder.fitness;
        this.popSize = builder.popSize;
        this.maxIter = builder.maxIter;
        this.eliteProb = builder.eliteProb;
        this.display = builder.display;
        this.bestObj = builder.bestObj;
        this.bounds = builder.bounds;
        this.random = new Random(builder.seed);
        // Snapshot the list so later addInitPop calls on the builder do not leak in.
        this.initPops = new ArrayList<>(builder.initPops);
        this.generator = builder.generator;
        this.precision = builder.precision;
    }

    /** Sorts the given individuals by ascending loss. Not called within this class. */
    private void eliteSelect(Unit[] unitArr) {
        Arrays.sort(unitArr);
    }

    /**
     * Runs the search and returns the best vector found.
     *
     * <p>The first generation consists of the user-supplied starting vectors
     * (at most {@code popSize} of them, each copied before rounding) topped up
     * with random vectors to exactly {@code popSize} individuals.
     *
     * @return the vector with the lowest loss in the final generation
     */
    public double[] run() {
        // Keep at least one elite: the generators index into the elite array and
        // would fail on an empty one when eliteProb * popSize truncates to 0.
        int eliteCount = Math.max(1, (int) (this.eliteProb * this.popSize));

        Array<Unit> population = new Array<>(Unit[].class, this.popSize, false);
        for (double[] start : this.initPops) {
            if (population.size() >= this.popSize) {
                break;
            }
            // Clone so rounding does not modify the caller's array.
            population.add(evaluate(start.clone()));
        }
        while (population.size() < this.popSize) {
            population.add(evaluate(Unit.generate(this.random, this.bounds)));
        }

        for (int iter = 0; iter < this.maxIter; iter++) {
            population.sort(Unit::compareTo);
            Unit best = population.get(0);
            if (this.display) {
                logger.info("This is NO.{} iteration, the optimal solution is {}, f(x)={}", Integer.valueOf(iter + 1), Arrays.toString(best.vec), Double.valueOf(best.lossValue));
            }
            // Stop on target reached, or on the last iteration (population is already sorted).
            if (best.lossValue <= this.bestObj || iter == this.maxIter - 1) {
                break;
            }

            Array<Unit> next = new Array<>(Unit[].class, this.popSize, false);
            Unit[] elites = (Unit[]) population.popFirst(eliteCount).toArray(new Unit[0]);
            next.addAll(elites);
            while (next.size() < this.popSize) {
                BiFunction<Unit[], Random, double[]> producer =
                        this.generator[this.random.nextInt(this.generator.length)];
                next.add(evaluate(producer.apply(elites, this.random)));
            }
            population = next;
        }

        Unit winner = population.get(0);
        if (this.display) {
            logger.info("Optimal Solution: {}, Fitness Value: {}", Arrays.toString(winner.vec), Double.valueOf(winner.lossValue));
        }
        return winner.vec;
    }

    /** Rounds/clamps the candidate in place and wraps it with its loss. */
    private Unit evaluate(double[] candidate) {
        double[] rounded = setPrecision(candidate);
        return new Unit(rounded, this.fitness.apply(rounded).doubleValue());
    }

    /**
     * Rounds every component down to {@code precision} decimal places and clamps it
     * into its {@code [lower, upper]} range. The array is modified in place.
     *
     * @param dArr candidate vector
     * @return the same array instance
     */
    double[] setPrecision(double[] dArr) {
        for (int k = 0; k < dArr.length; k++) {
            double value = BigDecimal.valueOf(dArr[k]).setScale(this.precision, RoundingMode.FLOOR).doubleValue();
            // Custom generators or starting vectors may exceed the upper bound.
            if (value > this.bounds[k][1]) {
                value = this.bounds[k][1];
            }
            // Flooring can drop a value just below the lower bound.
            if (value < this.bounds[k][0]) {
                value = this.bounds[k][0];
            }
            dArr[k] = value;
        }
        return dArr;
    }
}
