/*
 * Decompiled with CFR 0.152.
 */
package edu.sysu.pmglab.stat;

import java.util.Arrays;
import java.util.Comparator;
import java.util.stream.IntStream;
import org.apache.commons.math3.distribution.GammaDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.stat.StatUtils;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.factory.DecompositionFactory_DDRM;
import org.ejml.interfaces.decomposition.CholeskyDecomposition_F64;
import org.ejml.simple.SimpleMatrix;

/**
 * Bayesian MCMC estimation of the piecewise regression model
 *
 * <pre>
 *   Y = (X &lt; c ? A : k * X + b) + z * Z + e,    e ~ N(0, sigma2)
 * </pre>
 *
 * <p>{@code sigma2}, {@code A}, the line {@code (b, k)} and {@code z} are drawn from their
 * conjugate full conditionals (Gibbs steps). The threshold {@code c} is updated with a
 * random-walk Metropolis step restricted to (0, 1). The proposal scale is tuned only during
 * burn-in, so the retained part of the chain uses a fixed kernel.
 *
 * <p>Priors: {@code A, k, b, z ~ N(0, 100)} independently, {@code sigma2 ~ InvGamma(2, 1)},
 * and a flat prior on {@code c} over (0, 1).
 *
 * <p>All randomness comes from one shared {@link MersenneTwister} built without an explicit
 * seed, so results differ from run to run.
 */
public class MCMCEstimationContinuous {
    /** Shared generator behind every random draw in this class. */
    private static final MersenneTwister random = new MersenneTwister();
    /** Standard normal sampler drawing from {@link #random}. */
    private static final NormalDistribution normal = new NormalDistribution(random, 0.0, 1.0);

    // Normal priors N(MU_*, SIGMA_*2) on A, k, b and z.
    private static final double MU_A = 0.0;
    private static final double SIGMA_A2 = 100.0;
    private static final double MU_K = 0.0;
    private static final double SIGMA_K2 = 100.0;
    private static final double MU_B = 0.0;
    private static final double SIGMA_B2 = 100.0;
    private static final double MU_Z = 0.0;
    private static final double SIGMA_Z2 = 100.0;
    // Inverse-gamma prior InvGamma(A_SIGMA, B_SIGMA) on the residual variance.
    private static final double A_SIGMA = 2.0;
    private static final double B_SIGMA = 1.0;
    /** Initial standard deviation of the random-walk proposal for c. */
    private static final double INITIAL_C_PROPOSAL_SD = 0.005;
    /** Number of burn-in iterations between successive proposal-scale adjustments. */
    private static final int ADAPT_WINDOW = 100;
    /** A window acceptance rate below this shrinks the proposal scale. */
    private static final double TARGET_ACCEPT_LOW = 0.2;
    /** A window acceptance rate above this grows the proposal scale. */
    private static final double TARGET_ACCEPT_HIGH = 0.4;

