package edu.sysu.pmglab.stat;

import java.util.Arrays;
import java.util.Comparator;
import java.util.stream.IntStream;
import kotlin.time.DurationKt;
import org.apache.commons.math3.analysis.interpolation.MicrosphereInterpolator;
import org.apache.commons.math3.distribution.GammaDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.stat.StatUtils;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.factory.DecompositionFactory_DDRM;
import org.ejml.interfaces.decomposition.CholeskyDecomposition_F64;
import org.ejml.simple.SimpleMatrix;

/* loaded from: input_file:edu/sysu/pmglab/stat/MCMCEstimationContinuous.class */
/**
 * Bayesian MCMC estimation of a threshold (segmented) regression model with one extra covariate.
 *
 * <p>The model fitted by {@link #mcmcEstimation(double[], double[], double[], int, int)} is
 *
 * <pre>
 *   y_i = A           + g * z_i + e_i    if x_i &lt;  c
 *   y_i = b + k * x_i + g * z_i + e_i    if x_i &gt;= c
 *   e_i ~ N(0, sigma^2)
 * </pre>
 *
 * <p>Priors: A, b, k and g are independent N(0, 100); sigma^2 ~ Inverse-Gamma(shape 2, scale 1);
 * c ~ Uniform(0, 1). A, (b, k), g and sigma^2 are drawn by Gibbs steps from their conjugate full
 * conditionals. c is updated by a random-walk Metropolis step whose proposal standard deviation is
 * tuned during burn-in towards a 20%-40% acceptance rate.
 *
 * <p>All instances share one static, unseeded random generator, so results differ between runs.
 */
public class MCMCEstimationContinuous {
    private static final MersenneTwister random = new MersenneTwister();
    private static final NormalDistribution normal = new NormalDistribution(random, 0.0, 1.0);

    /** Prior mean of every regression coefficient (A, b, k, g). */
    private static final double PRIOR_COEF_MEAN = 0.0;
    /** Prior variance of every regression coefficient (A, b, k, g). */
    private static final double PRIOR_COEF_VARIANCE = 100.0;
    /** Shape of the inverse-gamma prior on the residual variance. */
    private static final double PRIOR_SIGMA2_SHAPE = 2.0;
    /** Scale of the inverse-gamma prior on the residual variance. */
    private static final double PRIOR_SIGMA2_SCALE = 1.0;
    /** Lower clamp for the data-driven starting value of the threshold. */
    private static final double MIN_INITIAL_THRESHOLD = 0.05;
    /** Upper clamp for the data-driven starting value of the threshold. */
    private static final double MAX_INITIAL_THRESHOLD = 0.9;
    /** Starting standard deviation of the Metropolis proposal for the threshold. */
    private static final double INITIAL_STEP_SIZE = 0.005;
    /** Iterations per step-size adaptation window during burn-in. */
    private static final int ADAPT_WINDOW = 100;
    /** Below this acceptance rate the proposal is narrowed. */
    private static final double MIN_TARGET_ACCEPTANCE = 0.2;
    /** Above this acceptance rate the proposal is widened. */
    private static final double MAX_TARGET_ACCEPTANCE = 0.4;

