/*
 * Decompiled with CFR 0.152.
 */
package edu.sysu.pmglab.stat;

import cern.jet.stat.Probability;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.math3.random.RandomDataGenerator;
import org.ejml.data.SingularMatrixException;
import org.ejml.dense.row.CommonOps_DDRM;
import org.ejml.simple.SimpleMatrix;

public class Logistic2ThresholdEstimator1 {
    private final int[] y_int_orig;
    private final double[][] X_orig;
    private final double[][] sortedX;
    private final double[][] Z_orig;
    private final int N;
    private final int m;
    private final int p;
    private static final int COEFF_INDEX_INTERCEPT = 0;
    private static final int COEFF_INDEX_K = 1;
    private static final int COEFF_INDEX_Z_START = 2;
    private final transient ThreadLocal<double[][]> threadLocalDesignMatrix;
    private final transient ThreadLocal<int[]> threadLocalCounts;
    private double tauMinProportion = 0.1;
    private int numQuantilesForC = 50;
    private int MAX_ITER_LOGISTIC = 25;
    private double CONVERGENCE_TOL_LOGISTIC = 1.0E-6;
    private int maxPermutations = 10000;
    private int minPermutations = 200;
    private int stopAfterNHits = 10;

    public Logistic2ThresholdEstimator1(double[] y, double[][] X_param, double[][] Z_param) {
        int i;
        this.y_int_orig = Arrays.stream(y).mapToInt(val -> (int)Math.round(val)).toArray();
        this.N = y.length;
        if (X_param.length != this.N) {
            throw new IllegalArgumentException("X length mismatch with Y");
        }
        int n = this.m = this.N > 0 && X_param[0] != null ? X_param[0].length : 0;
        if (this.m == 0 && this.N > 0) {
            throw new IllegalArgumentException("X cannot be empty.");
        }
        this.X_orig = new double[X_param.length][];
        this.sortedX = new double[X_param.length][];
        for (i = 0; i < X_param.length; ++i) {
            if (X_param[i] == null) continue;
            this.X_orig[i] = (double[])X_param[i].clone();
            this.sortedX[i] = (double[])X_param[i].clone();
            Arrays.sort(this.sortedX[i]);
        }
        if (Z_param != null) {
            if (Z_param.length != this.N) {
                throw new IllegalArgumentException("Z length mismatch with Y");
            }
            this.Z_orig = new double[Z_param.length][];
            this.p = this.N > 0 && Z_param[0] != null ? Z_param[0].length : 0;
            for (i = 0; i < Z_param.length; ++i) {
                if (Z_param[i] != null) {
                    this.Z_orig[i] = (double[])Z_param[i].clone();
                    continue;
                }
                if (this.p <= 0) continue;
                this.Z_orig[i] = new double[this.p];
            }
        } else {
            this.Z_orig = null;
            this.p = 0;
        }
        int numRegressorsExcludingIntercept = 1 + this.p;
        this.threadLocalDesignMatrix = ThreadLocal.withInitial(() -> new double[this.N][numRegressorsExcludingIntercept]);
        this.threadLocalCounts = ThreadLocal.withInitial(() -> new int[this.N]);
    }

    public void setTauMinProportion(double tau) {
        this.tauMinProportion = tau;
    }

    public void setNumQuantilesForC(int numQuantiles) {
        this.numQuantilesForC = Math.max(10, numQuantiles);
    }

    public void setMaxPermutations(int num) {
        this.maxPermutations = num;
    }

    public void setMinPermutations(int num) {
        this.minPermutations = num;
    }

    public void setStopAfterNHits(int num) {
        this.stopAfterNHits = num;
    }

    public void setMaxIterLogistic(int maxIter) {
        this.MAX_ITER_LOGISTIC = maxIter;
    }

    public void setConvergenceTolLogistic(double tol) {
        this.CONVERGENCE_TOL_LOGISTIC = tol;
    }

    private int countGreaterEqualSorted(double[] sorted_x_row, double c_val) {
        if (sorted_x_row == null || sorted_x_row.length == 0) {
            return 0;
        }
        int low = 0;
        int high = sorted_x_row.length - 1;
        int firstGeIndex = sorted_x_row.length;
        while (low <= high) {
            int mid = low + (high - low) / 2;
            if (sorted_x_row[mid] >= c_val) {
                firstGeIndex = mid;
                high = mid - 1;
                continue;
            }
            low = mid + 1;
        }
        return sorted_x_row.length - firstGeIndex;
    }

