/*
 * Decompiled with CFR 0.152.
 */
package edu.sysu.pmglab.kgga.command.python.toolkit;

import edu.sysu.pmglab.analysis.GenotypeLDUtils;
import edu.sysu.pmglab.container.list.IntList;
import edu.sysu.pmglab.executor.ThreadQueue;
import edu.sysu.pmglab.gtb.GTBReader;
import edu.sysu.pmglab.gtb.genome.Variant;
import edu.sysu.pmglab.io.text.TextRecord;
import edu.sysu.pmglab.io.text.reader.IHeaderParser;
import edu.sysu.pmglab.io.text.reader.TextReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import org.ejml.dense.row.factory.DecompositionFactory_DDRM;
import org.ejml.interfaces.decomposition.CholeskyDecomposition_F64;
import org.ejml.interfaces.decomposition.EigenDecomposition_F64;

/**
 * Counts the "valid" (effectively independent) variants in a group of variants from their
 * pairwise LD structure.
 *
 * <p>For a group of variants an LD matrix is built with 1 on the diagonal and, off the diagonal,
 * the absolute value of the {@code "R"} statistic returned by {@link GenotypeLDUtils} (pairs whose
 * value is NaN are left at 0). The matrix is Cholesky-factored as {@code L * L^T} and the count is
 * {@code floor(sum_k L[k][k]^2)}, capped at the number of variants. For a unit-diagonal matrix
 * {@code L[k][k]^2 = 1 - sum_{j<k} L[k][j]^2}, i.e. the part of variant k not explained by the
 * variants before it. If the matrix is not positive definite it is first repaired by raising its
 * small/negative eigenvalues to a small positive floor.
 */
public enum ValidVariantCounter {
    INSTANCE;

    /** Eigenvalues below this value are raised to it when repairing a non-positive-definite matrix. */
    private static final double MIN_EIGENVALUE = 1.0E-10;

    /**
     * Ad-hoc driver. It computes the count for 5000 variants whose positions are listed in one
     * row of an index TSV.
     *
     * <p>It skips the first 40 data rows of the index file (the header is consumed by
     * {@link IHeaderParser#FIRST_LINE}). It then seeks the GTB reader to the values of columns
     * {@code snp1..snp5000} of the next row and reads one variant at each.
     *
     * @param args optional {@code [gtbFile [indexFile]]}; each defaults to the author's local path
     * @throws IOException if either file cannot be read
     */
    public static void main(String[] args) throws IOException {
        String source2 = args.length > 0
                ? args[0]
                : "/Users/yiguoshabi/KGGA/resources/AD/merged_data_2_hg38_2.gtb";
        String index = args.length > 1
                ? args[1]
                : "/Users/yiguoshabi/KGGA/resources/AD/results/XgboostBasedWeightSampleTask/model_need_variant_index.tsv";
        int validCount;
        TextReader textReader = TextReader.setInput(index).setHeaderParser(IHeaderParser.FIRST_LINE).instance();
        try {
            // Skip 40 data rows; the following row supplies the variant positions.
            for (int i = 0; i < 40; ++i) {
                textReader.read();
            }
            List<Variant> variants = new ArrayList<>(5000);
            GTBReader reader = new GTBReader(source2);
            try {
                TextRecord textRecord = textReader.read();
                for (int j = 0; j < 5000; ++j) {
                    reader.seek(textRecord.get("snp" + (j + 1)).toLong());
                    variants.add(reader.read());
                }
                validCount = INSTANCE.calculateValidVariant(variants);
            } finally {
                // Close even when a seek/read fails.
                reader.close();
            }
        } finally {
            textReader.close();
        }
        System.out.println(validCount);
    }