    /**
     * Simulates a data set with known parameters, runs the sampler and prints the posterior means
     * next to the true values, together with the run time.
     *
     * <p>True model: c = 0.1, A = 1, k = 1, b = 2, g = 0. The residual noise variance is chosen so
     * that the noise-free signal explains 1% of the total variance. The signal variance is estimated
     * on a simulated population of one million points.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        final int populationSize = 1_000_000;
        final int sampleSize = 2000;
        final double heritability = 0.01;
        final double zEffect = 0.0;

        // Variance of the noise-free signal, estimated on a large simulated population.
        double[] population = new double[populationSize];
        double[] populationZ = new double[populationSize];
        for (int i = 0; i < populationSize; i++) {
            double xi = random.nextDouble();
            populationZ[i] = normal.sample();
            population[i] = simulatedSignal(xi);
        }
        double signalVariance = StatUtils.variance(population);
        for (int i = 0; i < populationSize; i++) {
            population[i] += zEffect * populationZ[i];
        }
        double totalVariance = StatUtils.variance(population);
        // Noise variance that makes the signal explain `heritability` of the total variance.
        double noiseVariance =
                (signalVariance * (1.0 - heritability) - (totalVariance - signalVariance) * heritability)
                        / heritability;
        double noiseSd = Math.sqrt(noiseVariance);

        double[] x = new double[sampleSize];
        double[] z = new double[sampleSize];
        double[] y = new double[sampleSize];
        for (int i = 0; i < sampleSize; i++) {
            x[i] = random.nextDouble();
            z[i] = normal.sample();
            y[i] = simulatedSignal(x[i]) + zEffect * z[i] + noiseSd * normal.sample();
        }

        long start = System.nanoTime();
        double[] estimates = new MCMCEstimationContinuous().mcmcEstimation(x, y, z, 50000, 10000);
        System.out.println("Runtime: " + ((System.nanoTime() - start) / 1.0E9d) + " seconds");
        System.out.println("Estimated c_true: " + estimates[0] + " True value: 0.1");
        System.out.println("Estimated A_true: " + estimates[1] + " True value: 1.0");
        System.out.println("Estimated k_true: " + estimates[2] + " True value: 1.0");
        System.out.println("Estimated b_true: " + estimates[3] + " True value: 2.0");
        System.out.println("Estimated z_true: " + estimates[4] + " True value: " + zEffect);
    }

    /** Noise-free mean used by {@link #main}: 1 below x = 0.1, x + 2 at or above it. */
    private static double simulatedSignal(double x) {
        return x < 0.1 ? 1.0 : 1.0 * x + 2.0;
    }

    /**
     * Runs the sampler and returns the posterior means of the model parameters.
     *
     * @param x threshold variable. The threshold c is restricted to (0, 1), so x is expected on
     *     that scale.
     * @param y response, same length as {@code x}
     * @param z additional covariate with a common coefficient g, same length as {@code x}
     * @param iterations total number of MCMC iterations
     * @param burnIn number of leading iterations that are discarded. The proposal scale for c is
     *     only adapted during these iterations.
     * @return posterior means over iterations {@code [burnIn, iterations)} of
     *     {@code {c, A, k, b, g}}, in that order
     * @throws IllegalArgumentException if the arrays differ in length, hold fewer than two points,
     *     or {@code burnIn} is not in {@code [0, iterations)}
     * @throws IllegalStateException if the posterior covariance of (b, k) is not positive definite
     */
    public double[] mcmcEstimation(double[] x, double[] y, double[] z, int iterations, int burnIn) {
        int n = y.length;
        if (x.length != n || z.length != n) {
            throw new IllegalArgumentException(
                    "x, y and z must have equal lengths, got " + x.length + ", " + n + ", " + z.length);
        }
        if (n < 2) {
            throw new IllegalArgumentException("at least two observations are required, got " + n);
        }
        if (burnIn < 0 || burnIn >= iterations) {
            throw new IllegalArgumentException(
                    "burnIn must be in [0, iterations), got burnIn=" + burnIn
                            + ", iterations=" + iterations);
        }

        // Starting values: threshold at the steepest step of the data, intercepts at segment means.
        double threshold = initialThreshold(x, y);
        double lowIntercept = segmentMean(x, y, threshold, true);
        double slope = 0.0;
        double highIntercept = segmentMean(x, y, threshold, false);
        double zCoef = 0.0;

        // The sum of z^2 does not depend on any parameter, so compute it once.
        double zSquaredSum = 0.0;
        for (double zi : z) {
            zSquaredSum += zi * zi;
        }

        double[] thresholdTrace = new double[iterations];
        double[] lowInterceptTrace = new double[iterations];
        double[] slopeTrace = new double[iterations];
        double[] highInterceptTrace = new double[iterations];
        double[] zCoefTrace = new double[iterations];

        double stepSize = INITIAL_STEP_SIZE;
        int acceptedInWindow = 0;

        for (int iter = 0; iter < iterations; iter++) {
            double sigma2 = sampleResidualVariance(
                    x, y, z, threshold, lowIntercept, slope, highIntercept, zCoef);
            lowIntercept = sampleLowIntercept(x, y, z, threshold, zCoef, sigma2);
            double[] line = sampleHighLine(x, y, z, threshold, zCoef, sigma2);
            highIntercept = line[0];
            slope = line[1];
            zCoef = sampleZCoefficient(
                    x, y, z, zSquaredSum, threshold, lowIntercept, slope, highIntercept, sigma2);

            // Random-walk Metropolis update of the threshold. The proposal is symmetric and the
            // Uniform(0, 1) prior is flat inside its support, so only the likelihood ratio remains;
            // proposals outside (0, 1) are rejected.
            double proposal = normal.sample() * stepSize + threshold;
            if (proposal > 0.0 && proposal < 1.0) {
                double logRatio =
                        logLikelihood(x, y, z, proposal, lowIntercept, slope, highIntercept, zCoef, sigma2)
                        - logLikelihood(x, y, z, threshold, lowIntercept, slope, highIntercept, zCoef, sigma2);
                if (Math.log(random.nextDouble()) < logRatio) {
                    threshold = proposal;
                    acceptedInWindow++;
                }
            }

            thresholdTrace[iter] = threshold;
            lowInterceptTrace[iter] = lowIntercept;
            slopeTrace[iter] = slope;
            highInterceptTrace[iter] = highIntercept;
            zCoefTrace[iter] = zCoef;

            // Tune the proposal scale only during burn-in, so that retained samples come from a
            // fixed (non-adaptive) kernel. Acceptances are counted per window of ADAPT_WINDOW
            // iterations.
            if (iter < burnIn && (iter + 1) % ADAPT_WINDOW == 0) {
                double acceptanceRate = acceptedInWindow / (double) ADAPT_WINDOW;
                if (acceptanceRate < MIN_TARGET_ACCEPTANCE) {
                    stepSize *= 0.9;
                }
                if (acceptanceRate > MAX_TARGET_ACCEPTANCE) {
                    stepSize *= 1.1;
                }
                acceptedInWindow = 0;
            }
        }

        return new double[] {
            posteriorMean(thresholdTrace, burnIn),
            posteriorMean(lowInterceptTrace, burnIn),
            posteriorMean(slopeTrace, burnIn),
            posteriorMean(highInterceptTrace, burnIn),
            posteriorMean(zCoefTrace, burnIn)
        };
    }