    /**
     * Simulation demo: generates data from known parameters, runs the sampler and prints the
     * estimates next to the true values.
     *
     * <p>A large population is simulated first to measure the variance of the piecewise term
     * and of {@code z_true * Z}. The residual variance is then chosen so that the piecewise
     * term accounts for a fraction {@code h2} of the total variance.
     */
    public static void main(String[] args) {
        double h2 = 0.01;
        double c_true = 0.1;
        double A_true = 1.0;
        double k_true = 1.0;
        double b_true = 2.0;
        double z_true = 0.0;
        int populationSize = 1000000;
        int sampleSize = 2000;

        double[] popX = new double[populationSize];
        double[] popZ = new double[populationSize];
        double[] popY = new double[populationSize];
        for (int i = 0; i < populationSize; ++i) {
            popX[i] = random.nextDouble();
            popZ[i] = normal.sample();
            popY[i] = piecewiseMean(popX[i], c_true, A_true, k_true, b_true);
        }
        double varX = StatUtils.variance(popY);
        for (int i = 0; i < populationSize; ++i) {
            popY[i] += popZ[i] * z_true;
        }
        double varZ = StatUtils.variance(popY) - varX;
        // Solve varX / (varX + varZ + varE) == h2 for varE. Clamp at zero: a large covariate
        // effect would otherwise make it negative and the noise scale NaN.
        double varE = Math.max(0.0, (varX * (1.0 - h2) - varZ * h2) / h2);
        double sdE = Math.sqrt(varE);

        double[] X1 = new double[sampleSize];
        double[] Z1 = new double[sampleSize];
        double[] Y1 = new double[sampleSize];
        for (int i = 0; i < sampleSize; ++i) {
            X1[i] = random.nextDouble();
            Z1[i] = normal.sample();
            double epsilon = normal.sample() * sdE;
            Y1[i] = modelMean(X1[i], Z1[i], c_true, A_true, k_true, b_true, z_true) + epsilon;
        }

        long startTime = System.nanoTime();
        double[] estimates = new MCMCEstimationContinuous().mcmcEstimation(X1, Y1, Z1, 50000, 10000);
        double duration = (double) (System.nanoTime() - startTime) / 1.0E9;
        System.out.println("Runtime: " + duration + " seconds");
        System.out.println("Estimated c_true: " + estimates[0] + " True value: " + c_true);
        System.out.println("Estimated A_true: " + estimates[1] + " True value: " + A_true);
        System.out.println("Estimated k_true: " + estimates[2] + " True value: " + k_true);
        System.out.println("Estimated b_true: " + estimates[3] + " True value: " + b_true);
        System.out.println("Estimated z_true: " + estimates[4] + " True value: " + z_true);
    }

