/*
 * Decompiled with CFR 0.152.
 */
package edu.sysu.pmglab.stat;

import cern.jet.stat.Probability;
import java.util.Arrays;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import org.ejml.dense.row.factory.LinearSolverFactory_DDRM;
import org.ejml.interfaces.linsol.LinearSolverDense;

/**
 * Linear-regression association test of genotypes against a quantitative phenotype,
 * with the covariate-only ("null") model fitted once and reused for every SNP.
 *
 * <p>Usage: call {@link #preCalculate(double[], double[][])} once, then call
 * {@link #testSnp(double[][])} for each SNP (or group of SNPs). Every result row has the
 * layout {@code {beta, standardError, pValue}}; the p-value is the upper tail of a
 * chi-square distribution with one degree of freedom. A SNP that cannot be tested is
 * reported as {@code {NaN, NaN, 1.0}}.
 *
 * <p>Note: the single-SNP path of {@link #testSnp(double[][])} writes into scratch buffers
 * held in instance fields, so one instance must not be used by several threads at once.
 */
public class LinearRegressionGtyFast {
    /** Smallest covariate-adjusted genotype sum of squares still treated as testable. */
    private static final double MIN_ADJUSTED_GENOTYPE_SS = 1.0E-8;

    private boolean isPreCalculated = false;
    /** Null-model residuals, n x 1. */
    private DMatrixRMaj residuals;
    /** Inverse of X'X, where X is the design matrix (intercept + covariates). */
    private DMatrixRMaj invXtX;
    /** Residual sum of squares of the null model divided by (nSamples - nCovars). */
    private double residualVariance;
    /** Transposed design matrix, nCovars x n. */
    private DMatrixRMaj xCovT;
    private int nSamples;
    /** Number of design-matrix columns, i.e. covariates plus the intercept. */
    private int nCovars;
    /** Scratch buffer for the imputed genotype vector of the single-SNP path. */
    private DMatrixRMaj gVector;
    /** Scratch buffer holding X'g in the single-SNP path. */
    private DMatrixRMaj term2Vec;
    /** Scratch buffer holding (X'X)^-1 X'g in the single-SNP path. */
    private DMatrixRMaj tempVec;

    /**
     * @return {@code true} once {@link #preCalculate(double[], double[][])} has completed
     *         successfully
     */
    public boolean isPreCalculated() {
        return this.isPreCalculated;
    }

    /**
     * Fits the null model {@code y ~ intercept + covariates} and caches everything the
     * per-SNP tests need. Instance state is only replaced after the whole fit succeeded,
     * so a failed call leaves a previously computed model untouched.
     *
     * @param yPhenotype phenotype values, one per sample
     * @param covariates covariate values indexed {@code [sample][covariate]}; {@code null}
     *                   or empty for an intercept-only model
     * @throws IllegalArgumentException if the phenotype is null or empty, the covariate
     *                                  matrix does not cover every sample, or there are not
     *                                  more samples than design-matrix columns
     * @throws RuntimeException         if X'X cannot be decomposed
     */
    public void preCalculate(double[] yPhenotype, double[][] covariates) {
        if (yPhenotype == null || yPhenotype.length == 0) {
            // "Phenotype data must not be empty."
            throw new IllegalArgumentException("\u8868\u578b\u6570\u636e\u4e0d\u80fd\u4e3a\u7a7a\u3002");
        }
        int sampleCount = yPhenotype.length;
        DMatrixRMaj xCov = LinearRegressionGtyFast.createDesignMatrixWithIntercept(covariates, sampleCount);
        int columnCount = xCov.getNumCols();
        // Checked before any linear algebra: the residual variance needs n - p > 0.
        if (sampleCount <= columnCount) {
            // "The number of samples must exceed the number of covariates."
            throw new IllegalArgumentException("\u6837\u672c\u6570\u5fc5\u987b\u5927\u4e8e\u534f\u53d8\u91cf\u6570\u3002");
        }

        DMatrixRMaj y = new DMatrixRMaj(yPhenotype);
        DMatrixRMaj designT = CommonOps_DDRM.transpose(xCov, null);
        DMatrixRMaj xtx = new DMatrixRMaj(columnCount, columnCount);
        CommonOps_DDRM.multTransA(xCov, xCov, xtx);

        LinearSolverDense<DMatrixRMaj> solver = LinearRegressionGtyFast.factorize(xtx);
        if (solver == null) {
            // "Matrix X'X is singular; the null model cannot be fitted."
            throw new RuntimeException("\u77e9\u9635 X'X \u5947\u5f02\uff0c\u96f6\u6a21\u578b\u65e0\u6cd5\u62df\u5408\u3002");
        }
        DMatrixRMaj inverse = new DMatrixRMaj(columnCount, columnCount);
        solver.invert(inverse);

        // betaNull = (X'X)^-1 X'y, residuals = y - X betaNull
        DMatrixRMaj xty = new DMatrixRMaj(columnCount, 1);
        CommonOps_DDRM.mult(designT, y, xty);
        DMatrixRMaj betaNull = new DMatrixRMaj(columnCount, 1);
        CommonOps_DDRM.mult(inverse, xty, betaNull);
        DMatrixRMaj yHatNull = new DMatrixRMaj(sampleCount, 1);
        CommonOps_DDRM.mult(xCov, betaNull, yHatNull);
        DMatrixRMaj nullResiduals = new DMatrixRMaj(sampleCount, 1);
        CommonOps_DDRM.subtract(y, yHatNull, nullResiduals);
        double rss = CommonOps_DDRM.dot(nullResiduals, nullResiduals);

        // Commit only now, so the instance never holds a half-updated model.
        this.nSamples = sampleCount;
        this.nCovars = columnCount;
        this.xCovT = designT;
        this.invXtX = inverse;
        this.residuals = nullResiduals;
        this.residualVariance = rss / (double)(sampleCount - columnCount);
        this.gVector = new DMatrixRMaj(sampleCount, 1);
        this.term2Vec = new DMatrixRMaj(columnCount, 1);
        this.tempVec = new DMatrixRMaj(columnCount, 1);
        this.isPreCalculated = true;
    }

