/*
 * Decompiled with CFR 0.152.
 */
package edu.sysu.pmglab.stat;

import cern.jet.stat.Probability;
import java.util.Arrays;
import java.util.Random;
import java.util.stream.IntStream;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import org.ejml.dense.row.factory.LinearSolverFactory_DDRM;
import org.ejml.interfaces.linsol.LinearSolverDense;

public class LogisticRegressionGtyFast {
    // Set to true at the very end of preCalculate(); testSnp() refuses to run while it is false.
    private boolean isPreCalculated = false;
    // Null-model residuals (phenotype minus fitted probability), nSamples x 1.
    private DMatrixRMaj residuals;
    // Inverse of the null-model Hessian X' W0 X, nCovars x nCovars.
    private DMatrixRMaj invHessianNull;
    // Null-model weights max(mu * (1 - mu), 1e-8), one per sample, nSamples x 1.
    private DMatrixRMaj w0;
    // Product X' W0 (transposed design matrix with each column scaled by its weight), nCovars x nSamples.
    private DMatrixRMaj xCovT_W0;
    // Number of samples, taken from the phenotype array length.
    private int nSamples;
    // Number of columns of the design matrix, i.e. intercept plus covariates.
    private int nCovars;
    // Scratch buffers allocated once in preCalculate() and overwritten by every
    // single-SNP test, so concurrent single-SNP tests on one instance would share them.
    private DMatrixRMaj g_vector;   // imputed genotype, nSamples x 1
    private DMatrixRMaj g_w0;       // element-wise g_vector * w0, nSamples x 1
    private DMatrixRMaj term2_vec;  // xCovT_W0 * g_vector, nCovars x 1
    private DMatrixRMaj tempVec;    // invHessianNull * term2_vec, nCovars x 1

    /**
     * Reports whether the null model has been fitted.
     *
     * @return {@code true} once {@code preCalculate} has run to completion
     */
    public boolean isPreCalculated() {
        return isPreCalculated;
    }

