/*
 * Decompiled with CFR 0.152.
 */
package edu.sysu.pmglab.container;

import edu.sysu.pmglab.ccf.record.IRecord;
import edu.sysu.pmglab.container.LDPairs;
import edu.sysu.pmglab.container.array.StringArray;
import edu.sysu.pmglab.container.list.IntList;
import edu.sysu.pmglab.container.list.List;
import edu.sysu.pmglab.gtb.GTBReader;
import edu.sysu.pmglab.gtb.genome.Variant;
import edu.sysu.pmglab.gtb.linkagedisequilibrium.GenotypeLD;
import edu.sysu.pmglab.gtb.linkagedisequilibrium.LDProperty;
import edu.sysu.pmglab.kgga.io.GlobalPedIndividuals;
import edu.sysu.pmglab.objectpool.GenericObjectPool;
import gnu.trove.map.TIntObjectMap;
import gnu.trove.map.hash.THashMap;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import java.io.IOException;
import java.util.Comparator;
import java.util.Map;

/**
 * LD-based pruning of variant lists.
 *
 * <p>This enum carries no state; {@link #INSTANCE} exists only because the type is declared as an
 * enum. All functionality is provided by static methods.
 */
public enum LDPruner {
    INSTANCE;


    /**
     * Greedily prunes {@code initialVariants} so that no two retained variants have a genotype
     * r&sup2; at or above {@code r2Cut}.
     *
     * <p>Side effects visible to the caller:
     * <ul>
     *   <li>{@code initialVariants} is re-sorted in place (first by {@code varFileIDComparator},
     *       then by {@link Variant#compareTo}).</li>
     *   <li>Variants without genotypes get them loaded from their source GTB file; readers opened
     *       for that purpose are added to {@code gtyReaders} and are not closed here.</li>
     *   <li>Variants that take part in an r&sup2; computation get an {@link LDProperty} borrowed
     *       from {@code ldPropertyObjectPool} attached as a property. It is not returned to the
     *       pool here (NOTE(review): presumably the caller releases it; confirm).</li>
     *   <li>{@code simplifiedLDUpper} / {@code simplifiedLDLower} are updated with the linked /
     *       unlinked position pairs discovered.</li>
     * </ul>
     *
     * @param initialVariants      variants to prune; sorted in place
     * @param gtyReaders           open GTB readers keyed by the variant's {@code SOURCE@FILE_ID}
     * @param fieldsPrioritization names of numeric variant properties copied into each variant's
     *                             rank vector (fields ending in {@code @MarkGeneFeature} or
     *                             {@code @P} are negated), or {@code null} for none
     * @param r2Cut                r&sup2; threshold at or above which two variants are linked
     * @param varFileIDComparator  order used while loading genotypes from files
     * @param ldPairsComparator    order deciding which linked variant is dropped next (the first
     *                             element after sorting is dropped)
     * @param simplifiedLDUpper    cache: position -&gt; positions already found linked to it
     * @param simplifiedLDLower    cache: position -&gt; positions already found unlinked to it
     * @param ldPropertyObjectPool pool supplying per-variant {@link LDProperty} objects
     * @return retained variants, sorted by {@link Variant#compareTo}
     * @throws IOException propagated from opening, seeking or reading GTB files
     */
    public static List<Variant> ldPruning(List<Variant> initialVariants, Map<String, GTBReader> gtyReaders, StringArray fieldsPrioritization, float r2Cut, Comparator<Variant> varFileIDComparator, Comparator<LDPairs> ldPairsComparator, TIntObjectMap<TIntHashSet> simplifiedLDUpper, TIntObjectMap<TIntHashSet> simplifiedLDLower, GenericObjectPool<LDProperty> ldPropertyObjectPool) throws IOException {
        loadMissingGenotypes(initialVariants, gtyReaders, varFileIDComparator);

        // The pairwise scan below indexes variants in this order; Variant::compareTo presumably
        // orders by genomic coordinate (TODO confirm).
        initialVariants.sort(Variant::compareTo);
        int size = initialVariants.size();
        List<TIntHashSet> linkedIndexes = buildLinkage(initialVariants, r2Cut, simplifiedLDUpper, simplifiedLDLower, ldPropertyObjectPool);

        // Slot in each rank vector reserved for LDPairs' size score; it sits right after the
        // prioritized fields and is the same for every variant.
        int sizeScoreIndex = fieldsPrioritization == null ? 0 : fieldsPrioritization.length();
        List<LDPairs> linkedVarScores = new List<LDPairs>(size);
        for (int i = 0; i < size; ++i) {
            double[] rankScores = rankScoresOf(initialVariants.get(i), fieldsPrioritization);
            linkedVarScores.add(new LDPairs(i, linkedIndexes.get(i), rankScores, sizeScoreIndex));
        }

        List<Variant> retainedVariants = new List<Variant>(size);
        if (!linkedVarScores.isEmpty()) {
            IntList retainedIndexes = multiConditionalPrune(linkedVarScores, ldPairsComparator);
            int retainedCount = retainedIndexes.size();
            for (int k = 0; k < retainedCount; ++k) {
                retainedVariants.add(initialVariants.get(retainedIndexes.fastGet(k)));
            }
        }
        retainedVariants.sort(Variant::compareTo);
        return retainedVariants;
    }

