/*
 * Decompiled with CFR 0.152.
 */
package edu.sysu.pmglab.simulation;

import java.io.BufferedWriter;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;

/**
 * Simulates per-gene spatial count grids on a {@code W x H} lattice.
 *
 * <p>For each gene the mean at spot {@code (x, y)} is {@code baseline} plus, for every hotspot,
 * {@code hotspotAmplitude * k(distance(spot, hotspot))}, where {@code k} is a Matern correlation
 * function with length scale {@code ell} and smoothness {@code nu} ({@code k(0) = 1}). Each spot's
 * count is 0 with probability {@code dropoutProb} and otherwise a Poisson draw with that mean.
 *
 * <p>Kernels are stored as sparse offset patches that keep only entries strictly above
 * {@code kernelCutoff}. Patches for every candidate {@code (ell, nu)} pair are built in the
 * constructor; any other pair is built lazily on first use and cached.
 */
public class SpatialTranscriptomicsSimulatorMaternKernel {
    /** Number of Gauss-Legendre nodes used to integrate the Bessel-K integral representation. */
    private static final int GAUSS_LEGENDRE_NODES = 64;

    /** Lanczos approximation coefficients (g = 7, matching the {@code z + 7.5} shift below). */
    private static final double[] LANCZOS_COEFFS = new double[]{676.5203681218851, -1259.1392167224028, 771.3234287776531, -176.6150291621406, 12.507343278686905, -0.13857109526572012, 9.984369578019572E-6, 1.5056327351493116E-7};

    /** Leading constant of the Lanczos series. */
    private static final double LANCZOS_C0 = 0.9999999999998099;

    /**
     * Largest mean sampled in one pass of Knuth's multiplication method. {@code exp(-500)} is
     * comfortably representable, whereas {@code exp(-lambda)} underflows to 0 for lambda above ~745
     * and would make the method return a fixed, wrong count.
     */
    private static final double KNUTH_MAX_LAMBDA = 500.0;

    /** Means above this are sampled with a rounded normal approximation. */
    private static final double NORMAL_APPROX_LAMBDA = 1000000.0;

    /** Golden-ratio-ish multiplier used to derive one RNG seed per worker from the master seed. */
    private static final long WORKER_SEED_STRIDE = 1315423911L;

    private final int W;
    private final int H;
    private final double baseline;
    private final double dropoutProb;
    /** Sorted copy of the candidate length scales. */
    private final double[] ellCands;
    /** Sorted copy of the candidate smoothness values. */
    private final double[] nuCands;
    private final int threads;
    private final long masterSeed;
    private final double kernelCutoff;
    /** Kernel patches keyed by {@link #key(double, double)}; shared by all worker threads. */
    private final ConcurrentMap<String, KernelPatch> kernelCache = new ConcurrentHashMap<String, KernelPatch>();
    /** Gauss-Legendre nodes on [-1, 1]. */
    private final double[] glX;
    /** Gauss-Legendre weights matching {@link #glX}. */
    private final double[] glW;

    /**
     * Creates a simulator and precomputes kernel patches for every {@code (ell, nu)} candidate pair.
     *
     * @param W            grid width (number of columns), must be >= 1
     * @param H            grid height (number of rows), must be >= 1
     * @param baseline     mean expression added to every spot
     * @param dropoutProb  probability that a spot's count is forced to 0
     * @param ellCands     candidate Matern length scales (copied and sorted)
     * @param nuCands      candidate Matern smoothness values (copied and sorted)
     * @param threads      worker threads used by {@link #accumulateGenesToSum}; values < 1 become 1
     * @param masterSeed   base seed from which each worker's RNG seed is derived
     * @param kernelCutoff kernel values at or below this are dropped from the patches
     * @throws IllegalArgumentException if {@code W < 1} or {@code H < 1}
     */
    public SpatialTranscriptomicsSimulatorMaternKernel(int W, int H, double baseline, double dropoutProb, double[] ellCands, double[] nuCands, int threads, long masterSeed, double kernelCutoff) {
        if (W < 1 || H < 1) {
            throw new IllegalArgumentException("W,H must be >=1");
        }
        this.W = W;
        this.H = H;
        this.baseline = baseline;
        this.dropoutProb = dropoutProb;
        this.ellCands = Arrays.copyOf(ellCands, ellCands.length);
        Arrays.sort(this.ellCands);
        this.nuCands = Arrays.copyOf(nuCands, nuCands.length);
        Arrays.sort(this.nuCands);
        this.threads = Math.max(1, threads);
        this.masterSeed = masterSeed;
        this.kernelCutoff = kernelCutoff;
        double[][] gl = this.computeGaussLegendre(GAUSS_LEGENDRE_NODES);
        this.glX = gl[0];
        this.glW = gl[1];
        this.precomputeAllPatchesAndReport();
    }