    /**
     * Fits the covariate-only (null) logistic model by Newton-Raphson and caches
     * everything the per-SNP score tests need: the residuals, the weights, the
     * product X' W0 and the inverse of the null Hessian X' W0 X.
     *
     * <p>The iteration stops after 25 steps or as soon as the largest absolute
     * coefficient update drops below 1e-6. No error is raised if the step limit
     * is reached first.
     *
     * <p>Instance state is only replaced after every computation has succeeded,
     * so a failed call leaves a previously fitted model untouched.
     *
     * @param yPhenotype binary phenotype, one entry per sample; must be non-empty
     * @param covariates covariate rows indexed {@code [sample][covariate]}; may be
     *                   {@code null} or empty for an intercept-only model
     * @throws IllegalArgumentException if the phenotype is null or empty, or if
     *                                  fewer covariate rows than samples are supplied
     * @throws RuntimeException         if a Hessian cannot be factored
     */
    public void preCalculate(double[] yPhenotype, double[][] covariates) {
        if (yPhenotype == null || yPhenotype.length == 0) {
            throw new IllegalArgumentException("\u8868\u578b\u6570\u636e\u4e0d\u80fd\u4e3a\u7a7a\u3002");
        }
        final int n = yPhenotype.length;
        boolean hasCovariates = covariates != null && covariates.length > 0
                && covariates[0] != null && covariates[0].length > 0;
        if (hasCovariates && covariates.length < n) {
            throw new IllegalArgumentException(
                    "covariates has " + covariates.length + " rows but there are " + n + " samples");
        }
        final int maxIter = 25;
        final double tolerance = 1.0E-6;

        DMatrixRMaj y = new DMatrixRMaj(n, 1, true, yPhenotype);
        DMatrixRMaj xCov = LogisticRegressionGtyFast.createDesignMatrixWithIntercept(covariates, n);
        final int p = xCov.getNumCols();
        DMatrixRMaj xCovT = CommonOps_DDRM.transpose(xCov, null);

        DMatrixRMaj beta = new DMatrixRMaj(p, 1);
        DMatrixRMaj eta = new DMatrixRMaj(n, 1);
        DMatrixRMaj mu = new DMatrixRMaj(n, 1);
        DMatrixRMaj weights = new DMatrixRMaj(n, 1);
        DMatrixRMaj resid = new DMatrixRMaj(n, 1);
        DMatrixRMaj score = new DMatrixRMaj(p, 1);
        DMatrixRMaj delta = new DMatrixRMaj(p, 1);
        LinearSolverDense<DMatrixRMaj> solver = LinearSolverFactory_DDRM.lu(p);

        for (int iter = 0; iter < maxIter; ++iter) {
            CommonOps_DDRM.mult(xCov, beta, eta);
            fillMeanAndWeights(eta, mu, weights);

            // Accumulate X' W X one sample at a time.
            DMatrixRMaj hessian = new DMatrixRMaj(p, p);
            for (int s = 0; s < n; ++s) {
                double ws = weights.get(s, 0);
                for (int k = 0; k < p; ++k) {
                    double xsk = xCov.get(s, k);
                    for (int t = 0; t < p; ++t) {
                        hessian.add(k, t, ws * xsk * xCov.get(s, t));
                    }
                }
            }
            // Try LU first; switch to QR for the rest of the fit if LU rejects the matrix.
            if (!solver.setA(hessian) && !(solver = LinearSolverFactory_DDRM.qr(p, p)).setA(hessian)) {
                throw new RuntimeException("\u77e9\u9635\u5947\u5f02\uff0c\u96f6\u6a21\u578b\u65e0\u6cd5\u62df\u5408\u3002");
            }
            CommonOps_DDRM.subtract(y, mu, resid);
            CommonOps_DDRM.mult(xCovT, resid, score);
            solver.solve(score, delta);
            CommonOps_DDRM.add(beta, delta, beta);
            if (CommonOps_DDRM.elementMaxAbs(delta) < tolerance) {
                break;
            }
        }

        // Re-evaluate the fitted probabilities and weights at the final coefficients.
        CommonOps_DDRM.mult(xCov, beta, eta);
        fillMeanAndWeights(eta, mu, weights);
        DMatrixRMaj nullResiduals = new DMatrixRMaj(n, 1);
        CommonOps_DDRM.subtract(y, mu, nullResiduals);

        // X' W0 is X' with column s scaled by weight s. Scaling directly avoids
        // materialising a dense n x n diagonal matrix.
        DMatrixRMaj xCovTW0 = new DMatrixRMaj(p, n);
        for (int s = 0; s < n; ++s) {
            double ws = weights.get(s, 0);
            for (int k = 0; k < p; ++k) {
                xCovTW0.set(k, s, xCov.get(s, k) * ws);
            }
        }
        DMatrixRMaj hessianNull = new DMatrixRMaj(p, p);
        CommonOps_DDRM.mult(xCovTW0, xCov, hessianNull);
        if (!solver.setA(hessianNull) && !(solver = LinearSolverFactory_DDRM.qr(p, p)).setA(hessianNull)) {
            throw new RuntimeException("\u77e9\u9635\u5947\u5f02\uff0c\u96f6\u6a21\u578b\u65e0\u6cd5\u62df\u5408\u3002");
        }
        DMatrixRMaj inverse = new DMatrixRMaj(p, p);
        solver.invert(inverse);

        // Commit: nothing above touched instance state.
        this.nSamples = n;
        this.nCovars = p;
        this.residuals = nullResiduals;
        this.w0 = weights;
        this.invHessianNull = inverse;
        this.xCovT_W0 = xCovTW0;
        this.g_vector = new DMatrixRMaj(n, 1);
        this.g_w0 = new DMatrixRMaj(n, 1);
        this.term2_vec = new DMatrixRMaj(p, 1);
        this.tempVec = new DMatrixRMaj(p, 1);
        this.isPreCalculated = true;
    }

    /**
     * Writes the logistic mean 1 / (1 + exp(-eta)) into {@code mu} and the
     * matching weight max(mu * (1 - mu), 1e-8) into {@code weights}, row by row.
     */
    private static void fillMeanAndWeights(DMatrixRMaj eta, DMatrixRMaj mu, DMatrixRMaj weights) {
        int rows = eta.getNumRows();
        for (int s = 0; s < rows; ++s) {
            double m = 1.0 / (1.0 + Math.exp(-eta.get(s, 0)));
            mu.set(s, 0, m);
            weights.set(s, 0, Math.max(m * (1.0 - m), 1.0E-8));
        }
    }