    private List<Double> createCGrid() {
        if (this.N == 0 || this.m == 0) {
            return Collections.emptyList();
        }
        HashSet<Double> uniqueXValuesSet = new HashSet<Double>();
        for (int i = 0; i < this.N; ++i) {
            if (this.X_orig[i] == null) continue;
            for (int j = 0; j < this.m; ++j) {
                uniqueXValuesSet.add(this.X_orig[i][j]);
            }
        }
        if (uniqueXValuesSet.isEmpty()) {
            return Collections.emptyList();
        }
        ArrayList<Double> sortedUniqueXValues = new ArrayList<Double>(uniqueXValuesSet);
        Collections.sort(sortedUniqueXValues);
        if (sortedUniqueXValues.size() <= this.numQuantilesForC) {
            return sortedUniqueXValues;
        }
        ArrayList c_grid = new ArrayList();
        for (int k = 0; k < this.numQuantilesForC; ++k) {
            double fraction = (double)k / ((double)this.numQuantilesForC - 1.0);
            int index = (int)Math.round(fraction * ((double)sortedUniqueXValues.size() - 1.0));
            c_grid.add(sortedUniqueXValues.get(index));
        }
        return c_grid.stream().distinct().sorted().collect(Collectors.toList());
    }

    private void computeXtWX_Optimized(SimpleMatrix X_full, double[] weights, SimpleMatrix XtWX_out) {
        int nObs = X_full.numRows();
        int numCoeffs = X_full.numCols();
        double[] XtWX_data = XtWX_out.getDDRM().getData();
        double[] xData = X_full.getDDRM().getData();
        for (int i = 0; i < numCoeffs; ++i) {
            for (int j = i; j < numCoeffs; ++j) {
                double sum = 0.0;
                for (int k = 0; k < nObs; ++k) {
                    sum += xData[k * numCoeffs + i] * xData[k * numCoeffs + j] * weights[k];
                }
                XtWX_data[i * numCoeffs + j] = sum;
                if (i == j) continue;
                XtWX_data[j * numCoeffs + i] = sum;
            }
        }
    }