    /** Builds the cache key for a kernel; parameters are rounded to 6 decimals. */
    private String key(double ell, double nu) {
        return String.format(Locale.ROOT, "e%.6f_n%.6f", ell, nu);
    }

    /** Builds and caches a patch for every combination of candidate length scale and smoothness. */
    private void precomputeAllPatchesAndReport() {
        for (double ell : this.ellCands) {
            for (double nu : this.nuCands) {
                this.kernelCache.put(this.key(ell, nu), this.buildPatchDynamicRadius(ell, nu, this.kernelCutoff));
            }
        }
    }

    /**
     * Builds a sparse kernel patch whose radius adapts to how fast the kernel decays.
     *
     * <p>The radius is the first integer distance at which the kernel drops below {@code cutoff},
     * capped at {@code max(W, H)}. Within the square of that radius, every offset whose kernel value
     * is strictly above {@code cutoff} is kept, in row-major order (dy outer, dx inner).
     */
    private KernelPatch buildPatchDynamicRadius(double ell, double nu, double cutoff) {
        int maxPossible = Math.max(this.W, this.H);
        int radius = maxPossible;
        for (int r = 1; r <= maxPossible; ++r) {
            if (this.maternRadial_GL(r, ell, nu) < cutoff) {
                radius = r;
                break;
            }
        }
        List<Integer> dxs = new ArrayList<Integer>();
        List<Integer> dys = new ArrayList<Integer>();
        List<Float> vals = new ArrayList<Float>();
        for (int dy = -radius; dy <= radius; ++dy) {
            for (int dx = -radius; dx <= radius; ++dx) {
                double kv = this.maternRadial_GL(Math.hypot(dx, dy), ell, nu);
                if (kv <= cutoff) {
                    continue;
                }
                dxs.add(dx);
                dys.add(dy);
                vals.add((float) kv);
            }
        }
        int n = dxs.size();
        int[] dxA = new int[n];
        int[] dyA = new int[n];
        float[] vA = new float[n];
        for (int i = 0; i < n; ++i) {
            dxA[i] = dxs.get(i);
            dyA[i] = dys.get(i);
            vA[i] = vals.get(i);
        }
        return new KernelPatch(radius, dxA, dyA, vA);
    }

    /**
     * Evaluates the Matern correlation {@code 2^(1-nu)/Gamma(nu) * z^nu * K_nu(z)} with
     * {@code z = sqrt(2 nu) r / ell}.
     *
     * <p>Closed forms are used for nu = 0.5, 1.5 and 2.5; otherwise the Bessel function is integrated
     * numerically. Returns 1 for {@code r <= 0}, and 0 if the numeric result is NaN or negative.
     */
    private double maternRadial_GL(double r, double ell, double nu) {
        if (r <= 0.0) {
            return 1.0;
        }
        double z = Math.sqrt(2.0 * nu) * r / ell;
        if (Math.abs(nu - 0.5) < 1.0E-12) {
            return Math.exp(-z);
        }
        if (Math.abs(nu - 1.5) < 1.0E-12) {
            double a = Math.sqrt(3.0) * r / ell;
            return (1.0 + a) * Math.exp(-a);
        }
        if (Math.abs(nu - 2.5) < 1.0E-12) {
            double a = Math.sqrt(5.0) * r / ell;
            double rr = r / ell;
            return (1.0 + a + 5.0 * rr * rr / 3.0) * Math.exp(-a);
        }
        double K = this.besselK_GL64(nu, z);
        double pref = Math.pow(2.0, 1.0 - nu) / this.gammaLanczos(nu);
        double val = pref * Math.pow(z, nu) * K;
        if (Double.isNaN(val) || val < 0.0) {
            return 0.0;
        }
        return val;
    }