    /**
     * Tests one or more SNPs against the cached null model.
     *
     * <p>Missing genotypes are encoded as {@code NaN} and replaced by the mean of
     * that SNP's observed values. A single SNP goes through the allocation-free
     * fast path; several SNPs are tested jointly.
     *
     * <p>In the joint path a SNP with no observed value is left as an all-zero
     * column of the genotype matrix.
     *
     * @param genotypes genotype matrix indexed {@code [snp][sample]}; every row
     *                  must have exactly as many entries as there are samples
     * @return one row per SNP holding {@code {beta, standard error, p-value}}
     * @throws IllegalStateException    if {@code preCalculate} has not been called
     * @throws IllegalArgumentException if the matrix is null, empty, or any row is
     *                                  null or of the wrong length
     */
    public double[][] testSnp(double[][] genotypes) {
        if (!this.isPreCalculated) {
            throw new IllegalStateException("\u5fc5\u987b\u5728\u6d4b\u8bd5SNP\u4e4b\u524d\u8c03\u7528 preCalculate() \u65b9\u6cd5\u3002");
        }
        if (genotypes == null || genotypes.length == 0) {
            throw new IllegalArgumentException("\u57fa\u56e0\u578b\u77e9\u9635\u7ef4\u5ea6\u4e0d\u6b63\u786e\u3002\u5e94\u4e3a [k_snps][n_samples]\u3002");
        }
        // Check every row, not just the first, so ragged input fails up front.
        for (double[] row : genotypes) {
            if (row == null || row.length != this.nSamples) {
                throw new IllegalArgumentException("\u57fa\u56e0\u578b\u77e9\u9635\u7ef4\u5ea6\u4e0d\u6b63\u786e\u3002\u5e94\u4e3a [k_snps][n_samples]\u3002");
            }
        }
        int kSnps = genotypes.length;
        if (kSnps == 1) {
            return this.testSingleSnpFastPath(genotypes[0]);
        }
        // A freshly allocated matrix is zero-filled, so all-missing SNPs need no extra work.
        DMatrixRMaj G = new DMatrixRMaj(this.nSamples, kSnps);
        for (int j = 0; j < kSnps; ++j) {
            double[] currentSnpData = genotypes[j];
            double sum = 0.0;
            int nonMissingCount = 0;
            for (int i = 0; i < this.nSamples; ++i) {
                if (!Double.isNaN(currentSnpData[i])) {
                    sum += currentSnpData[i];
                    ++nonMissingCount;
                }
            }
            if (nonMissingCount == 0) {
                continue;
            }
            double meanGty = sum / (double) nonMissingCount;
            for (int i = 0; i < this.nSamples; ++i) {
                G.set(i, j, Double.isNaN(currentSnpData[i]) ? meanGty : currentSnpData[i]);
            }
        }
        return this.performMultiSnpTestCore(G);
    }

    /**
     * Score test for a single SNP using the preallocated scratch vectors.
     *
     * <p>Missing ({@code NaN}) genotypes are replaced by the mean of the observed
     * ones. The method overwrites {@code g_vector}, {@code g_w0},
     * {@code term2_vec} and {@code tempVec}.
     *
     * @param genotype genotype values, one per sample
     * @return a single row {@code {beta, standard error, p-value}}, or
     *         {@code {NaN, NaN, 1.0}} when no genotype is observed or the
     *         adjusted variance is at most 1e-8
     */
    private double[][] testSingleSnpFastPath(double[] genotype) {
        final int n = this.nSamples;

        double observedTotal = 0.0;
        int observedCount = 0;
        for (int s = 0; s < n; ++s) {
            double value = genotype[s];
            if (!Double.isNaN(value)) {
                observedTotal += value;
                ++observedCount;
            }
        }
        if (observedCount == 0) {
            return new double[][]{{Double.NaN, Double.NaN, 1.0}};
        }

        // Mean-impute straight into the shared genotype buffer.
        double imputed = observedTotal / (double) observedCount;
        double[] buffer = this.g_vector.data;
        for (int s = 0; s < n; ++s) {
            double value = genotype[s];
            buffer[s] = Double.isNaN(value) ? imputed : value;
        }

        // Score: g' r.
        double scoreStat = CommonOps_DDRM.dot(this.g_vector, this.residuals);

        // Variance: g' W g minus (X' W g)' (X' W X)^-1 (X' W g).
        CommonOps_DDRM.elementMult(this.g_vector, this.w0, this.g_w0);
        double unadjusted = CommonOps_DDRM.dot(this.g_w0, this.g_vector);
        CommonOps_DDRM.mult(this.xCovT_W0, this.g_vector, this.term2_vec);
        CommonOps_DDRM.mult(this.invHessianNull, this.term2_vec, this.tempVec);
        double covariateShare = CommonOps_DDRM.dot(this.term2_vec, this.tempVec);
        double variance = unadjusted - covariateShare;
        if (variance <= 1.0E-8) {
            return new double[][]{{Double.NaN, Double.NaN, 1.0}};
        }

        double effect = scoreStat / variance;
        double stdErr = Math.sqrt(1.0 / variance);
        double statistic = scoreStat * scoreStat / variance;
        double[] row = new double[]{effect, stdErr, Probability.chiSquareComplemented(1.0, statistic)};
        return new double[][]{row};
    }

