/*
 * Decompiled with CFR 0.152.
 */
package edu.sysu.pmglab.simulation;

import edu.sysu.pmglab.stat.FastGpSampler;
import java.io.FileNotFoundException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.IntStream;
import org.apache.commons.math3.distribution.GammaDistribution;
import org.apache.commons.math3.distribution.PoissonDistribution;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.CholeskyDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;

/**
 * Simulates spatially resolved gene-expression counts.
 *
 * <p>For each gene a latent log-mean field is built as
 * {@code eta = beta0 + GP + alpha * hotspotMean}. Here {@code GP} is a zero-mean Gaussian-process draw with a
 * squared-exponential kernel, and {@code hotspotMean} is a sum of Gaussian bumps centred on the gene's
 * hotspots. Counts are then drawn from a zero-inflated negative binomial with mean {@code exp(eta)}.
 *
 * <p>Genes are simulated in parallel. Each gene gets its own {@link MersenneTwister}, seeded from
 * {@code masterRandom}. All seeds are drawn sequentially before the parallel section, because
 * {@code MersenneTwister} is not synchronized.
 */
public class SpatialTranscriptomicsSimulator {
    /** Added to the covariance diagonal so that the Cholesky factorisation stays numerically stable. */
    private static final double COVARIANCE_JITTER = 1.0E-6;
    /** Standard deviation of the per-gene Gaussian perturbation applied to {@code baseBeta0}. */
    private static final double BETA0_SD = 0.5;
    /** Convergence tolerance passed to {@link PoissonDistribution}. */
    private static final double POISSON_EPSILON = 1.0E-12;
    /** Maximum number of iterations passed to {@link PoissonDistribution}. */
    private static final int POISSON_MAX_ITERATIONS = 10000000;

    /**
     * Builds the squared-exponential covariance matrix over a set of spot locations.
     *
     * <p>Entry (i, j) is {@code sigma2 * exp(-d^2 / (2 * l^2))}, where {@code d} is the Euclidean distance
     * between locations i and j. A small jitter is then added to every diagonal entry.
     *
     * @param locations spot coordinates; the first two entries of each array are read as (x, y)
     * @param sigma2 marginal variance of the kernel
     * @param l kernel length scale
     * @return a symmetric {@code n x n} matrix, where {@code n = locations.size()}
     */
    public double[][] computeCovarianceMatrix(List<int[]> locations, double sigma2, double l) {
        // Copy to an array once so random access is O(1) whatever List implementation the caller passes.
        int[][] points = locations.toArray(new int[0][]);
        int n = points.length;
        double twoL2 = 2.0 * l * l;
        double[][] K = new double[n][n];
        for (int i = 0; i < n; ++i) {
            for (int j = i; j < n; ++j) {
                double cov = sigma2 * Math.exp(-squaredDistance(points[i], points[j]) / twoL2);
                K[i][j] = cov;
                K[j][i] = cov;
            }
            K[i][i] += COVARIANCE_JITTER;
        }
        return K;
    }

    /**
     * Draws one zero-inflated negative-binomial count.
     *
     * <p>With probability {@code w} the result is a structural zero. Otherwise a rate is drawn from
     * Gamma(shape = {@code r}, scale = {@code mu / r}), which has mean {@code mu}. A Poisson count is then
     * drawn with that rate.
     *
     * @param w zero-inflation probability
     * @param r Gamma shape (negative-binomial dispersion); must be positive
     * @param mu mean of the non-inflated component; non-positive values yield 0
     * @param random source of randomness
     * @return the sampled count
     */
    public int sampleZINB(double w, double r, double mu, RandomGenerator random) {
        if (random.nextDouble() < w) {
            return 0;
        }
        if (mu <= 0.0) {
            return 0;
        }
        double lambda = new GammaDistribution(random, r, mu / r).sample();
        // A tiny mu can make the Gamma draw underflow to 0. PoissonDistribution rejects a non-positive
        // mean, and the only sensible count for a zero rate is 0.
        if (lambda <= 0.0) {
            return 0;
        }
        return new PoissonDistribution(random, lambda, POISSON_EPSILON, POISSON_MAX_ITERATIONS).sample();
    }

    /**
     * Computes, for every location, the sum of Gaussian bumps {@code exp(-d^2 / (2 * l^2))} centred on each
     * hotspot.
     *
     * @param locations spot coordinates; the first two entries of each array are read as (x, y)
     * @param hotspots hotspot coordinates in the same (x, y) layout
     * @param l bump length scale
     * @return an array of length {@code locations.size()}, in the iteration order of {@code locations}
     */
    public double[] computeMeanVector(List<int[]> locations, List<int[]> hotspots, double l) {
        double twoL2 = 2.0 * l * l;
        double[] meanVec = new double[locations.size()];
        int i = 0;
        for (int[] location : locations) {
            double sum = 0.0;
            for (int[] h : hotspots) {
                sum += Math.exp(-squaredDistance(location, h) / twoL2);
            }
            meanVec[i++] = sum;
        }
        return meanVec;
    }