    /**
     * Runs the sampler and returns posterior means of the model parameters.
     *
     * @param X1      predictor values; the threshold c is sampled within (0, 1)
     * @param Y1      responses, same length as {@code X1}
     * @param Z1      covariate values, same length as {@code X1}
     * @param n_iter  total number of MCMC iterations
     * @param burn_in number of leading iterations excluded from the estimates; the proposal
     *                scale for c is tuned only during these iterations
     * @return {@code {c, A, k, b, z}}, each averaged over iterations {@code burn_in .. n_iter-1}
     * @throws IllegalArgumentException if the arrays differ in length or unless
     *                                  {@code 0 <= burn_in < n_iter}
     */
    public double[] mcmcEstimation(double[] X1, double[] Y1, double[] Z1, int n_iter, int burn_in) {
        int n = Y1.length;
        if (X1.length != n || Z1.length != n) {
            throw new IllegalArgumentException("X1, Y1 and Z1 must have equal length, got "
                    + X1.length + ", " + n + ", " + Z1.length);
        }
        if (burn_in < 0 || burn_in >= n_iter) {
            throw new IllegalArgumentException("Require 0 <= burn_in < n_iter, got burn_in="
                    + burn_in + ", n_iter=" + n_iter);
        }

        // Starting values: threshold at the steepest jump in the data, A and b as the mean
        // response on either side of it, k and z at zero. Grouping is by each observation's
        // own X value.
        double c = initialThreshold(X1, Y1);
        final double c0 = c;
        double A = IntStream.range(0, n).filter(j -> X1[j] < c0).mapToDouble(j -> Y1[j]).average().orElse(0.0);
        double k = 0.0;
        double b = IntStream.range(0, n).filter(j -> X1[j] >= c0).mapToDouble(j -> Y1[j]).average().orElse(0.0);
        double z = 0.0;

        double[] cChain = new double[n_iter];
        double[] aChain = new double[n_iter];
        double[] kChain = new double[n_iter];
        double[] bChain = new double[n_iter];
        double[] zChain = new double[n_iter];

        // Loop-invariant part of z's posterior precision.
        double sumZ2 = Arrays.stream(Z1).map(v -> v * v).sum();
        double proposalSd = INITIAL_C_PROPOSAL_SD;
        int windowAccepts = 0;

        for (int iter = 0; iter < n_iter; ++iter) {
            // 1) sigma2 | rest ~ InvGamma(A_SIGMA + n/2, B_SIGMA + SSR/2), drawn as 1 / Gamma(shape, scale).
            double ssr = 0.0;
            for (int j = 0; j < n; ++j) {
                double r = Y1[j] - modelMean(X1[j], Z1[j], c, A, k, b, z);
                ssr += r * r;
            }
            GammaDistribution gamma = new GammaDistribution(random, A_SIGMA + (double) n / 2.0, 1.0 / (B_SIGMA + ssr / 2.0));
            double sigma2 = 1.0 / gamma.sample();

            // Sufficient statistics of the offset response Y - z*Z on each side of c.
            int nLeft = 0;
            double sumLeft = 0.0;
            int nRight = 0;
            double sumX = 0.0;
            double sumXX = 0.0;
            double sumY = 0.0;
            double sumXY = 0.0;
            for (int j = 0; j < n; ++j) {
                double x = X1[j];
                double yAdj = Y1[j] - z * Z1[j];
                if (x < c) {
                    ++nLeft;
                    sumLeft += yAdj;
                } else if (x >= c) {
                    ++nRight;
                    sumX += x;
                    sumXX += x * x;
                    sumY += yAdj;
                    sumXY += x * yAdj;
                }
            }

            // 2) A | rest ~ Normal (conjugate). With no points left of c this reduces to the prior.
            double precA = nLeft / sigma2 + 1.0 / SIGMA_A2;
            double meanA = (sumLeft / sigma2 + MU_A / SIGMA_A2) / precA;
            A = normal.sample() * Math.sqrt(1.0 / precA) + meanA;

            // 3) (b, k) | rest ~ bivariate Normal from Bayesian regression of Y - z*Z on [1, X]
            //    over X >= c. Precision = X'X / sigma2 + prior precision, which is valid for any
            //    number of points (zero points gives back the prior).
            SimpleMatrix precKb = new SimpleMatrix(2, 2, true, new double[]{
                    nRight / sigma2 + 1.0 / SIGMA_B2, sumX / sigma2,
                    sumX / sigma2, sumXX / sigma2 + 1.0 / SIGMA_K2});
            SimpleMatrix rhsKb = new SimpleMatrix(2, 1, true, new double[]{
                    sumY / sigma2 + MU_B / SIGMA_B2,
                    sumXY / sigma2 + MU_K / SIGMA_K2});
            SimpleMatrix covKb = precKb.invert();
            double[] kb = multivariateNormalSample(covKb.mult(rhsKb).getDDRM().getData(), covKb);
            b = kb[0];
            k = kb[1];

            // 4) z | rest ~ Normal (conjugate), regressing the residual of the piecewise term on Z.
            double sumYZ = 0.0;
            for (int j = 0; j < n; ++j) {
                sumYZ += (Y1[j] - piecewiseMean(X1[j], c, A, k, b)) * Z1[j];
            }
            double precZ = sumZ2 / sigma2 + 1.0 / SIGMA_Z2;
            double meanZ = (sumYZ / sigma2 + MU_Z / SIGMA_Z2) / precZ;
            z = normal.sample() * Math.sqrt(1.0 / precZ) + meanZ;

            // 5) c | rest: random-walk Metropolis. Proposals outside (0, 1) are rejected (flat
            //    prior). The symmetric proposal makes log(alpha) a pure log-likelihood ratio.
            double cProposal = normal.sample() * proposalSd + c;
            if (cProposal > 0.0 && cProposal < 1.0) {
                double logAlpha = calcLikelihood(Y1, X1, Z1, cProposal, A, k, b, z, sigma2)
                        - calcLikelihood(Y1, X1, Z1, c, A, k, b, z, sigma2);
                if (Math.log(random.nextDouble()) < logAlpha) {
                    c = cProposal;
                    ++windowAccepts;
                }
            }

            cChain[iter] = c;
            aChain[iter] = A;
            kChain[iter] = k;
            bChain[iter] = b;
            zChain[iter] = z;

            // Tune the proposal scale from the acceptance rate of each complete burn-in window.
            // Acceptances must be counted during burn-in for this to work. Tuning stops before
            // the first retained draw.
            if (iter < burn_in && (iter + 1) % ADAPT_WINDOW == 0) {
                double acceptRate = (double) windowAccepts / ADAPT_WINDOW;
                if (acceptRate < TARGET_ACCEPT_LOW) {
                    proposalSd *= 0.9;
                } else if (acceptRate > TARGET_ACCEPT_HIGH) {
                    proposalSd *= 1.1;
                }
                windowAccepts = 0;
            }
        }

        return new double[]{
                posteriorMean(cChain, burn_in),
                posteriorMean(aChain, burn_in),
                posteriorMean(kChain, burn_in),
                posteriorMean(bChain, burn_in),
                posteriorMean(zChain, burn_in)};
    }

