/*
 * Decompiled with CFR 0.152.
 */
package edu.sysu.pmglab.simulation;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.math3.distribution.PascalDistribution;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.CholeskyDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;

/**
 * Simulates spatial transcriptomics count data on a set of 2-D spot locations and runs simple
 * diagnostics on the result.
 *
 * <p>For each gene, a latent log-mean field {@code beta0 + GP} is drawn. {@code beta0} is a
 * per-gene standard-normal intercept. {@code GP} is a multivariate normal draw with an anisotropic
 * squared-exponential covariance; its mean is a sum of Gaussian bumps centred on that gene's
 * hotspots, or zero when it has none. Counts are then drawn from a zero-inflated negative binomial
 * with mean {@code exp(beta0 + GP)}.
 */
public class SpatialTranscriptomicsSimulator1 {
    /**
     * Builds the anisotropic squared-exponential covariance matrix over a set of spot locations.
     *
     * <p>{@code K[i][j] = sigma2 * exp(-(dx^2 / (2 l_x^2) + dy^2 / (2 l_y^2)))}, where {@code dx} and
     * {@code dy} are the coordinate differences between locations {@code i} and {@code j}. The
     * matrix is symmetric (squaring removes the sign of the difference). Only the upper triangle is
     * evaluated and then mirrored.
     *
     * @param locations spot coordinates; element {@code [0]} is x and element {@code [1]} is y
     * @param sigma2 value placed on the diagonal (marginal variance)
     * @param l_x length scale along x
     * @param l_y length scale along y
     * @return an {@code n x n} matrix with {@code n = locations.size()}
     */
    public static double[][] computeCovarianceMatrix(List<double[]> locations, double sigma2, double l_x, double l_y) {
        int n = locations.size();
        double[][] K = new double[n][n];
        for (int i = 0; i < n; ++i) {
            double[] a = locations.get(i);
            for (int j = i; j < n; ++j) {
                double[] b = locations.get(j);
                double dx = a[0] - b[0];
                double dy = a[1] - b[1];
                double dist2_x = dx * dx / (2.0 * l_x * l_x);
                double dist2_y = dy * dy / (2.0 * l_y * l_y);
                double value = sigma2 * Math.exp(-(dist2_x + dist2_y));
                K[i][j] = value;
                K[j][i] = value;
            }
        }
        return K;
    }

    /**
     * Draws one sample from the multivariate normal distribution {@code N(mu, K)}.
     *
     * <p>{@code K} is factored as {@code L L^T} with {@link CholeskyDecomposition} using its default
     * thresholds. The sample is {@code mu + L z}, where {@code z} holds independent standard normals
     * drawn from {@code random}. The decomposition throws if {@code K} is not symmetric positive
     * definite within those thresholds.
     *
     * <p>NOTE(review): squared-exponential kernels over dense grids can be numerically singular; a
     * small diagonal jitter may be needed in that case.
     *
     * @param K covariance matrix
     * @param mu mean vector; must have one entry per row of {@code K}
     * @param random source of the standard-normal draws
     * @return the sampled vector
     * @throws IllegalArgumentException if {@code mu.length != K.length}
     */
    public static double[] sampleMVN(double[][] K, double[] mu, RandomGenerator random) {
        RealMatrix L = new CholeskyDecomposition(new Array2DRowRealMatrix(K)).getL();
        return sampleWithCholeskyFactor(L, mu, random);
    }

    /** Returns {@code mu + L z}, where {@code z} is a vector of standard normals drawn from {@code random}. */
    private static double[] sampleWithCholeskyFactor(RealMatrix L, double[] mu, RandomGenerator random) {
        int n = L.getRowDimension();
        if (mu.length != n) {
            throw new IllegalArgumentException("mean vector length " + mu.length + " does not match covariance size " + n);
        }
        double[] z = new double[n];
        for (int i = 0; i < n; ++i) {
            z[i] = random.nextGaussian();
        }
        double[] sample = L.operate(z);
        for (int i = 0; i < n; ++i) {
            sample[i] += mu[i];
        }
        return sample;
    }

