/*
 * Decompiled with CFR 0.152.
 */
package edu.sysu.pmglab.analysis;

import cern.colt.list.DoubleArrayList;
import cern.jet.stat.Probability;
import edu.sysu.pmglab.RuntimeProperty;
import edu.sysu.pmglab.analysis.CalcRegionSet;
import edu.sysu.pmglab.analysis.TNBRegressionParamSet;
import edu.sysu.pmglab.analysis.TNBRegressionRSource;
import edu.sysu.pmglab.bytecode.Bytes;
import edu.sysu.pmglab.container.indexable.IndexableSet;
import edu.sysu.pmglab.container.indexable.LinkedSet;
import edu.sysu.pmglab.container.list.DoubleList;
import edu.sysu.pmglab.container.list.List;
import edu.sysu.pmglab.executor.ThreadQueue;
import edu.sysu.pmglab.io.writer.WriterStream;
import edu.sysu.pmglab.kgga.command.SetupApplication;
import edu.sysu.pmglab.progressbar.ProgressBar;
import edu.sysu.pmglab.stat.RobustRegression;
import edu.sysu.pmglab.stat.Summary;
import gnu.trove.map.hash.THashMap;
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetSocketAddress;
import java.util.AbstractCollection;
import java.util.Arrays;
import java.util.Comparator;
import org.rosuda.REngine.REXPMismatchException;
import org.rosuda.REngine.REngineException;
import org.rosuda.REngine.Rserve.RConnection;
import org.rosuda.REngine.Rserve.RserveException;

/**
 * Explores parameter sets for a zero-truncated negative-binomial regression of per-region mutation
 * counts on region scores. Model fitting is delegated to an Rserve server: each connection loads the
 * MASS library and the R source in {@link TNBRegressionRSource} before fitting
 * {@code zerotrunc(..., dist="negbin-extended")}.
 */
public class TNBRegressionParamExplorer {
    /** Regions to analyse. */
    List<CalcRegionSet> allRegions;
    /** Predictor names (score heads with '@' and '-' replaced by '_'); used as R column and formula terms. */
    List<String> predictorNames;
    /**
     * When true, every uploaded data frame has its non-positive values replaced by 1E-6 and columns
     * 2..n log-transformed in R. Not set anywhere in this class.
     */
    boolean logarithmExplanatoryVar;
    /** Parameter sets to explore; after {@link #exploreOptimalParameters()} holds the completed ones sorted by MLFC. */
    List<TNBRegressionParamSet> paramSetList;
    /** FDR level of the Benjamini-Hochberg cutoff used to split significant regions while refitting. */
    double looseFDR = 0.7;
    String rHost;
    int rPort;
    /** Minimum number of usable regions required before the model is fitted. */
    int minSampleSize = 100;
    /** Whether to run {@code stepAIC} on the fitted model. */
    boolean runStepwise = false;
    /** Regions whose unfiltered count (score cut -1.0) is at most this value are skipped. */
    int minMutCount = 2;
    /** Residual type passed to the R residual functions. */
    String residualType = "deviance";
    int nThreads;
    /** Gene symbol lookup; stored for callers but not read by this class. */
    IndexableSet<String> geneSymbolMap;
    /** Multiplier for control counts subtracted from case counts; NaN disables the subtraction. */
    double caseControlRatio;

    /**
     * @param rServer          address of the Rserve server
     * @param caseControlRatio multiplier applied to control counts before subtracting them from case
     *                         counts, or NaN to disable the subtraction
     */
    public TNBRegressionParamExplorer(InetSocketAddress rServer, double caseControlRatio) {
        this.rHost = rServer.getHostName();
        this.rPort = rServer.getPort();
        this.caseControlRatio = caseControlRatio;
    }

    public void setGeneSymbolMap(IndexableSet<String> geneSymbolMap) {
        this.geneSymbolMap = geneSymbolMap;
    }

    public int getMinMutCount() {
        return this.minMutCount;
    }

    public void setMinMutCount(int minMutCount) {
        this.minMutCount = minMutCount;
    }

    public TNBRegressionParamExplorer setLooseFDR(double looseFDR) {
        this.looseFDR = looseFDR;
        return this;
    }

    public TNBRegressionParamExplorer setResidualType(String residualType) {
        this.residualType = residualType;
        return this;
    }

    public TNBRegressionParamExplorer setMinSampleSize(int minSampleSize) {
        this.minSampleSize = minSampleSize;
        return this;
    }