    /**
     * Tests one SNP, or several SNPs jointly, against the pre-computed null model.
     * Missing genotypes ({@code NaN}) are replaced by the mean of that SNP's observed values.
     *
     * <p>With several SNPs, all rows are reported as untestable ({@code {NaN, NaN, 1.0}})
     * when there are too few samples for the joint model, when any SNP has no observed
     * genotype at all, or when the adjusted genotype matrix cannot be decomposed.
     *
     * @param genotypes genotype values indexed {@code [snp][sample]}
     * @return one row {@code {beta, standardError, pValue}} per SNP
     * @throws IllegalStateException    if {@link #preCalculate(double[], double[][])} has not
     *                                  been called
     * @throws IllegalArgumentException if the matrix is null, empty, or any row is null or
     *                                  does not contain exactly one value per sample
     */
    public double[][] testSnp(double[][] genotypes) {
        if (!this.isPreCalculated) {
            // "preCalculate() must be called before testing SNPs."
            throw new IllegalStateException("\u5fc5\u987b\u5728\u6d4b\u8bd5SNP\u4e4b\u524d\u8c03\u7528 preCalculate() \u65b9\u6cd5\u3002");
        }
        if (genotypes == null || genotypes.length == 0) {
            throw LinearRegressionGtyFast.badGenotypeDimensions();
        }
        // Validate every row up front, not only the first one, so a ragged matrix is
        // rejected cleanly instead of failing halfway through the computation.
        for (double[] row : genotypes) {
            if (row == null || row.length != this.nSamples) {
                throw LinearRegressionGtyFast.badGenotypeDimensions();
            }
        }

        int numSnps = genotypes.length;
        if (numSnps == 1) {
            return this.testSingleSnpFastPath(genotypes[0]);
        }
        // The joint model needs more samples than parameters (covariates + SNPs).
        if (this.nSamples <= this.nCovars + numSnps) {
            return LinearRegressionGtyFast.untestableResults(numSnps);
        }

        DMatrixRMaj G = new DMatrixRMaj(this.nSamples, numSnps);
        for (int j = 0; j < numSnps; ++j) {
            double[] snp = genotypes[j];
            double meanGty = LinearRegressionGtyFast.meanOfObserved(snp);
            if (Double.isNaN(meanGty)) {
                // A SNP without any observed genotype makes the joint model unusable;
                // report every SNP as untestable, consistent with the other failure paths.
                return LinearRegressionGtyFast.untestableResults(numSnps);
            }
            for (int i = 0; i < this.nSamples; ++i) {
                G.set(i, j, Double.isNaN(snp[i]) ? meanGty : snp[i]);
            }
        }
        return this.performMultiSnpTestCore(G);
    }