    /**
     * Loads genotypes for every variant that has none, reading the record at the variant's
     * {@code SOURCE@FILE_POINTER} in the file named by {@code SOURCE@FILE_ID}.
     *
     * <p>Missing readers are opened and stored in {@code gtyReaders}. When
     * {@link GlobalPedIndividuals} is non-empty, the loaded genotypes are restricted to those
     * individuals. Otherwise {@code null} is passed to {@code subGenotypes}.
     */
    private static void loadMissingGenotypes(List<Variant> variants, Map<String, GTBReader> gtyReaders, Comparator<Variant> varFileIDComparator) throws IOException {
        variants.sort(varFileIDComparator);
        // Per-file subject subset, computed once per file per call. It must not depend on whether
        // the reader was opened here: gtyReaders outlives this call, and a pre-existing reader
        // previously yielded a null (unfiltered) subset.
        THashMap<String, IntList> subjectIndicesByFile = new THashMap<String, IntList>();
        for (Variant variant : variants) {
            if (variant.getGenotypes() != null && variant.getGenotypes().size() > 0) {
                continue;
            }
            String fileId = (String) variant.getProperty("SOURCE@FILE_ID");
            long pointer = (Long) variant.getProperty("SOURCE@FILE_POINTER");

            GTBReader reader = gtyReaders.get(fileId);
            if (reader == null) {
                reader = new GTBReader(fileId);
                gtyReaders.put(fileId, reader);
            }

            IntList subjectIDs;
            // containsKey, not get()==null: a cached null ("no subset") is a valid value.
            if (subjectIndicesByFile.containsKey(fileId)) {
                subjectIDs = subjectIndicesByFile.get(fileId);
            } else {
                subjectIDs = GlobalPedIndividuals.size() > 0 ? GlobalPedIndividuals.getIndividuals().getUIDs().findIndicesIn(reader.getIndividuals()) : null;
                subjectIndicesByFile.put(fileId, subjectIDs);
            }

            reader.seek(pointer);
            Variant gtyVariant = reader.read();
            // Re-code against this variant's allele 0 / allele 1 as indexed in the source record.
            variant.setGenotypes(gtyVariant.getGenotypes().subGenotypes(subjectIDs).toBiallelic(gtyVariant.indexOfAllele(variant.alleleOfIndex(0)), gtyVariant.indexOfAllele(variant.alleleOfIndex(1)), -1));
        }
    }

    /**
     * Builds, for each variant index, the set of indices it is linked to (always including itself).
     * Linkage is symmetric. Results are looked up in, and recorded into, the position-keyed caches.
     *
     * <p>NOTE(review): the caches are keyed by position alone, so variants on different chromosomes,
     * or distinct variants sharing a position, share cache entries. Confirm that callers only pass
     * variants whose positions are unique.
     */
    private static List<TIntHashSet> buildLinkage(List<Variant> variants, float r2Cut, TIntObjectMap<TIntHashSet> simplifiedLDUpper, TIntObjectMap<TIntHashSet> simplifiedLDLower, GenericObjectPool<LDProperty> ldPropertyObjectPool) throws IOException {
        int size = variants.size();
        List<TIntHashSet> linkedIndexes = new List<TIntHashSet>(size);
        for (int i = 0; i < size; ++i) {
            TIntHashSet self = new TIntHashSet();
            self.add(i);
            linkedIndexes.add(self);
        }

        for (int i = 0; i < size; ++i) {
            Variant variant0 = variants.get(i);
            int pos0 = variant0.getPosition();
            TIntHashSet upperSet = cachedSet(simplifiedLDUpper, pos0);
            TIntHashSet lowerSet = cachedSet(simplifiedLDLower, pos0);
            TIntHashSet linkedIndex = linkedIndexes.get(i);
            for (int t = i + 1; t < size; ++t) {
                Variant variant1 = variants.get(t);
                int pos1 = variant1.getPosition();
                if (lowerSet.contains(pos1)) {
                    continue; // cached: not linked
                }
                boolean linked = upperSet.contains(pos1);
                if (!linked) {
                    linked = isLinked(variant0, variant1, r2Cut, ldPropertyObjectPool);
                    if (linked) {
                        upperSet.add(pos1);
                    } else {
                        lowerSet.add(pos1);
                    }
                }
                if (linked) {
                    linkedIndex.add(t);
                    linkedIndexes.get(t).add(i);
                }
            }
        }
        return linkedIndexes;
    }

