/*
 * Decompiled with CFR 0.152.
 */
package edu.sysu.pmglab.analysis;

import cern.colt.list.DoubleArrayList;
import edu.sysu.pmglab.analysis.CalcRegionSet;
import edu.sysu.pmglab.analysis.DriverType;
import edu.sysu.pmglab.analysis.GenomeRegion;
import edu.sysu.pmglab.analysis.TNBRegressionParamExplorer;
import edu.sysu.pmglab.analysis.TNBRegressionParamSet;
import edu.sysu.pmglab.container.indexable.IndexableSet;
import edu.sysu.pmglab.container.list.List;
import edu.sysu.pmglab.io.writer.WriterStream;
import edu.sysu.pmglab.kgga.command.SetupApplication;
import edu.sysu.pmglab.plot.PValuePainter;
import edu.sysu.pmglab.stat.DynamicScanWindows;
import edu.sysu.pmglab.stat.Summary;
import edu.sysu.pmglab.utils.ValueUtils;
import gnu.trove.map.hash.THashMap;
import java.io.File;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Locale;

/**
 * Runs a truncated negative binomial regression over a list of {@link CalcRegionSet}s.
 *
 * <p>The class builds a grid of candidate {@link TNBRegressionParamSet}s, hands the grid and the
 * region data to a {@link TNBRegressionParamExplorer}, and then writes the resulting p-values as
 * a QQ plot ({@code <prefix>.qq.pdf}) and a tab-separated table ({@code <prefix>.txt}).
 *
 * <p>Instances are mutable and configured through the setters below; nothing in this class
 * synchronizes access to that state.
 */
public class TNBRegression {
    /** Names of the broad gene-region type codes 0..4, indexed by code. */
    private static final String[] BROAD_REGION_NAMES = {"Exons", "UTR5", "UTR3", "Upstream", "Downstream"};
    /** Value subtracted from a region type code to obtain the number printed after "Intron". */
    private static final int INTRON_START_ID = 10;
    /** Byte value of the tab column separator written to the output table. */
    private static final int TAB = 9;
    /** Byte value of the line terminator written to the output table. */
    private static final int NEWLINE = 10;

    /** Region sets under analysis; sorted and pruned in place by methods of this class. */
    List<CalcRegionSet> regions;
    /** Address passed to the {@link TNBRegressionParamExplorer} constructor. */
    InetSocketAddress rServer;
    /** Column names of the additional per-region values written to the output table. */
    List<String> addedValueFields;
    /** Residual type forwarded to the explorer. */
    String residualType = "pearson";
    /** Stepwise flag forwarded to the explorer. */
    boolean runStepwise = true;
    /** Minimum number of region sets required before the regression is attempted. */
    int minSampleSize = 200;
    /** Driver type; stored here but not read anywhere in this class. */
    DriverType driverType = DriverType.REGION;
    /** Minimum mutation count forwarded to the explorer. */
    int minMutCount = 2;
    /** Thread count forwarded to the explorer. */
    int nThreads = 1;
    /** Selects which of the two parameter grids is built and whether the "Diff" column is emitted. */
    boolean adjustAF;
    /** Copied into {@code useControlMutPredictor} of each parameter set when {@link #adjustAF} is off. */
    boolean containRef;
    /** Copied into {@code iRunner} of each parameter set when {@link #adjustAF} is off. */
    boolean iRunner;

    /** @return the region sets currently held (the live list, not a copy) */
    public List<CalcRegionSet> getRegions() {
        return this.regions;
    }

    /** @param regions the region sets to analyse; the list is stored by reference */
    public void setRegions(List<CalcRegionSet> regions) {
        this.regions = regions;
    }

    /** @param accuValueFields names of the additional value columns; stored by reference */
    public void setFieldLabels(List<String> accuValueFields) {
        this.addedValueFields = accuValueFields;
    }

    /**
     * @param minSampleSize minimum number of region sets required to run the regression
     * @return this instance, for chaining
     */
    public TNBRegression setMinSampleSize(int minSampleSize) {
        this.minSampleSize = minSampleSize;
        return this;
    }

    /** @return the minimum mutation count forwarded to the explorer */
    public int getMinMutCount() {
        return this.minMutCount;
    }

    /**
     * @param minMutCount minimum mutation count forwarded to the explorer
     * @return this instance, for chaining
     */
    public TNBRegression setMinMutCount(int minMutCount) {
        this.minMutCount = minMutCount;
        return this;
    }

    /**
     * @param residualType residual type forwarded to the explorer
     * @return this instance, for chaining
     */
    public TNBRegression setResidualType(String residualType) {
        this.residualType = residualType;
        return this;
    }