    /**
     * Joint score-based test of several SNPs against the cached null model.
     *
     * <p>Builds the score vector U = G' r and the covariate-adjusted matrix
     * V = G' W G - (X' W G)' (X' W X)^-1 (X' W G), then reports V^-1 U as the
     * effects and the diagonal of V^-1 as their variances.
     *
     * @param G imputed genotype matrix, samples in rows and SNPs in columns
     * @return one row per SNP holding {@code {beta, standard error, p-value}}.
     *         If V cannot be factored every row is {@code {NaN, NaN, 1.0}}; a SNP
     *         whose variance is not strictly positive gets a {@code NaN} standard
     *         error and a p-value of 1.0
     */
    private double[][] performMultiSnpTestCore(DMatrixRMaj G) {
        int kSnps = G.getNumCols();

        DMatrixRMaj U = new DMatrixRMaj(kSnps, 1);
        CommonOps_DDRM.multTransA(G, this.residuals, U);

        // W G: every genotype row scaled by that sample's weight.
        DMatrixRMaj G_w_scaled = new DMatrixRMaj(this.nSamples, kSnps);
        for (int i = 0; i < this.nSamples; ++i) {
            double w_i = this.w0.get(i, 0);
            for (int j = 0; j < kSnps; ++j) {
                G_w_scaled.set(i, j, G.get(i, j) * w_i);
            }
        }
        DMatrixRMaj term1 = new DMatrixRMaj(kSnps, kSnps);
        CommonOps_DDRM.multTransA(G, G_w_scaled, term1);

        DMatrixRMaj M = new DMatrixRMaj(this.nCovars, kSnps);
        CommonOps_DDRM.mult(this.xCovT_W0, G, M);
        DMatrixRMaj temp_m_x_k = new DMatrixRMaj(this.nCovars, kSnps);
        CommonOps_DDRM.mult(this.invHessianNull, M, temp_m_x_k);
        DMatrixRMaj term2 = new DMatrixRMaj(kSnps, kSnps);
        CommonOps_DDRM.multTransA(M, temp_m_x_k, term2);

        DMatrixRMaj V = new DMatrixRMaj(kSnps, kSnps);
        CommonOps_DDRM.subtract(term1, term2, V);

        // A solver may factor its input in place (see LinearSolver.modifiesA()).
        // Each attempt gets its own copy so a failed first attempt cannot corrupt
        // the matrix handed to the fallback.
        LinearSolverDense<DMatrixRMaj> solver = LinearSolverFactory_DDRM.symmPosDef(kSnps);
        if (!solver.setA(V.copy())) {
            solver = LinearSolverFactory_DDRM.lu(kSnps);
            if (!solver.setA(V.copy())) {
                double[][] failed = new double[kSnps][3];
                for (double[] row : failed) {
                    Arrays.fill(row, Double.NaN);
                    row[2] = 1.0;
                }
                return failed;
            }
        }

        DMatrixRMaj beta_G = new DMatrixRMaj(kSnps, 1);
        solver.solve(U, beta_G);
        DMatrixRMaj V_inv = new DMatrixRMaj(kSnps, kSnps);
        solver.invert(V_inv);

        double[][] results = new double[kSnps][3];
        for (int j = 0; j < kSnps; ++j) {
            double beta = beta_G.get(j, 0);
            double var = V_inv.get(j, j);
            results[j][0] = beta;
            // Rejects negative, zero and NaN variances; zero would give se = 0
            // and a meaningless z-score.
            if (!(var > 0.0)) {
                results[j][1] = Double.NaN;
                results[j][2] = 1.0;
                continue;
            }
            double se = Math.sqrt(var);
            double z = beta / se;
            results[j][1] = se;
            results[j][2] = Probability.chiSquareComplemented(1.0, z * z);
        }
        return results;
    }