    /**
     * Draws {@code L * z + mu}, where {@code z} is a vector of independent standard normals.
     *
     * @param L lower Cholesky factor of the target covariance
     * @param mu mean vector; must have at least {@code L.getRowDimension()} entries
     * @param random source of randomness
     * @return one multivariate-normal sample
     */
    public double[] sampleMVNFromL(RealMatrix L, double[] mu, RandomGenerator random) {
        double[] sample = this.sampleMVNFromL(L, random);
        for (int i = 0; i < sample.length; ++i) {
            sample[i] += mu[i];
        }
        return sample;
    }

    /**
     * Draws one zero-mean multivariate-normal sample with covariance {@code K}.
     *
     * <p>{@code K} is Cholesky-factorised on every call. When drawing repeatedly from the same covariance,
     * factor it once and use {@link #sampleMVNFromL(RealMatrix, RandomGenerator)} instead.
     *
     * @param K symmetric positive-definite covariance matrix
     * @param random source of randomness
     * @return one sample of length {@code K.length}
     */
    public double[] sampleMVN(double[][] K, RandomGenerator random) {
        RealMatrix L = new CholeskyDecomposition(new Array2DRowRealMatrix(K)).getL();
        return this.sampleMVNFromL(L, random);
    }

    /**
     * Draws {@code L * z}, where {@code z} is a vector of independent standard normals.
     *
     * @param L lower Cholesky factor of the target covariance
     * @param random source of randomness
     * @return one zero-mean multivariate-normal sample of length {@code L.getRowDimension()}
     */
    public double[] sampleMVNFromL(RealMatrix L, RandomGenerator random) {
        double[] z = new double[L.getRowDimension()];
        for (int i = 0; i < z.length; ++i) {
            z[i] = random.nextGaussian();
        }
        return L.operate(z);
    }

    /**
     * Simulates a count matrix using an exact (Cholesky-based) Gaussian-process draw per gene.
     *
     * <p>Builds an {@code n x n} covariance matrix and factorises it, so memory and time grow quickly with
     * the number of locations.
     *
     * @param locations spot coordinates
     * @param numGenes number of genes (columns) to simulate
     * @param r negative-binomial dispersion
     * @param l length scale shared by the GP kernel and the hotspot bumps
     * @param sigma2 GP marginal variance
     * @param alpha weight of the hotspot mean term
     * @param baseBeta0 baseline log-mean; each gene adds an independent N(0, 0.5^2) perturbation
     * @param spatialSparsity zero-inflation probability
     * @param hotspotsPerGene optional per-gene hotspot lists; a missing, {@code null} or empty entry means no hotspots
     * @param masterRandom source of the per-gene seeds; only used from the calling thread
     * @return counts indexed as {@code [location][gene]}
     */
    public int[][] generateData(List<int[]> locations, int numGenes, double r, double l, double sigma2, double alpha, double baseBeta0, double spatialSparsity, List<List<int[]>> hotspotsPerGene, RandomGenerator masterRandom) {
        int n = locations.size();
        int[][] counts = new int[n][numGenes];
        double[][] K = this.computeCovarianceMatrix(locations, sigma2, l);
        RealMatrix L = new CholeskyDecomposition(new Array2DRowRealMatrix(K)).getL();
        // Seeds are drawn up front: masterRandom is not thread-safe, and fixing the seed for gene g here
        // keeps results reproducible regardless of thread scheduling.
        long[] seeds = drawSeeds(masterRandom, numGenes);
        IntStream.range(0, numGenes).parallel().forEach(g -> {
            MersenneTwister threadRandom = new MersenneTwister(seeds[g]);
            double beta0 = baseBeta0 + threadRandom.nextGaussian() * BETA0_SD;
            double[] muG = this.hotspotMean(g, n, locations, hotspotsPerGene, l);
            double[] GP = this.sampleMVNFromL(L, threadRandom);
            // Each task writes only column g, so the concurrent writes never touch the same element.
            this.fillGeneCounts(counts, g, n, beta0, GP, alpha, muG, spatialSparsity, r, threadRandom);
        });
        return counts;
    }