    /**
     * @param runSetwise stepwise flag forwarded to the explorer
     * @return this instance, for chaining
     */
    public TNBRegression setRunStepwise(boolean runSetwise) {
        this.runStepwise = runSetwise;
        return this;
    }

    /**
     * @param address address passed to the explorer's constructor
     * @return this instance, for chaining
     */
    public TNBRegression setRServer(InetSocketAddress address) {
        this.rServer = address;
        return this;
    }

    /**
     * @param driverType driver type to store; {@code ValueUtils.getOrDefault} is used so that
     *                   {@link DriverType#REGION} acts as the fallback value
     * @return this instance, for chaining
     */
    public TNBRegression setDriverType(DriverType driverType) {
        this.driverType = ValueUtils.getOrDefault(driverType, DriverType.REGION);
        return this;
    }

    /**
     * Explores the parameter grid, computes the final p-values with the first remaining parameter
     * set, and writes the QQ plot and the result table.
     *
     * <p>The method logs an error and returns without output when fewer than
     * {@link #minSampleSize} region sets are available (an unset region list counts as zero), or
     * when the parameter list is empty after exploration.
     *
     * @param geneSymbolMap          lookup from numeric region id to printable name
     * @param considerFunVar         when false, a single score bin at cut 0.0 is used together
     *                               with ten truncation points
     * @param outPathPrefix          prefix of the generated {@code .qq.pdf} and {@code .txt} files
     * @param weightCountAdjustModel value copied into every candidate parameter set
     * @param caseControlRatio       value passed to the explorer's constructor
     * @throws Exception if exploration, plotting or writing fails
     */
    public void produceRegionPValues(IndexableSet<String> geneSymbolMap, boolean considerFunVar, String outPathPrefix, int weightCountAdjustModel, double caseControlRatio) throws Exception {
        int regionCount = this.regions == null ? 0 : this.regions.size();
        if (regionCount < this.minSampleSize) {
            String info = "The number of genes or regions with variants is only " + regionCount + " (<" + this.minSampleSize + "), insufficient for the regression analysis. The process aborted!";
            SetupApplication.GlobalLogger.error(info);
            return;
        }
        double looseFDR = this.adjustAF ? 0.8 : 0.5;
        List<TNBRegressionParamSet> taskParamSetList = this.buildParamGrid(considerFunVar, weightCountAdjustModel);

        TNBRegressionParamExplorer tnbExplorer = new TNBRegressionParamExplorer(this.rServer, caseControlRatio);
        tnbExplorer.setGeneSymbolMap(geneSymbolMap);
        tnbExplorer.setRunStepwise(this.runStepwise).setLooseFDR(looseFDR).setResidualType(this.residualType).setThreads(this.nThreads);
        tnbExplorer.setDataSet(this.regions, this.addedValueFields).setMinSampleSize(this.minSampleSize);
        tnbExplorer.setMinMutCount(this.minMutCount);
        tnbExplorer.setParamSetList(taskParamSetList);
        tnbExplorer.exploreOptimalParameters();
        // NOTE(review): this check only makes sense if the explorer prunes/reorders the list it
        // was given so that index 0 ends up being the selected set - confirm against the explorer.
        if (taskParamSetList.isEmpty()) {
            SetupApplication.GlobalLogger.error("Failed to fit the truncated negative binomial regression model");
            return;
        }
        TNBRegressionParamSet best = taskParamSetList.get(0);
        SetupApplication.GlobalLogger.info("Best standardized score bin: {}; Optimal truncation point: {};  MLFC:{}", best.scoreBinCut, best.truncationPoint, best.MLFC);
        DoubleArrayList pValues = tnbExplorer.calculateFinalPValues(best);
        this.drawQQPlot(pValues, outPathPrefix);
        this.savePValuesInTSVFile(best, pValues, geneSymbolMap, outPathPrefix);
        pValues.clear();
    }