    /**
     * Builds the design matrix: a leading column of ones followed by the
     * covariate columns.
     *
     * <p>The number of covariates is taken from the first covariate row. A
     * {@code null} or empty array, or a {@code null} first row, yields an
     * intercept-only matrix. Entries beyond that column count, and rows beyond
     * {@code nSamples}, are ignored.
     *
     * @param covariates covariate rows indexed {@code [sample][covariate]}, or {@code null}
     * @param nSamples   number of rows in the result
     * @return an {@code nSamples x (1 + nCovars)} matrix
     * @throws IllegalArgumentException if covariates are present but there are
     *                                  fewer rows than samples, or a needed row is
     *                                  {@code null} or shorter than the first row
     */
    private static DMatrixRMaj createDesignMatrixWithIntercept(double[][] covariates, int nSamples) {
        int nCovars = covariates != null && covariates.length > 0 && covariates[0] != null ? covariates[0].length : 0;
        if (nCovars > 0 && covariates.length < nSamples) {
            throw new IllegalArgumentException(
                    "covariates has " + covariates.length + " rows but " + nSamples + " are required");
        }
        DMatrixRMaj designMatrix = new DMatrixRMaj(nSamples, 1 + nCovars);
        for (int i = 0; i < nSamples; ++i) {
            designMatrix.set(i, 0, 1.0);
            if (nCovars == 0) {
                continue;
            }
            double[] row = covariates[i];
            if (row == null || row.length < nCovars) {
                throw new IllegalArgumentException(
                        "covariate row " + i + " is missing or has fewer than " + nCovars + " values");
            }
            for (int j = 0; j < nCovars; ++j) {
                designMatrix.set(i, 1 + j, row[j]);
            }
        }
        return designMatrix;
    }