    /**
     * Single-SNP test using the cached null-model quantities:
     * {@code beta = g'r / V} and {@code se = sqrt(sigma2 / V)} with
     * {@code V = g'g - g'X (X'X)^-1 X'g}, where {@code r} are the null-model residuals.
     *
     * @param genotype genotype values of one SNP, one per sample ({@code NaN} = missing)
     * @return a single row {@code {beta, standardError, pValue}}
     */
    private double[][] testSingleSnpFastPath(double[] genotype) {
        double meanGty = LinearRegressionGtyFast.meanOfObserved(genotype);
        if (Double.isNaN(meanGty)) {
            return LinearRegressionGtyFast.untestableResults(1);
        }
        for (int i = 0; i < this.nSamples; ++i) {
            this.gVector.data[i] = Double.isNaN(genotype[i]) ? meanGty : genotype[i];
        }

        double score = CommonOps_DDRM.dot(this.gVector, this.residuals);
        double gTg = CommonOps_DDRM.dot(this.gVector, this.gVector);
        CommonOps_DDRM.mult(this.xCovT, this.gVector, this.term2Vec);
        CommonOps_DDRM.mult(this.invXtX, this.term2Vec, this.tempVec);
        double projected = CommonOps_DDRM.dot(this.term2Vec, this.tempVec);

        // Genotype sum of squares left after removing what the covariates explain.
        double adjustedSs = gTg - projected;
        if (adjustedSs <= MIN_ADJUSTED_GENOTYPE_SS) {
            // (Nearly) no genotype variation independent of the covariates.
            return LinearRegressionGtyFast.untestableResults(1);
        }

        double beta = score / adjustedSs;
        double se = Math.sqrt(this.residualVariance / adjustedSs);
        double chiSq = score * score / (this.residualVariance * adjustedSs);
        double pValue = Probability.chiSquareComplemented(1.0, chiSq);
        return new double[][]{{beta, se, pValue}};
    }

    /**
     * Joint test of several SNPs: solves {@code V beta = G'r} with
     * {@code V = G'G - G'X (X'X)^-1 X'G} and takes
     * {@code var(beta_i) = sigma2 * (V^-1)_ii}.
     *
     * @param G imputed genotype matrix, samples x SNPs
     * @return one row {@code {beta, standardError, pValue}} per SNP
     */
    private double[][] performMultiSnpTestCore(DMatrixRMaj G) {
        int k = G.getNumCols();

        DMatrixRMaj score = new DMatrixRMaj(k, 1);
        CommonOps_DDRM.multTransA(G, this.residuals, score);
        DMatrixRMaj gtg = new DMatrixRMaj(k, k);
        CommonOps_DDRM.multTransA(G, G, gtg);
        DMatrixRMaj xtG = new DMatrixRMaj(this.nCovars, k);
        CommonOps_DDRM.mult(this.xCovT, G, xtG);
        DMatrixRMaj invXtXxtG = new DMatrixRMaj(this.nCovars, k);
        CommonOps_DDRM.mult(this.invXtX, xtG, invXtXxtG);
        DMatrixRMaj projected = new DMatrixRMaj(k, k);
        CommonOps_DDRM.multTransA(xtG, invXtXxtG, projected);
        DMatrixRMaj adjusted = new DMatrixRMaj(k, k);
        CommonOps_DDRM.subtract(gtg, projected, adjusted);

        LinearSolverDense<DMatrixRMaj> snpSolver = LinearRegressionGtyFast.factorize(adjusted);
        if (snpSolver == null) {
            return LinearRegressionGtyFast.untestableResults(k);
        }
        DMatrixRMaj betas = new DMatrixRMaj(k, 1);
        snpSolver.solve(score, betas);
        DMatrixRMaj adjustedInv = new DMatrixRMaj(k, k);
        snpSolver.invert(adjustedInv);

        double[][] results = new double[k][];
        for (int i = 0; i < k; ++i) {
            double beta = betas.get(i, 0);
            double varBeta = this.residualVariance * adjustedInv.get(i, i);
            if (varBeta < 0.0 || Double.isNaN(varBeta)) {
                // Numerically invalid variance: keep the estimate but report no evidence.
                results[i] = new double[]{beta, Double.NaN, 1.0};
                continue;
            }
            double chiSq = beta * beta / varBeta;
            results[i] = new double[]{beta, Math.sqrt(varBeta), Probability.chiSquareComplemented(1.0, chiSq)};
        }
        return results;
    }