    /**
     * Builds the candidate parameter sets: one per (score-bin cut, truncation point) pair.
     *
     * <p>With {@link #adjustAF} on, the grid has cuts {@code 0.05 * s} for {@code s = 1..9} and
     * five truncation points; otherwise cuts {@code 0.025 * s} for {@code s = 0..19} and five
     * truncation points. When {@code considerFunVar} is false both variants collapse to the single
     * cut 0.0 with ten truncation points. (Previously the {@code adjustAF} variant skipped bin 0
     * even in the collapsed case and therefore produced an empty grid.)
     */
    private List<TNBRegressionParamSet> buildParamGrid(boolean considerFunVar, int weightCountAdjustModel) {
        int scoreBinNum = this.adjustAF ? 10 : 20;
        double scoreBinLen = this.adjustAF ? 0.05 : 0.025;
        int truncatedNum = 5;
        int firstBin = this.adjustAF ? 1 : 0;
        if (!considerFunVar) {
            scoreBinLen = 0.0;
            scoreBinNum = 1;
            truncatedNum = 10;
            // The only bin is bin 0, so it must not be skipped.
            firstBin = 0;
        }
        List<TNBRegressionParamSet> grid = new List<TNBRegressionParamSet>();
        for (int s = firstBin; s < scoreBinNum; ++s) {
            double scoreBinCut = scoreBinLen * (double) s;
            for (int tn = 0; tn < truncatedNum; ++tn) {
                TNBRegressionParamSet paraSet = new TNBRegressionParamSet(scoreBinCut, tn);
                paraSet.weightCountAdjustModel = weightCountAdjustModel;
                if (this.adjustAF) {
                    paraSet.adjustAF = true;
                } else {
                    paraSet.iRunner = this.isIRunner();
                    paraSet.useControlMutPredictor = this.isContainRef();
                }
                grid.add(paraSet);
            }
        }
        return grid;
    }

    /**
     * Draws a single-series QQ plot of the given p-values to {@code <outPathPrefix>.qq.pdf} and
     * logs the output location.
     */
    private void drawQQPlot(DoubleArrayList pValues, String outPathPrefix) throws Exception {
        PValuePainter painter = new PValuePainter(300, 270);
        // One unnamed series.
        List<String> names = new List<String>();
        names.add("");
        List<DoubleArrayList> pValueList = new List<DoubleArrayList>();
        pValueList.add(pValues);
        String outputPath = outPathPrefix + ".qq.pdf";
        painter.drawMultipleQQPlotPDF(pValueList, names, null, outputPath, 1.0E-10);
        SetupApplication.GlobalLogger.info("The QQ plot of p-values is generated at " + outputPath);
    }