    /**
     * Draws one count from a zero-inflated negative binomial distribution.
     *
     * <p>With probability {@code w} the result is a structural zero. Otherwise the result is a draw
     * from commons-math's {@link PascalDistribution}, with {@code (int) r} successes and success
     * probability {@code p = r / (r + mu)}. That distribution's mean is {@code r (1 - p) / p}, so
     * the negative-binomial part has mean {@code mu} when {@code r} is a whole number.
     *
     * @param w zero-inflation probability
     * @param r number of successes (dispersion); truncated to {@code int}, and
     *     {@link PascalDistribution} rejects values that truncate to 0 or less
     * @param mu mean of the negative-binomial component
     * @param random generator used for both the zero-inflation coin and the negative-binomial draw
     * @return the sampled count
     */
    public static int sampleZINB(double w, double r, double mu, RandomGenerator random) {
        if (random.nextDouble() < w) {
            return 0;
        }
        double p = r / (r + mu);
        // Pass the caller's generator: the (r, p) constructor would create and self-seed its own
        // generator, making the simulation irreproducible even with a fixed seed.
        return new PascalDistribution(random, (int) r, p).sample();
    }

    /**
     * Computes, for every location, the sum of unit-height Gaussian bumps centred on the hotspots.
     *
     * @param locations spot coordinates ({@code [0]} = x, {@code [1]} = y)
     * @param hotspots bump centres ({@code [0]} = x, {@code [1]} = y); may be empty
     * @param l_x bump width along x
     * @param l_y bump width along y
     * @return one value per location; all zeros when {@code hotspots} is empty
     */
    public static double[] computeMeanVector(List<double[]> locations, List<double[]> hotspots, double l_x, double l_y) {
        int n = locations.size();
        double[] meanVec = new double[n];
        for (int i = 0; i < n; ++i) {
            double[] loc = locations.get(i);
            double sum = 0.0;
            for (double[] h : hotspots) {
                double dx = loc[0] - h[0];
                double dy = loc[1] - h[1];
                double dist2_x = dx * dx / (2.0 * l_x * l_x);
                double dist2_y = dy * dy / (2.0 * l_y * l_y);
                sum += Math.exp(-(dist2_x + dist2_y));
            }
            meanVec[i] = sum;
        }
        return meanVec;
    }

    /**
     * Simulates a spot-by-gene count matrix.
     *
     * <p>Per gene, the following values are drawn from {@code random} in this order:
     * <ul>
     *   <li>the intercept {@code beta0 ~ N(0, 1)};</li>
     *   <li>the zero-inflation rate {@code w ~ U(0, 0.5)};</li>
     *   <li>the latent field {@code GP ~ N(muG, K)}, where {@code muG} is the hotspot bump sum;</li>
     *   <li>one ZINB count per spot, with mean {@code exp(beta0 + GP[i])}.</li>
     * </ul>
     *
     * @param locations spot coordinates ({@code [0]} = x, {@code [1]} = y)
     * @param numGenes number of genes (columns) to simulate
     * @param r negative-binomial number of successes; see {@link #sampleZINB}
     * @param spreadX length scale along x, used for both the covariance and the hotspot bumps
     * @param spreadY length scale along y, used for both the covariance and the hotspot bumps
     * @param sigma2 marginal variance of the latent field
     * @param hotspotsPerGene hotspot centres for each gene; must have at least {@code numGenes}
     *     entries
     * @param random source of all randomness
     * @return {@code counts[spot][gene]}
     */
    public static int[][] generateData(List<double[]> locations, int numGenes, double r, double spreadX, double spreadY, double sigma2, List<List<double[]>> hotspotsPerGene, RandomGenerator random) {
        int n = locations.size();
        int[][] counts = new int[n][numGenes];
        if (n == 0 || numGenes == 0) {
            return counts;
        }
        double[][] K = computeCovarianceMatrix(locations, sigma2, spreadX, spreadY);
        // K is the same for every gene, so factor it once rather than once per gene.
        RealMatrix L = new CholeskyDecomposition(new Array2DRowRealMatrix(K)).getL();
        for (int g = 0; g < numGenes; ++g) {
            double beta0 = random.nextGaussian();
            double w = random.nextDouble() * 0.5;
            // Yields all zeros when the gene has no hotspots.
            double[] muG = computeMeanVector(locations, hotspotsPerGene.get(g), spreadX, spreadY);
            double[] GP = sampleWithCholeskyFactor(L, muG, random);
            for (int i = 0; i < n; ++i) {
                double mean = Math.exp(beta0 + GP[i]);
                counts[i][g] = sampleZINB(w, r, mean, random);
            }
        }
        return counts;
    }