    /**
     * Approximates the modified Bessel function of the second kind {@code K_nu(z)}.
     *
     * <p>Uses the integral {@code K_nu(z) = int_0^inf exp(-z cosh t) cosh(nu t) dt}, truncated at
     * {@code T} where {@code z cosh T ~= ln(1/1e-14)} (at least 7), and integrated with Gauss-Legendre
     * quadrature. Tiny {@code z} uses the leading small-argument term (presumably valid only for
     * nu > 0 — TODO confirm callers never pass nu <= 0). Large {@code z} uses the leading asymptotic
     * term, which ignores the nu-dependent correction.
     */
    private double besselK_GL64(double nu, double z) {
        if (z <= 0.0) {
            return Double.POSITIVE_INFINITY;
        }
        if (z < 1.0E-12) {
            return 0.5 * this.gammaLanczos(nu) * Math.pow(z / 2.0, -nu);
        }
        if (z > 80.0) {
            return Math.sqrt(Math.PI / (2.0 * z)) * Math.exp(-z);
        }
        double tol = 1.0E-14;
        double target = Math.log(1.0 / tol) / Math.max(z, 1.0E-300);
        double upper;
        if (target <= 1.0) {
            upper = 7.0;
        } else {
            // acosh(target): where the integrand's exp(-z cosh t) factor falls below tol.
            upper = Math.max(Math.log(target + Math.sqrt(target * target - 1.0)), 7.0);
        }
        // Map [-1, 1] nodes onto [0, upper]: t = upper/2 + upper/2 * u.
        double halfT = upper / 2.0;
        double sum = 0.0;
        for (int i = 0; i < this.glX.length; ++i) {
            double t = halfT + halfT * this.glX[i];
            double e = -z * this.coshSafe(t);
            if (e < -700.0) {
                // exp(e) would underflow; the node contributes nothing.
                continue;
            }
            double f = Math.exp(e) * this.coshSafe(nu * t);
            sum += this.glW[i] * f;
        }
        return sum * halfT;
    }

    /** cosh(t), using {@code exp(|t|)/2} for |t| >= 20 where the exp(-|t|) term is negligible. */
    private double coshSafe(double t) {
        double a = Math.abs(t);
        if (a < 20.0) {
            return 0.5 * (Math.exp(t) + Math.exp(-t));
        }
        return Math.exp(a) / 2.0;
    }

    /** Gamma function via the Lanczos approximation, with the reflection formula for z < 0.5. */
    private double gammaLanczos(double z) {
        if (z < 0.5) {
            return Math.PI / (Math.sin(Math.PI * z) * this.gammaLanczos(1.0 - z));
        }
        z -= 1.0;
        double x = LANCZOS_C0;
        for (int i = 0; i < LANCZOS_COEFFS.length; ++i) {
            x += LANCZOS_COEFFS[i] / (z + (double) i + 1.0);
        }
        double t = z + (double) LANCZOS_COEFFS.length - 0.5;
        return Math.sqrt(Math.PI * 2) * Math.pow(t, z + 0.5) * Math.exp(-t) * x;
    }

    /**
     * Computes the n-point Gauss-Legendre rule on [-1, 1].
     *
     * <p>Each positive root is found by Newton iteration (at most 100 steps, stopping once the step is
     * below 1e-15) from the initial guess {@code cos(pi (i + 0.75) / (n + 0.5))}, and is mirrored to
     * its negative counterpart. Weights are {@code 2 / ((1 - x^2) P_n'(x)^2)}.
     *
     * @return {@code {nodes, weights}}, nodes in ascending order
     */
    private double[][] computeGaussLegendre(int n) {
        double[] x = new double[n];
        double[] w = new double[n];
        int m = (n + 1) / 2;
        for (int i = 0; i < m; ++i) {
            double root = Math.cos(Math.PI * ((double) i + 0.75) / ((double) n + 0.5));
            for (int iter = 0; iter < 100; ++iter) {
                double[] pw = this.legendrePandPderiv(n, root);
                double delta = pw[0] / pw[1];
                root -= delta;
                if (Math.abs(delta) < 1.0E-15) {
                    break;
                }
            }
            double pd = this.legendrePandPderiv(n, root)[1];
            double weight = 2.0 / ((1.0 - root * root) * pd * pd);
            x[i] = -root;
            x[n - 1 - i] = root;
            w[i] = weight;
            w[n - 1 - i] = weight;
        }
        return new double[][]{x, w};
    }