    public TNBRegressionParamExplorer setRunStepwise(boolean runStepwise) {
        this.runStepwise = runStepwise;
        return this;
    }

    /**
     * Sets the regions to analyse and derives the predictor names from the score headers.
     *
     * @param allRegions regions to analyse
     * @param scoreHeads score header names; '@' and '-' are replaced by '_'
     * @return this explorer
     */
    public TNBRegressionParamExplorer setDataSet(List<CalcRegionSet> allRegions, List<String> scoreHeads) {
        this.allRegions = allRegions;
        this.predictorNames = new List<>();
        for (String head : scoreHeads) {
            // The names are pasted unquoted into the R model formula, so drop characters R would parse as operators.
            this.predictorNames.add(head.replace('@', '_').replace('-', '_'));
        }
        return this;
    }

    public void setParamSetList(List<TNBRegressionParamSet> paramSetList) {
        this.paramSetList = paramSetList;
    }

    public TNBRegressionParamExplorer setThreads(int nThreads) {
        this.nThreads = RuntimeProperty.verifyThreads(nThreads);
        return this;
    }

    /**
     * Builds the regression table, fits the zero-truncated NB model in R and refits it on the
     * non-significant regions until the significant set stops growing. On success it stores MLFC, the
     * number of significant regions, the residual mean/sd and an incremented cycle count in
     * {@code curTNBParam}.
     *
     * @param rcon                      open R connection with the TNB R source loaded
     * @param regionLabelTrunc          output: IDs of the regions kept for the regression
     * @param scoreListTrunc            output: one row per kept region (dependent count first)
     * @param curTNBParam               parameter set to evaluate; updated in place on success
     * @param threadIdC                 suffix used for the R data-frame name
     * @param subtractCtrlFromCaseCount if true, the dependent variable is the case count minus the
     *                                  scaled control count; otherwise the control count is used
     * @return true if the model was fitted and the statistics were recorded; false if too few regions
     *         were usable or fitting failed
     * @throws IOException kept for interface compatibility
     */
    public boolean calculatePValuesAndObtainMFLC(RConnection rcon, List<String> regionLabelTrunc, List<double[]> scoreListTrunc, TNBRegressionParamSet curTNBParam, int threadIdC, boolean subtractCtrlFromCaseCount) throws IOException {
        double scoreBinCut = curTNBParam.scoreBinCut;
        int truncPoint = curTNBParam.truncationPoint;
        boolean adjustAF = curTNBParam.adjustAF;
        for (CalcRegionSet gr : this.allRegions) {
            double[] row = this.buildRegressionRow(gr, curTNBParam, subtractCtrlFromCaseCount);
            if (row == null) {
                continue;
            }
            scoreListTrunc.add(row);
            regionLabelTrunc.add(gr.getAllRegionIDsFull());
        }
        if (!adjustAF && !scoreListTrunc.isEmpty()) {
            RobustRegression.INSTANCE.removeOutlierRow(regionLabelTrunc, scoreListTrunc, 1, 1);
        }
        if (regionLabelTrunc.isEmpty() || regionLabelTrunc.size() < this.minSampleSize) {
            String info = "The number of genes with non-zero mutation counts is  " + regionLabelTrunc.size() + " at truncation point " + truncPoint + ". The analysis of RUNNER is stopped.";
            SetupApplication.GlobalLogger.info(info);
            return false;
        }

        double[] meanSD = new double[2];
        DoubleList pValues = new DoubleList();
        // Declared through AbstractCollection (LinkedSet's superclass) to match the original call sites.
        AbstractCollection previousSig = new LinkedSet();
        AbstractCollection currentSig = new LinkedSet();
        List<double[]> insigRows = new List<>();
        insigRows.addAll(scoreListTrunc);
        double mflc = 1.0;
        int sigGeneNum = 0;
        double adjGenePValueCutoff = Double.NaN;
        try {
            int colNum = scoreListTrunc.get(0).length;
            // Residuals are always evaluated on all kept regions so they align with regionLabelTrunc.
            double[] fullMatrix = toRowMajor(scoreListTrunc, colNum);
            for (int iteration = 1; ; ++iteration) {
                double[] fitMatrix = toRowMajor(insigRows, colNum);
                double[] residuals = this.calculateResidueByR(rcon, this.predictorNames, fitMatrix, colNum, truncPoint, fullMatrix, threadIdC, meanSD, adjustAF);
                if (residuals == null) {
                    if (iteration == 1) {
                        // No model could be fitted at all: this parameter set has no result.
                        SetupApplication.GlobalLogger.error("Errors at scoreBinCut {}, truncation point {}!", (Object) scoreBinCut, (Object) truncPoint);
                        return false;
                    }
                    // A refit failed: report the previous successful iteration, whose sorted p-values,
                    // cutoff and significant set are still held (meanSD is only written on success).
                    mflc = Summary.MLFC(adjGenePValueCutoff, pValues);
                    sigGeneNum = previousSig.size();
                    break;
                }
                int geneNum = residuals.length;
                double[] geneP = new double[geneNum];
                pValues.clear();
                for (int i = 0; i < geneNum; ++i) {
                    geneP[i] = upperTailP((residuals[i] - meanSD[0]) / meanSD[1]);
                    pValues.add(geneP[i]);
                }
                pValues.sort();
                adjGenePValueCutoff = Summary.benjaminiHochbergFDR(this.looseFDR, pValues);
                insigRows.clear();
                currentSig.clear();
                for (int i = 0; i < geneNum; ++i) {
                    if (geneP[i] <= adjGenePValueCutoff) {
                        currentSig.add(regionLabelTrunc.get(i));
                    } else {
                        insigRows.add(scoreListTrunc.get(i));
                    }
                }
                // Stop once the significant set no longer grows, or when no region is left to refit on.
                if (previousSig.size() >= currentSig.size() || insigRows.isEmpty()) {
                    mflc = Summary.MLFC(adjGenePValueCutoff, pValues);
                    sigGeneNum = currentSig.size();
                    break;
                }
                previousSig.clear();
                previousSig.addAll(currentSig);
            }
            curTNBParam.MLFC = mflc;
            curTNBParam.sigGeneNum = sigGeneNum;
            curTNBParam.mean = meanSD[0];
            curTNBParam.sd = meanSD[1];
            ++curTNBParam.cycTime;
            return true;
        } catch (Exception e1) {
            SetupApplication.GlobalLogger.error("Errors at scoreBinCut {}, truncation point {}!", (Object) scoreBinCut, (Object) truncPoint);
            e1.printStackTrace();
            return false;
        } finally {
            insigRows.clear();
        }
    }