    /** Mean of {@code trace} over the retained iterations {@code [burnIn, trace.length)}. */
    private static double posteriorMean(double[] trace, int burnIn) {
        return Arrays.stream(trace, burnIn, trace.length).average().orElse(0.0);
    }

    /**
     * Returns a data-driven starting value for the threshold, clamped to
     * [MIN_INITIAL_THRESHOLD, MAX_INITIAL_THRESHOLD].
     *
     * <p>The value is the x at the left end of the steepest step between observations that are
     * adjacent in x order. Steps between equal x values (undefined slope) are ignored. If no finite
     * slope exists, the middle x value is used instead.
     */
    private static double initialThreshold(double[] x, double[] y) {
        int[] order = IntStream.range(0, x.length).boxed()
                .sorted(Comparator.comparingDouble(i -> x[i]))
                .mapToInt(Integer::intValue)
                .toArray();
        double start = x[order[order.length / 2]];
        double steepest = -1.0;
        for (int k = 0; k + 1 < order.length; k++) {
            int left = order[k];
            int right = order[k + 1];
            double absSlope = Math.abs((y[right] - y[left]) / (x[right] - x[left]));
            // Strict '>' keeps the first of several equally steep steps.
            if (Double.isFinite(absSlope) && absSlope > steepest) {
                steepest = absSlope;
                start = x[left];
            }
        }
        return Math.max(Math.min(start, MAX_INITIAL_THRESHOLD), MIN_INITIAL_THRESHOLD);
    }