    /**
     * Evaluates the Legendre polynomial {@code P_n(x)} by the three-term recurrence, together with
     * its derivative {@code n (x P_n - P_{n-1}) / (x^2 - 1)} (undefined at x = +-1 for n >= 2).
     *
     * @return {@code {P_n(x), P_n'(x)}}
     */
    private double[] legendrePandPderiv(int n, double x) {
        double p0 = 1.0;
        double p1 = x;
        if (n == 0) {
            return new double[]{p0, 0.0};
        }
        if (n == 1) {
            return new double[]{p1, 1.0};
        }
        double pn = 0.0;
        for (int k = 2; k <= n; ++k) {
            pn = ((2.0 * (double) k - 1.0) * x * p1 - ((double) k - 1.0) * p0) / (double) k;
            p0 = p1;
            p1 = pn;
        }
        double pnm1 = p0;
        double pd = (double) n * (x * pn - pnm1) / (x * x - 1.0);
        return new double[]{pn, pd};
    }

    /**
     * Simulates one gene's count grid.
     *
     * <p>The mean grid starts at {@code baseline}. For each hotspot {@code {x, y}}, the kernel patch
     * for {@code (ell, nu)} scaled by {@code hotspotAmplitude} is added; offsets falling outside the
     * grid are skipped. Each spot then draws one uniform for dropout, and a Poisson count only if it
     * is not dropped.
     *
     * @param hotspotPositions hotspot centres as {@code {x, y}} pairs; {@code null} or empty for none
     * @param hotspotAmplitude multiplier applied to the kernel values
     * @param ell              Matern length scale
     * @param nu               Matern smoothness
     * @param rng              random source; not shared across threads by this class's own callers
     * @return counts indexed {@code [y][x]} with shape {@code H x W}
     */
    public int[][] generateSingleGeneExpression(List<int[]> hotspotPositions, double hotspotAmplitude, double ell, double nu, RandomGenerator rng) {
        double[][] mean = new double[this.H][this.W];
        if (this.baseline != 0.0) {
            for (double[] row : mean) {
                Arrays.fill(row, this.baseline);
            }
        }
        if (hotspotPositions != null && !hotspotPositions.isEmpty()) {
            KernelPatch patch = this.patchFor(ell, nu);
            int[] dx = patch.dx;
            int[] dy = patch.dy;
            float[] vals = patch.val;
            for (int[] pos : hotspotPositions) {
                int cx = pos[0];
                int cy = pos[1];
                for (int k = 0; k < vals.length; ++k) {
                    int xx = cx + dx[k];
                    int yy = cy + dy[k];
                    if (xx < 0 || xx >= this.W || yy < 0 || yy >= this.H) {
                        continue;
                    }
                    mean[yy][xx] += hotspotAmplitude * (double) vals[k];
                }
            }
        }
        int[][] counts = new int[this.H][this.W];
        for (int y = 0; y < this.H; ++y) {
            for (int x = 0; x < this.W; ++x) {
                // The dropout draw always happens first, so RNG consumption stays deterministic.
                counts[y][x] = rng.nextDouble() < this.dropoutProb ? 0 : this.samplePoisson(mean[y][x], rng);
            }
        }
        return counts;
    }

    /**
     * Returns the cached patch for {@code (ell, nu)}, building and caching it on first use.
     * If several threads race to build the same patch, one result wins and only the winner logs.
     */
    private KernelPatch patchFor(double ell, double nu) {
        String cacheKey = this.key(ell, nu);
        KernelPatch cached = this.kernelCache.get(cacheKey);
        if (cached != null) {
            return cached;
        }
        KernelPatch built = this.buildPatchDynamicRadius(ell, nu, this.kernelCutoff);
        KernelPatch raced = this.kernelCache.putIfAbsent(cacheKey, built);
        if (raced != null) {
            return raced;
        }
        System.out.printf("Built patch on-the-fly for (ell=%.4f,nu=%.4f): radius=%d entries=%d%n", ell, nu, built.radius, built.val.length);
        return built;
    }

    /**
     * Draws a Poisson-distributed count with mean {@code lambda}.
     *
     * <p>Returns 0 for {@code lambda <= 0}. Means above 1e6 use a rounded normal approximation,
     * clamped to {@code [0, Integer.MAX_VALUE]}. Otherwise the count is exact: means above
     * {@link #KNUTH_MAX_LAMBDA} are split into chunks, because a sum of independent Poissons is
     * Poisson, and each chunk is sampled with Knuth's method.
     */
    private int samplePoisson(double lambda, RandomGenerator rng) {
        if (lambda <= 0.0) {
            return 0;
        }
        if (lambda > NORMAL_APPROX_LAMBDA) {
            double val = rng.nextGaussian() * Math.sqrt(lambda) + lambda;
            return (int) Math.min((long) Integer.MAX_VALUE, Math.max(0L, Math.round(val)));
        }
        int total = 0;
        double remaining = lambda;
        while (remaining > KNUTH_MAX_LAMBDA) {
            total += samplePoissonKnuth(KNUTH_MAX_LAMBDA, rng);
            remaining -= KNUTH_MAX_LAMBDA;
        }
        return total + samplePoissonKnuth(remaining, rng);
    }