    /**
     * Simulates a count matrix on an {@code nx x ny} grid, drawing each gene's Gaussian-process field with
     * {@link FastGpSampler}.
     *
     * <p>NOTE(review): this assumes {@code FastGpSampler.sample} returns {@code nx * ny} values in the same
     * order as {@code locations}, and is safe to call from several threads at once. Confirm both against
     * FastGpSampler.
     *
     * @param nx number of grid rows
     * @param ny number of grid columns
     * @param locations spot coordinates; only read when hotspots are supplied, in which case its size must
     *     equal {@code nx * ny}
     * @param numGenes number of genes (columns) to simulate
     * @param r negative-binomial dispersion
     * @param l length scale shared by the GP kernel and the hotspot bumps
     * @param sigma2 GP marginal variance
     * @param alpha weight of the hotspot mean term
     * @param baseBeta0 baseline log-mean; each gene adds an independent N(0, 0.5^2) perturbation
     * @param spatialSparsity zero-inflation probability
     * @param hotspotsPerGene optional per-gene hotspot lists; a missing, {@code null} or empty entry means no hotspots
     * @param masterRandom source of the per-gene seeds; only used from the calling thread
     * @return counts indexed as {@code [location][gene]}
     * @throws IllegalArgumentException if hotspots are supplied and {@code locations.size() != nx * ny}
     */
    public int[][] generateDataWithFFT(int nx, int ny, List<int[]> locations, int numGenes, double r, double l, double sigma2, double alpha, double baseBeta0, double spatialSparsity, List<List<int[]>> hotspotsPerGene, RandomGenerator masterRandom) {
        int nTotal = nx * ny;
        if (usesHotspots(hotspotsPerGene, numGenes) && locations.size() != nTotal) {
            throw new IllegalArgumentException("locations.size() (" + locations.size() + ") must equal nx * ny (" + nTotal + ") when hotspots are supplied");
        }
        int[][] counts = new int[nTotal][numGenes];
        FastGpSampler sampler = new FastGpSampler(nx, ny, sigma2, l);
        // Seeds are drawn up front: masterRandom is not thread-safe, and fixing the seed for gene g here
        // keeps results reproducible regardless of thread scheduling.
        long[] seeds = drawSeeds(masterRandom, numGenes);
        IntStream.range(0, numGenes).parallel().forEach(g -> {
            MersenneTwister threadRandom = new MersenneTwister(seeds[g]);
            double beta0 = baseBeta0 + threadRandom.nextGaussian() * BETA0_SD;
            double[] muG = this.hotspotMean(g, nTotal, locations, hotspotsPerGene, l);
            double[] GP = sampler.sample(threadRandom);
            // Each task writes only column g, so the concurrent writes never touch the same element.
            this.fillGeneCounts(counts, g, nTotal, beta0, GP, alpha, muG, spatialSparsity, r, threadRandom);
        });
        return counts;
    }

    /** Returns the squared Euclidean distance between the (x, y) points {@code a} and {@code b}. */
    private static double squaredDistance(int[] a, int[] b) {
        // Subtract in double so that extreme int coordinates cannot overflow.
        double dx = (double) a[0] - b[0];
        double dy = (double) a[1] - b[1];
        return dx * dx + dy * dy;
    }

    /** Draws {@code count} seeds sequentially from {@code masterRandom} (call from a single thread only). */
    private static long[] drawSeeds(RandomGenerator masterRandom, int count) {
        long[] seeds = new long[count];
        for (int g = 0; g < count; ++g) {
            seeds[g] = masterRandom.nextLong();
        }
        return seeds;
    }

    /** Returns gene {@code g}'s hotspot list, or {@code null} when it has none (missing, null or empty). */
    private static List<int[]> hotspotsFor(List<List<int[]>> hotspotsPerGene, int g) {
        if (hotspotsPerGene == null || g >= hotspotsPerGene.size()) {
            return null;
        }
        List<int[]> hotspots = hotspotsPerGene.get(g);
        return hotspots == null || hotspots.isEmpty() ? null : hotspots;
    }

    /** Returns true if any of the first {@code numGenes} genes has at least one hotspot. */
    private static boolean usesHotspots(List<List<int[]>> hotspotsPerGene, int numGenes) {
        for (int g = 0; g < numGenes; ++g) {
            if (hotspotsFor(hotspotsPerGene, g) != null) {
                return true;
            }
        }
        return false;
    }

    /** Returns gene {@code g}'s hotspot mean vector, or an all-zero vector of length {@code n} if it has no hotspots. */
    private double[] hotspotMean(int g, int n, List<int[]> locations, List<List<int[]>> hotspotsPerGene, double l) {
        List<int[]> hotspots = hotspotsFor(hotspotsPerGene, g);
        return hotspots == null ? new double[n] : this.computeMeanVector(locations, hotspots, l);
    }