    /**
     * Returns the mean of {@code y} over the points below ({@code x < threshold}) or at/above
     * ({@code x >= threshold}) the threshold, or 0 if that segment is empty.
     */
    private static double segmentMean(double[] x, double[] y, double threshold, boolean below) {
        double sum = 0.0;
        int count = 0;
        for (int i = 0; i < x.length; i++) {
            boolean inSegment = below ? x[i] < threshold : x[i] >= threshold;
            if (inSegment) {
                sum += y[i];
                count++;
            }
        }
        return count == 0 ? 0.0 : sum / count;
    }

    /** Returns the model mean for one observation, excluding the {@code g * z} term. */
    private static double segmentPrediction(
            double xi, double threshold, double lowIntercept, double slope, double highIntercept) {
        return xi < threshold ? lowIntercept : slope * xi + highIntercept;
    }

    /**
     * Gibbs draw of sigma^2 from its full conditional,
     * Inverse-Gamma(shape + n/2, scale + SSR/2).
     */
    private static double sampleResidualVariance(double[] x, double[] y, double[] z, double threshold,
            double lowIntercept, double slope, double highIntercept, double zCoef) {
        double ssr = 0.0;
        for (int i = 0; i < y.length; i++) {
            double residual = y[i]
                    - (segmentPrediction(x[i], threshold, lowIntercept, slope, highIntercept) + zCoef * z[i]);
            ssr += residual * residual;
        }
        double shape = PRIOR_SIGMA2_SHAPE + y.length / 2.0;
        double rate = PRIOR_SIGMA2_SCALE + ssr / 2.0;
        // If G ~ Gamma(shape, scale = 1 / rate), then 1 / G ~ Inverse-Gamma(shape, rate).
        return 1.0 / new GammaDistribution(random, shape, 1.0 / rate).sample();
    }

    /**
     * Gibbs draw of a coefficient with a N(PRIOR_COEF_MEAN, PRIOR_COEF_VARIANCE) prior.
     *
     * @param sumSquares sum over the data of the squared regressor
     * @param sumCross sum over the data of regressor times partial residual
     * @param sigma2 current residual variance
     */
    private static double sampleCoefficient(double sumSquares, double sumCross, double sigma2) {
        double precision = sumSquares / sigma2 + 1.0 / PRIOR_COEF_VARIANCE;
        double mean = (sumCross / sigma2 + PRIOR_COEF_MEAN / PRIOR_COEF_VARIANCE) / precision;
        return mean + normal.sample() * Math.sqrt(1.0 / precision);
    }

    /** Returns one draw from the N(PRIOR_COEF_MEAN, PRIOR_COEF_VARIANCE) coefficient prior. */
    private static double sampleCoefficientPrior() {
        return normal.sample() * Math.sqrt(PRIOR_COEF_VARIANCE) + PRIOR_COEF_MEAN;
    }

    /**
     * Gibbs draw of the low-segment intercept A, using the points with {@code x < threshold}.
     * If no such points exist, A is drawn from the prior.
     */
    private static double sampleLowIntercept(
            double[] x, double[] y, double[] z, double threshold, double zCoef, double sigma2) {
        double sum = 0.0;
        int count = 0;
        for (int i = 0; i < x.length; i++) {
            if (x[i] < threshold) {
                sum += y[i] - zCoef * z[i];
                count++;
            }
        }
        if (count == 0) {
            return sampleCoefficientPrior();
        }
        // The regressor for an intercept is 1, so its sum of squares is the point count.
        return sampleCoefficient(count, sum, sigma2);
    }