    /**
     * Returns 1-based ranks of {@code values2}; tied values share the average of their ranks.
     *
     * <p>Ordering follows {@link Double#compare}. Values that compare {@code ==} (including
     * {@code 0.0} and {@code -0.0}) are ties, and all NaNs form one tied group ranked last.
     *
     * @param values2 values to rank; must not contain {@code null}
     * @return {@code ranks[i]} is the rank of {@code values2.get(i)}
     */
    public static double[] getRanks(List<Double> values2) {
        int n = values2.size();
        List<Pair> pairs = new ArrayList<>(n);
        int index = 0;
        for (Double value : values2) {
            pairs.add(new Pair(value, index++));
        }
        pairs.sort((a, b) -> Double.compare(a.value, b.value));
        double[] ranks = new double[n];
        int pos = 0;
        while (pos < n) {
            int start = pos;
            double v = pairs.get(start).value;
            // The first element always matches itself (NaN included), so pos advances every pass.
            while (pos < n && isTie(pairs.get(pos).value, v)) {
                ++pos;
            }
            int end = pos;
            double avgRank = ((double) start + 1.0 + (double) end) / 2.0;
            for (int k = start; k < end; ++k) {
                ranks[pairs.get(k).index] = avgRank;
            }
        }
        return ranks;
    }

    /** Tie test for ranking: plain {@code ==}, plus NaN matching NaN (which {@code ==} never does). */
    private static boolean isTie(double a, double b) {
        return a == b || (Double.isNaN(a) && Double.isNaN(b));
    }

    /**
     * Computes the Pearson correlation coefficient of two equal-length samples.
     *
     * <p>Uses a two-pass formulation (means first, then centred sums), which avoids the
     * cancellation error of the one-pass sum-of-squares formula.
     *
     * @param X2 first sample
     * @param Y second sample
     * @return the correlation; {@code NaN} if the lengths differ or fewer than two points are given;
     *     {@code 0.0} if either sample is constant
     */
    public static double pearsonCorrelation(double[] X2, double[] Y) {
        int n = X2.length;
        if (n != Y.length || n < 2) {
            return Double.NaN;
        }
        double meanX = 0.0;
        double meanY = 0.0;
        for (int i = 0; i < n; ++i) {
            meanX += X2[i];
            meanY += Y[i];
        }
        meanX /= n;
        meanY /= n;
        double SXX = 0.0;
        double SYY = 0.0;
        double SXY = 0.0;
        for (int i = 0; i < n; ++i) {
            double dx = X2[i] - meanX;
            double dy = Y[i] - meanY;
            SXX += dx * dx;
            SYY += dy * dy;
            SXY += dx * dy;
        }
        if (SXX == 0.0 || SYY == 0.0) {
            return 0.0;
        }
        // Take the square roots separately so the product cannot overflow for large sums.
        return SXY / (Math.sqrt(SXX) * Math.sqrt(SYY));
    }

    /**
     * Computes the Spearman rank correlation: the Pearson correlation of the average ranks.
     *
     * @param X2 first sample; must not contain {@code null}
     * @param Y second sample; must not contain {@code null}
     * @return the correlation; {@code NaN} if the sizes differ or fewer than two points are given;
     *     {@code 0.0} if either sample is constant
     */
    public static double spearmanCorrelation(List<Double> X2, List<Double> Y) {
        if (X2.size() != Y.size()) {
            return Double.NaN;
        }
        // getRanks does not modify its argument, so no defensive copy is needed.
        return pearsonCorrelation(getRanks(X2), getRanks(Y));
    }

    /** Unboxes a list of {@code Double} into a primitive array (throws on {@code null} elements). */
    public static double[] listToDoubleArray(List<Double> list) {
        return list.stream().mapToDouble(Double::doubleValue).toArray();
    }