    /**
     * Decomposes a square matrix, trying the symmetric-positive-definite solver first and
     * a QR solver second.
     *
     * <p>The first attempt is given a copy: EJML solvers are allowed to overwrite their
     * input (see {@code LinearSolver.modifiesA()}), and the fallback must see the
     * original values if the first decomposition gives up part-way.
     *
     * @param a square matrix to decompose
     * @return a solver ready for {@code solve}/{@code invert}, or {@code null} if both
     *         decompositions were rejected
     */
    private static LinearSolverDense<DMatrixRMaj> factorize(DMatrixRMaj a) {
        int n = a.getNumRows();
        LinearSolverDense<DMatrixRMaj> solver = LinearSolverFactory_DDRM.symmPosDef(n);
        if (solver.setA(a.copy())) {
            return solver;
        }
        solver = LinearSolverFactory_DDRM.qr(n, n);
        return solver.setA(a) ? solver : null;
    }

    /**
     * @param values genotype values, {@code NaN} marking a missing entry
     * @return mean of the non-missing values, or {@code NaN} if every value is missing
     */
    private static double meanOfObserved(double[] values) {
        double sum = 0.0;
        int observed = 0;
        for (double value : values) {
            if (Double.isNaN(value)) {
                continue;
            }
            sum += value;
            ++observed;
        }
        return observed == 0 ? Double.NaN : sum / (double)observed;
    }

    /**
     * @param count number of SNPs
     * @return {@code count} rows of {@code {NaN, NaN, 1.0}}
     */
    private static double[][] untestableResults(int count) {
        double[][] results = new double[count][];
        for (int i = 0; i < count; ++i) {
            results[i] = new double[]{Double.NaN, Double.NaN, 1.0};
        }
        return results;
    }

    /** Builds the exception used for every genotype-matrix shape violation. */
    private static IllegalArgumentException badGenotypeDimensions() {
        // "Genotype matrix has wrong dimensions. Expected [k_snps][n_samples]."
        return new IllegalArgumentException("\u57fa\u56e0\u578b\u77e9\u9635\u7ef4\u5ea6\u4e0d\u6b63\u786e\u3002\u5e94\u4e3a [k_snps][n_samples]\u3002");
    }

    /**
     * Builds the design matrix: a leading column of ones followed by the covariates.
     * The number of covariates is taken from the first row; a {@code null} or empty
     * matrix (or a {@code null}/empty first row) yields an intercept-only design.
     *
     * @param covariates covariate values indexed {@code [sample][covariate]}, may be null
     * @param nSamples   number of samples (rows of the design matrix)
     * @return matrix of size {@code nSamples x (1 + covariateCount)}
     * @throws IllegalArgumentException if covariates are present but there are fewer rows
     *                                  than samples, or a needed row is null or shorter
     *                                  than the first row
     */
    private static DMatrixRMaj createDesignMatrixWithIntercept(double[][] covariates, int nSamples) {
        int covariateCount = 0;
        if (covariates != null && covariates.length > 0 && covariates[0] != null) {
            covariateCount = covariates[0].length;
        }
        if (covariateCount > 0 && covariates.length < nSamples) {
            throw LinearRegressionGtyFast.badCovariateDimensions();
        }

        DMatrixRMaj designMatrix = new DMatrixRMaj(nSamples, 1 + covariateCount);
        for (int i = 0; i < nSamples; ++i) {
            designMatrix.set(i, 0, 1.0);
            if (covariateCount == 0) {
                continue;
            }
            double[] row = covariates[i];
            if (row == null || row.length < covariateCount) {
                throw LinearRegressionGtyFast.badCovariateDimensions();
            }
            for (int j = 0; j < covariateCount; ++j) {
                designMatrix.set(i, 1 + j, row[j]);
            }
        }
        return designMatrix;
    }

    /** Builds the exception used for every covariate-matrix shape violation. */
    private static IllegalArgumentException badCovariateDimensions() {
        // "Covariate matrix has wrong dimensions. Expected [n_samples][n_covariates]."
        return new IllegalArgumentException("\u534f\u53d8\u91cf\u77e9\u9635\u7ef4\u5ea6\u4e0d\u6b63\u786e\u3002\u5e94\u4e3a [n_samples][n_covariates]\u3002");
    }
}