    /**
     * Sorts {@link #regions} by ascending p-value and writes one row per region set to
     * {@code <outPathPrefix>.txt}: three {@code #}-prefixed parameter lines, a header line, then
     * the data rows. Region sets whose {@code getFinalFeatureScore()} is null are skipped.
     *
     * <p>The writer is closed even when a write fails. Numeric columns formatted with
     * {@code %.3g} use {@link Locale#ROOT} so the decimal separator is always a dot.
     *
     * @param paramSet      parameter set whose MLFC, score-bin cut and truncation point are
     *                      written to the file preamble; its {@code adjustAF} flag controls the
     *                      "Diff" column
     * @param pValues       p-values passed to the Benjamini-Hochberg routine
     * @param geneSymbolMap lookup from numeric region id to printable name
     * @param outPathPrefix prefix of the output file
     * @throws IOException if writing fails
     */
    private void savePValuesInTSVFile(TNBRegressionParamSet paramSet, DoubleArrayList pValues, IndexableSet<String> geneSymbolMap, String outPathPrefix) throws IOException {
        this.regions.sort(Comparator.comparingDouble(o -> o.p));
        File outpath = new File(outPathPrefix + ".txt");
        DoubleArrayList qValues = new DoubleArrayList();
        int addedGeneScoreNum = this.addedValueFields.size() + 1;
        boolean writeDiffColumn = paramSet.adjustAF;
        boolean cutoffModel = true;
        // NOTE(review): row i below pairs regions.get(i) (sorted by p just above) with
        // qValues.getQuick(i). That is only correct if benjaminiHochbergFDR yields q-values in
        // ascending-p order - confirm against Summary.
        Summary.benjaminiHochbergFDR(0.05, pValues, qValues);

        WriterStream fs = new WriterStream(outpath, WriterStream.Option.DEFAULT);
        try {
            fs.writeChar("#MLFC\t" + paramSet.MLFC + "\n");
            fs.writeChar("#scoreBinCut\t" + paramSet.scoreBinCut + "\n");
            fs.writeChar("#truncationPoint\t" + paramSet.truncationPoint + "\n");
            fs.writeChar("ID\tRegion\tChromosome\tStartPosition\tEndPosition\tCaseUnweightedMutationCounts");
            for (String str : this.addedValueFields) {
                fs.writeChar("\t" + str);
            }
            // The control-count column is always emitted; the previous code read
            // paramSet.useControlMutPredictor and then unconditionally overrode it with true.
            fs.writeChar("\tControlUnweightedMutationCounts");
            if (writeDiffColumn) {
                fs.writeChar("\tUnweightedMutationCountsDiff");
            }
            fs.writeChar("\tz\tp\tFDRq\n");

            StringBuilder joined = new StringBuilder();
            int size = qValues.size();
            for (int i = 0; i < size; ++i) {
                CalcRegionSet gr = this.regions.get(i);
                double[] addedValues = gr.getFinalFeatureScore();
                if (addedValues == null) {
                    continue;
                }

                // ID column: each ';'-separated numeric id is translated through geneSymbolMap.
                String label = gr.getAllRegionIDs();
                if (label.contains(";")) {
                    String[] ids = label.split(";");
                    joined.setLength(0);
                    for (int j = 0; j < ids.length; ++j) {
                        if (j > 0) {
                            joined.append(';');
                        }
                        joined.append(geneSymbolMap.valueOf(Integer.parseInt(ids[j])));
                    }
                    fs.writeChar(joined.toString());
                } else {
                    fs.writeChar(geneSymbolMap.valueOf(Integer.parseInt(label)));
                }
                fs.write(TAB);

                // Region column: ';'-joined names of the region type codes (empty if none).
                int[] regionTypes = gr.getTypes();
                DynamicScanWindows scan = DynamicScanWindows.getInstance();
                joined.setLength(0);
                for (int t = 0; t < regionTypes.length; ++t) {
                    if (t > 0) {
                        joined.append(';');
                    }
                    joined.append(TNBRegression.regionTypeName(regionTypes[t], scan));
                }
                fs.writeChar(joined.toString());
                fs.write(TAB);

                fs.writeChar(gr.getChromID());
                fs.write(TAB);
                fs.writeChar(String.valueOf(gr.getStart()));
                fs.write(TAB);
                fs.writeChar(String.valueOf(gr.getEnd()));
                fs.write(TAB);
                fs.writeChar(String.valueOf(gr.getInteractedMutNumCase(-1.0, cutoffModel)));
                // Index 0 of the feature-score array is not written; columns start at index 1.
                for (int j = 1; j < addedGeneScoreNum; ++j) {
                    fs.writeChar("\t" + addedValues[j]);
                }
                fs.writeChar("\t" + gr.getInteractedMutNumControl(-1.0, cutoffModel));
                if (writeDiffColumn) {
                    fs.writeChar("\t" + gr.getDiff(-1.0, cutoffModel));
                }
                fs.write(TAB);
                fs.writeChar(String.format(Locale.ROOT, "%.3g", gr.z));
                fs.write(TAB);
                fs.writeChar(String.format(Locale.ROOT, "%.3g", gr.p));
                fs.write(TAB);
                fs.writeChar(String.format(Locale.ROOT, "%.3g", qValues.getQuick(i)));
                fs.write(NEWLINE);
            }
        } finally {
            fs.close();
        }
        SetupApplication.GlobalLogger.info("The p-values of genomic regions are generated at {}.", (Object)outpath);
    }

    /**
     * Translates a region type code into the text shown in the "Region" column.
     *
     * <p>Codes below 5 map to a broad region name (or the code itself when no name exists).
     * Without window scanning, codes above {@link #INTRON_START_ID} become {@code Intron<k>} and
     * everything else is {@code Unknown}. With window scanning, {@code code / scale} is the window
     * index and {@code code % scale} the type within the window; a ":Window<w>" suffix is added.
     *
     * <p>NOTE(review): the un-windowed path tests {@code > INTRON_START_ID} while the windowed
     * path tests {@code >= INTRON_START_ID}; kept as found - confirm which bound is intended.
     */
    private static String regionTypeName(int regionType, DynamicScanWindows scan) {
        if (regionType < BROAD_REGION_NAMES.length) {
            String name = TNBRegression.broadRegionName(regionType);
            return name != null ? name : String.valueOf(regionType);
        }
        if (!scan.isUseScan()) {
            return regionType > INTRON_START_ID ? "Intron" + (regionType - INTRON_START_ID) : "Unknown";
        }
        int scale = scan.getScale();
        int windowIndex = regionType / scale;
        if (regionType > INTRON_START_ID && windowIndex == 0) {
            return "Intron" + (regionType - INTRON_START_ID);
        }
        if (windowIndex > 0) {
            int typeIndex = regionType % scale;
            String suffix = ":Window" + windowIndex;
            if (typeIndex < BROAD_REGION_NAMES.length) {
                String name = TNBRegression.broadRegionName(typeIndex);
                // When no name exists the full code (not the in-window type) is printed.
                return (name != null ? name : String.valueOf(regionType)) + suffix;
            }
            return typeIndex >= INTRON_START_ID ? "Intron" + (typeIndex - INTRON_START_ID) + suffix : "Unknown";
        }
        return "Unknown";
    }

