/*
 * Decompiled with CFR 0.152.
 */
package edu.sysu.pmglab.kgga.command.task;

import cern.jet.stat.Gamma;
import edu.sysu.pmglab.ccf.field.FieldGroupMeta;
import edu.sysu.pmglab.ccf.field.IFieldCollection;
import edu.sysu.pmglab.ccf.record.BoxRecord;
import edu.sysu.pmglab.ccf.toolkit.annotator.PointerDatabase;
import edu.sysu.pmglab.ccf.type.FieldType;
import edu.sysu.pmglab.commandParser.exception.ParameterException;
import edu.sysu.pmglab.container.indexable.IndexableSet;
import edu.sysu.pmglab.container.indexable.LinkedSet;
import edu.sysu.pmglab.container.list.DoubleList;
import edu.sysu.pmglab.container.list.FloatList;
import edu.sysu.pmglab.container.list.IntList;
import edu.sysu.pmglab.executor.Context;
import edu.sysu.pmglab.executor.ITask;
import edu.sysu.pmglab.executor.Status;
import edu.sysu.pmglab.executor.ThreadQueue;
import edu.sysu.pmglab.executor.track.ITrack;
import edu.sysu.pmglab.gtb.GTBManager;
import edu.sysu.pmglab.gtb.GTBReader;
import edu.sysu.pmglab.gtb.GTBReaderOption;
import edu.sysu.pmglab.gtb.GTBWriter;
import edu.sysu.pmglab.gtb.filter.GTBFilter;
import edu.sysu.pmglab.gtb.genome.Variant;
import edu.sysu.pmglab.gtb.genome.coordinate.Coordinate;
import edu.sysu.pmglab.gtb.genome.genotype.Genotype;
import edu.sysu.pmglab.gtb.genome.genotype.IGenotypes;
import edu.sysu.pmglab.gtb.genome.genotype.counter.ICounter;
import edu.sysu.pmglab.gtb.toolkit.GTBAnnotator;
import edu.sysu.pmglab.gtb.toolkit.GTBIndexer;
import edu.sysu.pmglab.io.FileUtils;
import edu.sysu.pmglab.io.text.TextRecord;
import edu.sysu.pmglab.io.text.writer.IHeaderFormatter;
import edu.sysu.pmglab.io.text.writer.TextWriter;
import edu.sysu.pmglab.kgga.command.SetupApplication;
import edu.sysu.pmglab.kgga.command.TaskTracker;
import edu.sysu.pmglab.kgga.command.Utility;
import edu.sysu.pmglab.kgga.command.pipeline.GeneralIOOptions;
import edu.sysu.pmglab.kgga.command.pipeline.PhenoPredictionOptions;
import edu.sysu.pmglab.kgga.command.python.toolkit.DirectNumpyBinaryArray;
import edu.sysu.pmglab.kgga.command.python.toolkit.DirectNumpyGenotypeArray;
import edu.sysu.pmglab.kgga.command.setting.LeadOptionsSet;
import edu.sysu.pmglab.kgga.command.setting.PhenoFilesSet;
import edu.sysu.pmglab.kgga.io.GlobalPedIndividuals;
import edu.sysu.pmglab.kgga.io.InputOutputFileSet;
import edu.sysu.pmglab.pyserve.GlobalPythonInterpreter;
import edu.sysu.pmglab.stat.ContingencyTable;
import edu.sysu.pmglab.utils.Assert;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
import java.util.AbstractCollection;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import jep.DirectNDArray;
import jep.MainInterpreter;