    /**
     * Demo entry point. Simulates covariates, genotypes and a binary phenotype
     * from a fixed seed, blanks out part of the genotypes, and prints the result
     * of the precomputed fast test next to a complete-case full logistic fit.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        final int sampleCount = 5000;
        final int covariateCount = 2;
        final int snpCount = 1;
        Random rng = new Random(2024L);

        double[] phenotype = new double[sampleCount];
        double[][] genotype = new double[snpCount][sampleCount];
        double[][] covariate = new double[sampleCount][covariateCount];

        // Covariates: scaled standard normal draws, sample-major.
        for (double[] row : covariate) {
            for (int c = 0; c < covariateCount; ++c) {
                row[c] = rng.nextGaussian() * 0.5;
            }
        }

        // Genotypes: 2 / 1 / 0 according to cumulative thresholds built from the allele frequency.
        double[] mafs = new double[]{0.3, 0.25};
        for (int s = 0; s < snpCount; ++s) {
            double maf = mafs[s];
            for (int i = 0; i < sampleCount; ++i) {
                double draw = rng.nextDouble();
                double value;
                if (draw < maf * maf) {
                    value = 2.0;
                } else if (draw < 2.0 * maf * (1.0 - maf) + maf * maf) {
                    value = 1.0;
                } else {
                    value = 0.0;
                }
                genotype[s][i] = value;
            }
        }

        // Phenotype: Bernoulli draw from the logistic model.
        double beta0 = -1.5;
        double[] betaG = new double[]{0.3, 0.4};
        double[] betaCov = new double[]{0.5, -0.2};
        for (int i = 0; i < sampleCount; ++i) {
            double logit = beta0;
            for (int s = 0; s < snpCount; ++s) {
                logit += betaG[s] * genotype[s][i];
            }
            for (int c = 0; c < covariateCount; ++c) {
                logit += betaCov[c] * covariate[i][c];
            }
            double prob = 1.0 / (1.0 + Math.exp(-logit));
            phenotype[i] = rng.nextDouble() < prob ? 1.0 : 0.0;
        }

        // Blank out randomly chosen genotypes; indices are drawn with replacement.
        int[][] missingIndexes = new int[snpCount][];
        for (int s = 0; s < snpCount; ++s) {
            int numMissing = (int) ((double) sampleCount * 0.4);
            int[] chosen = new int[numMissing];
            for (int k = 0; k < numMissing; ++k) {
                int idx = rng.nextInt(sampleCount);
                genotype[s][idx] = Double.NaN;
                chosen[k] = idx;
            }
            missingIndexes[s] = chosen;
        }

        System.out.println("--- \u65b9\u6cd5 1: LogisticRegressionGtyFast (\u9884\u8ba1\u7b97 + \u5747\u503c\u63d2\u8865) ---");
        LogisticRegressionGtyFast fast = new LogisticRegressionGtyFast();
        long t0 = System.currentTimeMillis();
        fast.preCalculate(phenotype, covariate);
        long t1 = System.currentTimeMillis();
        double[][] fastResults = fast.testSnp(genotype);
        long t2 = System.currentTimeMillis();
        System.out.println("\u9884\u8ba1\u7b97\u8017\u65f6: " + (t1 - t0) + " ms");
        System.out.println("\u68c0\u9a8c\u8017\u65f6: " + (t2 - t1) + " ms");
        LogisticRegressionGtyFast.printResultsTable(fastResults, "SNP");

        System.out.println("\n--- \u65b9\u6cd5 2: LogisticRegressionFull (\u9ec4\u91d1\u6807\u51c6 - \u5b8c\u6574\u6848\u4f8b\u5206\u6790) ---");
        // A sample is dropped if any SNP is missing for it.
        boolean[] dropped = new boolean[sampleCount];
        for (int s = 0; s < snpCount; ++s) {
            for (int i = 0; i < sampleCount; ++i) {
                if (Double.isNaN(genotype[s][i])) {
                    dropped[i] = true;
                }
            }
        }
        int nComplete = 0;
        for (boolean flag : dropped) {
            if (!flag) {
                ++nComplete;
            }
        }
        System.out.println("\u539f\u59cb\u6837\u672c\u6570: " + sampleCount + ", \u5b8c\u6574\u6848\u4f8b\u6837\u672c\u6570: " + nComplete);

        // Predictor layout for the full model: SNP columns first, then covariates.
        double[] phenotypesComplete = new double[nComplete];
        double[][] predictorsComplete = new double[nComplete][snpCount + covariateCount];
        int next = 0;
        for (int i = 0; i < sampleCount; ++i) {
            if (dropped[i]) {
                continue;
            }
            phenotypesComplete[next] = phenotype[i];
            double[] target = predictorsComplete[next];
            for (int s = 0; s < snpCount; ++s) {
                target[s] = genotype[s][i];
            }
            for (int c = 0; c < covariateCount; ++c) {
                target[snpCount + c] = covariate[i][c];
            }
            ++next;
        }

        try {
            LogisticRegressionFull model = new LogisticRegressionFull(predictorsComplete, phenotypesComplete);
            long fitStart = System.currentTimeMillis();
            model.fit();
            long fitEnd = System.currentTimeMillis();
            System.out.println("\u5b8c\u6574\u6a21\u578b\u62df\u5408\u8017\u65f6: " + (fitEnd - fitStart) + " ms");
            double[] coef = model.getCoefficients();
            double[] stdErr = model.getStandardErrors();
            double[] pVal = model.getPValues();
            // Index 0 is the intercept, so SNP s sits at position 1 + s.
            double[][] resultsGold = new double[snpCount][3];
            for (int s = 0; s < snpCount; ++s) {
                int at = 1 + s;
                resultsGold[s][0] = coef[at];
                resultsGold[s][1] = stdErr[at];
                resultsGold[s][2] = pVal[at];
            }
            LogisticRegressionGtyFast.printResultsTable(resultsGold, "SNP");
        }
        catch (Exception e) {
            System.out.println("\u9ec4\u91d1\u6807\u51c6\u6a21\u578b\u62df\u5408\u5931\u8d25: " + e.getMessage());
            e.printStackTrace();
        }

        System.out.println("\n--- \u7ed3\u8bba ---");
        System.out.println("\u65b9\u6cd51 (\u8fd1\u4f3c) \u548c \u65b9\u6cd52 (\u9ec4\u91d1\u6807\u51c6) \u7684\u7ed3\u679c\u975e\u5e38\u63a5\u8fd1\uff0c\u4f46\u7531\u4e8e\u5904\u7406\u7f3a\u5931\u6570\u636e\u7684\u65b9\u5f0f\u4e0d\u540c\u800c\u7565\u6709\u5dee\u5f02\u3002");
        System.out.println("\u65b9\u6cd51 \u7684\u901f\u5ea6\u4f18\u52bf\u5728\u4e8e\u9884\u8ba1\u7b97\uff0c\u5355\u6b21\u68c0\u9a8c\u975e\u5e38\u5feb\uff0c\u9002\u5408\u5927\u89c4\u6a21\u626b\u63cf\u3002");
    }

    /**
     * Prints a fixed-width table of estimates to standard output.
     *
     * @param results   rows of {@code {beta, standard error, p-value}}; a
     *                  {@code null} or empty array prints a placeholder line instead
     * @param varPrefix label for the first column; rows are numbered from 1
     */
    public static void printResultsTable(double[][] results, String varPrefix) {
        boolean nothingToShow = results == null || results.length == 0;
        if (nothingToShow) {
            System.out.println("\u65e0\u7ed3\u679c\u53ef\u663e\u793a\u3002");
            return;
        }
        final String rule = "-----------------------------------------------------";
        System.out.println(rule);
        System.out.printf("%-8s | %-12s | %-12s | %-12s%n", varPrefix, "Est. BETA", "Std. Error", "P-value");
        System.out.println(rule);
        int rowNumber = 0;
        for (double[] row : results) {
            ++rowNumber;
            String label = varPrefix + " " + rowNumber;
            System.out.printf("%-8s | %-12.4f | %-12.4f | %-12.3e%n", label, row[0], row[1], row[2]);
        }
        System.out.println(rule);
    }

