/*
 * Decompiled with CFR 0.152.
 */
package edu.sysu.pmglab.stat;

import org.apache.commons.math3.complex.Complex;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.transform.DftNormalization;
import org.apache.commons.math3.transform.FastFourierTransformer;
import org.apache.commons.math3.transform.TransformType;

/**
 * Draws realisations of a zero-mean Gaussian process on a regular {@code nx x ny} grid with unit
 * spacing, using circulant embedding and a 2-D FFT.
 *
 * <p>The covariance between two grid points at squared index distance {@code d2} is
 * {@code sigma2 * exp(-d2 / (2 * l * l))}. The grid is embedded in a periodic
 * {@code N x M} torus (both sides padded to a power of two, as required by
 * {@link FastFourierTransformer}) whose covariance is diagonalised by the 2-D DFT. The square
 * roots of the resulting eigenvalues are computed once in the constructor and reused by every
 * call to {@link #sample(RandomGenerator)}.
 *
 * <p>Negative eigenvalues of the embedding are clamped to zero, so when the embedding is not
 * non-negative definite the sampled covariance is only an approximation of the requested one.
 */
public class FastGpSampler {
    /** Scale of the independent Gaussian noise added to every returned value. */
    private static final double JITTER = 1.0E-6;

    /** Largest grid side whose padded (power-of-two) length still fits in an {@code int}. */
    private static final int MAX_GRID_SIDE = 1 << 29;

    private final int nx;
    private final int ny;
    /** Padded (power-of-two) side lengths of the periodic embedding. */
    private final int paddedNx;
    private final int paddedNy;
    /** Square roots of the clamped embedding eigenvalues, row-major ({@code i * paddedNy + j}). */
    private final double[] sqrtEigenvalues;
    /** Factor turning the normalised inverse FFT output into a sample with the target covariance. */
    private final double sampleScale;
    private final FastFourierTransformer fft;