    /**
     * Knuth's multiplication method: counts uniforms until their product drops to
     * {@code exp(-lambda)} or below. Only accurate while {@code exp(-lambda)} does not underflow.
     */
    private static int samplePoissonKnuth(double lambda, RandomGenerator rng) {
        double limit = Math.exp(-lambda);
        int k = 0;
        double p = 1.0;
        do {
            ++k;
            p *= rng.nextDouble();
        } while (p > limit);
        return k - 1;
    }

    /**
     * Simulates {@code numGenes} genes in parallel and adds every gene's counts into {@code sumExpr}.
     *
     * <p>Genes are split into contiguous ranges, one per worker; the first
     * {@code numGenes % threads} workers take one extra gene. Worker {@code t} uses an RNG seeded with
     * {@code masterSeed + t * 1315423911}, so results depend on the configured thread count.
     *
     * @param numGenes         number of genes to simulate, must be >= 0
     * @param hotspotsPerGene  per-gene hotspot lists (at least {@code numGenes} entries), or
     *                         {@code null} for no hotspots
     * @param hotspotAmplitude kernel amplitude shared by all hotspots
     * @param ellsPerGene      per-gene length scales, or {@code null} to pick uniformly from the
     *                         candidates
     * @param nusPerGene       per-gene smoothness values, or {@code null} to pick uniformly from the
     *                         candidates
     * @param sumExpr          {@code H x W} accumulator, indexed {@code [y][x]}; updated under its
     *                         own monitor
     * @throws IllegalArgumentException if {@code numGenes < 0}
     * @throws InterruptedException     if interrupted while waiting for the workers
     * @throws RuntimeException         any unchecked exception thrown by a worker is re-thrown here
     */
    public void accumulateGenesToSum(int numGenes, List<List<int[]>> hotspotsPerGene, double hotspotAmplitude, List<Double> ellsPerGene, List<Double> nusPerGene, double[][] sumExpr) throws InterruptedException {
        if (numGenes < 0) {
            throw new IllegalArgumentException("numGenes must be >= 0: " + numGenes);
        }
        // Do not clamp perThread to >= 1: with fewer genes than threads that would assign gene
        // indices past numGenes.
        int perThread = numGenes / this.threads;
        int rem = numGenes % this.threads;
        List<Callable<Void>> tasks = new ArrayList<Callable<Void>>(this.threads);
        for (int t = 0; t < this.threads; ++t) {
            final int startIndex = t * perThread + Math.min(t, rem);
            final int count = perThread + (t < rem ? 1 : 0);
            if (count == 0) {
                continue;
            }
            final long seed = this.masterSeed + (long) t * WORKER_SEED_STRIDE;
            tasks.add(() -> {
                double[][] localSum = this.simulateGeneRange(startIndex, count, seed, hotspotsPerGene, hotspotAmplitude, ellsPerGene, nusPerGene);
                synchronized (sumExpr) {
                    addInto(sumExpr, localSum);
                }
                return null;
            });
        }
        if (tasks.isEmpty()) {
            return;
        }
        ExecutorService poolExec = Executors.newFixedThreadPool(this.threads);
        List<Future<Void>> futures;
        try {
            futures = poolExec.invokeAll(tasks);
        } finally {
            // Always release the worker threads, even if invokeAll was interrupted.
            poolExec.shutdown();
        }
        poolExec.awaitTermination(1L, TimeUnit.HOURS);
        for (Future<Void> future : futures) {
            try {
                future.get();
            } catch (ExecutionException e) {
                Throwable cause = e.getCause();
                if (cause instanceof RuntimeException) {
                    throw (RuntimeException) cause;
                }
                if (cause instanceof Error) {
                    throw (Error) cause;
                }
                throw new IllegalStateException("gene simulation task failed", cause);
            }
        }
    }