    /** Returns the broad region name for codes 0..4, or null for any other code. */
    private static String broadRegionName(int code) {
        return code >= 0 && code < BROAD_REGION_NAMES.length ? BROAD_REGION_NAMES[code] : null;
    }

    /**
     * Sorts {@link #regions} with {@link RegionComparator} and removes, in place, every later
     * region set that compares equal to an earlier one and whose first region has identical
     * region scores. The labels of the removed duplicates are attached to the surviving first
     * region via {@code setSameRegionsLabel}.
     */
    public void removeDuplicateRegionSet() {
        SetupApplication.GlobalLogger.info("Sorting regions.");
        RegionComparator comparator = new RegionComparator();
        this.regions.sort(comparator);
        SetupApplication.GlobalLogger.info("Removing duplicate regions.");
        for (int i = 0; i < this.regions.size(); ++i) {
            CalcRegionSet kept = this.regions.get(i);
            GenomeRegion keptRegion = kept.regions.get(0);
            List<String> duplicateLabels = null;
            for (int j = i + 1; j < this.regions.size(); ++j) {
                CalcRegionSet candidate = this.regions.get(j);
                // The list is sorted, so the first non-equal key ends the run of candidates.
                if (comparator.compare(kept, candidate) != 0) {
                    break;
                }
                GenomeRegion candidateRegion = candidate.regions.get(0);
                if (!Arrays.equals(keptRegion.getRegionScores(), candidateRegion.getRegionScores())) {
                    continue;
                }
                if (duplicateLabels == null) {
                    duplicateLabels = new List<String>();
                }
                duplicateLabels.add(candidateRegion.getLabel());
                int sizeBefore = this.regions.size();
                this.regions.remove(candidate);
                // After a removal the next element slides into slot j; step back so the loop's
                // ++j re-examines that slot instead of skipping it.
                if (this.regions.size() < sizeBefore) {
                    --j;
                }
            }
            if (duplicateLabels != null) {
                keptRegion.setSameRegionsLabel(duplicateLabels);
            }
        }
    }

    /** @param adjustAF selects the parameter grid variant and the "Diff" output column */
    public void setAdjustAF(boolean adjustAF) {
        this.adjustAF = adjustAF;
    }

    /** @return the current value of the {@code adjustAF} flag */
    public boolean getAdjust() {
        return this.adjustAF;
    }

    /** @param containRef value later copied into each parameter set's {@code useControlMutPredictor} */
    public void setUseControlMutPredictor(boolean containRef) {
        this.containRef = containRef;
    }

    /** @return the flag set through {@link #setUseControlMutPredictor(boolean)} */
    public boolean isContainRef() {
        return this.containRef;
    }

    /** @param iRunner value later copied into each parameter set's {@code iRunner} */
    public void setIsIRunner(boolean iRunner) {
        this.iRunner = iRunner;
    }

    /** @return the flag set through {@link #setIsIRunner(boolean)} */
    public boolean isIRunner() {
        return this.iRunner;
    }

    /**
     * @param threads thread count forwarded to the explorer
     * @return this instance, for chaining
     */
    public TNBRegression setThreads(int threads) {
        this.nThreads = threads;
        return this;
    }

    /**
     * Orders region sets by the chromosome index, start position and type of their first region.
     * When window scanning is enabled the type is compared modulo the scan scale. The scan
     * settings are read once, when the comparator is created.
     */
    private static class RegionComparator
    implements Comparator<CalcRegionSet> {
        private final boolean useScan = DynamicScanWindows.getInstance().isUseScan();
        private final int scale = DynamicScanWindows.getInstance().getScale();

        private RegionComparator() {
        }

        @Override
        public int compare(CalcRegionSet g1List, CalcRegionSet g2List) {
            GenomeRegion g1 = g1List.regions.get(0);
            // Must come from the second argument; reading g1List here made every pair equal.
            GenomeRegion g2 = g2List.regions.get(0);
            int chromosomeComparison = Integer.compare(g1.getChromID().getIndex(), g2.getChromID().getIndex());
            if (chromosomeComparison != 0) {
                return chromosomeComparison;
            }
            int startComparison = Integer.compare(g1.getStart(), g2.getStart());
            if (startComparison != 0) {
                return startComparison;
            }
            if (this.useScan) {
                return Integer.compare(g1.getType() % this.scale, g2.getType() % this.scale);
            }
            return Integer.compare(g1.getType(), g2.getType());
        }
    }
}