public class TrainLEAPModelTask
implements ITask {
    // Directory every output path below is resolved against; the constructor narrows it to a
    // per-class subdirectory when called with makeDir=true.
    File outputDir;
    // Annotated variant set written by the LEAP annotation step in execute(); the file name
    // comes from InputOutputFileSet.getAnnotationFileName().
    File outputFile;
    // Genotype file used for model training. execute() may replace it with the output of a
    // previously recorded, matching task from the global task tracker.
    File gtyFile;
    // Manager opened over gtyFile once that file is available (assigned in execute()).
    GTBManager manager;
    // Worker thread count, copied from GeneralIOOptions.threads in the constructor.
    int threads;
    // Training/validation metrics table (validate_metrics.tsv); handed to the Python side as
    // java_train_save_path.
    File validateSavePath;
    // Test metrics table (test_metrics.tsv); handed to the Python side as java_test_save_path.
    File testSavePath;
    // Saved model list (model_save_list.pth); handed to the Python side as java_model_save_path.
    File modelSavePath;
    // Location reported in the log for the selected feature variants (model_need_variant.tsv).
    // NOTE(review): no visible Java code writes it; presumably the Python "save" script does — confirm.
    File variantSavePath;
    // Per-individual prediction output (predict_result.tsv); only written when
    // phenoFilesSet.needPredict is true.
    File predictResultPath;
    GeneralIOOptions generalIOOptions;
    LeadOptionsSet leadOptionsSet;
    PhenoFilesSet phenoFilesSet;
    // Names of variant properties read as external p-values; null when none were supplied.
    String[] pFieldLabels;

    /**
     * Developer entry point: runs the {@code predict} command through {@link SetupApplication}
     * with hard-coded local input files and LEAP options. The {@code args} parameter is ignored.
     */
    public static void main(String[] args) throws Exception {
        String command = "predict";
        String genotypeFile = "/public2/ukb/wgs/comm/variants.maf05.hg38.gtb";
        String pedFile = "./AD/merged_all.fam";
        String summaryFile = "file=./AD/merged_filter_standard_p_raw.tsv";
        String trainingSample = "trainingSample=./AD/merged_data_with_age_sex.fam";
        String testingSample = "testingSample=./AD/merged_test_data_with_age_sex.fam";
        String threadCount = "32";
        String resultDir = "./AD/results";
        SetupApplication.execute(command,
                "-i", genotypeFile,
                "--ped-file", pedFile,
                "--sum-file", summaryFile, "cp12Cols=CHROM,POS", "pbsCols=P", "refG=hg19",
                "--assign-sample", trainingSample, testingSample, "pheno=disease", "covar=age,sex",
                "--elag", "maxEpoch=20", "crossFold=5", "baggingNumber=10",
                "variantSampleSize=300:1000:200", "impute=y", "downSample=y", "permutNum=100",
                "--threads", threadCount,
                "-o", resultDir);
    }

    /**
     * Creates the LEAP training task and resolves every output path it will write.
     *
     * @param generalIOOptions       general I/O options; stored, and its {@code threads} value is cached
     * @param pFieldLabels           names of variant properties holding external p-values,
     *                               or {@code null} when none are supplied
     * @param phenoPredictionOptions source of the LEAP option set and the phenotype file set
     * @param outputDir              base output directory
     * @param makeDir                when {@code true}, outputs are placed in a subdirectory named after
     *                               this class, which is created if it does not exist yet
     * @throws IllegalStateException if {@code makeDir} is {@code true} and the subdirectory cannot be created
     */
    public TrainLEAPModelTask(GeneralIOOptions generalIOOptions, String[] pFieldLabels, PhenoPredictionOptions phenoPredictionOptions, File outputDir, boolean makeDir) {
        File workDir = outputDir;
        if (makeDir) {
            workDir = FileUtils.getSubFile(outputDir, this.getClass().getSimpleName());
            // Fail fast: an uncreatable directory would otherwise only surface when results are
            // written, after the whole training run. isDirectory() covers a concurrent creator.
            if (!workDir.exists() && !workDir.mkdirs() && !workDir.isDirectory()) {
                throw new IllegalStateException("Unable to create output directory: " + workDir.getAbsolutePath());
            }
        }
        this.outputDir = workDir;
        this.outputFile = FileUtils.getSubFile(workDir, InputOutputFileSet.getAnnotationFileName());
        this.gtyFile = FileUtils.getSubFile(workDir, InputOutputFileSet.getAnnotationGtyFileName());
        this.validateSavePath = FileUtils.getSubFile(workDir, "validate_metrics.tsv");
        this.testSavePath = FileUtils.getSubFile(workDir, "test_metrics.tsv");
        this.modelSavePath = FileUtils.getSubFile(workDir, "model_save_list.pth");
        this.variantSavePath = FileUtils.getSubFile(workDir, "model_need_variant.tsv");
        this.predictResultPath = FileUtils.getSubFile(workDir, "predict_result.tsv");
        this.generalIOOptions = generalIOOptions;
        this.threads = generalIOOptions.threads;
        this.leadOptionsSet = phenoPredictionOptions.leadOptionsSet;
        this.phenoFilesSet = phenoPredictionOptions.phenoFilesSet;
        this.pFieldLabels = pFieldLabels;
    }

    /*
     * Enabled force condition propagation
     * Lifted jumps to return sites
     */
    @Override
    public void execute(Status status, Context context) throws Exception, Error {
        File inputFile = (File)context.cast("AnnotationBaseVariantSet");
        boolean ignoreGty = this.leadOptionsSet.isIgnoreGty();
        if (ignoreGty && this.phenoFilesSet.covariableNames == null) {
            throw new Exception("ignoreGty=yes but covariate is null!");
        }
        MainInterpreter.setJepLibraryPath(this.getJepPath());
        GlobalPythonInterpreter interpreter = new GlobalPythonInterpreter();
        int K = this.leadOptionsSet.getCrossFold();
        int n = this.leadOptionsSet.getVariantSampleTimes();
        int epoch = this.leadOptionsSet.getEpoch();
        IntList selectSize = this.leadOptionsSet.getVariantSampleSize();
        interpreter.setValue("java_K", K);
        interpreter.setValue("java_n", n);
        interpreter.setValue("ignoreGty", ignoreGty);
        long seed = this.leadOptionsSet.getSeed();
        interpreter.setValue("seed", seed);
        interpreter.setValue("threads", Math.min(Runtime.getRuntime().availableProcessors(), Math.max(this.threads, 1)));
        Map<String, String> scripts = this.getLEAPScriptFromPython();
        interpreter.exec(scripts.get("import"));
        interpreter.exec(scripts.get("function"));
        double[] labels = this.phenoFilesSet.trainPhenotypes.toArray();
        interpreter.setValue("labels", labels);
        if (this.leadOptionsSet.isDownSample()) {
            interpreter.setValue("down_sample", true);
        } else {
            interpreter.setValue("down_sample", false);
        }
        interpreter.exec(scripts.get("sample"));
        int[][][] folds = interpreter.getValue("folds", int[][][].class);
        IndexableSet<String> testDownSampleUIDs = this.phenoFilesSet.testUIDs;
        DoubleList testDownSamplePhenotypes = this.phenoFilesSet.testPhenotypes;
        double[][] testDownSampleCovariables = this.phenoFilesSet.testCovariables;
        if (this.leadOptionsSet.isDownSample()) {
            int[] index;
            testDownSampleUIDs = new LinkedSet<String>();
            testDownSamplePhenotypes = new DoubleList();
            interpreter.setValue("testLabels", this.phenoFilesSet.testPhenotypes.toArray());
            interpreter.eval("index = balanced_downsample_indices(testLabels)");
            for (int i : index = interpreter.getValue("index", int[].class)) {
                testDownSampleUIDs.add(this.phenoFilesSet.testUIDs.valueOf(i));
                testDownSamplePhenotypes.add((int)this.phenoFilesSet.testPhenotypes.fastGet(i));
            }
            if (this.phenoFilesSet.covariableNames != null) {
                testDownSampleCovariables = new double[this.phenoFilesSet.covariableNames.length][index.length];
                for (int covIndex = 0; covIndex < this.phenoFilesSet.covariableNames.length; ++covIndex) {
                    testDownSampleCovariables[covIndex] = new double[index.length];
                    for (int idx = 0; idx < index.length; ++idx) {
                        testDownSampleCovariables[covIndex][idx] = this.phenoFilesSet.testCovariables[covIndex][index[idx]];
                    }
                }
            }
        }
        final int permutNum = this.leadOptionsSet.getPermutNum();
        interpreter.setValue("permutation", permutNum);
        if (!ignoreGty) {
            TaskTracker.TaskResult completeTaskResult = new TaskTracker.TaskResult(this.getClass().getName(), Utility.MD5File(inputFile), this.digest(inputFile));
            Optional<File> outputPathOpt = SetupApplication.GlobalTaskTracker.checkTask(completeTaskResult);
            if (outputPathOpt.isPresent()) {
                this.gtyFile = outputPathOpt.get();
                SetupApplication.GlobalLogger.info("Use existing gty file: {}.", (Object)this.gtyFile.getAbsoluteFile());
            } else {
                this.generateGtyFile((File)context.cast("OutputGTYFile"), (File)context.cast("AnnotationBaseVariantSet"), this.gtyFile, this.threads);
                completeTaskResult.setOutputPath(this.gtyFile);
                SetupApplication.GlobalTaskTracker.recordTaskCompletion(completeTaskResult);
            }
            this.manager = new GTBManager(this.gtyFile);
            SetupApplication.GlobalLogger.info("{} variants will be used for downstream model training.", (Object)this.manager.numOfVariants());
            double[] summaryPValues = new double[(int)this.manager.numOfVariants()];
            Arrays.fill(summaryPValues, -999.0);
            int validCount = 0;
            int pointer = 0;
            if (this.pFieldLabels != null) {
                Variant spVar;
                GTBReader sP = new GTBReader(this.manager);
                while ((spVar = sP.read()) != null) {
                    DoubleList pValues = new DoubleList();
                    for (String pFieldName : this.pFieldLabels) {
                        double p = (Double)spVar.getProperty(pFieldName);
                        if (Double.isNaN(p)) continue;
                        if (p < 1.0E-200) {
                            pValues.add(200.0);
                            continue;
                        }
                        pValues.add(-Math.log10(p));
                    }
                    if (pValues.size() == 1) {
                        summaryPValues[pointer] = pValues.fastGet(0);
                        ++validCount;
                    } else if (pValues.size() > 1) {
                        double sumLogP = 0.0;
                        for (int i = 0; i < pValues.size(); ++i) {
                            sumLogP += Math.log(pValues.fastGet(i));
                        }
                        double chiSquared = -2.0 * sumLogP;
                        int degreesOfFreedom = 2 * pValues.size();
                        summaryPValues[pointer] = Gamma.incompleteGammaComplement((double)degreesOfFreedom / 2.0, chiSquared / 2.0);
                        ++validCount;
                    }
                    ++pointer;
                }
                sP.close();
                SetupApplication.GlobalLogger.info("{}% was used as the external p-value.", (Object)String.format("%.2f", (double)validCount / (double)this.manager.numOfVariants() * 100.0));
            }
            interpreter.setValue("java_summary_pvalues", summaryPValues);
            if (!this.leadOptionsSet.getFunctionScore().isEmpty()) {
                Variant variant;
                if (!this.manager.containsField(this.leadOptionsSet.getFunctionScore())) throw new ParameterException("Field not found: " + this.leadOptionsSet.getFunctionScore());
                GTBReader reader = new GTBReader(this.manager);
                FloatList caddScore = new FloatList();
                while ((variant = reader.read()) != null) {
                    caddScore.add(((Float)variant.getProperty(this.leadOptionsSet.getFunctionScore())).floatValue());
                }
                Assert.that((long)caddScore.size() == this.manager.numOfVariants(), "caddScore.size() != variants.size() !");
                reader.close();
                interpreter.setValue("java_cadd_score", caddScore.toArray());
                interpreter.setValue("java_function_score_type", "cadd");
                interpreter.exec("cadd_score = np.array(java_cadd_score)\n# \u8ba1\u7b97\u5747\u503c\uff08NaN\u4f1a\u88ab\u81ea\u52a8\u5ffd\u7565\uff09\nmean_value = np.nanmean(cadd_score)\n# \u7528\u5747\u503c\u586b\u5145NaN\ncadd_score_filled = np.where(np.isnan(cadd_score), mean_value, cadd_score)\n# cadd_score_filled = MinMaxScaler().fit_transform(cadd_score_filled.reshape(-1, 1))\ncadd_score = torch.tensor(cadd_score_filled.reshape(-1))\ncadd_score = (torch.tanh(cadd_score) + 1) / 2");
            } else {
                interpreter.setValue("java_function_score_type", "normal");
            }
            double[][] allPValues = this.calAllPValues(K, folds, this.manager, this.phenoFilesSet);
            DirectNDArray[][] testX = new DirectNDArray[K][n];
            DirectNumpyBinaryArray array = new DirectNumpyBinaryArray(testDownSampleUIDs.size());
            if (!GlobalPedIndividuals.isBinaryPhenotypes(0)) {
                throw new Exception("Error! Phenotype currently only supports binary values.");
            }
            int[] phenotypes = Arrays.stream(testDownSamplePhenotypes.toArray()).mapToInt(value -> (int)value).toArray();
            array.setRange(0, phenotypes);
            DirectNDArray testY = array.getAsDirectNDArray();
            Coordinate[][][] advancedCoords = new Coordinate[K][n][];
            int bestIndex = 0;
            double bestRoc = 0.0;
            for (int s2 = 0; s2 < selectSize.size(); ++s2) {
                int sampleSize = selectSize.fastGet(s2);
                Coordinate[][][] advancedCoordsTemp = new Coordinate[K][n][sampleSize];
                SetupApplication.GlobalLogger.info("Start model training and testing. (SampleSize: {}, Epochs: {}, Folds: {}, Samples: {})", sampleSize, epoch, K, n);
                double roc = this.train(epoch, K, n, allPValues, sampleSize, folds, testDownSampleUIDs, testDownSampleCovariables, advancedCoordsTemp, testX, testY, interpreter, scripts);
                if (roc > bestRoc) {
                    bestIndex = s2;
                    bestRoc = roc;
                    advancedCoords = advancedCoordsTemp;
                    interpreter.eval("global_models_save_list = models_save_list");
                    interpreter.eval("global_x_train_all = X_train_all");
                    interpreter.eval("global_y_train_all = y_train_all");
                    interpreter.eval("global_train_metrics_df = train_df.copy()");
                    interpreter.eval("global_test_metrics_df = test_df.copy()");
                }
                interpreter.eval("train_df = train_df.iloc[0:0]");
                interpreter.eval("test_df = test_df.iloc[0:0]");
                interpreter.eval("probabilities_list = []");
                interpreter.eval("params_group = [\n    {\n        'beta': torch.nn.Parameter(torch.nn.init.constant_(torch.empty(1), 1)),\n        'gamma': torch.nn.Parameter(torch.nn.init.uniform_(torch.empty(1), 0, 1)),\n    }\n    for _ in range(K)\n]");
                interpreter.eval("best_models_auc_roc = [0 for _ in range(K)]");
                interpreter.eval("models_save_list = [0 for _ in range(K)]");
            }
            int bestSampleSize = selectSize.fastGet(bestIndex);
            SetupApplication.GlobalLogger.info("{} is the optimal number of sampling variants for the model, with the corresponding ROC score: {}", (Object)bestSampleSize, (Object)bestRoc);
            if (this.phenoFilesSet.needPredict) {
                DirectNDArray[][] X_predict = new DirectNDArray[K][n];
                GTBManager predictManager = new GTBManager(this.phenoFilesSet.predictingGtyFile);
                if (predictManager.getIndexer() == null) {
                    GTBIndexer.setInput(this.phenoFilesSet.predictingGtyFile, new String[0]).save(this.threads);
                    predictManager = new GTBManager(this.phenoFilesSet.predictingGtyFile);
                }
                IndexableSet<String> predictUIDs = this.phenoFilesSet.predictUIDs;
                IndexableSet<String> predictIndividuals = predictManager.getIndividuals();
                if (predictIndividuals.isEmpty()) {
                    throw new Exception("The predicted genotype file contains no sample information!");
                }
                IntList predictUIDsIndices = predictUIDs.findIndicesIn(predictIndividuals);
                LinkedSet<String> validUIDs = new LinkedSet<String>();
                LinkedSet invalidUIDs = new LinkedSet();
                for (int i = 0; i < predictUIDsIndices.size(); ++i) {
                    int index = predictUIDsIndices.fastGet(i);
                    if (index == -1) {
                        ((AbstractCollection)invalidUIDs).add(predictUIDs.valueOf(i));
                        continue;
                    }
                    ((AbstractCollection)validUIDs).add(predictUIDs.valueOf(i));
                }
                if (!invalidUIDs.isEmpty()) {
                    SetupApplication.GlobalLogger.warn("The following samples to be predicted were not found in the predicted genotype file: {}. Subsequent predictions will ignore these samples.", (Object)invalidUIDs);
                }
                AtomicInteger notFound = new AtomicInteger(0);
                IntList validUIDsIndices = validUIDs.findIndicesIn(predictIndividuals);
                IntList trainUIDsIndices = this.phenoFilesSet.trainUIDs.findIndicesIn(GlobalPedIndividuals.getIndividuals().getUIDs());
                HashSet<Coordinate> allCoordsSet = new HashSet<Coordinate>();
                for (int currentK = 0; currentK < K; ++currentK) {
                    for (int currentN = 0; currentN < n; ++currentN) {
                        allCoordsSet.addAll(Arrays.asList(advancedCoords[currentK][currentN]));
                    }
                }
                ArrayList allCoords = new ArrayList(allCoordsSet);
                allCoords.sort(Coordinate::compareTo);
                HashMap<Coordinate, IGenotypes> predictGenotypesCache = new HashMap<Coordinate, IGenotypes>();
                HashMap<Coordinate, IGenotypes> trainGenotypesCache = new HashMap<Coordinate, IGenotypes>();
                GTBReader predictReader = new GTBReader(predictManager);
                GTBFilter predictFilter = new GTBFilter(predictManager);
                GTBReader trainReader = new GTBReader(this.manager);
                GTBFilter trainFilter = new GTBFilter(this.manager);
                for (Coordinate coord : allCoords) {
                    if (predictFilter.find(coord)) {
                        predictReader.seek(predictFilter.tell());
                        predictGenotypesCache.put(coord, predictReader.read().getGenotypes().subGenotypes(validUIDsIndices));
                        continue;
                    }
                    if (!this.leadOptionsSet.isInfer()) continue;
                    if (!trainFilter.find(coord)) throw new Exception("The feature variants used for prediction were not found in the training genotype file!");
                    trainReader.seek(trainFilter.tell());
                    trainGenotypesCache.put(coord, trainReader.read().getGenotypes().subGenotypes(trainUIDsIndices));
                }
                predictReader.close();
                predictFilter.close();
                trainReader.close();
                trainFilter.close();
                ThreadQueue threadQueue = new ThreadQueue(this.threads);
                int currentK = 0;
                while (currentK < K) {
                    int finalCurrentK = currentK++;
                    Coordinate[][][] finalAdvancedCoords = advancedCoords;
                    threadQueue.addTask((s, c) -> {
                        for (int currentN = 0; currentN < n; ++currentN) {
                            Coordinate[] coords = finalAdvancedCoords[finalCurrentK][currentN];
                            DirectNumpyGenotypeArray genotypeArray = new DirectNumpyGenotypeArray(validUIDs.size(), coords.length);
                            for (int p = 0; p < coords.length; ++p) {
                                int j;
                                Coordinate coord = coords[p];
                                IGenotypes predictGenotypes = (IGenotypes)predictGenotypesCache.get(coord);
                                if (predictGenotypes != null) {
                                    for (j = 0; j < validUIDs.size(); ++j) {
                                        genotypeArray.set(j, p, predictGenotypes.get(j).getAC());
                                    }
                                    continue;
                                }
                                notFound.incrementAndGet();
                                if (this.leadOptionsSet.isInfer()) {
                                    IGenotypes trainGenotypes = (IGenotypes)trainGenotypesCache.get(coord);
                                    ICounter counter = trainGenotypes.counter();
                                    int AA = counter.count(Genotype.of(0, 0));
                                    int Aa = counter.count(Genotype.of(0, 1)) + counter.count(Genotype.of(1, 0));
                                    int aa = counter.count(Genotype.of(1, 1));
                                    for (int j2 = 0; j2 < validUIDs.size(); ++j2) {
                                        genotypeArray.set(j2, p, ThreadLocalRandom.current().nextInt(AA + Aa + aa) < AA ? 0 : (ThreadLocalRandom.current().nextInt(Aa + aa) < Aa ? 1 : 2));
                                    }
                                    continue;
                                }
                                for (j = 0; j < validUIDs.size(); ++j) {
                                    genotypeArray.set(j, p, 3);
                                }
                            }
                            X_predict[finalCurrentK][currentN] = genotypeArray.getAsDirectNDArray();
                        }
                    });
                }
                threadQueue.close();
                if (notFound.get() > 0) {
                    SetupApplication.GlobalLogger.warn("{}% variants in the predict dataset were not found in the training dataset! Impute the genotypes of these variants based on the genotype data from the training set. The prediction accuracy may be compromised!", (Object)String.format("%.2f", (double)notFound.get() / (double)bestSampleSize / (double)K / (double)n * 100.0));
                }
                interpreter.setValue("java_predict_covars", this.phenoFilesSet.predictCovariables);
                interpreter.setValue("java_X_predict", X_predict);
                interpreter.exec(scripts.get("predict"));
                double[][] yPredictAll = interpreter.getValue("y_predict_all", double[][].class);
                TextWriter.Builder builder1 = TextWriter.setOutput(this.predictResultPath).setHeaderFormatter(IHeaderFormatter.DIRECTLY).addFields("IID");
                for (int i = 0; i < K * n; ++i) {
                    builder1.addField(String.valueOf(i + 1));
                }
                builder1.addField("mean");
                TextWriter textWriter = builder1.instance();
                for (int i = 0; i < ((AbstractCollection)validUIDs).size(); ++i) {
                    TextRecord textRecord1 = textWriter.getRecord();
                    textRecord1.set("IID", (String)((IndexableSet)validUIDs).valueOf(i));
                    double thisMean = 0.0;
                    for (int j = 0; j < K * n; ++j) {
                        double prob = yPredictAll[j][i];
                        thisMean += prob;
                        textRecord1.set(String.valueOf(j + 1), String.valueOf(prob));
                    }
                    textRecord1.set("mean", thisMean / (double)K / (double)n);
                    textWriter.write(textRecord1);
                }
                textWriter.close();
            }
            interpreter.setValue("java_train_save_path", this.validateSavePath.getAbsolutePath());
            interpreter.setValue("java_test_save_path", this.testSavePath.getAbsolutePath());
            interpreter.setValue("java_model_save_path", this.modelSavePath.getAbsolutePath());
            interpreter.exec(scripts.get("save"));
            SetupApplication.GlobalLogger.info("The model has been saved in {}\nTrain metrics has been saved in {}\nTest metrics has been saved in {}\nBest variants has been saved in {}", this.modelSavePath, this.validateSavePath, this.testSavePath, this.variantSavePath);
            String[][][] snpNames = new String[K][n][];
            for (int currentK = 0; currentK < K; ++currentK) {
                for (int currentN = 0; currentN < n; ++currentN) {
                    Coordinate[] coordinates = advancedCoords[currentK][currentN];
                    if (this.phenoFilesSet.covariableNames != null) {
                        snpNames[currentK][currentN] = new String[coordinates.length + this.phenoFilesSet.covariableNames.length];
                        for (int c2 = 0; c2 < coordinates.length; ++c2) {
                            snpNames[currentK][currentN][c2] = coordinates[c2].toString();
                        }
                        System.arraycopy(this.phenoFilesSet.covariableNames, 0, snpNames[currentK][currentN], coordinates.length, this.phenoFilesSet.covariableNames.length);
                        continue;
                    }
                    snpNames[currentK][currentN] = (String[])Arrays.stream(coordinates).map(Coordinate::toString).toArray(String[]::new);
                }
            }
            interpreter.setValue("java_all_snp_list", snpNames);
            interpreter.exec(scripts.get("get_importance"));
            final Map snpGain = interpreter.getValue("export_gains", Map.class);
            final Map snpP = interpreter.getValue("export_p", Map.class);
            final Map countUsage = interpreter.getValue("export_usages", Map.class);
            if (this.phenoFilesSet.covariableNames != null) {
                String key;
                Map.Entry entry2;
                HashSet<String> targetKeys = new HashSet<String>(Arrays.asList(this.phenoFilesSet.covariableNames));
                List sortedEntries = snpGain.entrySet().stream().filter(entry -> targetKeys.contains(entry.getKey())).sorted(Map.Entry.comparingByValue().reversed()).collect(Collectors.toList());
                int maxKeyLength = snpGain.keySet().stream().mapToInt(String::length).max().orElse(20);
                String keyLine = new String(new char[maxKeyLength]).replace('\u0000', '-');
                String valueLine = new String(new char[10]).replace('\u0000', '-');
                StringBuilder logBuilder = new StringBuilder("\nCovariable Feature Gain and Rank Comparison:\n");
                if (permutNum != 0) {
                    logBuilder.append(String.format("%-" + maxKeyLength + "s | %10s | %10s | %10s\n", "Covar", "Gain", "P", "Rank"));
                    logBuilder.append(String.format("%-" + maxKeyLength + "s-+-%10s-+-%10s-+-%10s\n", keyLine, valueLine, valueLine, valueLine));
                    for (int i = 0; i < sortedEntries.size(); ++i) {
                        entry2 = (Map.Entry)sortedEntries.get(i);
                        key = (String)entry2.getKey();
                        Double pValue = snpP.getOrDefault(key, -1.0);
                        logBuilder.append(String.format("%-" + maxKeyLength + "s | %10.4f | %10.4f | %10d\n", key, entry2.getValue(), pValue, i + 1));
                    }
                } else {
                    logBuilder.append(String.format("%-" + maxKeyLength + "s | %10s | %10s\n", "Covar", "Gain", "Rank"));
                    logBuilder.append(String.format("%-" + maxKeyLength + "s-+-%10s-+-%10s\n", keyLine, valueLine, valueLine));
                    for (int i = 0; i < sortedEntries.size(); ++i) {
                        entry2 = (Map.Entry)sortedEntries.get(i);
                        key = (String)entry2.getKey();
                        logBuilder.append(String.format("%-" + maxKeyLength + "s | %10.4f | %10d\n", key, entry2.getValue(), i + 1));
                    }
                }
                SetupApplication.GlobalLogger.info(logBuilder.toString());
            }
            final FieldGroupMeta fields = new FieldGroupMeta("LEAP");
            fields.addField("Gain", FieldType.float64);
            if (permutNum != 0) {
                fields.addField("P", FieldType.float64);
            }
            fields.addField("Count", FieldType.varInt32);
            GTBAnnotator.setInput(new GTBReaderOption(this.manager, false, true)).setOutput(this.outputFile).addMeta(this.manager.getMeta()).addDatabase(new PointerDatabase<Variant>(){

                @Override
                public Object getSource(Variant variant) {
                    return null;
                }

                @Override
                public long getPointer(Variant variant) {
                    return 0L;
                }

                @Override
                public IFieldCollection getAllFields() {
                    return fields;
                }

                @Override
                public boolean annotate(edu.sysu.pmglab.container.list.List<BoxRecord> databaseRecords, long pointer, Variant variant) {
                    String coords = variant.getCoordinate().toString();
                    if (snpGain.containsKey(coords)) {
                        variant.setProperty("LEAP@Gain", snpGain.get(coords));
                        if (permutNum != 0) {
                            variant.setProperty("LEAP@P", snpP.get(coords));
                        }
                        variant.setProperty("LEAP@Count", countUsage.get(coords));
                        return true;
                    }
                    return false;
                }
            }).submit(this.threads);
            context.put("AnnotationBaseVariantSet", this.outputFile);
        } else {
            SetupApplication.GlobalLogger.info("Covariate: {} will be used for downstream model training.", (Object)Arrays.toString(this.phenoFilesSet.covariableNames));
            DirectNumpyBinaryArray array = new DirectNumpyBinaryArray(testDownSampleUIDs.size());
            if (!GlobalPedIndividuals.isBinaryPhenotypes(0)) {
                throw new Exception("Error! Phenotype currently only supports binary values.");
            }
            int[] phenotypes = Arrays.stream(testDownSamplePhenotypes.toArray()).mapToInt(value -> (int)value).toArray();
            array.setRange(0, phenotypes);
            DirectNDArray testY = array.getAsDirectNDArray();
            this.trainWithoutGty(K, folds, testDownSampleCovariables, testY, interpreter, scripts);
            interpreter.eval("global_models_save_list = models_save_list");
            interpreter.eval("global_x_train_all = X_train_all");
            interpreter.eval("global_y_train_all = y_train_all");
            interpreter.eval("global_train_metrics_df = train_df.copy()");
            interpreter.eval("global_test_metrics_df = test_df.copy()");
            if (this.phenoFilesSet.needPredict) {
                interpreter.setValue("java_predict_covars", this.phenoFilesSet.predictCovariables);
                interpreter.exec(scripts.get("predict"));
                double[][] yPredictAll = interpreter.getValue("y_predict_all", double[][].class);
                TextWriter.Builder builder1 = TextWriter.setOutput(this.predictResultPath).setHeaderFormatter(IHeaderFormatter.DIRECTLY).addFields("IID");
                builder1.addField("mean");
                for (String covariableName : this.phenoFilesSet.covariableNames) {
                    builder1.addField(covariableName);
                }
                TextWriter textWriter = builder1.instance();
                for (int i = 0; i < this.phenoFilesSet.predictUIDs.size(); ++i) {
                    TextRecord textRecord1 = textWriter.getRecord();
                    textRecord1.set("IID", this.phenoFilesSet.predictUIDs.valueOf(i));
                    double thisMean = 0.0;
                    for (int j = 0; j < this.phenoFilesSet.covariableNames.length; ++j) {
                        double prob = yPredictAll[j][i];
                        thisMean += prob;
                        textRecord1.set(this.phenoFilesSet.covariableNames[j], String.valueOf(prob));
                    }
                    textRecord1.set("mean", thisMean / (double)K / (double)n);
                    textWriter.write(textRecord1);
                }
                textWriter.close();
            }
            interpreter.setValue("java_train_save_path", this.validateSavePath.getAbsolutePath());
            interpreter.setValue("java_test_save_path", this.testSavePath.getAbsolutePath());
            interpreter.setValue("java_model_save_path", this.modelSavePath.getAbsolutePath());
            interpreter.exec(scripts.get("save"));
            SetupApplication.GlobalLogger.info("The model has been saved in {}\nTrain metrics has been saved in {}\nTest metrics has been saved in {}\nBest variants has been saved in {}", this.modelSavePath, this.validateSavePath, this.testSavePath, this.variantSavePath);
            String[][][] snpNames = new String[K][n][this.phenoFilesSet.covariableNames.length];
            for (int currentK = 0; currentK < K; ++currentK) {
                for (int currentN = 0; currentN < n; ++currentN) {
                    System.arraycopy(this.phenoFilesSet.covariableNames, 0, snpNames[currentK][currentN], 0, this.phenoFilesSet.covariableNames.length);
                }
            }
            interpreter.setValue("java_all_snp_list", snpNames);
            interpreter.exec(scripts.get("get_importance"));
            Map snpGain = interpreter.getValue("export_gains", Map.class);
            Map snpP = interpreter.getValue("export_p", Map.class);
            List sortedEntries = snpGain.entrySet().stream().sorted(Map.Entry.comparingByValue().reversed()).collect(Collectors.toList());
            int maxKeyLength = snpGain.keySet().stream().mapToInt(String::length).max().orElse(20);
            String keyLine = new String(new char[maxKeyLength]).replace('\u0000', '-');
            String valueLine = new String(new char[10]).replace('\u0000', '-');
            StringBuilder logBuilder = new StringBuilder("\nCovariable Feature Gain and Rank Comparison:\n");
            if (permutNum != 0) {
                logBuilder.append(String.format("%-" + maxKeyLength + "s | %10s | %10s | %10s\n", "Covar", "Gain", "P", "Rank"));
                logBuilder.append(String.format("%-" + maxKeyLength + "s-+-%10s-+-%10s-+-%10s\n", keyLine, valueLine, valueLine, valueLine));
                for (int i = 0; i < sortedEntries.size(); ++i) {
                    Map.Entry entry3 = (Map.Entry)sortedEntries.get(i);
                    String key = (String)entry3.getKey();
                    Double pValue = snpP.getOrDefault(key, -1.0);
                    logBuilder.append(String.format("%-" + maxKeyLength + "s | %10.4f | %10.4f | %10d\n", key, entry3.getValue(), pValue, i + 1));
                }
            } else {
                logBuilder.append(String.format("%-" + maxKeyLength + "s | %10s | %10s\n", "Covar", "Gain", "Rank"));
                logBuilder.append(String.format("%-" + maxKeyLength + "s-+-%10s-+-%10s\n", keyLine, valueLine, valueLine));
                for (int i = 0; i < sortedEntries.size(); ++i) {
                    Map.Entry entry4 = (Map.Entry)sortedEntries.get(i);
                    String key = (String)entry4.getKey();
                    logBuilder.append(String.format("%-" + maxKeyLength + "s | %10.4f | %10d\n", key, entry4.getValue(), i + 1));
                }
            }
            SetupApplication.GlobalLogger.info(logBuilder.toString());
        }
        interpreter.close();
    }

    /**
     * Runs the epoch loop of weighted variant sampling and model training.
     *
     * <p>Each epoch:
     * <ol>
     *   <li>runs the Python {@code select} script per fold;</li>
     *   <li>draws new variant subsets via {@code sample(...)};</li>
     *   <li>runs the {@code train} script;</li>
     *   <li>for every fold listed in the Python {@code change_index}, copies that fold's sampled
     *       coordinates and test matrices into {@code advancedCoordsTemp}/{@code testX};</li>
     *   <li>runs the {@code test} script.</li>
     * </ol>
     *
     * <p>Training stops early once {@code change_index} has been empty for two consecutive epochs.
     *
     * @param epoch maximum number of epochs to run
     * @param K number of folds
     * @param n number of sub-models per fold
     * @param allPValues per-fold p-values passed to the {@code select} script
     * @param sampleSize number of variants sampled per sub-model
     * @param folds per-fold index arrays; {@code folds[k][0]} are training rows, {@code folds[k][1]} validation rows
     * @param testDownSampleUIDs UIDs of the test samples
     * @param testDownSampleCovariables test covariables, passed only when covariables are configured
     * @param advancedCoordsTemp output: coordinates of the currently accepted variant subsets, updated in place
     * @param testX output: test matrices of the currently accepted subsets, updated in place
     * @param testY test labels
     * @param interpreter embedded Python interpreter holding the training state
     * @param scripts named script sections loaded from LEAP.py
     * @return the {@code auc_roc} value of the last row of {@code test_df} after the last completed epoch
     *         (0.0 if no epoch reached the test step)
     * @throws IOException if reading genotypes during sampling fails
     */
    private double train(int epoch, int K, int n, double[][] allPValues, int sampleSize, int[][][] folds, IndexableSet<String> testDownSampleUIDs, double[][] testDownSampleCovariables, Coordinate[][][] advancedCoordsTemp, DirectNDArray[][] testX, DirectNDArray testY, GlobalPythonInterpreter interpreter, Map<String, String> scripts) throws IOException {
        // Consecutive epochs without any fold model update that trigger early stopping.
        final int earlyStopPatience = 2;
        boolean needCovar = this.phenoFilesSet.covariableNames != null;

        // Folds and covariables never change between epochs, so slice the covariables once up front.
        double[][][] trainCovariables = null;
        double[][][] validateCovariables = null;
        if (needCovar) {
            int covarNum = this.phenoFilesSet.covariableNames.length;
            trainCovariables = new double[K][covarNum][];
            validateCovariables = new double[K][covarNum][];
            for (int currentK = 0; currentK < K; ++currentK) {
                int[] trainFold = folds[currentK][0];
                int[] validFold = folds[currentK][1];
                for (int currentC = 0; currentC < covarNum; ++currentC) {
                    trainCovariables[currentK][currentC] = new double[trainFold.length];
                    for (int idx = 0; idx < trainFold.length; ++idx) {
                        trainCovariables[currentK][currentC][idx] = this.phenoFilesSet.trainCovariables[currentC][trainFold[idx]];
                    }
                    validateCovariables[currentK][currentC] = new double[validFold.length];
                    for (int idx = 0; idx < validFold.length; ++idx) {
                        validateCovariables[currentK][currentC][idx] = this.phenoFilesSet.trainCovariables[currentC][validFold[idx]];
                    }
                }
            }
        }

        int noImprovementCount = 0;
        double bestRoc = 0.0;
        double bestPr = 0.0;
        for (int i = 0; i < epoch; ++i) {
            interpreter.setValue("java_current_i", i);
            for (int currentK = 0; currentK < K; ++currentK) {
                interpreter.setValue("java_p_values", allPValues[currentK]);
                interpreter.setValue("java_k", currentK);
                interpreter.setValue("java_select_size", sampleSize);
                interpreter.exec(scripts.get("select"));
            }
            WeightSampleResult sampleResults = this.sample(sampleSize, folds, this.manager, K, n, testDownSampleUIDs, this.threads, interpreter);
            interpreter.setValue("java_X_train", sampleResults.getTrainX());
            interpreter.setValue("java_X_validate", sampleResults.getValidateX());
            interpreter.setValue("java_y_train", sampleResults.getTrainY());
            interpreter.setValue("java_y_validate", sampleResults.getValidateY());
            if (needCovar) {
                interpreter.setValue("needCovar", true);
                interpreter.setValue("java_train_covars", trainCovariables);
                interpreter.setValue("java_validate_covars", validateCovariables);
            } else {
                interpreter.setValue("needCovar", false);
            }
            interpreter.exec(scripts.get("train"));
            // Indices of the folds whose model was replaced by this epoch's candidate.
            int[] changeIndex = interpreter.getValue("change_index", int[].class);
            if (changeIndex.length != 0) {
                noImprovementCount = 0;
                for (int currentK = 0; currentK < K; ++currentK) {
                    final int fold = currentK;
                    boolean isModelUpdated = Arrays.stream(changeIndex).anyMatch(updated -> updated == fold);
                    if (!isModelUpdated) {
                        continue;
                    }
                    System.arraycopy(sampleResults.getSnpCoords()[currentK], 0, advancedCoordsTemp[currentK], 0, n);
                    System.arraycopy(sampleResults.getTestX()[currentK], 0, testX[currentK], 0, n);
                }
            } else if (++noImprovementCount == earlyStopPatience) {
                SetupApplication.GlobalLogger.info("The model has shown no performance improvement for {} consecutive epochs. The early stopping strategy has been triggered, and training is exited at epoch {}.", earlyStopPatience, i);
                break;
            }
            interpreter.setValue("java_X_test", testX);
            interpreter.setValue("java_y_test", testY);
            if (needCovar) {
                interpreter.setValue("java_test_covars", testDownSampleCovariables);
            }
            interpreter.exec(scripts.get("test"));
            bestRoc = interpreter.getValue("test_df.iloc[-1]['auc_roc']", Double.class);
            bestPr = interpreter.getValue("test_df.iloc[-1]['auc_pr']", Double.class);
            SetupApplication.GlobalLogger.info("Test mean metric for epoch {} is: AUC-ROC: {}, AUC-PR: {}", i, bestRoc, bestPr);
            interpreter.eval("probabilities_list = []");
            sampleResults.clear();
        }
        return bestRoc;
    }

    /**
     * Trains and evaluates the fold models on covariables alone, without genotype features.
     *
     * <p>For every fold it builds integer label arrays and slices the training covariables by the
     * fold's row indices. It then pushes everything into the interpreter and runs the
     * {@code train} and {@code test} scripts once, with {@code java_current_i = 0}.
     * Finally it logs the last-row AUC-ROC / AUC-PR of {@code test_df}.
     *
     * @param K number of folds
     * @param folds per-fold index arrays; {@code folds[k][0]} are training rows, {@code folds[k][1]} validation rows
     * @param testDownSampleCovariables test covariables handed to the {@code test} script
     * @param testY test labels
     * @param interpreter embedded Python interpreter
     * @param scripts named script sections loaded from LEAP.py
     */
    private void trainWithoutGty(int K, int[][][] folds, double[][] testDownSampleCovariables, DirectNDArray testY, GlobalPythonInterpreter interpreter, Map<String, String> scripts) {
        DirectNDArray[] foldTrainLabels = new DirectNDArray[K];
        DirectNDArray[] foldValidateLabels = new DirectNDArray[K];
        for (int fold = 0; fold < K; fold++) {
            int[] trainRows = folds[fold][0];
            int[] validateRows = folds[fold][1];
            DirectNumpyBinaryArray trainLabels = new DirectNumpyBinaryArray(trainRows.length);
            DirectNumpyBinaryArray validateLabels = new DirectNumpyBinaryArray(validateRows.length);
            int row = 0;
            while (row < trainRows.length) {
                double label = this.phenoFilesSet.trainPhenotypes.fastGet(trainRows[row]);
                trainLabels.set(row, (int) label);
                row++;
            }
            row = 0;
            while (row < validateRows.length) {
                double label = this.phenoFilesSet.trainPhenotypes.fastGet(validateRows[row]);
                validateLabels.set(row, (int) label);
                row++;
            }
            foldTrainLabels[fold] = trainLabels.getAsDirectNDArray();
            foldValidateLabels[fold] = validateLabels.getAsDirectNDArray();
        }

        final int covariableCount = this.phenoFilesSet.covariableNames.length;
        double[][][] foldTrainCovars = new double[K][covariableCount][];
        double[][][] foldValidateCovars = new double[K][covariableCount][];
        for (int fold = 0; fold < K; fold++) {
            int[] trainRows = folds[fold][0];
            int[] validateRows = folds[fold][1];
            for (int column = 0; column < covariableCount; column++) {
                double[] trainColumn = new double[trainRows.length];
                for (int row = 0; row < trainRows.length; row++) {
                    trainColumn[row] = this.phenoFilesSet.trainCovariables[column][trainRows[row]];
                }
                foldTrainCovars[fold][column] = trainColumn;
                double[] validateColumn = new double[validateRows.length];
                for (int row = 0; row < validateRows.length; row++) {
                    validateColumn[row] = this.phenoFilesSet.trainCovariables[column][validateRows[row]];
                }
                foldValidateCovars[fold][column] = validateColumn;
            }
        }

        interpreter.setValue("needCovar", true);
        interpreter.setValue("java_train_covars", foldTrainCovars);
        interpreter.setValue("java_validate_covars", foldValidateCovars);
        interpreter.setValue("java_y_train", foldTrainLabels);
        interpreter.setValue("java_y_validate", foldValidateLabels);
        interpreter.setValue("java_current_i", 0);
        interpreter.exec(scripts.get("train"));

        interpreter.setValue("java_y_test", testY);
        interpreter.setValue("java_test_covars", testDownSampleCovariables);
        interpreter.exec(scripts.get("test"));
        double rocAuc = interpreter.getValue("test_df.iloc[-1]['auc_roc']", Double.class);
        double prAuc = interpreter.getValue("test_df.iloc[-1]['auc_pr']", Double.class);
        SetupApplication.GlobalLogger.info("Test mean metric is: AUC-ROC: {}, AUC-PR: {}", (Object) rocAuc, (Object) prAuc);
    }

    /**
     * Computes, for each fold, an allelic chi-square p-value per variant using only that fold's training samples.
     *
     * <p>Training samples with phenotype 0.0 are controls and those with 1.0 are cases; other values
     * are ignored. For each variant, a 2x2 table of allele counts (cases/controls x ref/alt alleles,
     * derived from the 0/0, 0/1, 1/0 and 1/1 genotype counts) is passed to
     * {@code ContingencyTable.chiSquareTest}. A NaN p-value is replaced by 0.9999999999999999.
     * Folds are processed in parallel, one task per fold.
     *
     * @param K number of folds
     * @param folds per-fold index arrays; {@code folds[k][0]} are training rows into {@code phenoFilesSet}
     * @param manager GTB file holding the genotypes
     * @param phenoFilesSet training UIDs and phenotypes
     * @return {@code [K][numOfVariants]} p-values, in the reader's variant order
     */
    private double[][] calAllPValues(int K, int[][][] folds, GTBManager manager, PhenoFilesSet phenoFilesSet) {
        int numOfVariants = (int)manager.numOfVariants();
        double[][] allPValues = new double[K][numOfVariants];
        try (ThreadQueue threadQueue = new ThreadQueue(this.threads)) {
            for (int i = 0; i < K; ++i) {
                int finalI = i;
                threadQueue.addTask((s, c) -> {
                    // The phenotype type does not depend on the sample, so check it once per task.
                    if (!GlobalPedIndividuals.isBinaryPhenotypes(0)) {
                        throw new Exception("Error! Phenotype currently only supports binary values.");
                    }
                    int[] trainIndices = folds[finalI][0];
                    LinkedSet<String> caseIDsInPed = new LinkedSet<String>();
                    LinkedSet<String> controlIDsInPed = new LinkedSet<String>();
                    for (int index : trainIndices) {
                        double phenotype = phenoFilesSet.trainPhenotypes.fastGet(index);
                        if (phenotype == 0.0) {
                            ((AbstractCollection)controlIDsInPed).add(phenoFilesSet.trainUIDs.valueOf(index));
                        } else if (phenotype == 1.0) {
                            ((AbstractCollection)caseIDsInPed).add(phenoFilesSet.trainUIDs.valueOf(index));
                        }
                    }
                    IntList controlIndices = controlIDsInPed.findIndicesIn(GlobalPedIndividuals.getIndividuals().getUIDs());
                    IntList caseIndices = caseIDsInPed.findIndicesIn(GlobalPedIndividuals.getIndividuals().getUIDs());
                    double[] pValues = new double[numOfVariants];
                    GTBReader reader = new GTBReader(manager);
                    try {
                        int p = 0;
                        while (reader.hasNext()) {
                            Variant variant = reader.read();
                            ICounter controlGenotypeCounter = variant.getGenotypes().subGenotypes(controlIndices).counter();
                            ICounter caseGenotypeCounter = variant.getGenotypes().subGenotypes(caseIndices).counter();
                            int AACase = caseGenotypeCounter.count(Genotype.of(0, 0));
                            int AaCase = caseGenotypeCounter.count(Genotype.of(0, 1)) + caseGenotypeCounter.count(Genotype.of(1, 0));
                            int aaCase = caseGenotypeCounter.count(Genotype.of(1, 1));
                            int AAControl = controlGenotypeCounter.count(Genotype.of(0, 0));
                            int AaControl = controlGenotypeCounter.count(Genotype.of(0, 1)) + controlGenotypeCounter.count(Genotype.of(1, 0));
                            int aaControl = controlGenotypeCounter.count(Genotype.of(1, 1));
                            // Rows: case / control; columns: ref / alt allele counts.
                            long[][] counts = new long[2][2];
                            counts[0][0] = 2L * (long)AACase + (long)AaCase;
                            counts[0][1] = 2L * (long)aaCase + (long)AaCase;
                            counts[1][0] = 2L * (long)AAControl + (long)AaControl;
                            counts[1][1] = 2L * (long)aaControl + (long)AaControl;
                            double pValue = ContingencyTable.chiSquareTest(counts);
                            pValues[p++] = Double.isNaN(pValue) ? 0.9999999999999999 : pValue;
                        }
                    } finally {
                        // Release the reader even when reading or the test throws.
                        reader.close();
                    }
                    allPValues[finalI] = pValues;
                });
            }
        }
        return allPValues;
    }

    /**
     * Draws one round of weighted variant subsets and builds the feature and label arrays for every fold.
     *
     * <p>The Python {@code sample_indices} function returns {@code indices_selected_list},
     * indexed as [fold][sub-model][variant]. The union of those variant indices is read once from
     * the GTB file, in ascending order. Then one task per (fold, sub-model) fills the train,
     * validate and test genotype matrices, using {@code getAC()} per genotype.
     *
     * @param sampleSize number of variants per sub-model (column count of each matrix)
     * @param folds per-fold index arrays; {@code folds[k][0]} are training rows, {@code folds[k][1]} validation rows
     * @param manager GTB file holding the genotypes
     * @param K number of folds
     * @param n number of sub-models per fold
     * @param testUIDs UIDs of the test samples
     * @param threads worker thread count
     * @param interpreter embedded Python interpreter holding {@code probabilities_list}
     * @return the sampled matrices, per-fold labels and coordinates of the sampled variants
     * @throws IOException if reading the GTB file fails
     */
    private WeightSampleResult sample(int sampleSize, int[][][] folds, GTBManager manager, int K, int n, IndexableSet<String> testUIDs, int threads, GlobalPythonInterpreter interpreter) throws IOException {
        IntList trainUIDsIndices = this.phenoFilesSet.trainUIDs.findIndicesIn(GlobalPedIndividuals.getIndividuals().getUIDs());
        IntList testUIDsIndices = testUIDs.findIndicesIn(GlobalPedIndividuals.getIndividuals().getUIDs());
        interpreter.setValue("select_size", sampleSize);
        interpreter.exec("indices_selected_list = sample_indices(probabilities_list, K, n, select_size)");
        int[][][] sampledIndices = interpreter.getValue("indices_selected_list", int[][][].class);

        // Union of all sampled variant indices, sorted so the reader only seeks forward.
        HashSet<Integer> globalIndexSet = new HashSet<Integer>();
        for (int k = 0; k < K; ++k) {
            for (int nIdx = 0; nIdx < n; ++nIdx) {
                for (int index : sampledIndices[k][nIdx]) {
                    globalIndexSet.add(index);
                }
            }
        }
        int[] globalSortedIndices = globalIndexSet.stream().mapToInt(Integer::intValue).sorted().toArray();
        HashMap<Integer, IGenotypes> globalTrainMap = new HashMap<Integer, IGenotypes>();
        HashMap<Integer, IGenotypes> globalTestMap = new HashMap<Integer, IGenotypes>();
        HashMap<Integer, Coordinate> globalCoordsMap = new HashMap<Integer, Coordinate>();
        GTBReader reader = new GTBReader(new GTBReaderOption(manager, true, false));
        try {
            for (int index : globalSortedIndices) {
                reader.seek(index);
                Variant variant = reader.read();
                globalTrainMap.put(index, variant.getGenotypes().subGenotypes(trainUIDsIndices));
                globalTestMap.put(index, variant.getGenotypes().subGenotypes(testUIDsIndices));
                globalCoordsMap.put(index, variant.getCoordinate());
            }
        } finally {
            // Close the reader even if seeking or reading fails.
            reader.close();
        }

        DirectNDArray[][] trainX = new DirectNDArray[K][n];
        DirectNDArray[][] validateX = new DirectNDArray[K][n];
        DirectNDArray[][] testX = new DirectNDArray[K][n];
        DirectNDArray[] trainY = new DirectNDArray[K];
        DirectNDArray[] validateY = new DirectNDArray[K];
        Coordinate[][][] snpCoords = new Coordinate[K][n][sampleSize];
        int testSize = testUIDs.size();
        try (ThreadQueue threadQueue = new ThreadQueue(threads)) {
            for (int currentK = 0; currentK < K; ++currentK) {
                int finalCurrentK = currentK;
                int[] trainIndex = folds[currentK][0];
                int[] validateIndex = folds[currentK][1];
                // Labels depend only on the fold, so build them once per fold rather than once per sub-model.
                DirectNumpyBinaryArray trainY0 = new DirectNumpyBinaryArray(trainIndex.length);
                DirectNumpyBinaryArray validateY0 = new DirectNumpyBinaryArray(validateIndex.length);
                for (int r = 0; r < trainIndex.length; ++r) {
                    double phenotype = this.phenoFilesSet.trainPhenotypes.fastGet(trainIndex[r]);
                    trainY0.set(r, (int)phenotype);
                }
                for (int r = 0; r < validateIndex.length; ++r) {
                    double phenotype = this.phenoFilesSet.trainPhenotypes.fastGet(validateIndex[r]);
                    validateY0.set(r, (int)phenotype);
                }
                trainY[currentK] = trainY0.getAsDirectNDArray();
                validateY[currentK] = validateY0.getAsDirectNDArray();

                for (int currentN = 0; currentN < n; ++currentN) {
                    int finalCurrentN = currentN;
                    threadQueue.addTask((s, c) -> {
                        DirectNumpyGenotypeArray trainX0 = new DirectNumpyGenotypeArray(trainIndex.length, sampleSize);
                        DirectNumpyGenotypeArray validateX0 = new DirectNumpyGenotypeArray(validateIndex.length, sampleSize);
                        DirectNumpyGenotypeArray testX0 = new DirectNumpyGenotypeArray(testSize, sampleSize);
                        int[] sortedSampled = Arrays.stream(sampledIndices[finalCurrentK][finalCurrentN]).sorted().toArray();
                        Coordinate[] coords = new Coordinate[sortedSampled.length];
                        int p = 0;
                        for (int index : sortedSampled) {
                            IGenotypes genotypes = globalTrainMap.get(index);
                            IGenotypes genotypesTest = globalTestMap.get(index);
                            coords[p] = globalCoordsMap.get(index);
                            for (int row = 0; row < trainIndex.length; ++row) {
                                trainX0.set(row, p, genotypes.get(trainIndex[row]).getAC());
                            }
                            for (int row = 0; row < validateIndex.length; ++row) {
                                validateX0.set(row, p, genotypes.get(validateIndex[row]).getAC());
                            }
                            for (int row = 0; row < testSize; ++row) {
                                testX0.set(row, p, genotypesTest.get(row).getAC());
                            }
                            ++p;
                        }
                        trainX[finalCurrentK][finalCurrentN] = trainX0.getAsDirectNDArray();
                        validateX[finalCurrentK][finalCurrentN] = validateX0.getAsDirectNDArray();
                        testX[finalCurrentK][finalCurrentN] = testX0.getAsDirectNDArray();
                        snpCoords[finalCurrentK][finalCurrentN] = coords;
                    });
                }
            }
        }
        return new WeightSampleResult(trainX, validateX, testX, trainY, validateY, null, snpCoords);
    }

    /**
     * Writes a GTB file that combines the records of {@code coorFile} with genotypes fetched from their source files.
     *
     * <p>Each record of {@code coorFile} carries {@code SOURCE@FILE_ID} (the GTB file to open) and
     * {@code SOURCE@FILE_POINTER} (the position to seek to). The genotypes read at that position are
     * attached to the record, which is written to {@code output}. The coordinate file is split into
     * one part per task. When all tasks finish, the output is indexed.
     *
     * @param source2 GTB file whose individuals are written into the output header
     * @param coorFile GTB file with the coordinate records and source pointers
     * @param output destination GTB file
     * @param threads number of parallel parts / writer slots
     * @throws IOException if reading, writing or indexing fails
     */
    private void generateGtyFile(File source2, File coorFile, File output, int threads) throws IOException {
        GTBManager manager = new GTBManager(source2);
        GTBManager coordsManager = new GTBManager(coorFile);
        edu.sysu.pmglab.container.list.List<GTBReader> coorReaders = new GTBReader(new GTBReaderOption(coordsManager, true, true)).part(threads);
        GTBWriter writers = GTBWriter.setOutput(output).addIndividuals(manager.getIndividuals()).addFields(coordsManager.getAllFields()).instance(threads);
        try (ThreadQueue queue = new ThreadQueue(threads)) {
            for (int i = 0; i < coorReaders.size(); ++i) {
                int threadId = i;
                queue.addTask((status, context) -> {
                    GTBReader coorReader = (GTBReader)coorReaders.fastGet(threadId);
                    // One genotype reader per distinct source file, reused across records of this part.
                    Map<String, GTBReader> gtyReaders = new HashMap<String, GTBReader>();
                    try {
                        Variant variant;
                        while ((variant = coorReader.read()) != null) {
                            String fileId = (String)variant.getProperty("SOURCE@FILE_ID");
                            long pointer = (Long)variant.getProperty("SOURCE@FILE_POINTER");
                            GTBReader gtyReader = gtyReaders.get(fileId);
                            if (gtyReader == null) {
                                gtyReader = new GTBReader(new GTBReaderOption(fileId, true, false));
                                gtyReaders.put(fileId, gtyReader);
                            }
                            gtyReader.seek(pointer);
                            Variant gtyVariant = gtyReader.read();
                            variant.setGenotypes(gtyVariant.getGenotypes());
                            writers.write(threadId, variant);
                        }
                        writers.finish(threadId);
                    } finally {
                        // Release every reader opened by this task, also when a read/write fails.
                        for (GTBReader gtyReader : gtyReaders.values()) {
                            gtyReader.close();
                        }
                        gtyReaders.clear();
                        coorReader.close();
                    }
                });
            }
        }
        writers.close();
        GTBIndexer.setInput(output, new String[0]).save(threads);
    }

    /**
     * Loads the bundled LEAP.py resource and splits it into named script sections.
     *
     * <p>Sections are delimited by {@code #>>>}. The first line after each delimiter, trimmed, is
     * the section name. The value is the trimmed section text with the {@code #>>>} marker
     * re-prepended. Sections with an empty name are skipped.
     *
     * @return map from section name to script text
     * @throws IOException if the resource is missing or cannot be read
     */
    private Map<String, String> getLEAPScriptFromPython() throws IOException {
        String resourcePath = "/edu/sysu/pmglab/kgga/command/python/script/LEAP.py";
        InputStream resource = TrainLEAPModelTask.class.getResourceAsStream(resourcePath);
        // An explicit check instead of `assert`, which is a no-op unless the JVM runs with -ea.
        if (resource == null) {
            throw new IOException("Cannot find the bundled LEAP script resource: " + resourcePath);
        }
        StringBuilder contentBuilder = new StringBuilder();
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(resource, StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                contentBuilder.append(line).append('\n');
            }
        }
        HashMap<String, String> scripts = new HashMap<String, String>();
        for (String section : contentBuilder.toString().split("#>>>")) {
            String title = section.split("\n")[0].trim();
            if (title.isEmpty()) {
                continue;
            }
            scripts.put(title, "#>>>" + section.trim());
        }
        return scripts;
    }

    /**
     * Locates the Jep native library.
     *
     * <p>Runs {@code pip show jep} and uses its {@code Location:} line. If pip cannot be started,
     * exits non-zero, or reports no location, it falls back to the {@code PYTHONJEP} environment
     * variable. That value is returned as an absolute path without appending the library name.
     *
     * @return absolute path of the Jep native library, or of {@code PYTHONJEP} when the fallback is used
     * @throws IOException declared for compatibility with existing callers
     * @throws InterruptedException if interrupted while waiting for pip to exit
     * @throws RuntimeException if neither pip nor {@code PYTHONJEP} yields a location
     */
    private String getJepPath() throws IOException, InterruptedException {
        // Locale.ROOT keeps the OS-name match independent of the default locale (e.g. Turkish dotless i).
        String osName = System.getProperty("os.name").toLowerCase(java.util.Locale.ROOT);
        String libName = osName.contains("win") ? "jep.dll" : (osName.contains("mac") ? "libjep.jnilib" : "libjep.so");
        String location = null;
        int exitCode = -1;
        try {
            // Merge stderr so the child cannot block on an undrained error pipe.
            Process process = new ProcessBuilder("pip", "show", "jep").redirectErrorStream(true).start();
            try (BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream()))) {
                String line;
                // Read the whole output (not just up to Location:) so the process can exit cleanly.
                while ((line = reader.readLine()) != null) {
                    if (location == null && line.startsWith("Location:")) {
                        location = line.substring("Location:".length()).trim();
                    }
                }
            }
            exitCode = process.waitFor();
        } catch (IOException ignored) {
            // pip is missing or unreadable: fall through to the PYTHONJEP environment variable.
            location = null;
        }
        if (exitCode != 0 || location == null) {
            String envLocation = System.getenv("PYTHONJEP");
            if (envLocation != null) {
                return Paths.get(envLocation).toAbsolutePath().toString();
            }
            throw new RuntimeException("Failed to find Jep installation location");
        }
        return Paths.get(location, "jep", libName).toAbsolutePath().toString();
    }

    /**
     * Builds a change-detection digest for a file from its canonical path, size and modification time.
     *
     * @param inputFile file to fingerprint
     * @return {@code ITrack.digest} of {@code "<canonicalPath>|<length>|<lastModified>"}
     * @throws IOException if the canonical path cannot be resolved
     */
    private String digest(File inputFile) throws IOException {
        String canonicalPath = inputFile.getCanonicalPath();
        long sizeInBytes = inputFile.length();
        long modifiedAt = inputFile.lastModified();
        String fingerprint = String.join("|", canonicalPath, Long.toString(sizeInBytes), Long.toString(modifiedAt));
        return ITrack.digest(fingerprint);
    }

    /**
     * Mutable holder for the arrays produced by one round of weighted sampling.
     *
     * <p>As built by {@code sample(...)}:
     * <ul>
     *   <li>the {@code *X} matrices are indexed as [fold][sub-model];</li>
     *   <li>the {@code *Y} label arrays are indexed as [fold];</li>
     *   <li>{@code snpCoords} is indexed as [fold][sub-model][variant].</li>
     * </ul>
     * {@link #clear()} drops every reference and then requests a garbage collection.
     */
    private static class WeightSampleResult {
        // Feature matrices.
        private DirectNDArray[][] trainX;
        private DirectNDArray[][] validateX;
        private DirectNDArray[][] testX;
        // Label arrays.
        private DirectNDArray[] trainY;
        private DirectNDArray[] validateY;
        private DirectNDArray testY;
        // Coordinates of the sampled variants.
        private Coordinate[][][] snpCoords;

        public WeightSampleResult(DirectNDArray[][] trainX, DirectNDArray[][] validateX, DirectNDArray[][] testX, DirectNDArray[] trainY, DirectNDArray[] validateY, DirectNDArray testY, Coordinate[][][] snpCoords) {
            this.snpCoords = snpCoords;
            this.testY = testY;
            this.validateY = validateY;
            this.trainY = trainY;
            this.testX = testX;
            this.validateX = validateX;
            this.trainX = trainX;
        }

        public DirectNDArray[][] getTrainX() {
            return trainX;
        }

        public DirectNDArray[][] getValidateX() {
            return validateX;
        }

        public DirectNDArray[][] getTestX() {
            return testX;
        }

        public DirectNDArray[] getTrainY() {
            return trainY;
        }

        public DirectNDArray[] getValidateY() {
            return validateY;
        }

        public DirectNDArray getTestY() {
            return testY;
        }

        public Coordinate[][][] getSnpCoords() {
            return snpCoords;
        }

        /**
         * Drops all held arrays and asks the JVM to run a garbage collection.
         */
        public void clear() {
            this.trainX = this.validateX = this.testX = null;
            this.trainY = this.validateY = null;
            this.testY = null;
            this.snpCoords = null;
            System.gc();
        }
    }
}