    /**
     * Plain logistic regression fitted by Newton-Raphson on an intercept plus the
     * supplied predictors. Used by {@code main} as the reference fit.
     *
     * <p>Call {@link #fit()} before any of the accessors.
     */
    static class LogisticRegressionFull {
        private static final int MAX_ITERATIONS = 50;
        private static final double CONVERGENCE_TOLERANCE = 1.0E-7;
        /** Lower bound applied to each weight mu * (1 - mu). */
        private static final double MIN_WEIGHT = 1.0E-8;

        private final DMatrixRMaj X;
        private final DMatrixRMaj y;
        private DMatrixRMaj beta;
        private DMatrixRMaj varCovar;

        /**
         * @param predictors predictor rows indexed {@code [sample][predictor]};
         *                   an intercept column is prepended
         * @param response   response values, one per sample
         */
        public LogisticRegressionFull(double[][] predictors, double[] response) {
            this.y = new DMatrixRMaj(response.length, 1, true, response);
            this.X = LogisticRegressionGtyFast.createDesignMatrixWithIntercept(predictors, response.length);
        }

        /**
         * Runs at most 50 Newton-Raphson steps, stopping once the largest absolute
         * coefficient update is below 1e-7, then stores the coefficients and the
         * inverse Hessian evaluated at those final coefficients.
         *
         * @throws RuntimeException if a Hessian cannot be factored
         */
        public void fit() {
            int nSamples = this.X.getNumRows();
            int nParams = this.X.getNumCols();
            DMatrixRMaj coef = new DMatrixRMaj(nParams, 1);
            DMatrixRMaj eta = new DMatrixRMaj(nSamples, 1);
            DMatrixRMaj mu = new DMatrixRMaj(nSamples, 1);
            DMatrixRMaj resid = new DMatrixRMaj(nSamples, 1);
            DMatrixRMaj score = new DMatrixRMaj(nParams, 1);
            DMatrixRMaj delta = new DMatrixRMaj(nParams, 1);
            DMatrixRMaj weightedXT = new DMatrixRMaj(nParams, nSamples);
            DMatrixRMaj XT = CommonOps_DDRM.transpose(this.X, null);
            LinearSolverDense<DMatrixRMaj> solver = LinearSolverFactory_DDRM.lu(nParams);

            for (int iter = 0; iter < MAX_ITERATIONS; ++iter) {
                CommonOps_DDRM.mult(this.X, coef, eta);
                updateMean(eta, mu);
                DMatrixRMaj hessian = weightedGram(mu, weightedXT);
                if (!solver.setA(hessian)) {
                    throw new RuntimeException("\u77e9\u9635\u5947\u5f02");
                }
                CommonOps_DDRM.subtract(this.y, mu, resid);
                CommonOps_DDRM.mult(XT, resid, score);
                solver.solve(score, delta);
                CommonOps_DDRM.add(coef, delta, coef);
                if (CommonOps_DDRM.elementMaxAbs(delta) < CONVERGENCE_TOLERANCE) {
                    break;
                }
            }

            // mu still reflects the coefficients from before the last update;
            // refresh it so the covariance matches the reported coefficients.
            CommonOps_DDRM.mult(this.X, coef, eta);
            updateMean(eta, mu);
            DMatrixRMaj finalHessian = weightedGram(mu, weightedXT);
            if (!solver.setA(finalHessian)) {
                throw new RuntimeException("\u77e9\u9635\u5947\u5f02");
            }
            DMatrixRMaj covariance = new DMatrixRMaj(nParams, nParams);
            solver.invert(covariance);

            this.beta = coef;
            this.varCovar = covariance;
        }