    /** Returns the set cached under {@code position}, creating and storing an empty one if absent. */
    private static TIntHashSet cachedSet(TIntObjectMap<TIntHashSet> cache, int position) {
        TIntHashSet set = cache.get(position);
        if (set == null) {
            set = new TIntHashSet();
            cache.put(position, set);
        }
        return set;
    }

    /** Computes the genotype LD of the pair and tests {@code |R^2| >= r2Cut}. */
    private static boolean isLinked(Variant variant0, Variant variant1, float r2Cut, GenericObjectPool<LDProperty> ldPropertyObjectPool) throws IOException {
        ensureLDProperty(variant0, ldPropertyObjectPool);
        ensureLDProperty(variant1, ldPropertyObjectPool);
        IRecord record = GenotypeLD.INSTANCE.apply(variant0, variant1);
        float r2 = (Float) record.get("R^2");
        return Math.abs(r2) >= r2Cut;
    }

    /** Attaches a pooled {@link LDProperty}, reloaded from {@code variant}, if none is attached yet. */
    private static void ensureLDProperty(Variant variant, GenericObjectPool<LDProperty> ldPropertyObjectPool) throws IOException {
        String key = LDProperty.class.getName();
        if (variant.getProperty(key) == null) {
            variant.setProperty(key, ldPropertyObjectPool.borrowObject().reload(variant));
        }
    }

    /**
     * Builds a rank vector laid out as {@code [field_0 .. field_{k-1}, sizeScoreSlot, MAF]}, where
     * {@code k} is the number of prioritized fields (0 when {@code fieldsPrioritization} is null).
     * The size-score slot is left at 0 here.
     */
    private static double[] rankScoresOf(Variant variant, StringArray fieldsPrioritization) {
        int fieldCount = fieldsPrioritization == null ? 0 : fieldsPrioritization.length();
        double[] rankScores = new double[fieldCount + 2];
        for (int j = 0; j < fieldCount; ++j) {
            String field = fieldsPrioritization.get(j);
            double value = Double.parseDouble(variant.getProperty(field).toString());
            rankScores[j] = field.endsWith("@MarkGeneFeature") || field.endsWith("@P") ? -value : value;
        }
        rankScores[fieldCount + 1] = minorAlleleFrequency(variant);
        return rankScores;
    }

    /** Returns the variant's allele frequency folded into [0, 0.5]. */
    private static double minorAlleleFrequency(Variant variant) {
        double af = variant.getGenotypes().counter().getAF();
        return af > 0.5 ? 1.0 - af : af;
    }

    /**
     * Greedy pruning over linkage sets.
     *
     * <p>Repeatedly sorts the still-linked entries with {@code comparator} and drops the first one:
     * its index is removed from every set it belongs to, and its own set is cleared. This stops when
     * no entry is linked to another. Each entry's {@code indices} must contain its own
     * {@code initIndex}, and {@code linkedVarScores.get(k).initIndex == k} must hold.
     *
     * @return the {@code initIndex} of every entry whose set is still non-empty, in list order
     */
    static IntList multiConditionalPrune(List<LDPairs> linkedVarScores, Comparator<LDPairs> comparator) {
        List<LDPairs> candidates = new List<LDPairs>();
        for (LDPairs ldPairs : linkedVarScores) {
            if (ldPairs.indices.size() > 1) {
                candidates.add(ldPairs);
            }
        }
        List<LDPairs> stillLinked = new List<LDPairs>();
        while (!candidates.isEmpty()) {
            candidates.sort(comparator);
            int size = candidates.size();
            int id = 0;
            while (id < size && candidates.get(id).indices.size() <= 1) {
                ++id;
            }
            if (id == size) {
                break;
            }
            LDPairs first = candidates.get(id);
            int removedIndex = first.initIndex;
            // Iterate a snapshot: first.indices contains removedIndex itself, so removing from the
            // neighbours' sets would otherwise mutate the set currently being iterated.
            for (int i : first.indices.toArray()) {
                LDPairs neighbour = (LDPairs) linkedVarScores.fastGet(i);
                neighbour.indices.remove(removedIndex);
                neighbour.updateSizeScore();
            }
            first.indices.clear();

            for (LDPairs ldPairs : candidates) {
                if (ldPairs.indices.size() > 1) {
                    stillLinked.add(ldPairs);
                }
            }
            candidates.clear();
            candidates.addAll(stillLinked);
            stillLinked.clear();
        }

        IntList retainedIndices = new IntList();
        for (LDPairs linkedVar : linkedVarScores) {
            if (!linkedVar.indices.isEmpty()) {
                retainedIndices.add(linkedVar.initIndex);
            }
        }
        return retainedIndices;
    }
}