    /**
     * Reads every variant of a GTB file into memory and counts valid variants per block.
     *
     * @param blocks block boundaries as indices into the file's variant order; see
     *     {@link #calculateValidVariant(IntList, List, int)}
     * @param gtbFile path of the GTB file
     * @param threads number of worker threads passed to {@link ThreadQueue}
     * @return one count per block
     * @throws IOException if the file cannot be read
     */
    public IntList calculateValidVariant(IntList blocks, String gtbFile, int threads) throws IOException {
        List<Variant> variants = new ArrayList<>();
        GTBReader reader = new GTBReader(gtbFile);
        try {
            for (long i = 0L; i < reader.numOfVariants(); ++i) {
                variants.add(reader.read());
            }
        } finally {
            // Release the file even if a read fails part-way.
            reader.close();
        }
        return this.calculateValidVariant(blocks, variants, threads);
    }

    /**
     * Counts valid variants independently for each block of {@code variants}.
     *
     * <p>Block {@code i} is {@code variants.subList(blocks[i], blocks[i + 1])}, so {@code n}
     * boundaries give {@code n - 1} blocks. Blocks are submitted as separate tasks to a
     * {@link ThreadQueue} created with {@code threads}.
     *
     * @param blocks ascending block boundaries (indices into {@code variants})
     * @param variants all variants; {@code minN} for the LD computation is derived from the first one
     * @param threads number of worker threads passed to {@link ThreadQueue}
     * @return one count per block; empty when fewer than two boundaries are given
     * @throws IllegalStateException if a block's LD matrix cannot be repaired and factored
     */
    public IntList calculateValidVariant(IntList blocks, List<Variant> variants, int threads) {
        int blockCount = blocks.size() - 1;
        if (blockCount <= 0) {
            // Fewer than two boundaries define no block (and previously gave a negative array size).
            return new IntList(new int[0]);
        }
        // Taken from the whole list so every block uses the same threshold.
        int minN = minSampleCount(variants);
        int[] validCounts = new int[blockCount];
        ThreadQueue threadQueue = new ThreadQueue(threads);
        try {
            for (int i = 0; i < blockCount; ++i) {
                int block = i;
                threadQueue.addTask((s, c) -> {
                    List<Variant> blockVariants =
                            variants.subList(blocks.fastGet(block), blocks.fastGet(block + 1));
                    // Each task writes a distinct slot of validCounts.
                    validCounts[block] = this.countEffectiveVariants(blockVariants, minN);
                });
            }
        } finally {
            // NOTE(review): assumes close() waits for all queued tasks; validCounts is read right after.
            threadQueue.close();
        }
        return new IntList(validCounts);
    }

    /**
     * Counts valid variants in a single group.
     *
     * @param variants the group; {@code minN} for the LD computation is derived from the first one
     * @return {@code floor(sum_k L[k][k]^2)} capped at {@code variants.size()}; 0 for an empty list
     * @throws IllegalStateException if the LD matrix cannot be repaired and factored
     */
    public int calculateValidVariant(List<Variant> variants) {
        if (variants.isEmpty()) {
            return 0;
        }
        return this.countEffectiveVariants(variants, minSampleCount(variants));
    }

    /**
     * Returns the {@code minN} argument passed to {@link GenotypeLDUtils}. It is half the genotype
     * count of the first variant, truncated, or 0 for an empty list.
     * NOTE(review): presumably the minimum number of usable samples per pair — confirm.
     */
    private static int minSampleCount(List<Variant> variants) {
        if (variants.isEmpty()) {
            return 0;
        }
        return (int) ((double) variants.get(0).getGenotypes().size() * 0.5);
    }

    /**
     * Builds the LD matrix, factors it, and sums the squared diagonal of the Cholesky factor.
     *
     * @return {@code floor(sum_k L[k][k]^2)} capped at {@code variants.size()}; 0 for an empty list
     */
    private int countEffectiveVariants(List<Variant> variants, int minN) {
        int size = variants.size();
        if (size == 0) {
            return 0;
        }
        DMatrixRMaj lower = this.choleskyLower(buildLdMatrix(variants, minN));
        double sum = 0.0;
        for (int k = 0; k < size; ++k) {
            double diag = lower.get(k, k);
            // Defensive: skip any non-numeric pivot rather than poisoning the sum.
            if (!Double.isNaN(diag)) {
                sum += diag * diag;
            }
        }
        return (int) Math.min((double) size, Math.floor(sum));
    }