    /**
     * Builds one regression row for a region, or returns null if the region is filtered out.
     * Row layout: [dependent count, region scores...]. Without AF adjustment it also has
     * [reference count, reference count * region length, (control count)]. In that case the last
     * region score is used as the region length.
     */
    private double[] buildRegressionRow(CalcRegionSet gr, TNBRegressionParamSet param, boolean subtractCtrlFromCaseCount) {
        double scoreBinCut = param.scoreBinCut;
        boolean cutoffModel = param.weightCountAdjustModel == 1;
        int count;
        if (param.adjustAF) {
            if (gr.getDiff(-1.0, cutoffModel) <= this.minMutCount) {
                return null;
            }
            count = gr.getDiff(scoreBinCut, cutoffModel);
        } else if (subtractCtrlFromCaseCount) {
            if (gr.getInteractedMutNumCase(-1.0, cutoffModel) <= this.minMutCount) {
                return null;
            }
            count = gr.getInteractedMutNumCase(scoreBinCut, cutoffModel);
            if (!Double.isNaN(this.caseControlRatio)) {
                int controlCount = gr.getInteractedMutNumControl(scoreBinCut, cutoffModel);
                count -= (int) ((double) controlCount * this.caseControlRatio);
            }
        } else {
            if (gr.getInteractedMutNumControl(-1.0, cutoffModel) <= this.minMutCount) {
                return null;
            }
            count = gr.getInteractedMutNumControl(scoreBinCut, cutoffModel);
        }
        if (count <= param.truncationPoint) {
            return null;
        }
        double[] scores = gr.getRegionScores();
        if (scores == null) {
            return null;
        }
        for (double v : scores) {
            if (Double.isNaN(v)) {
                return null;
            }
        }

        double[] row;
        if (param.adjustAF) {
            row = new double[scores.length + 1];
            System.arraycopy(scores, 0, row, 1, scores.length);
            row[0] = count;
        } else {
            if (scores.length == 0) {
                // The region length is read from the last score; without scores the row cannot be built.
                SetupApplication.GlobalLogger.warn("Gene/Region {} has missed scores!", (Object) gr.getAllRegionIDsFull());
                return null;
            }
            int base = scores.length;
            row = new double[base + (param.useControlMutPredictor ? 4 : 3)];
            System.arraycopy(scores, 0, row, 1, base);
            row[0] = count;
            double regionLen = scores[base - 1];
            double weightedRef = gr.getMutNumRef(scoreBinCut, cutoffModel);
            row[base] = regionLen;
            row[base + 1] = weightedRef;
            row[base + 2] = weightedRef * regionLen;
            if (param.useControlMutPredictor) {
                row[base + 3] = gr.getInteractedMutNumControl(scoreBinCut, cutoffModel);
            }
        }
        if (row.length < this.predictorNames.size() + 1) {
            SetupApplication.GlobalLogger.warn("Gene/Region {} has missed scores!", (Object) gr.getAllRegionIDsFull());
            return null;
        }
        return row;
    }