    /**
     * Data-driven starting value for the threshold: the left X of the adjacent pair (in
     * ascending X order) with the steepest absolute slope, clamped to [0.05, 0.9].
     * Pairs with equal X are skipped because their slope is undefined. If no usable pair
     * exists (fewer than two points, or all X tied), 0.5 is used.
     */
    private static double initialThreshold(double[] x, double[] y) {
        int[] order = IntStream.range(0, x.length).boxed()
                .sorted(Comparator.comparingDouble(i -> x[i]))
                .mapToInt(Integer::intValue)
                .toArray();
        double steepest = -1.0;
        double cInit = 0.5;
        for (int r = 0; r + 1 < order.length; ++r) {
            double dx = x[order[r + 1]] - x[order[r]];
            if (dx == 0.0) {
                continue;
            }
            double slope = Math.abs((y[order[r + 1]] - y[order[r]]) / dx);
            // Strict '>' keeps the first of several equally steep pairs.
            if (slope > steepest) {
                steepest = slope;
                cInit = x[order[r]];
            }
        }
        return Math.max(Math.min(cInit, 0.9), 0.05);
    }

    /** Mean of a chain over its retained draws, from index {@code from} to the end. */
    private static double posteriorMean(double[] chain, int from) {
        return Arrays.stream(chain, from, chain.length).average().orElse(0.0);
    }

    /** Piecewise part of the model: {@code A} below the threshold, the line {@code k*x + b} at or above it. */
    private static double piecewiseMean(double x, double c, double A, double k, double b) {
        return x < c ? A : k * x + b;
    }

    /** Full model mean: the piecewise part plus the covariate effect {@code z * zCov}. */
    private static double modelMean(double x, double zCov, double c, double A, double k, double b, double z) {
        return piecewiseMean(x, c, A, k, b) + z * zCov;
    }

    /** Gaussian log-likelihood of {@code Y1} under the model with the given parameters. */
    private static double calcLikelihood(double[] Y1, double[] X1, double[] Z1, double c, double A, double k, double b, double z, double sigma2) {
        double sd = Math.sqrt(sigma2);
        double logLik = 0.0;
        for (int i = 0; i < Y1.length; ++i) {
            logLik += normalLogDensity(Y1[i], modelMean(X1[i], Z1[i], c, A, k, b, z), sd);
        }
        return logLik;
    }

    /** Log of the normal density N(mean, sd^2) evaluated at {@code x}. */
    private static double normalLogDensity(double x, double mean, double sd) {
        double u = (x - mean) / sd;
        return -0.5 * Math.log(Math.PI * 2 * sd * sd) - 0.5 * u * u;
    }

    /**
     * Draws one sample from N(mean, cov) as {@code mean + L u}, where {@code cov = L L^T} is
     * the lower Cholesky factor and {@code u} holds independent standard normals.
     *
     * @param mean mean vector
     * @param cov  covariance matrix of size {@code mean.length} square; left unmodified
     * @return the sample, of length {@code mean.length}
     * @throws IllegalArgumentException if the Cholesky decomposition fails (e.g. cov not positive definite)
     */
    private static double[] multivariateNormalSample(double[] mean, SimpleMatrix cov) {
        int dim = mean.length;
        double[] u = new double[dim];
        for (int i = 0; i < dim; ++i) {
            u[i] = normal.sample();
        }
        CholeskyDecomposition_F64<DMatrixRMaj> cholesky = DecompositionFactory_DDRM.chol(dim, true);
        // The decomposition may overwrite its input, so factor a copy and leave cov intact.
        DMatrixRMaj work = new DMatrixRMaj(cov.getDDRM());
        if (!cholesky.decompose(work)) {
            throw new IllegalArgumentException("Cholesky decomposition failed");
        }
        SimpleMatrix lower = new SimpleMatrix(cholesky.getT(null));
        SimpleMatrix sample = lower.mult(new SimpleMatrix(dim, 1, false, u))
                .plus(new SimpleMatrix(dim, 1, false, mean));
        return sample.getDDRM().getData();
    }
}