    /** Fills column {@code g} of {@code counts} with ZINB draws whose mean is {@code exp(beta0 + gp + alpha * muG)}. */
    private void fillGeneCounts(int[][] counts, int g, int n, double beta0, double[] gp, double alpha, double[] muG, double w, double r, RandomGenerator random) {
        for (int i = 0; i < n; ++i) {
            double eta = beta0 + gp[i] + alpha * muG[i];
            counts[i][g] = this.sampleZINB(w, r, Math.exp(eta), random);
        }
    }

    /**
     * Demo: simulates a single gene with one hotspot on a 100 x 100 grid many times. Writes the per-spot
     * average count as a tab-separated table to {@code hotspot_test_output.csv}.
     */
    public static void main(String[] args) {
        int nx = 100;
        int ny = 100;
        int numGenes = 1;
        double baseBeta0 = -0.5;
        double alpha = 2.5;
        double l = 5.0;
        double sigma2 = 0.5;
        double r = 5.0;
        double spatialSparsity = 0.8;
        int simuTime = 10000;
        MersenneTwister random = new MersenneTwister(System.currentTimeMillis());

        List<int[]> hotspotsForGene = new ArrayList<int[]>();
        hotspotsForGene.add(new int[]{10, 10});
        List<List<int[]>> hotspotsPerGene = new ArrayList<List<int[]>>();
        hotspotsPerGene.add(hotspotsForGene);

        // Row-major grid: spot (i, j) gets linear index i * ny + j.
        List<int[]> locations = new ArrayList<int[]>(nx * ny);
        for (int i = 0; i < nx; ++i) {
            for (int j = 0; j < ny; ++j) {
                locations.add(new int[]{i, j});
            }
        }

        // "Starting to simulate expression data for a single gene..."
        System.out.println("\u5f00\u59cb\u6a21\u62df\u5355\u4e2a\u57fa\u56e0\u7684\u8868\u8fbe\u6570\u636e...");
        SpatialTranscriptomicsSimulator simulator = new SpatialTranscriptomicsSimulator();
        double[][] countSum = new double[nx][ny];
        long start = System.nanoTime();
        for (int run = 0; run < simuTime; ++run) {
            int[][] counts = simulator.generateDataWithFFT(nx, ny, locations, numGenes, r, l, sigma2, alpha, baseBeta0, spatialSparsity, hotspotsPerGene, random);
            for (int k = 0; k < nx; ++k) {
                for (int t = 0; t < ny; ++t) {
                    // Must match the row-major layout used to build `locations` above.
                    countSum[k][t] += counts[k * ny + t][0];
                }
            }
        }
        System.out.println((double) (System.nanoTime() - start) / 1.0E9);

        String outputFileName = "hotspot_test_output.csv";
        // "Saving results to <file> ..."
        System.out.println("\u6b63\u5728\u5c06\u7ed3\u679c\u4fdd\u5b58\u5230 " + outputFileName + " ...");
        boolean saved = false;
        try (PrintWriter writer = new PrintWriter(outputFileName)) {
            // Header row: one column per grid column t in [0, ny).
            writer.print("col");
            for (int t = 0; t < ny; ++t) {
                writer.print("\t" + t);
            }
            writer.print("\n");
            for (int k = 0; k < nx; ++k) {
                writer.print(k);
                for (int t = 0; t < ny; ++t) {
                    writer.print("\t" + countSum[k][t] / (double) simuTime);
                }
                writer.print("\n");
            }
            // PrintWriter swallows IOExceptions; checkError() flushes and reports whether any write failed.
            saved = !writer.checkError();
        } catch (FileNotFoundException e) {
            // "Error: unable to write file <file>"
            System.err.println("\u9519\u8bef\uff1a\u65e0\u6cd5\u5199\u5165\u6587\u4ef6 " + outputFileName);
            e.printStackTrace();
            return;
        }
        if (!saved) {
            // "Error: unable to write file <file>"
            System.err.println("\u9519\u8bef\uff1a\u65e0\u6cd5\u5199\u5165\u6587\u4ef6 " + outputFileName);
            return;
        }
        // "File saved successfully!"
        System.out.println("\u6587\u4ef6\u4fdd\u5b58\u6210\u529f\uff01");
        // "You can now open <file> with Python, R, or Excel to draw a heatmap."
        System.out.println("\u73b0\u5728\u60a8\u53ef\u4ee5\u4f7f\u7528Python, R, \u6216 Excel \u6253\u5f00 " + outputFileName + " \u6765\u7ed8\u5236\u70ed\u529b\u56fe\u3002");
    }
}