    /** Simulates genes {@code [startIndex, startIndex + count)} with a private RNG; returns their summed counts. */
    private double[][] simulateGeneRange(int startIndex, int count, long seed, List<List<int[]>> hotspotsPerGene, double hotspotAmplitude, List<Double> ellsPerGene, List<Double> nusPerGene) {
        MersenneTwister rng = new MersenneTwister(seed);
        double[][] localSum = new double[this.H][this.W];
        for (int i = 0; i < count; ++i) {
            int gi = startIndex + i;
            List<int[]> hotspots = hotspotsPerGene != null ? hotspotsPerGene.get(gi) : Collections.<int[]>emptyList();
            // ell is resolved before nu so RNG consumption order is fixed.
            double ell = ellsPerGene != null ? ellsPerGene.get(gi) : this.ellCands[rng.nextInt(this.ellCands.length)];
            double nu = nusPerGene != null ? nusPerGene.get(gi) : this.nuCands[rng.nextInt(this.nuCands.length)];
            addInto(localSum, this.generateSingleGeneExpression(hotspots, hotspotAmplitude, ell, nu, rng));
        }
        return localSum;
    }

    /** Adds {@code src} into {@code target} element-wise, row by row. */
    private static void addInto(double[][] target, int[][] src) {
        for (int y = 0; y < src.length; ++y) {
            double[] trow = target[y];
            int[] srow = src[y];
            for (int x = 0; x < srow.length; ++x) {
                trow[x] += (double) srow[x];
            }
        }
    }

    /** Adds {@code src} into {@code target} element-wise, row by row. */
    private static void addInto(double[][] target, double[][] src) {
        for (int y = 0; y < src.length; ++y) {
            double[] trow = target[y];
            double[] srow = src[y];
            for (int x = 0; x < srow.length; ++x) {
                trow[x] += srow[x];
            }
        }
    }

    /**
     * Demo entry point: simulates 50,000 genes on a 100x100 grid, each with one hotspot at (10, 10),
     * and writes the per-spot mean count to {@code spot_mean_expression.tsv}.
     */
    public static void main(String[] args) throws Exception {
        int threads = Math.max(1, Runtime.getRuntime().availableProcessors() - 1);
        int W = 100;
        int H = 100;
        double baseline = 1.0;
        double dropout = 0.7;
        double[] ellCands = new double[]{1.5};
        double[] nuCands = new double[]{2.5};
        long seed = 1234567L;
        int numGenes = 50000;
        double hotspotAmp = 2.0;
        double kernelCutoff = 0.001;
        System.out.printf("W=%d H=%d genes=%d threads=%d baseline=%.3f dropout=%.3f cutoff=%g%n", W, H, numGenes, threads, baseline, dropout, kernelCutoff);
        SpatialTranscriptomicsSimulatorMaternKernel sim = new SpatialTranscriptomicsSimulatorMaternKernel(W, H, baseline, dropout, ellCands, nuCands, threads, seed, kernelCutoff);
        List<List<int[]>> hotspotsPerGene = new ArrayList<List<int[]>>(numGenes);
        for (int g = 0; g < numGenes; ++g) {
            List<int[]> hs = new ArrayList<int[]>(1);
            hs.add(new int[]{10, 10});
            hotspotsPerGene.add(hs);
        }
        double[][] sumExpr = new double[H][W];
        long t0 = System.currentTimeMillis();
        sim.accumulateGenesToSum(numGenes, hotspotsPerGene, hotspotAmp, null, null, sumExpr);
        long t1 = System.currentTimeMillis();
        System.out.printf("Accumulation done in %.1f s%n", (double) (t1 - t0) / 1000.0);
        String out = "spot_mean_expression.tsv";
        try (BufferedWriter writer = Files.newBufferedWriter(Paths.get(out))) {
            writer.write("row/col");
            for (int x = 0; x < W; ++x) {
                writer.write("\t" + x);
            }
            writer.newLine();
            for (int y = 0; y < H; ++y) {
                writer.write(String.valueOf(y));
                for (int x = 0; x < W; ++x) {
                    double mean = sumExpr[y][x] / (double) numGenes;
                    writer.write("\t" + String.format(Locale.ROOT, "%.6f", mean));
                }
                writer.newLine();
            }
        }
        long t2 = System.currentTimeMillis();
        System.out.printf("Wrote %s. Total time %.1f s%n", out, (double) (t2 - t0) / 1000.0);
    }

    /** Sparse kernel patch: offsets {@code (dx[i], dy[i])} from a hotspot centre with kernel value {@code val[i]}. */
    static class KernelPatch {
        final int radius;
        final int[] dx;
        final int[] dy;
        final float[] val;

        KernelPatch(int radius, int[] dx, int[] dy, float[] val) {
            this.radius = radius;
            this.dx = dx;
            this.dy = dy;
            this.val = val;
        }
    }
}

