/*
 * Decompiled with CFR 0.152.
 */
package edu.sysu.pmglab.kgga.command.python.toolkit;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Random;

/**
 * One train/test split of a cross-validation run, stored as two lists of
 * sample indices, together with helpers to build stratified splits and to
 * persist them with Java object serialization.
 */
public class Fold
implements Serializable {
    private static final long serialVersionUID = 1L;
    /** Indices of the samples assigned to the training side of this fold. */
    public List<Integer> trainIndices;
    /** Indices of the samples assigned to the test side of this fold. */
    public List<Integer> testIndices;

    /**
     * Creates a fold from the given index lists. The lists are stored by
     * reference, not copied.
     *
     * @param train indices of the training samples
     * @param test2 indices of the test samples
     */
    public Fold(List<Integer> train, List<Integer> test2) {
        this.trainIndices = train;
        this.testIndices = test2;
    }

    /**
     * Splits the sample indices {@code 0 .. labels.length - 1} into
     * {@code nSplits} folds so that every distinct label value is spread as
     * evenly as possible over the folds.
     *
     * <p>For each label value the indices are shuffled and cut into
     * {@code nSplits} consecutive chunks; when the class size is not a
     * multiple of {@code nSplits}, the leftover samples go to the
     * lowest-numbered folds, so earlier folds may end up slightly larger.
     * Fold {@code i} uses chunk {@code i} of every class as its test set and
     * all remaining chunks as its training set. Both index lists are shuffled
     * before being returned. The random generator is not seeded, so repeated
     * calls produce different splits.
     *
     * @param labels  label of each sample; labels are grouped by their boxed
     *                {@link Double} value
     * @param nSplits number of folds to create, at least 1
     * @return a list of {@code nSplits} folds
     * @throws IllegalArgumentException if {@code nSplits} is less than 1
     * @throws NullPointerException     if {@code labels} is {@code null}
     */
    public static List<Fold> stratifiedKFold(double[] labels, int nSplits) {
        // Fail fast: zero would otherwise surface as a division by zero below and a
        // negative value would silently yield an empty result.
        if (nSplits < 1) {
            throw new IllegalArgumentException("nSplits must be at least 1, got " + nSplits);
        }
        // Group sample indices by label value.
        HashMap<Double, List<Integer>> classIndices = new HashMap<>();
        for (int i = 0; i < labels.length; ++i) {
            classIndices.computeIfAbsent(labels[i], k -> new ArrayList<>()).add(i);
        }
        Random random = new Random();
        // globalFolds.get(f) collects the test indices of fold f across all classes.
        List<List<Integer>> globalFolds = new ArrayList<>();
        for (int i = 0; i < nSplits; ++i) {
            globalFolds.add(new ArrayList<>());
        }
        for (List<Integer> indices : classIndices.values()) {
            Collections.shuffle(indices, random);
            int foldSize = indices.size() / nSplits;
            int extra = indices.size() % nSplits;
            int start = 0;
            for (int foldIndex = 0; foldIndex < nSplits; ++foldIndex) {
                // The first `extra` folds each take one of the leftover samples.
                int end = start + foldSize + (foldIndex < extra ? 1 : 0);
                globalFolds.get(foldIndex).addAll(indices.subList(start, end));
                start = end;
            }
        }
        List<Fold> result = new ArrayList<>();
        for (int i = 0; i < nSplits; ++i) {
            List<Integer> testIndices = new ArrayList<>(globalFolds.get(i));
            List<Integer> trainIndices = new ArrayList<>(labels.length - testIndices.size());
            for (int j = 0; j < nSplits; ++j) {
                if (j == i) continue;
                trainIndices.addAll(globalFolds.get(j));
            }
            // Mix the classes, which were appended one after another above.
            Collections.shuffle(trainIndices, random);
            Collections.shuffle(testIndices, random);
            result.add(new Fold(trainIndices, testIndices));
        }
        return result;
    }

    /**
     * Writes the folds to {@code filePath} using Java object serialization:
     * first the number of folds as an {@code int}, then each fold as an
     * object. An existing file at that path is overwritten.
     *
     * @param folds    folds to write
     * @param filePath path of the file to create or overwrite
     * @throws IOException          if the file cannot be opened or written
     * @throws NullPointerException if {@code folds} is {@code null}; the
     *                              target file is left untouched in that case
     */
    public static void saveFolds(List<Fold> folds, String filePath) throws IOException {
        // Check before opening the stream: opening truncates an existing file.
        if (folds == null) {
            throw new NullPointerException("folds");
        }
        try (ObjectOutputStream oos = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(filePath)));){
            oos.writeInt(folds.size());
            for (Fold fold : folds) {
                oos.writeObject(fold);
            }
        }
    }

    /**
     * Reads folds from a file in the layout produced by
     * {@link #saveFolds(List, String)}: an {@code int} count followed by that
     * many serialized {@code Fold} objects.
     *
     * <p>NOTE(review): this uses Java native deserialization, which is unsafe
     * on files from untrusted sources; only load files this application wrote
     * itself, or install an {@code ObjectInputFilter}.
     *
     * @param filePath path of the file to read
     * @return the folds in the order they were stored
     * @throws IOException            if the file cannot be opened or read
     * @throws ClassNotFoundException if a stored object's class cannot be resolved
     */
    public static List<Fold> loadFolds(String filePath) throws IOException, ClassNotFoundException {
        List<Fold> folds = new ArrayList<>();
        try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(filePath)));){
            int foldCount = ois.readInt();
            for (int i = 0; i < foldCount; ++i) {
                folds.add((Fold)ois.readObject());
            }
        }
        return folds;
    }
}