    /**
     * Gibbs draw of (b, k) from their bivariate normal full conditional, using the points with
     * {@code x >= threshold}. If no such points exist, both are drawn from the prior.
     *
     * @return {@code {b, k}}
     * @throws IllegalStateException if the posterior covariance is not positive definite
     */
    private static double[] sampleHighLine(
            double[] x, double[] y, double[] z, double threshold, double zCoef, double sigma2) {
        int count = 0;
        for (double xi : x) {
            if (xi >= threshold) {
                count++;
            }
        }
        if (count == 0) {
            double intercept = sampleCoefficientPrior();
            double slope = sampleCoefficientPrior();
            return new double[] {intercept, slope};
        }

        double[][] rows = new double[count][2];
        double[] response = new double[count];
        int row = 0;
        for (int i = 0; i < x.length; i++) {
            if (x[i] >= threshold) {
                rows[row][0] = 1.0;
                rows[row][1] = x[i];
                response[row] = y[i] - zCoef * z[i];
                row++;
            }
        }
        SimpleMatrix design = new SimpleMatrix(rows);
        SimpleMatrix priorPrecision =
                SimpleMatrix.diag(1.0 / PRIOR_COEF_VARIANCE, 1.0 / PRIOR_COEF_VARIANCE);
        SimpleMatrix priorMean =
                new SimpleMatrix(2, 1, false, new double[] {PRIOR_COEF_MEAN, PRIOR_COEF_MEAN});
        // Posterior covariance V = (X'X / sigma^2 + L0)^-1 and mean V (X'r / sigma^2 + L0 m0),
        // where L0 and m0 are the prior precision and prior mean.
        SimpleMatrix covariance =
                design.transpose().mult(design).scale(1.0 / sigma2).plus(priorPrecision).invert();
        SimpleMatrix mean = covariance.mult(
                design.transpose().scale(1.0 / sigma2)
                        .mult(new SimpleMatrix(count, 1, false, response))
                        .plus(priorPrecision.mult(priorMean)));
        return multivariateNormalSample(mean.getDDRM().getData(), covariance);
    }

    /**
     * Gibbs draw of the z coefficient g, given the residuals of the segmented part of the model.
     *
     * @param zSquaredSum precomputed sum of {@code z[i] * z[i]}
     */
    private static double sampleZCoefficient(double[] x, double[] y, double[] z, double zSquaredSum,
            double threshold, double lowIntercept, double slope, double highIntercept, double sigma2) {
        double cross = 0.0;
        for (int i = 0; i < y.length; i++) {
            double residual = y[i] - segmentPrediction(x[i], threshold, lowIntercept, slope, highIntercept);
            cross += residual * z[i];
        }
        return sampleCoefficient(zSquaredSum, cross, sigma2);
    }

    /** Gaussian log-likelihood of {@code y} under the full model with the given parameters. */
    private static double logLikelihood(double[] x, double[] y, double[] z, double threshold,
            double lowIntercept, double slope, double highIntercept, double zCoef, double sigma2) {
        double sd = Math.sqrt(sigma2);
        double total = 0.0;
        for (int i = 0; i < y.length; i++) {
            double mean =
                    segmentPrediction(x[i], threshold, lowIntercept, slope, highIntercept) + zCoef * z[i];
            total += logNormalDensity(y[i], mean, sd);
        }
        return total;
    }

    /** Log of the N(mean, sd^2) probability density at {@code value}. */
    private static double logNormalDensity(double value, double mean, double sd) {
        double standardized = (value - mean) / sd;
        return -0.5 * Math.log(2.0 * Math.PI * sd * sd) - 0.5 * standardized * standardized;
    }

    /**
     * Draws one sample from N(mean, covariance) as {@code mean + L u}. Here L is the lower Cholesky
     * factor of {@code covariance} and u is a vector of independent standard normals.
     *
     * @throws IllegalStateException if {@code covariance} is not positive definite
     */
    private static double[] multivariateNormalSample(double[] mean, SimpleMatrix covariance) {
        int dim = mean.length;
        double[] standard = new double[dim];
        for (int i = 0; i < dim; i++) {
            standard[i] = normal.sample();
        }
        CholeskyDecomposition_F64<DMatrixRMaj> chol = DecompositionFactory_DDRM.chol(dim, true);
        // EJML decompositions may overwrite their input (see inputModified()), so decompose a
        // copy and leave the caller's matrix intact.
        if (!chol.decompose(covariance.getDDRM().copy())) {
            throw new IllegalStateException("Cholesky decomposition failed");
        }
        SimpleMatrix lower = new SimpleMatrix(chol.getT(null));
        return lower.mult(new SimpleMatrix(dim, 1, false, standard))
                .plus(new SimpleMatrix(dim, 1, false, mean))
                .getDDRM()
                .getData();
    }
}