        /** Writes 1 / (1 + exp(-eta)) into {@code mu}, row by row. */
        private static void updateMean(DMatrixRMaj eta, DMatrixRMaj mu) {
            int rows = eta.getNumRows();
            for (int j = 0; j < rows; ++j) {
                mu.set(j, 0, 1.0 / (1.0 + Math.exp(-eta.get(j, 0))));
            }
        }

        /**
         * Computes X' W X for the weights implied by {@code mu}. The columns of
         * X' are scaled directly into {@code scratch} instead of multiplying by a
         * dense nSamples x nSamples diagonal matrix.
         */
        private DMatrixRMaj weightedGram(DMatrixRMaj mu, DMatrixRMaj scratch) {
            int nSamples = this.X.getNumRows();
            int nParams = this.X.getNumCols();
            for (int j = 0; j < nSamples; ++j) {
                double mu_j = mu.get(j, 0);
                double w_j = Math.max(mu_j * (1.0 - mu_j), MIN_WEIGHT);
                for (int k = 0; k < nParams; ++k) {
                    scratch.set(k, j, this.X.get(j, k) * w_j);
                }
            }
            DMatrixRMaj gram = new DMatrixRMaj(nParams, nParams);
            CommonOps_DDRM.mult(scratch, this.X, gram);
            return gram;
        }

        /** Fails fast with a clear message when an accessor is used before {@link #fit()}. */
        private void requireFitted() {
            if (this.beta == null || this.varCovar == null) {
                throw new IllegalStateException("fit() must be called before reading results");
            }
        }

        /**
         * @return a copy of the coefficients, intercept first
         * @throws IllegalStateException if {@link #fit()} has not been called
         */
        public double[] getCoefficients() {
            requireFitted();
            return this.beta.getData().clone();
        }

        /**
         * @return square roots of the diagonal of the inverse Hessian
         * @throws IllegalStateException if {@link #fit()} has not been called
         */
        public double[] getStandardErrors() {
            requireFitted();
            double[] se = new double[this.beta.getNumRows()];
            for (int i = 0; i < se.length; ++i) {
                se[i] = Math.sqrt(this.varCovar.get(i, i));
            }
            return se;
        }

        /**
         * @return upper-tail chi-square (1 d.f.) probabilities of the squared
         *         coefficient-to-standard-error ratios
         * @throws IllegalStateException if {@link #fit()} has not been called
         */
        public double[] getPValues() {
            double[] se = this.getStandardErrors();
            double[] pValues = new double[se.length];
            for (int i = 0; i < pValues.length; ++i) {
                double z = this.beta.get(i, 0) / se[i];
                pValues[i] = Probability.chiSquareComplemented(1.0, z * z);
            }
            return pValues;
        }
    }
}