    /**
     * Averages, over genes, the Pearson correlation between each grid spot and the spot {@code dx}
     * rows further along the first (x) axis.
     *
     * <p>Assumes spot {@code (i, j)} of a {@code gridSize x gridSize} grid is row
     * {@code i * gridSize + j} of {@code counts}.
     *
     * @param dx non-negative lag along the first grid axis
     * @param counts {@code counts[spot][gene]}
     * @param gridSize grid side length
     * @return the mean correlation; {@code NaN} if there are no genes, or if {@code dx >= gridSize}
     *     (no spot pairs)
     * @throws IllegalArgumentException if {@code dx < 0}
     */
    public static double computeCx(int dx, int[][] counts, int gridSize) {
        return averageLaggedCorrelation(counts, gridSize, dx, 0);
    }

    /**
     * Averages, over genes, the Pearson correlation between each grid spot and the spot {@code dy}
     * columns further along the second (y) axis.
     *
     * <p>Assumes spot {@code (i, j)} of a {@code gridSize x gridSize} grid is row
     * {@code i * gridSize + j} of {@code counts}.
     *
     * @param dy non-negative lag along the second grid axis
     * @param counts {@code counts[spot][gene]}
     * @param gridSize grid side length
     * @return the mean correlation; {@code NaN} if there are no genes, or if {@code dy >= gridSize}
     *     (no spot pairs)
     * @throws IllegalArgumentException if {@code dy < 0}
     */
    public static double computeCy(int dy, int[][] counts, int gridSize) {
        return averageLaggedCorrelation(counts, gridSize, 0, dy);
    }

    /**
     * Shared implementation of {@link #computeCx} and {@link #computeCy}: averages, over genes, the
     * correlation between spot {@code (i, j)} and spot {@code (i + lagI, j + lagJ)}. Only pairs with
     * both spots inside the grid are used.
     */
    private static double averageLaggedCorrelation(int[][] counts, int gridSize, int lagI, int lagJ) {
        if (lagI < 0 || lagJ < 0) {
            throw new IllegalArgumentException("lags must be non-negative, got (" + lagI + ", " + lagJ + ")");
        }
        if (counts.length == 0) {
            return Double.NaN;
        }
        int numGenes = counts[0].length;
        int rows = Math.max(0, gridSize - lagI);
        int cols = Math.max(0, gridSize - lagJ);
        double[] a = new double[rows * cols];
        double[] b = new double[rows * cols];
        double sumCorr = 0.0;
        for (int g = 0; g < numGenes; ++g) {
            int k = 0;
            for (int i = 0; i < rows; ++i) {
                for (int j = 0; j < cols; ++j) {
                    a[k] = counts[i * gridSize + j][g];
                    b[k] = counts[(i + lagI) * gridSize + (j + lagJ)][g];
                    ++k;
                }
            }
            sumCorr += pearsonCorrelation(a, b);
        }
        return sumCorr / (double) numGenes;
    }

    /**
     * Returns the hotspot bump sum for gene {@code g} at every location (see
     * {@link #computeMeanVector}), or {@code null} when the gene has no hotspots.
     */
    public static double[] computeMg(int g, List<List<double[]>> hotspotsPerGene, List<double[]> locations, double spreadX, double spreadY) {
        List<double[]> hotspotsG = hotspotsPerGene.get(g);
        if (hotspotsG.isEmpty()) {
            return null;
        }
        return computeMeanVector(locations, hotspotsG, spreadX, spreadY);
    }