    private LogisticRegressionResult performLogisticRegression(double[][] designMatrix, int[] yResponse, double[] initialBetaArray) {
        int nObs = yResponse.length;
        int numPredictors = designMatrix != null && nObs > 0 && designMatrix[0] != null ? designMatrix[0].length : 0;
        int numCoeffs = 1 + numPredictors;
        LogisticRegressionResult result = new LogisticRegressionResult();
        result.stdErrors = new double[numCoeffs];
        SimpleMatrix X_full = new SimpleMatrix(nObs, numCoeffs);
        for (int i = 0; i < nObs; ++i) {
            X_full.set(i, 0, 1.0);
            if (numPredictors <= 0 || designMatrix[i] == null) continue;
            for (int j = 0; j < numPredictors; ++j) {
                X_full.set(i, j + 1, designMatrix[i][j]);
            }
        }
        SimpleMatrix beta = initialBetaArray != null && initialBetaArray.length == numCoeffs ? new SimpleMatrix(numCoeffs, 1, true, initialBetaArray) : new SimpleMatrix(numCoeffs, 1);
        double[] yResponseDouble = new double[nObs];
        for (int i = 0; i < nObs; ++i) {
            yResponseDouble[i] = yResponse[i];
        }
        SimpleMatrix yVec = new SimpleMatrix(nObs, 1, true, yResponseDouble);
        SimpleMatrix linearPredictor = new SimpleMatrix(nObs, 1);
        SimpleMatrix pVec = new SimpleMatrix(nObs, 1);
        SimpleMatrix residuals = new SimpleMatrix(nObs, 1);
        SimpleMatrix gradient = new SimpleMatrix(numCoeffs, 1);
        SimpleMatrix hessian = new SimpleMatrix(numCoeffs, numCoeffs);
        SimpleMatrix deltaBeta = new SimpleMatrix(numCoeffs, 1);
        double[] probabilities = new double[nObs];
        double[] diagonalWeights = new double[nObs];
        for (int iter = 0; iter < this.MAX_ITER_LOGISTIC; ++iter) {
            SimpleMatrix hessianInv;
            int i;
            CommonOps_DDRM.mult(X_full.getDDRM(), beta.getDDRM(), linearPredictor.getDDRM());
            for (i = 0; i < nObs; ++i) {
                double lp = linearPredictor.get(i, 0);
                probabilities[i] = lp > 35.0 ? 0.999999999999999 : (lp < -35.0 ? 1.0E-15 : 1.0 / (1.0 + Math.exp(-lp)));
            }
            pVec.getDDRM().setData(probabilities);
            CommonOps_DDRM.subtract(yVec.getDDRM(), pVec.getDDRM(), residuals.getDDRM());
            CommonOps_DDRM.multTransA(X_full.getDDRM(), residuals.getDDRM(), gradient.getDDRM());
            for (i = 0; i < nObs; ++i) {
                double pi = probabilities[i];
                diagonalWeights[i] = Math.max(pi * (1.0 - pi), 1.0E-10);
            }
            this.computeXtWX_Optimized(X_full, diagonalWeights, hessian);
            CommonOps_DDRM.scale(-1.0, hessian.getDDRM());
            try {
                hessianInv = (SimpleMatrix)hessian.invert();
            }
            catch (SingularMatrixException e) {
                result.converged = false;
                result.coefficients = beta.getDDRM().getData();
                Arrays.fill(result.stdErrors, Double.NaN);
                return result;
            }
            CommonOps_DDRM.mult(hessianInv.getDDRM(), gradient.getDDRM(), deltaBeta.getDDRM());
            CommonOps_DDRM.scale(-1.0, deltaBeta.getDDRM());
            CommonOps_DDRM.add(beta.getDDRM(), deltaBeta.getDDRM(), beta.getDDRM());
            if (!(deltaBeta.normF() < this.CONVERGENCE_TOL_LOGISTIC)) continue;
            result.converged = true;
            break;
        }
        result.coefficients = beta.getDDRM().getData();
        if (result.converged) {
            int i;
            double ll = 0.0;
            CommonOps_DDRM.mult(X_full.getDDRM(), beta.getDDRM(), linearPredictor.getDDRM());
            for (i = 0; i < nObs; ++i) {
                double prob = 1.0 / (1.0 + Math.exp(-linearPredictor.get(i, 0)));
                probabilities[i] = Math.max(1.0E-15, Math.min(0.999999999999999, prob));
                if (yResponse[i] == 1) {
                    ll += Math.log(probabilities[i]);
                    continue;
                }
                ll += Math.log(1.0 - probabilities[i]);
            }
            result.logLikelihood = ll;
            try {
                for (i = 0; i < nObs; ++i) {
                    double pi = probabilities[i];
                    diagonalWeights[i] = Math.max(pi * (1.0 - pi), 1.0E-10);
                }
                this.computeXtWX_Optimized(X_full, diagonalWeights, hessian);
                SimpleMatrix covMatrix = (SimpleMatrix)hessian.invert();
                for (int i2 = 0; i2 < numCoeffs; ++i2) {
                    double var = covMatrix.get(i2, i2);
                    result.stdErrors[i2] = var >= 0.0 ? Math.sqrt(var) : Double.NaN;
                }
            }
            catch (SingularMatrixException e) {
                Arrays.fill(result.stdErrors, Double.NaN);
            }
        }
        return result;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     * WARNING - void declaration
     */
    private GridSearchResult runGridSearchForMaxWaldStat(int[] currentY, int innerThreads) {
        List<Double> c_grid = this.createCGrid();
        if (c_grid.isEmpty()) {
            return new GridSearchResult();
        }
        GridSearchResult overallBestResult = new GridSearchResult();
        int numRegressorsExcludingIntercept = 1 + this.p;
        double[][] designMatrixTemplate = new double[this.N][numRegressorsExcludingIntercept];
        if (this.p > 0 && this.Z_orig != null) {
            for (int i = 0; i < this.N; ++i) {
                if (this.Z_orig[i] == null) continue;
                System.arraycopy(this.Z_orig[i], 0, designMatrixTemplate[i], 1, this.p);
            }
        }
        if (c_grid.size() < innerThreads * 2) {
            innerThreads = 1;
        }
        if (innerThreads > 1) {
            ExecutorService cExecutor = Executors.newFixedThreadPool(innerThreads);
            try {
                ArrayList<Future<GridSearchResult>> cFutures = new ArrayList<Future<GridSearchResult>>();
                int chunkSize = (int)Math.ceil((double)c_grid.size() / (double)innerThreads);
                for (int i = 0; i < c_grid.size(); i += chunkSize) {
                    List<Double> list = c_grid.subList(i, Math.min(i + chunkSize, c_grid.size()));
                    Callable<GridSearchResult> cTask = () -> {
                        GridSearchResult localBestResult = new GridSearchResult();
                        int[] local_counts_for_c = this.threadLocalCounts.get();
                        double[][] localDesignMatrix = this.threadLocalDesignMatrix.get();
                        Iterator iterator2 = c_chunk.iterator();
                        while (iterator2.hasNext()) {
                            double c_candidate = (Double)iterator2.next();
                            double[] lastBeta = null;
                            for (int k = 0; k < this.N; ++k) {
                                System.arraycopy(designMatrixTemplate[k], 0, localDesignMatrix[k], 0, numRegressorsExcludingIntercept);
                                local_counts_for_c[k] = this.countGreaterEqualSorted(this.sortedX[k], c_candidate);
                            }
                            HashSet<Integer> d_critical_points_set = new HashSet<Integer>();
                            for (int count_val : local_counts_for_c) {
                                if (count_val >= 1 && count_val <= this.m) {
                                    d_critical_points_set.add(count_val);
                                }
                                if (count_val + 1 < 1 || count_val + 1 > this.m) continue;
                                d_critical_points_set.add(count_val + 1);
                            }
                            if (this.m > 0) {
                                d_critical_points_set.add(1);
                                d_critical_points_set.add(this.m);
                            }
                            ArrayList reduced_d_grid = new ArrayList(d_critical_points_set);
                            Collections.sort(reduced_d_grid);
                            Iterator iterator3 = reduced_d_grid.iterator();
                            while (iterator3.hasNext()) {
                                int d_candidate = (Integer)iterator3.next();
                                int countRegime1 = 0;
                                for (int k = 0; k < this.N; ++k) {
                                    boolean D_val = local_counts_for_c[k] >= d_candidate;
                                    localDesignMatrix[k][0] = (double)D_val;
                                    if (!D_val) continue;
                                    ++countRegime1;
                                }
                                int countRegime0 = this.N - countRegime1;
                                if ((double)countRegime0 < (double)this.N * this.tauMinProportion || (double)countRegime1 < (double)this.N * this.tauMinProportion) continue;
                                LogisticRegressionResult fit = this.performLogisticRegression(localDesignMatrix, currentY, lastBeta);
                                if (fit.converged) {
                                    lastBeta = fit.coefficients;
                                    double currentWaldStat = Double.NaN;
                                    if (fit.stdErrors.length > 1 && !Double.isNaN(fit.stdErrors[1]) && fit.stdErrors[1] > 1.0E-9) {
                                        double tStat = fit.coefficients[1] / fit.stdErrors[1];
                                        currentWaldStat = tStat * tStat;
                                    }
                                    if (Double.isNaN(currentWaldStat) || !Double.isNaN(localBestResult.maxWaldStat) && !(currentWaldStat > localBestResult.maxWaldStat)) continue;
                                    localBestResult.maxWaldStat = currentWaldStat;
                                    localBestResult.bestModelDetails.logLikelihood = fit.logLikelihood;
                                    localBestResult.bestModelDetails.c_hat = c_candidate;
                                    localBestResult.bestModelDetails.d_hat = d_candidate;
                                    localBestResult.bestModelDetails.b_hat = fit.coefficients[0];
                                    localBestResult.bestModelDetails.k_hat = fit.coefficients[1];
                                    localBestResult.bestModelDetails.k_stdErr = fit.stdErrors[1];
                                    localBestResult.bestModelDetails.k_waldStat = currentWaldStat;
                                    if (this.p > 0) {
                                        localBestResult.bestModelDetails.z_hat = new double[this.p];
                                        System.arraycopy(fit.coefficients, 2, localBestResult.bestModelDetails.z_hat, 0, this.p);
                                        continue;
                                    }
                                    localBestResult.bestModelDetails.z_hat = new double[0];
                                    continue;
                                }
                                lastBeta = null;
                            }
                        }
                        return localBestResult;
                    };
                    cFutures.add(cExecutor.submit(cTask));
                }
                for (Future future : cFutures) {
                    GridSearchResult fitFromTask = (GridSearchResult)future.get();
                    if (fitFromTask == null || Double.isNaN(fitFromTask.maxWaldStat) || !Double.isNaN(overallBestResult.maxWaldStat) && !(fitFromTask.maxWaldStat > overallBestResult.maxWaldStat)) continue;
                    overallBestResult = fitFromTask;
                }
            }
            catch (InterruptedException | ExecutionException e) {
                if (e instanceof InterruptedException) {
                    Thread.currentThread().interrupt();
                } else {
                    System.err.println("Error in sub-task during grid search: " + e.getCause());
                }
                GridSearchResult chunkSize = new GridSearchResult();
                return chunkSize;
            }
            finally {
                cExecutor.shutdownNow();
            }
        } else {
            int[] counts_for_c = new int[this.N];
            for (double c_candidate : c_grid) {
                Object var11_22 = null;
                for (int i = 0; i < this.N; ++i) {
                    counts_for_c[i] = this.countGreaterEqualSorted(this.sortedX[i], c_candidate);
                }
                HashSet<Integer> d_critical_points_set = new HashSet<Integer>();
                for (int count_val : counts_for_c) {
                    if (count_val >= 1 && count_val <= this.m) {
                        d_critical_points_set.add(count_val);
                    }
                    if (count_val + 1 < 1 || count_val + 1 > this.m) continue;
                    d_critical_points_set.add(count_val + 1);
                }
                if (this.m > 0) {
                    d_critical_points_set.add(1);
                    d_critical_points_set.add(this.m);
                }
                ArrayList reduced_d_grid = new ArrayList(d_critical_points_set);
                Collections.sort(reduced_d_grid);
                Iterator iterator2 = reduced_d_grid.iterator();
                while (iterator2.hasNext()) {
                    void var11_21;
                    int d_candidate = (Integer)iterator2.next();
                    int countRegime1 = 0;
                    for (int i = 0; i < this.N; ++i) {
                        boolean D_val = counts_for_c[i] >= d_candidate;
                        designMatrixTemplate[i][0] = (double)D_val;
                        if (!D_val) continue;
                        ++countRegime1;
                    }
                    int countRegime0 = this.N - countRegime1;
                    if ((double)countRegime0 < (double)this.N * this.tauMinProportion || (double)countRegime1 < (double)this.N * this.tauMinProportion) continue;
                    LogisticRegressionResult fit = this.performLogisticRegression(designMatrixTemplate, currentY, (double[])var11_21);
                    if (fit != null && fit.converged) {
                        double[] dArray = fit.coefficients;
                        double currentWaldStat = Double.NaN;
                        if (fit.stdErrors != null && fit.stdErrors.length > 1 && !Double.isNaN(fit.stdErrors[1]) && fit.stdErrors[1] > 1.0E-9) {
                            double tStat = fit.coefficients[1] / fit.stdErrors[1];
                            currentWaldStat = tStat * tStat;
                        }
                        if (Double.isNaN(currentWaldStat) || !Double.isNaN(overallBestResult.maxWaldStat) && !(currentWaldStat > overallBestResult.maxWaldStat)) continue;
                        overallBestResult.maxWaldStat = currentWaldStat;
                        overallBestResult.bestModelDetails.logLikelihood = fit.logLikelihood;
                        overallBestResult.bestModelDetails.c_hat = c_candidate;
                        overallBestResult.bestModelDetails.d_hat = d_candidate;
                        overallBestResult.bestModelDetails.b_hat = fit.coefficients[0];
                        overallBestResult.bestModelDetails.k_hat = fit.coefficients[1];
                        overallBestResult.bestModelDetails.k_stdErr = fit.stdErrors[1];
                        overallBestResult.bestModelDetails.k_waldStat = currentWaldStat;
                        if (this.p > 0) {
                            overallBestResult.bestModelDetails.z_hat = new double[this.p];
                            System.arraycopy(fit.coefficients, 2, overallBestResult.bestModelDetails.z_hat, 0, this.p);
                            continue;
                        }
                        overallBestResult.bestModelDetails.z_hat = new double[0];
                        continue;
                    }
                    Object var11_24 = null;
                }
            }
        }
        return overallBestResult;
    }

    public ThresholdModelResult fit(long randSeed, int threadNum) {
        GridSearchResult observedResult = this.runGridSearchForMaxWaldStat(this.y_int_orig, threadNum);
        ThresholdModelResult finalResult = new ThresholdModelResult();
        if (Double.isNaN(observedResult.maxWaldStat)) {
            finalResult.pValuePermutation = 1.0;
            finalResult.k_pValue = 1.0;
            return finalResult;
        }
        ModelFitDetails bestObservedModel = observedResult.bestModelDetails;
        finalResult.b_hat = bestObservedModel.b_hat;
        finalResult.k_hat = bestObservedModel.k_hat;
        finalResult.z_hat = bestObservedModel.z_hat != null ? (double[])bestObservedModel.z_hat.clone() : new double[]{};
        finalResult.c_hat = bestObservedModel.c_hat;
        finalResult.d_hat = bestObservedModel.d_hat;
        finalResult.logLikelihood = bestObservedModel.logLikelihood;
        finalResult.k_stdErr = bestObservedModel.k_stdErr;
        finalResult.k_waldStat = bestObservedModel.k_waldStat;
        if (!Double.isNaN(finalResult.k_hat) && !Double.isNaN(finalResult.k_stdErr) && finalResult.k_stdErr > 1.0E-9) {
            finalResult.k_zValue = finalResult.k_hat / finalResult.k_stdErr;
            finalResult.k_pValue = 2.0 * Probability.normal(-Math.abs(finalResult.k_zValue));
        } else {
            finalResult.k_pValue = Double.NaN;
        }
        double observedMaxWaldStat = observedResult.maxWaldStat;
        if (this.maxPermutations <= 0) {
            return finalResult;
        }
        AtomicInteger hits = new AtomicInteger(0);
        int permutationsRun = 0;
        Random masterRng = new Random(randSeed);
        List baseIndices = IntStream.range(0, this.N).boxed().collect(Collectors.toList());
        if (threadNum > 1) {
            ExecutorService permExecutor = Executors.newFixedThreadPool(threadNum);
            ArrayList<Future<Double>> futures = new ArrayList<Future<Double>>(this.maxPermutations);
            for (int b = 0; b < this.maxPermutations; ++b) {
                long pSeed = masterRng.nextLong();
                Callable<Double> task = () -> {
                    ArrayList permutedIndices = new ArrayList(baseIndices);
                    Collections.shuffle(permutedIndices, new Random(pSeed));
                    int[] y_perm = new int[this.N];
                    for (int i = 0; i < this.N; ++i) {
                        y_perm[i] = this.y_int_orig[(Integer)permutedIndices.get(i)];
                    }
                    GridSearchResult permutedGridResult = this.runGridSearchForMaxWaldStat(y_perm, 1);
                    return permutedGridResult.maxWaldStat;
                };
                futures.add(permExecutor.submit(task));
            }
            try {
                for (int i = 0; i < futures.size(); ++i) {
                    Future future = (Future)futures.get(i);
                    Double permutedMaxWaldStat = (Double)future.get();
                    ++permutationsRun;
                    if (permutedMaxWaldStat != null && !Double.isNaN(permutedMaxWaldStat) && permutedMaxWaldStat >= observedMaxWaldStat) {
                        hits.incrementAndGet();
                    }
                    if (permutationsRun < this.minPermutations || hits.get() < this.stopAfterNHits) continue;
                    for (int j = i + 1; j < futures.size(); ++j) {
                        ((Future)futures.get(j)).cancel(true);
                    }
                    permExecutor.shutdownNow();
                }
            }
            catch (InterruptedException | CancellationException e) {
                Thread.currentThread().interrupt();
            }
            catch (ExecutionException e) {
                System.err.println("Error during permutation: " + e.getCause());
                throw new RuntimeException("A permutation task failed.", e.getCause());
            }
            finally {
                if (!permExecutor.isTerminated()) {
                    permExecutor.shutdownNow();
                }
            }
        } else {
            Random permRng = new Random();
            for (int b = 0; b < this.maxPermutations; ++b) {
                permRng.setSeed(masterRng.nextLong());
                ArrayList permutedIndices = new ArrayList(baseIndices);
                Collections.shuffle(permutedIndices, permRng);
                int[] y_perm = new int[this.N];
                for (int i = 0; i < this.N; ++i) {
                    y_perm[i] = this.y_int_orig[(Integer)permutedIndices.get(i)];
                }
                GridSearchResult permutedGridResult = this.runGridSearchForMaxWaldStat(y_perm, 1);
                Double permutedMaxWaldStat = permutedGridResult.maxWaldStat;
                ++permutationsRun;
                if (permutedMaxWaldStat != null && !Double.isNaN(permutedMaxWaldStat) && permutedMaxWaldStat >= observedMaxWaldStat) {
                    hits.incrementAndGet();
                }
                if (permutationsRun >= this.minPermutations && hits.get() >= this.stopAfterNHits) break;
            }
        }
        finalResult.permutationsPerformed = permutationsRun;
        finalResult.pValuePermutation = ((double)hits.get() + 1.0) / ((double)permutationsRun + 1.0);
        return finalResult;
    }

    public static void main(String[] args) {
        double[] dArray;
        int N_samples = 4000;
        int m_dim = 50;
        int p_dim = 2;
        RandomDataGenerator rng = new RandomDataGenerator();
        rng.reSeed(12345L);
        double b_true = -0.5;
        double k_true = 0.0;
        double c_true = 0.8;
        int d_true = 10;
        if (p_dim > 0) {
            double[] dArray2 = new double[2];
            dArray2[0] = 0.8;
            dArray = dArray2;
            dArray2[1] = -0.6;
        } else {
            dArray = null;
        }
        double[] z_true = dArray;
        System.out.println("--- Generating Test Data (H0: k=0) ---");
        Logistic2ThresholdEstimator1.runTest("H0 Test (k=0)", N_samples, m_dim, p_dim, b_true, k_true, c_true, d_true, z_true);
        k_true = 0.5;
        System.out.println("\n--- Generating Test Data (H1: k=0.5) ---");
        Logistic2ThresholdEstimator1.runTest("H1 Test (k=0.5)", N_samples, m_dim, p_dim, b_true, k_true, c_true, d_true, z_true);
    }

    private static void runTest(String testName, int N_samples, int m_dim, int p_dim, double b_true, double k_true, double c_true, int d_true, double[] z_true) {
        RandomDataGenerator rng = new RandomDataGenerator();
        rng.reSeed(12345L);
        double[] y_data_double = new double[N_samples];
        double[][] X_data = new double[N_samples][m_dim];
        double[][] Z_data = p_dim > 0 ? new double[N_samples][p_dim] : (double[][])null;
        System.out.println("Test: " + testName);
        System.out.println("N=" + N_samples + ", m=" + m_dim + ", p=" + p_dim);
        System.out.println("True params: b=" + b_true + ", k=" + k_true + ", c=" + c_true + ", d=" + d_true + (p_dim > 0 && z_true != null ? ", z=" + Arrays.toString(z_true) : ""));
        long caseCount = 0L;
        for (int i = 0; i < N_samples; ++i) {
            int j;
            for (j = 0; j < m_dim; ++j) {
                X_data[i][j] = rng.nextGaussian(0.0, 1.0);
            }
            if (p_dim > 0 && Z_data != null) {
                for (j = 0; j < p_dim; ++j) {
                    Z_data[i][j] = rng.nextUniform(-1.0, 1.0);
                }
            }
            int count_X_ge_c = 0;
            for (double x_val : X_data[i]) {
                if (!(x_val >= c_true)) continue;
                ++count_X_ge_c;
            }
            boolean I_val = count_X_ge_c >= d_true;
            double W_i = b_true + k_true * (double)I_val;
            if (p_dim > 0 && Z_data != null && Z_data[i] != null && z_true != null) {
                for (int j2 = 0; j2 < p_dim; ++j2) {
                    W_i += Z_data[i][j2] * z_true[j2];
                }
            }
            double prob_y1 = 1.0 / (1.0 + Math.exp(-W_i));
            double d = y_data_double[i] = rng.nextUniform(0.0, 1.0) < prob_y1 ? 1.0 : 0.0;
            if (y_data_double[i] != 1.0) continue;
            ++caseCount;
        }
        System.out.println("Cases: " + caseCount + ", Controls: " + ((long)N_samples - caseCount));
        Logistic2ThresholdEstimator1 estimator = new Logistic2ThresholdEstimator1(y_data_double, X_data, Z_data);
        estimator.setTauMinProportion(0.05);
        estimator.setNumQuantilesForC(30);
        estimator.setMinPermutations(200);
        estimator.setMaxPermutations(10000);
        estimator.setStopAfterNHits(20);
        System.out.println("\n--- Starting Analysis ---");
        int nThreads = Runtime.getRuntime().availableProcessors();
        System.out.println("Settings: tau=" + estimator.tauMinProportion + ", numCQuantiles=" + estimator.numQuantilesForC + ", minPerms=" + estimator.minPermutations + ", maxPerms=" + estimator.maxPermutations + ", stopAfterHits=" + estimator.stopAfterNHits + ", threads=" + nThreads);
        long startTime = System.currentTimeMillis();
        ThresholdModelResult result = estimator.fit(1000L, nThreads);
        long endTime = System.currentTimeMillis();
        System.out.println("Analysis finished in " + (endTime - startTime) + " ms.");
        System.out.println("\n--- Final Results for " + testName + " ---");
        System.out.println(result);
    }

    private static class GridSearchResult {
        double maxWaldStat = Double.NaN;
        ModelFitDetails bestModelDetails = new ModelFitDetails();

        private GridSearchResult() {
        }
    }

    private static class LogisticRegressionResult {
        double[] coefficients;
        double[] stdErrors;
        double logLikelihood = Double.NEGATIVE_INFINITY;
        boolean converged = false;

        private LogisticRegressionResult() {
        }
    }

    private static class ModelFitDetails {
        public double b_hat = Double.NaN;
        public double k_hat = Double.NaN;
        public double[] z_hat;
        public double c_hat = Double.NaN;
        public int d_hat = -1;
        public double logLikelihood = Double.NEGATIVE_INFINITY;
        public double k_stdErr = Double.NaN;
        public double k_waldStat = Double.NaN;

        private ModelFitDetails() {
        }
    }

    public static class ThresholdModelResult {
        public double b_hat = Double.NaN;
        public double k_hat = Double.NaN;
        public double[] z_hat;
        public double c_hat = Double.NaN;
        public int d_hat = -1;
        public double logLikelihood = Double.NEGATIVE_INFINITY;
        public double k_stdErr = Double.NaN;
        public double k_waldStat = Double.NaN;
        public double k_zValue = Double.NaN;
        public double k_pValue = Double.NaN;
        public double pValuePermutation = Double.NaN;
        public int permutationsPerformed = 0;

        public String toString() {
            String z_hat_str = "N/A";
            if (this.z_hat != null && this.z_hat.length > 0) {
                z_hat_str = Arrays.stream(this.z_hat).mapToObj(val -> String.format("%.4f", val)).collect(Collectors.joining(", "));
            }
            return "ThresholdModelResult{\nb_hat=" + String.format("%.4f", this.b_hat) + ", \nk_hat=" + String.format("%.4f", this.k_hat) + ", \nz_hat=[" + z_hat_str + "], \nc_hat=" + String.format("%.4f", this.c_hat) + ", \nd_hat=" + this.d_hat + ", \nlogLikelihood=" + String.format("%.4f", this.logLikelihood) + ", \nk_stdErr=" + String.format("%.4f", this.k_stdErr) + ", \nk_waldStat=" + String.format("%.4f", this.k_waldStat) + ", \nk_zValue=" + String.format("%.4f", this.k_zValue) + ", \nk_pValue=" + (Double.isNaN(this.k_pValue) ? "NaN" : String.format("%.5g", this.k_pValue)) + ", \npValuePermutation=" + (Double.isNaN(this.pValuePermutation) ? "NaN" : String.format("%.5g", this.pValuePermutation)) + ", \npermutationsPerformed=" + this.permutationsPerformed + "\n}";
        }
    }
}