    /**
     * Builds the symmetric matrix with 1 on the diagonal and {@code |R|} between each pair of
     * variants. Pairs whose {@code R} is NaN are left at 0.
     */
    private static DMatrixRMaj buildLdMatrix(List<Variant> variants, int minN) {
        int size = variants.size();
        // DMatrixRMaj starts zero-filled, so skipped pairs stay 0.
        DMatrixRMaj ldMatrix = new DMatrixRMaj(size, size);
        for (int x = 0; x < size; ++x) {
            ldMatrix.set(x, x, 1.0);
            Variant first = variants.get(x);
            for (int y = x + 1; y < size; ++y) {
                float r = ((Float) GenotypeLDUtils.INSTANCE.apply(first, variants.get(y), minN)
                        .get("R")).floatValue();
                if (Float.isNaN(r)) {
                    continue;
                }
                float absR = Math.abs(r);
                ldMatrix.set(x, y, absR);
                ldMatrix.set(y, x, absR);
            }
        }
        return ldMatrix;
    }

    /**
     * Returns the lower Cholesky factor of {@code matrix}. If {@code matrix} is not positive
     * definite, it factors the eigenvalue-repaired matrix from {@link #adjustMatrix} instead.
     *
     * @throws IllegalStateException if even the repaired matrix cannot be factored
     */
    private DMatrixRMaj choleskyLower(DMatrixRMaj matrix) {
        int n = matrix.numRows;
        CholeskyDecomposition_F64<DMatrixRMaj> chol = DecompositionFactory_DDRM.chol(n, true);
        // Factor a copy so the untouched matrix is still available for the repair path.
        if (!chol.decompose(matrix.copy())) {
            DMatrixRMaj repaired = this.adjustMatrix(matrix);
            if (!chol.decompose(repaired)) {
                // Previously ignored, which silently produced a count from a partial factorization.
                throw new IllegalStateException(
                        "Cholesky decomposition failed after eigenvalue repair (" + n + " x " + n + ")");
            }
        }
        return chol.getT(new DMatrixRMaj(n, n));
    }

    /**
     * Rebuilds {@code matrix} as {@code V * diag(max(lambda, MIN_EIGENVALUE)) * V^T} from its
     * eigendecomposition. This makes it positive definite. {@code matrix} itself is not modified.
     *
     * @throws IllegalStateException if the eigendecomposition fails
     */
    private DMatrixRMaj adjustMatrix(DMatrixRMaj matrix) {
        int n = matrix.numRows;
        EigenDecomposition_F64<DMatrixRMaj> eig = DecompositionFactory_DDRM.eig(n, true);
        if (!eig.decompose(matrix)) {
            // Message text means "eigendecomposition failed".
            throw new IllegalStateException("\u7279\u5f81\u5206\u89e3\u5931\u8d25");
        }
        DMatrixRMaj eigenvectors = new DMatrixRMaj(n, n);
        DMatrixRMaj flooredValues = new DMatrixRMaj(n, n);
        for (int i = 0; i < n; ++i) {
            double lambda = eig.getEigenvalue(i).real;
            // Floor zero/tiny eigenvalues as well as negative ones: leaving them untouched keeps the
            // rebuilt matrix (numerically) singular and the second Cholesky attempt would fail.
            flooredValues.set(i, i, lambda < MIN_EIGENVALUE ? MIN_EIGENVALUE : lambda);
            DMatrixRMaj vector = (DMatrixRMaj) eig.getEigenVector(i);
            CommonOps_DDRM.insert(vector, eigenvectors, 0, i);
        }
        DMatrixRMaj temp = new DMatrixRMaj(n, n);
        DMatrixRMaj result = new DMatrixRMaj(n, n);
        CommonOps_DDRM.mult(eigenvectors, flooredValues, temp);
        CommonOps_DDRM.multTransB(temp, eigenvectors, result);
        return result;
    }
}