    /**
     * Prints two diagnostics for a simulated count matrix to standard output.
     *
     * <p>First, for lags 1..min(5, gridSize - 1) along each axis, it prints the gene-averaged
     * correlation of counts next to the kernel correlation {@code exp(-lag^2 / (2 spread^2))}. The
     * kernel value describes the latent field, while the observed value is computed on noisy counts,
     * so agreement is only approximate. Second, for every gene with hotspots, it prints the Spearman
     * correlation between the gene's counts and its hotspot bump sum, followed by the average.
     *
     * @param counts {@code counts[spot][gene]}, spots laid out as in {@link #computeCx}
     * @param locations spot coordinates
     * @param hotspotsPerGene hotspot centres for each gene
     * @param spreadX length scale along x
     * @param spreadY length scale along y
     * @param gridSize grid side length
     */
    public static void performStatisticalChecks(int[][] counts, List<double[]> locations, List<List<double[]>> hotspotsPerGene, double spreadX, double spreadY, int gridSize) {
        // Lags of gridSize or more leave no spot pairs and would only print NaN.
        int maxLag = Math.min(5, gridSize - 1);
        System.out.println("Anisotropic Spatial Correlation Check:");
        for (int dx = 1; dx <= maxLag; ++dx) {
            double Cx = computeCx(dx, counts, gridSize);
            double expectedCx = Math.exp((double) (-(dx * dx)) / (2.0 * spreadX * spreadX));
            System.out.printf("For dx=%d, C_x=%.4f, expected=%.4f%n", dx, Cx, expectedCx);
        }
        for (int dy = 1; dy <= maxLag; ++dy) {
            double Cy = computeCy(dy, counts, gridSize);
            double expectedCy = Math.exp((double) (-(dy * dy)) / (2.0 * spreadY * spreadY));
            System.out.printf("For dy=%d, C_y=%.4f, expected=%.4f%n", dy, Cy, expectedCy);
        }
        System.out.println("\nHotspot Effect Check:");
        int M = counts.length == 0 ? 0 : counts[0].length;
        int N = locations.size();
        double sumCorrHotspot = 0.0;
        int countHotspot = 0;
        for (int g = 0; g < M; ++g) {
            double[] Mg = computeMg(g, hotspotsPerGene, locations, spreadX, spreadY);
            if (Mg == null) {
                continue; // gene has no hotspots
            }
            List<Double> exprList = new ArrayList<>(N);
            for (int k = 0; k < N; ++k) {
                exprList.add((double) counts[k][g]);
            }
            List<Double> MgList = Arrays.stream(Mg).boxed().collect(Collectors.toList());
            double corr = spearmanCorrelation(exprList, MgList);
            System.out.printf("For gene %d, Spearman correlation with M_g: %.4f%n", g, corr);
            sumCorrHotspot += corr;
            ++countHotspot;
        }
        if (countHotspot > 0) {
            double avgCorrHotspot = sumCorrHotspot / (double) countHotspot;
            System.out.printf("Average Spearman correlation with M_g for genes with hotspots: %.4f%n", avgCorrHotspot);
        }
    }

    /**
     * Demo: simulates 1000 genes on a 10 x 10 grid with seed 42, giving genes 0 and 2 hotspots, and
     * prints the diagnostics.
     */
    public static void main(String[] args) {
        int gridSize = 10;
        List<double[]> locations = new ArrayList<>();
        // Row-major with x as the outer index, matching the layout computeCx/computeCy expect.
        for (int i = 0; i < gridSize; ++i) {
            for (int j = 0; j < gridSize; ++j) {
                locations.add(new double[]{i, j});
            }
        }
        int numGenes = 1000;
        double r = 10.0;
        double spreadX = 2.0;
        double spreadY = 1.0;
        double sigma2 = 1.0;
        RandomGenerator random = new MersenneTwister(42);
        List<List<double[]>> hotspotsPerGene = new ArrayList<>(numGenes);
        for (int g = 0; g < numGenes; ++g) {
            hotspotsPerGene.add(new ArrayList<>());
        }
        hotspotsPerGene.get(0).add(new double[]{2.0, 2.0});
        hotspotsPerGene.get(0).add(new double[]{7.0, 7.0});
        hotspotsPerGene.get(2).add(new double[]{5.0, 5.0});
        int[][] counts = generateData(locations, numGenes, r, spreadX, spreadY, sigma2, hotspotsPerGene, random);
        performStatisticalChecks(counts, locations, hotspotsPerGene, spreadX, spreadY, gridSize);
    }

    /** A value paired with its original position, used to compute ranks after sorting. */
    static class Pair {
        final double value;
        final int index;

        Pair(double v, int idx) {
            this.value = v;
            this.index = idx;
        }
    }
}