    /**
     * Precomputes the spectral factorisation of the embedded covariance.
     *
     * @param nx     number of grid points along the first axis; must be positive
     * @param ny     number of grid points along the second axis; must be positive
     * @param sigma2 marginal variance of the process; must be finite and non-negative
     * @param l      length scale of the squared-exponential kernel, in grid units; must be positive
     * @throws IllegalArgumentException if any argument is out of range or the padded grid would
     *                                  not fit in a Java array
     */
    public FastGpSampler(int nx, int ny, double sigma2, double l) {
        if (nx <= 0 || ny <= 0) {
            throw new IllegalArgumentException("Grid dimensions (nx, ny) must be positive. Got: nx=" + nx + ", ny=" + ny);
        }
        if (nx > MAX_GRID_SIDE || ny > MAX_GRID_SIDE) {
            throw new IllegalArgumentException("Grid dimensions (nx, ny) are too large. Got: nx=" + nx + ", ny=" + ny);
        }
        // Negated comparisons so that NaN is rejected as well.
        if (!(sigma2 >= 0.0) || Double.isInfinite(sigma2)) {
            throw new IllegalArgumentException("Variance sigma2 must be finite and non-negative. Got: " + sigma2);
        }
        if (!(l > 0.0)) {
            throw new IllegalArgumentException("Length scale l must be positive. Got: " + l);
        }

        int rows = paddedLength(nx);
        int cols = paddedLength(ny);
        if ((long) rows * (long) cols > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Padded grid of " + rows + " x " + cols + " points is too large");
        }

        this.nx = nx;
        this.ny = ny;
        this.paddedNx = rows;
        this.paddedNy = cols;
        this.fft = new FastFourierTransformer(DftNormalization.STANDARD);

        // With STANDARD normalisation the inverse transform carries a 1/(N*M) factor, whereas the
        // spectral synthesis needs 1/sqrt(N*M); multiplying by sqrt(N*M) makes up the difference.
        this.sampleScale = Math.sqrt((double) rows * (double) cols);

        Complex[][] eigenvalues = this.fft2d(buildCirculantCovariance(rows, cols, sigma2, l), TransformType.FORWARD);
        this.sqrtEigenvalues = new double[rows * cols];
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < cols; ++j) {
                // The base is real and symmetric, so the spectrum is real up to rounding; clamp
                // negative values so the square root is defined.
                double clamped = Math.max(0.0, eigenvalues[i][j].getReal());
                this.sqrtEigenvalues[i * cols + j] = Math.sqrt(clamped);
            }
        }
    }

    /**
     * Returns the padded side length for a grid side of {@code n} points: 1 when {@code n == 1},
     * otherwise the smallest power of two strictly greater than {@code 2 * n - 2}.
     */
    private static int paddedLength(int n) {
        return n == 1 ? 1 : Integer.highestOneBit(2 * n - 2) << 1;
    }

    /**
     * Builds the first "row" of the block-circulant covariance: entry {@code (i, j)} is the
     * covariance between the origin and point {@code (i, j)} using wrap-around distances.
     */
    private static Complex[][] buildCirculantCovariance(int rows, int cols, double sigma2, double l) {
        double twoLengthSquared = 2.0 * l * l;
        Complex[][] base = new Complex[rows][cols];
        for (int i = 0; i < rows; ++i) {
            double dx = Math.min(i, rows - i);
            for (int j = 0; j < cols; ++j) {
                double dy = Math.min(j, cols - j);
                double dist2 = dx * dx + dy * dy;
                base[i][j] = new Complex(sigma2 * Math.exp(-dist2 / twoLengthSquared), 0.0);
            }
        }
        return base;
    }

    /**
     * Draws one realisation of the field.
     *
     * <p>Each frequency gets a complex Gaussian coefficient (real part drawn first, then the
     * imaginary part, in row-major frequency order) scaled by the square root of its eigenvalue.
     * The real part of the inverse 2-D FFT, restricted to the original grid, is the sample. A
     * small independent jitter is added to every value.
     *
     * @param random source of standard normal variates
     * @return a new array of length {@code nx * ny} in row-major order (index {@code i * ny + j})
     * @throws NullPointerException if {@code random} is {@code null}
     */
    public double[] sample(RandomGenerator random) {
        if (random == null) {
            throw new NullPointerException("random must not be null");
        }

        Complex[][] spectrum = new Complex[this.paddedNx][this.paddedNy];
        for (int i = 0; i < this.paddedNx; ++i) {
            for (int j = 0; j < this.paddedNy; ++j) {
                double amplitude = this.sqrtEigenvalues[i * this.paddedNy + j];
                // Unit-variance real and imaginary parts: the real part of the synthesised field
                // then has exactly the embedded covariance.
                double re = random.nextGaussian();
                double im = random.nextGaussian();
                spectrum[i][j] = new Complex(amplitude * re, amplitude * im);
            }
        }

        Complex[][] field = this.fft2d(spectrum, TransformType.INVERSE);

        double[] result = new double[this.nx * this.ny];
        for (int i = 0; i < this.nx; ++i) {
            for (int j = 0; j < this.ny; ++j) {
                double value = field[i][j].getReal() * this.sampleScale;
                result[i * this.ny + j] = value + random.nextGaussian() * JITTER;
            }
        }
        return result;
    }

    /**
     * Applies a 2-D transform by transforming every row and then every column of the result.
     * The input array is not modified.
     */
    private Complex[][] fft2d(Complex[][] data, TransformType type) {
        int rows = data.length;
        int cols = rows == 0 ? 0 : data[0].length;

        Complex[][] rowPass = new Complex[rows][];
        for (int i = 0; i < rows; ++i) {
            rowPass[i] = this.fft.transform(data[i], type);
        }

        Complex[][] out = new Complex[rows][cols];
        Complex[] column = new Complex[rows];
        for (int j = 0; j < cols; ++j) {
            for (int i = 0; i < rows; ++i) {
                column[i] = rowPass[i][j];
            }
            Complex[] transformed = this.fft.transform(column, type);
            for (int i = 0; i < rows; ++i) {
                out[i][j] = transformed[i];
            }
        }
        return out;
    }
}