    /**
     * Fits the model for {@code curTNBParam} on a fresh R connection and returns one-sided
     * (upper-tail) p-values of the standardized full-data residuals. As side effects, each region's
     * final feature scores, {@code z} and {@code p} are set.
     *
     * @param curTNBParam parameter set to fit; its {@code mean}/{@code sd} are refreshed by the fit
     * @return p-values of all regions with a non-NaN p-value, or an empty list when no model could be fitted
     * @throws UncheckedIOException if the regression step reports an I/O error
     */
    public DoubleArrayList calculateFinalPValues(TNBRegressionParamSet curTNBParam) throws REngineException, REXPMismatchException {
        List<String> regionLabelTrunc = new List<>();
        List<double[]> countScoreListTrunc = new List<>();
        try (RConnection rcon = new RConnection(this.rHost, this.rPort)) {
            rcon.eval("library(MASS)");
            rcon.eval("set.seed(123456)");
            rcon.eval(TNBRegressionRSource.INSTANCE.sourceCodeR);
            rcon.eval("set.seed(1978)");
            if (curTNBParam.weightCountAdjustModel == 1) {
                curTNBParam.cycTime = 0;
            }
            boolean complete;
            try {
                complete = this.calculatePValuesAndObtainMFLC(rcon, regionLabelTrunc, countScoreListTrunc, curTNBParam, 1, true);
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
            if (!complete || countScoreListTrunc.isEmpty()) {
                // Without a fitted model there is no m1 in R to evaluate residuals against.
                SetupApplication.GlobalLogger.warn("No model could be fitted; no final p-values at truncation point {}.", (Object) curTNBParam.truncationPoint);
                return new DoubleArrayList();
            }
            String logTxt = rcon.eval("summarySimple.zerotrunc(m1)").asString();
            SetupApplication.GlobalLogger.info("The zero-truncated negative-binomial regression model fitted for region-based mutation counts regression:\n{}", (Object) logTxt);

            int colNum = countScoreListTrunc.get(0).length;
            double[] countsMatrix = toRowMajor(countScoreListTrunc, colNum);
            String frameName = "genemutetestF" + (int) (Math.random() * 10000.0);
            this.uploadDataFrame(rcon, frameName, countsMatrix, colNum, this.predictorNames);
            double[] residuals = rcon.eval(this.fullResidualsCall(frameName, curTNBParam.truncationPoint)).asDoubles();

            THashMap<String, Double> regionZ = new THashMap<>();
            THashMap<String, double[]> regionScoreMap = new THashMap<>();
            for (int t = 0; t < residuals.length; ++t) {
                double z = (residuals[t] - curTNBParam.mean) / curTNBParam.sd;
                regionZ.put(regionLabelTrunc.get(t), z);
                regionScoreMap.put(regionLabelTrunc.get(t), countScoreListTrunc.get(t));
            }

            DoubleArrayList pValues = new DoubleArrayList();
            for (CalcRegionSet gr : this.allRegions) {
                String regionId = gr.getAllRegionIDsFull();
                double[] finalScores = regionScoreMap.get(regionId);
                if (finalScores == null) {
                    continue;
                }
                gr.setFinalFeatureScore(finalScores);
                Double z = regionZ.get(regionId);
                if (z == null) {
                    continue;
                }
                gr.z = z;
                double p = upperTailP(z);
                gr.p = p;
                if (!Double.isNaN(p)) {
                    pValues.add(p);
                }
            }
            return pValues;
        }
    }

    /**
     * Evaluates every parameter set in {@link #paramSetList} in parallel, one R connection per task.
     * Afterwards the list holds only the sets that completed, sorted by ascending MLFC.
     */
    public void exploreOptimalParameters() {
        List<TNBRegressionParamSet> taskParamSetList = new List<>(this.paramSetList);
        this.paramSetList.clear();
        ProgressBar bar = new ProgressBar.Builder().setTextRenderer("Truncated negative binomial regression", "sets").setInitialMax(taskParamSetList.size()).build();
        ThreadQueue threadPool = new ThreadQueue(this.nThreads);
        for (int j = 0; j < taskParamSetList.size(); ++j) {
            TNBRegressionParamSet curTNBParam = taskParamSetList.get(j);
            int threadId = j;
            threadPool.addTask((status, context) -> {
                List<String> regionLabelTrunc = new List<>();
                List<double[]> scoreListTrunc = new List<>();
                try (RConnection rcon = new RConnection(this.rHost, this.rPort)) {
                    rcon.eval("library(MASS)");
                    rcon.eval("set.seed(123456)");
                    rcon.eval(TNBRegressionRSource.INSTANCE.sourceCodeR);
                    if (this.calculatePValuesAndObtainMFLC(rcon, regionLabelTrunc, scoreListTrunc, curTNBParam, threadId, true)) {
                        // paramSetList is shared by all tasks; guard it with its own monitor.
                        synchronized (this.paramSetList) {
                            this.paramSetList.add(curTNBParam);
                        }
                    }
                } finally {
                    // Advance the bar even when a task fails so the progress total stays consistent.
                    bar.step(1L);
                }
            });
        }
        threadPool.close();
        bar.close();
        this.paramSetList.sort(Comparator.comparingDouble(o -> o.MLFC));
    }

    /**
     * Fits {@code m1 <- zerotrunc(...)} in R on {@code countsMatrix}, optionally runs stepAIC, and
     * evaluates residuals of the fitted model on {@code orgCountsMatrix}.
     *
     * @param rcon            open R connection with the TNB R source loaded
     * @param scoreNames      predictor names (columns 2..colNum); must not be empty
     * @param countsMatrix    row-major fitting data, dependent count in column 1
     * @param colNum          number of columns per row
     * @param trunc           truncation point of the model
     * @param orgCountsMatrix row-major data whose residuals are returned; may be null
     * @param threadID        suffix for the R data-frame name and the failure dump file
     * @param meanSD          output: residual mean in [0] and sd in [1], written only on success
     * @param adjustAF        if true, mean/sd come from the returned residuals; otherwise from robustly
     *                        weighted residuals of the fitting data
     * @return residuals for the rows of {@code orgCountsMatrix}; null if it is null or R reported an error
     */
    public double[] calculateResidueByR(RConnection rcon, List<String> scoreNames, double[] countsMatrix, int colNum, int trunc, double[] orgCountsMatrix, int threadID, double[] meanSD, boolean adjustAF) throws Exception {
        if (scoreNames.isEmpty()) {
            throw new IllegalArgumentException("At least one predictor name is required to build the model formula");
        }
        String frameName = "genemutetest" + threadID;
        this.uploadDataFrame(rcon, frameName, countsMatrix, colNum, scoreNames);
        String fitCall = buildFitCall(frameName, scoreNames, trunc);
        try {
            rcon.voidEval(fitCall);
            if (this.runStepwise) {
                String direction = adjustAF ? "both" : "backward";
                rcon.voidEval(" m1 <- stepAIC(m1,direction = \"" + direction + "\",trace = 0)");
            }
            double[] residuals = null;
            if (orgCountsMatrix != null) {
                // Same frame name and transform as the fitting frame, so m1 sees identical columns.
                this.uploadDataFrame(rcon, frameName, orgCountsMatrix, colNum, scoreNames);
                residuals = rcon.eval(this.fullResidualsCall(frameName, trunc)).asDoubles();
            }
            double mean;
            double sd;
            if (adjustAF) {
                mean = Summary.mean(residuals);
                sd = Summary.sd(residuals);
            } else {
                double[] fitResiduals = rcon.eval("residuals.zerotrunc(m1,type = \"" + this.residualType + "\", truncPoint=" + trunc + ")").asDoubles();
                double[] weights = new double[fitResiduals.length];
                Arrays.fill(weights, Double.NaN);
                if (RobustRegression.INSTANCE.iterativeWeighter(fitResiduals, weights, 100)) {
                    mean = Summary.mean(fitResiduals, weights, true);
                    sd = Summary.stddev(fitResiduals, weights);
                } else {
                    mean = Summary.mean(fitResiduals);
                    sd = Summary.sd(fitResiduals);
                }
            }
            meanSD[0] = mean;
            meanSD[1] = sd;
            return residuals;
        } catch (RserveException ex) {
            SetupApplication.GlobalLogger.error(fitCall);
            ex.printStackTrace();
            // Dump the fitting data so the failing model can be reproduced outside Java.
            int rowNum = countsMatrix.length / colNum;
            WriterStream fs = new WriterStream(new File("./test" + threadID + ".txt"), WriterStream.Option.DEFAULT);
            try {
                StringBuilder header = new StringBuilder("DepVarScore");
                for (String name : scoreNames) {
                    header.append('\t').append(name);
                }
                fs.writeChar(header.toString());
                fs.write(10);
                for (int i = 0; i < rowNum; ++i) {
                    fs.writeChar(String.valueOf(countsMatrix[i * colNum]));
                    for (int j = 1; j < colNum; ++j) {
                        fs.write(9);
                        fs.writeChar(String.valueOf(countsMatrix[i * colNum + j]));
                    }
                    fs.write(10);
                }
            } finally {
                fs.close();
            }
            // Signal failure: a partially evaluated model must not be paired with a stale meanSD.
            return null;
        }
    }

    /**
     * Uploads a row-major matrix to R as data frame {@code frameName} and names its columns
     * DepVarScore followed by {@code names}. If {@link #logarithmExplanatoryVar} is set, it also
     * applies the log transform.
     */
    private void uploadDataFrame(RConnection rcon, String frameName, double[] rowMajor, int colNum, List<String> names) throws REngineException {
        int rowNum = rowMajor.length / colNum;
        rcon.assign("valMat", rowMajor);
        rcon.voidEval("valMat<-matrix(valMat, nrow=" + rowNum + ", ncol=" + colNum + ", byrow = TRUE)");
        rcon.voidEval(frameName + "<-data.frame(valMat);");
        StringBuilder colnames = new StringBuilder("colnames(").append(frameName).append(") <- c(\"DepVarScore\"");
        for (String name : names) {
            colnames.append(",\"").append(name).append("\"");
        }
        colnames.append(")");
        rcon.voidEval(colnames.toString());
        if (this.logarithmExplanatoryVar) {
            rcon.voidEval(frameName + "[" + frameName + " <=0] <- 1E-6");
            rcon.voidEval("len<-dim(" + frameName + ")[2]");
            rcon.voidEval(frameName + "[,seq(2,len)] <- log(" + frameName + "[,seq(2,len)])");
        }
    }

    /** Builds the R statement that fits {@code m1} with every predictor as an additive term. */
    private static String buildFitCall(String frameName, List<String> scoreNames, int trunc) {
        StringBuilder fit = new StringBuilder("m1 <- zerotrunc(DepVarScore ~ ");
        for (int i = 0; i < scoreNames.size(); ++i) {
            if (i > 0) {
                fit.append('+');
            }
            fit.append(scoreNames.get(i));
        }
        fit.append(", data = ").append(frameName).append(", dist=\"negbin-extended\", truncPoint=").append(trunc).append(")");
        return fit.toString();
    }

    /** Builds the R expression that evaluates residuals of {@code m1} on data frame {@code frameName}. */
    private String fullResidualsCall(String frameName, int truncPoint) {
        return "residuals.zrnbFull(m1, " + frameName + ", truncPoint=" + truncPoint + ",type = \"" + this.residualType + "\")";
    }

    /** Flattens rows of length {@code colNum} into a single row-major array. */
    private static double[] toRowMajor(List<double[]> rows, int colNum) {
        int rowNum = rows.size();
        double[] matrix = new double[rowNum * colNum];
        for (int r = 0; r < rowNum; ++r) {
            System.arraycopy(rows.get(r), 0, matrix, r * colNum, colNum);
        }
        return matrix;
    }

    /**
     * Upper-tail standard-normal probability of {@code z}. Positive z uses the lower tail at -z, which
     * keeps precision for small p-values.
     */
    private static double upperTailP(double z) {
        return z < 0.0 ? 1.0 - Probability.normal(z) : Probability.normal(-z);
    }
}

