/*
 * Decompiled with CFR 0.152.
 */
package edu.sysu.pmglab.stat;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.TreeSet;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
import org.apache.commons.math3.distribution.ChiSquaredDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.random.RandomDataGenerator;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.ejml.data.SingularMatrixException;
import org.ejml.simple.SimpleMatrix;

/**
 * Single-threshold logistic regression with a two-stage significance test for the threshold effect.
 *
 * <p>The fitted model is {@code P(y = 1) = logistic(b + k * 1{X >= d} + z'Z)}, where {@code X} is an
 * integer variable, {@code Z} an optional matrix of {@code p} covariates, and the integer threshold
 * {@code d} is chosen by maximising the log-likelihood over a grid of candidate thresholds. In the
 * observed-data search, candidates leaving fewer than {@code tauMinProportion * N} observations in
 * either regime are skipped.
 *
 * <p>Testing {@code k = 0}: stage 1 computes a Wald p-value at the selected threshold. If it is below
 * the screening level, stage 2 computes a parametric-bootstrap p-value. That stage simulates
 * responses from the fitted null model (intercept + Z) and reruns the full threshold search.
 */
public class Logistic1ThresholdEstimator2 {
    /** Binary (0/1) response; defensive copy of the constructor argument. */
    private final int[] y_int_orig;
    /** Integer variable whose threshold defines the two regimes; defensive copy. */
    private final int[] X_scalar_orig;
    /** Covariate rows of length {@link #p}, or {@code null} when no covariates were supplied. */
    private final double[][] Z_orig;
    /** Number of observations. */
    private final int N;
    /** Number of covariate columns (0 without covariates). */
    private final int p;
    /** Minimum share of observations each regime needs in the observed-data threshold search. */
    private double tauMinProportion = 0.1;
    /** Number of parametric-bootstrap replicates for stage 2 (0 or less disables stage 2). */
    private int numBootstrapReplications = 199;
    /** Maximum number of Newton-Raphson iterations per logistic fit. */
    private static final int MAX_ITER_LOGISTIC = 25;
    /** Newton-Raphson stops once the Frobenius norm of the step falls below this value. */
    private static final double CONVERGENCE_TOL_LOGISTIC = 1.0E-6;

    /**
     * Creates an estimator over defensive copies of the supplied data.
     *
     * @param y binary response; every entry must be 0 or 1
     * @param X_scalar_param integer threshold variable, same length as {@code y}
     * @param Z_param optional covariates (may be {@code null}), one row per observation; all
     *     non-null rows must have the same length, and {@code null} rows are treated as all-zero
     * @throws IllegalArgumentException on a length mismatch, ragged covariate rows, or a response
     *     value other than 0/1
     */
    public Logistic1ThresholdEstimator2(int[] y, int[] X_scalar_param, double[][] Z_param) {
        this.y_int_orig = y.clone();
        this.N = y.length;
        if (X_scalar_param.length != this.N) {
            throw new IllegalArgumentException("X_scalar length mismatch with Y");
        }
        for (int i = 0; i < this.N; ++i) {
            if (y[i] != 0 && y[i] != 1) {
                throw new IllegalArgumentException(
                        "Y must contain only 0/1 values; found " + y[i] + " at index " + i);
            }
        }
        this.X_scalar_orig = X_scalar_param.clone();
        if (Z_param == null) {
            this.Z_orig = null;
            this.p = 0;
        } else {
            if (Z_param.length != this.N) {
                throw new IllegalArgumentException("Z length mismatch with Y");
            }
            // Take the covariate count from the first non-null row so a leading null row does not
            // silently discard every covariate.
            int width = 0;
            for (double[] row : Z_param) {
                if (row != null) {
                    width = row.length;
                    break;
                }
            }
            this.p = width;
            this.Z_orig = new double[this.N][];
            for (int i = 0; i < this.N; ++i) {
                double[] row = Z_param[i];
                if (row != null) {
                    if (row.length != width) {
                        throw new IllegalArgumentException(
                                "Z row " + i + " has length " + row.length + ", expected " + width);
                    }
                    this.Z_orig[i] = row.clone();
                } else if (width > 0) {
                    this.Z_orig[i] = new double[width];
                }
            }
        }
    }

    /**
     * Sets the minimum fraction of observations each regime must contain for a candidate threshold
     * to be fitted in the observed-data search. Values above 0.5 leave no admissible threshold.
     *
     * @param tau minimum regime proportion
     */
    public void setTauMinProportion(double tau) {
        this.tauMinProportion = tau;
    }

    /**
     * Sets the number of parametric-bootstrap replicates used in stage 2.
     *
     * @param num replicate count; 0 or less disables the bootstrap, leaving the nominal p-value
     */
    public void setNumBootstrapReplications(int num) {
        this.numBootstrapReplications = num;
    }

    /**
     * Builds the candidate thresholds for the indicator {@code 1{X >= d}}.
     *
     * <p>Each distinct split of the data appears exactly once, at the smallest integer {@code d}
     * producing it. The list is {@code min(X)}, then {@code v + 1} for every distinct value
     * {@code v} in ascending order. For densely populated integer X this equals
     * {@code min(X), min(X) + 1, ..., max(X) + 1}. For sparse X it avoids refitting identical splits
     * and never materialises the full integer range.
     *
     * @param x_values threshold variable; may be {@code null}
     * @return ascending candidate thresholds; empty when {@code x_values} is null or empty
     */
    private List<Integer> createDGrid(int[] x_values) {
        if (x_values == null || x_values.length == 0) {
            return Collections.emptyList();
        }
        TreeSet<Integer> distinct = new TreeSet<Integer>();
        for (int value : x_values) {
            distinct.add(value);
        }
        List<Integer> grid = new ArrayList<Integer>(distinct.size() + 1);
        grid.add(distinct.first());
        for (int value : distinct) {
            if (value == Integer.MAX_VALUE) {
                // value + 1 would overflow. That split puts everyone in regime 0, whose all-zero
                // indicator column cannot give an identifiable fit anyway.
                break;
            }
            grid.add(value + 1);
        }
        return grid;
    }

    /**
     * Computes the weighted cross-product {@code X' diag(weights) X}.
     *
     * <p>Rows are visited in the outer loop so the row-major data is read sequentially. Each entry is
     * still summed over rows in ascending order, exactly as a per-entry loop would.
     *
     * @param X_full design matrix (rows = observations)
     * @param weights one non-negative weight per row
     * @return symmetric {@code numCols x numCols} matrix
     */
    private SimpleMatrix computeXtWX_Optimized(SimpleMatrix X_full, double[] weights) {
        final int rows = X_full.numRows();
        final int cols = X_full.numCols();
        final double[] x = X_full.getDDRM().getData();
        final double[][] acc = new double[cols][cols];
        for (int k = 0; k < rows; ++k) {
            final int base = k * cols;
            final double w = weights[k];
            for (int i = 0; i < cols; ++i) {
                final double xi = x[base + i];
                for (int j = i; j < cols; ++j) {
                    acc[i][j] += xi * x[base + j] * w;
                }
            }
        }
        for (int i = 0; i < cols; ++i) {
            for (int j = i + 1; j < cols; ++j) {
                acc[j][i] = acc[i][j];
            }
        }
        return new SimpleMatrix(acc);
    }

    /**
     * Logistic function with the tails clamped to {@code [1e-15, 1 - 1e-15]}.
     * The clamp keeps {@code log(p)} and {@code log(1 - p)} finite.
     */
    private static double clampedLogistic(double lp) {
        if (lp > 35.0) {
            return 0.999999999999999;
        }
        if (lp < -35.0) {
            return 1.0E-15;
        }
        return 1.0 / (1.0 + Math.exp(-lp));
    }

    /** Writes {@code clampedLogistic} of each entry of the column vector {@code linearPredictor} into {@code out}. */
    private static void fillProbabilities(SimpleMatrix linearPredictor, double[] out) {
        for (int i = 0; i < out.length; ++i) {
            out[i] = clampedLogistic(linearPredictor.get(i, 0));
        }
    }

    /** Writes the IRLS weights {@code p(1 - p)}, floored at 1e-10, into {@code weights}. */
    private static void fillWeights(double[] probabilities, double[] weights) {
        for (int i = 0; i < probabilities.length; ++i) {
            double pi = probabilities[i];
            weights[i] = Math.max(pi * (1.0 - pi), 1.0E-10);
        }
    }

    /**
     * Fits a logistic regression with an added intercept by Newton-Raphson.
     *
     * @param designMatrix predictor rows without an intercept column; a {@code null} row contributes
     *     zero predictors
     * @param yResponse 0/1 responses, one per row
     * @param initialBetaArray optional starting coefficients (intercept first). It is ignored unless
     *     its length equals {@code 1 + number of predictors}.
     * @return the fit. On success it holds the coefficients, standard errors from
     *     {@code (X'WX)^-1} and the log-likelihood. If the iteration does not converge within
     *     {@link #MAX_ITER_LOGISTIC} steps, or the information matrix is singular, {@code converged}
     *     is false, the log-likelihood is negative infinity and the standard errors are NaN.
     */
    private LogisticRegressionResult performLogisticRegression(double[][] designMatrix, int[] yResponse, double[] initialBetaArray) {
        final int nObs = yResponse.length;
        final int numPredictors = designMatrix != null && nObs > 0 && designMatrix[0] != null ? designMatrix[0].length : 0;
        final int numCoeffs = 1 + numPredictors;
        LogisticRegressionResult result = new LogisticRegressionResult();
        result.stdErrors = new double[numCoeffs];

        SimpleMatrix X_full = new SimpleMatrix(nObs, numCoeffs);
        for (int i = 0; i < nObs; ++i) {
            X_full.set(i, 0, 1.0);
            if (numPredictors <= 0 || designMatrix[i] == null) {
                continue;
            }
            for (int j = 0; j < numPredictors; ++j) {
                X_full.set(i, j + 1, designMatrix[i][j]);
            }
        }
        // X' does not change between Newton steps, so transpose it once.
        SimpleMatrix X_transposed = X_full.transpose();

        SimpleMatrix beta = initialBetaArray != null && initialBetaArray.length == numCoeffs
                ? new SimpleMatrix(numCoeffs, 1, true, initialBetaArray)
                : new SimpleMatrix(numCoeffs, 1);
        double[] probabilities = new double[nObs];
        double[] weights = new double[nObs];
        double[] yAsDouble = new double[nObs];
        for (int i = 0; i < nObs; ++i) {
            yAsDouble[i] = yResponse[i];
        }
        SimpleMatrix yVec = new SimpleMatrix(nObs, 1, true, yAsDouble);

        for (int iter = 0; iter < MAX_ITER_LOGISTIC; ++iter) {
            fillProbabilities(X_full.mult(beta), probabilities);
            SimpleMatrix gradient = X_transposed.mult(yVec.minus(new SimpleMatrix(nObs, 1, true, probabilities)));
            fillWeights(probabilities, weights);
            // Newton step with Hessian H = -X'WX: beta += -H^-1 g = (X'WX)^-1 X'(y - p).
            SimpleMatrix informationInverse;
            try {
                informationInverse = this.computeXtWX_Optimized(X_full, weights).invert();
            } catch (SingularMatrixException e) {
                result.converged = false;
                result.coefficients = beta.getDDRM().getData();
                Arrays.fill(result.stdErrors, Double.NaN);
                result.logLikelihood = Double.NEGATIVE_INFINITY;
                return result;
            }
            SimpleMatrix step = informationInverse.mult(gradient);
            beta = beta.plus(step);
            if (step.normF() < CONVERGENCE_TOL_LOGISTIC) {
                result.converged = true;
                break;
            }
        }
        result.coefficients = beta.getDDRM().getData();
        if (!result.converged) {
            result.logLikelihood = Double.NEGATIVE_INFINITY;
            Arrays.fill(result.stdErrors, Double.NaN);
            return result;
        }

        fillProbabilities(X_full.mult(beta), probabilities);
        double ll = 0.0;
        for (int i = 0; i < nObs; ++i) {
            ll += yResponse[i] == 1 ? Math.log(probabilities[i]) : Math.log(1.0 - probabilities[i]);
        }
        result.logLikelihood = ll;

        fillWeights(probabilities, weights);
        try {
            SimpleMatrix covariance = this.computeXtWX_Optimized(X_full, weights).invert();
            for (int c = 0; c < numCoeffs; ++c) {
                double variance = covariance.get(c, c);
                result.stdErrors[c] = variance >= 0.0 ? Math.sqrt(variance) : Double.NaN;
            }
        } catch (SingularMatrixException e) {
            Arrays.fill(result.stdErrors, Double.NaN);
        }
        return result;
    }

    /**
     * Searches the threshold grid for the candidate {@code d} with the largest log-likelihood.
     *
     * @param currentY response to fit (observed data or a bootstrap replicate)
     * @param isBootstrapRun selects the looser admissibility rule and forces a single thread
     * @return the best converged fit. When no candidate converges, a default
     *     {@link ModelFitDetails} is returned (log-likelihood negative infinity, {@code d_hat = -1},
     *     NaN estimates).
     */
    private ModelFitDetails estimateParametersInternalLogistic(int[] currentY, boolean isBootstrapRun) {
        List<Integer> d_grid = this.createDGrid(this.X_scalar_orig);
        if (d_grid.isEmpty()) {
            return new ModelFitDetails();
        }
        double[][] designTemplate = this.buildDesignTemplate();
        int innerThreads = this.chooseInnerThreads(d_grid.size(), isBootstrapRun);
        return innerThreads > 1
                ? this.searchGridParallel(d_grid, designTemplate, currentY, isBootstrapRun, innerThreads)
                : this.searchGridSequential(d_grid, designTemplate, currentY, isBootstrapRun);
    }

    /** Builds an {@code N x (1 + p)} design with column 0 reserved for the indicator and Z in columns 1..p. */
    private double[][] buildDesignTemplate() {
        double[][] design = new double[this.N][1 + this.p];
        if (this.p > 0 && this.Z_orig != null) {
            for (int i = 0; i < this.N; ++i) {
                if (this.Z_orig[i] != null) {
                    System.arraycopy(this.Z_orig[i], 0, design[i], 1, this.p);
                }
            }
        }
        return design;
    }

    /**
     * Chooses the worker count for the grid search. It uses one thread for bootstrap replicates
     * (which already run in parallel) and for grids under 5 candidates. Otherwise it uses
     * processors - 1, capped at half the grid size.
     */
    private int chooseInnerThreads(int gridSize, boolean isBootstrapRun) {
        if (isBootstrapRun || gridSize < 5) {
            return 1;
        }
        int threads = Math.max(1, Runtime.getRuntime().availableProcessors() - 1);
        if (gridSize < threads * 2) {
            threads = Math.max(1, gridSize / 2);
        }
        return threads;
    }

    /** Sets column 0 of {@code design} to {@code 1{X >= d_candidate}} and returns how many rows are 1. */
    private int fillThresholdIndicator(double[][] design, int d_candidate) {
        int countRegime1 = 0;
        for (int i = 0; i < this.N; ++i) {
            boolean aboveThreshold = this.X_scalar_orig[i] >= d_candidate;
            design[i][0] = aboveThreshold ? 1.0 : 0.0;
            if (aboveThreshold) {
                ++countRegime1;
            }
        }
        return countRegime1;
    }

    /**
     * Decides whether a split is fitted. The observed-data search needs at least
     * {@code N * tauMinProportion} observations per regime; bootstrap replicates need only one.
     * NOTE(review): because of this asymmetry, bootstrap replicates search more thresholds than
     * stage 1 did. Confirm this is intended.
     */
    private boolean isAdmissibleSplit(int countRegime0, int countRegime1, boolean isBootstrapRun) {
        if (isBootstrapRun) {
            return countRegime0 >= 1 && countRegime1 >= 1;
        }
        double minCount = (double) this.N * this.tauMinProportion;
        return !((double) countRegime0 < minCount || (double) countRegime1 < minCount);
    }

    /**
     * Packs a regression result for threshold {@code d_candidate} into a fresh
     * {@link ModelFitDetails}. Non-converged fits leave everything except {@code d_hat} at its default.
     */
    private ModelFitDetails toFitDetails(int d_candidate, LogisticRegressionResult fit) {
        ModelFitDetails details = new ModelFitDetails();
        details.d_hat = d_candidate;
        if (fit == null || !fit.converged) {
            return details;
        }
        details.logLikelihood = fit.logLikelihood;
        details.b_hat = fit.coefficients[0];
        details.k_hat = fit.coefficients[1];
        details.z_hat = new double[this.p];
        System.arraycopy(fit.coefficients, 2, details.z_hat, 0, this.p);
        if (fit.stdErrors != null && fit.stdErrors.length > 1) {
            details.k_stdErr = fit.stdErrors[1];
            if (!Double.isNaN(details.k_stdErr) && details.k_stdErr > 1.0E-9) {
                details.k_tStat = details.k_hat / details.k_stdErr;
            }
        }
        return details;
    }

    /** Fits one candidate on a private copy of the design (safe to run concurrently) without warm start. */
    private ModelFitDetails fitCandidateIndependently(int d_candidate, double[][] template, int[] currentY, boolean isBootstrapRun) {
        double[][] localDesign = new double[this.N][];
        for (int i = 0; i < this.N; ++i) {
            localDesign[i] = template[i].clone();
        }
        int countRegime1 = this.fillThresholdIndicator(localDesign, d_candidate);
        if (!this.isAdmissibleSplit(this.N - countRegime1, countRegime1, isBootstrapRun)) {
            ModelFitDetails skipped = new ModelFitDetails();
            skipped.d_hat = d_candidate;
            return skipped;
        }
        return this.toFitDetails(d_candidate, this.performLogisticRegression(localDesign, currentY, null));
    }

    /**
     * Fits all candidates on a thread pool and keeps the first maximum in grid order.
     * A task failure is reported on stderr, and the best fit seen up to that point is returned.
     */
    private ModelFitDetails searchGridParallel(List<Integer> d_grid, double[][] template, int[] currentY, boolean isBootstrapRun, int threads) {
        ModelFitDetails best = new ModelFitDetails();
        ExecutorService dExecutor = Executors.newFixedThreadPool(threads);
        try {
            List<Future<ModelFitDetails>> dFutures = new ArrayList<Future<ModelFitDetails>>(d_grid.size());
            for (int d_candidate : d_grid) {
                Callable<ModelFitDetails> dTask =
                        () -> this.fitCandidateIndependently(d_candidate, template, currentY, isBootstrapRun);
                dFutures.add(dExecutor.submit(dTask));
            }
            dExecutor.shutdown();
            for (Future<ModelFitDetails> future : dFutures) {
                ModelFitDetails candidateFit = future.get();
                if (candidateFit != null && candidateFit.logLikelihood > best.logLikelihood) {
                    best = candidateFit;
                }
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            System.err.println("Error during parallel d_grid search: " + e.getMessage());
            e.printStackTrace();
        } catch (ExecutionException e) {
            System.err.println("Error during parallel d_grid search: " + e.getMessage());
            e.printStackTrace();
        } finally {
            // Cancels any tasks still pending after a failure; a no-op once everything finished.
            dExecutor.shutdownNow();
        }
        return best;
    }

    /**
     * Fits candidates in ascending order on one shared design matrix. Each Newton-Raphson run is
     * warm-started from the previous converged coefficients.
     */
    private ModelFitDetails searchGridSequential(List<Integer> d_grid, double[][] design, int[] currentY, boolean isBootstrapRun) {
        ModelFitDetails best = new ModelFitDetails();
        double[] lastBeta = null;
        for (int d_candidate : d_grid) {
            int countRegime1 = this.fillThresholdIndicator(design, d_candidate);
            if (!this.isAdmissibleSplit(this.N - countRegime1, countRegime1, isBootstrapRun)) {
                continue;
            }
            LogisticRegressionResult fit = this.performLogisticRegression(design, currentY, lastBeta);
            if (fit == null || !fit.converged) {
                lastBeta = null;
                continue;
            }
            lastBeta = fit.coefficients;
            if (fit.logLikelihood > best.logLikelihood) {
                // A fresh object per new best, so no field (e.g. k_tStat) leaks from an earlier candidate.
                best = this.toFitDetails(d_candidate, fit);
            }
        }
        return best;
    }

    /**
     * Fits the threshold model to the observed data and tests {@code k = 0} in two stages.
     *
     * <p>Stage 1 reports the Wald p-value {@code P(chi2_1 >= (k_hat / se)^2)}. Stage 2 runs only when
     * that p-value is below {@code screeningAlpha} and bootstrap replicates are enabled. It then
     * replaces the p-value with the parametric-bootstrap value {@code (#{|t*| >= |t|} + 1) / (B + 1)}.
     * Here B is the number of finite bootstrap t-statistics, and each {@code t*} is re-estimated on a
     * response simulated from the fitted null model. Stage 2 also reports a p-value from a normal
     * distribution fitted to the bootstrap t-statistics.
     *
     * @param screeningAlpha stage-1 significance level that triggers the bootstrap
     * @param randSeed seed from which every bootstrap replicate's random stream is derived
     * @return fitted parameters and p-values. {@code pValueK} is NaN if the initial fit, the
     *     null-model fit or the bootstrap execution fails.
     */
    public ThresholdModelResult estimateParametersAndTestTwoStage(double screeningAlpha, int randSeed) {
        ModelFitDetails initialFit = this.estimateParametersInternalLogistic(this.y_int_orig, false);
        ThresholdModelResult finalResult = new ThresholdModelResult();
        copyFitInto(initialFit, finalResult);
        if (initialFit.logLikelihood == Double.NEGATIVE_INFINITY || Double.isNaN(initialFit.k_hat)) {
            System.err.println("Stage 1: Initial fit failed. Skipping Bootstrap.");
            finalResult.pValueK = Double.NaN;
            return finalResult;
        }
        double nominalPValue = this.computeNominalWaldPValue(finalResult);
        finalResult.pValueK = nominalPValue;
        if (nominalPValue < screeningAlpha && this.numBootstrapReplications > 0) {
            System.out.println("  Stage 1 p-value < " + screeningAlpha + ". Proceeding to Bootstrap confirmation (" + this.numBootstrapReplications + " reps).");
            this.runBootstrapStage(finalResult, randSeed);
        } else if (this.numBootstrapReplications > 0) {
            System.out.println("  Stage 1 p-value >= " + screeningAlpha + ". Skipping Bootstrap confirmation.");
        } else {
            System.out.println("  Bootstrap skipped as numBootstrapReplications is 0. pValue is nominal.");
        }
        return finalResult;
    }

    /** Copies the point estimates and log-likelihood of {@code source} into {@code target} (z_hat is cloned). */
    private static void copyFitInto(ModelFitDetails source, ThresholdModelResult target) {
        target.b_hat = source.b_hat;
        target.k_hat = source.k_hat;
        target.z_hat = source.z_hat != null ? source.z_hat.clone() : null;
        target.d_hat = source.d_hat;
        target.logLikelihood = source.logLikelihood;
        target.k_stdErr = source.k_stdErr;
        target.k_tStat = source.k_tStat;
    }

    /** Returns the 1-df chi-square Wald p-value for k, or 1.0 when the t-statistic is unusable. */
    private double computeNominalWaldPValue(ThresholdModelResult fit) {
        if (!Double.isNaN(fit.k_tStat) && !Double.isNaN(fit.k_stdErr) && fit.k_stdErr > 1.0E-12) {
            double waldStat = Math.pow(fit.k_tStat, 2.0);
            ChiSquaredDistribution chi2Dist = new ChiSquaredDistribution(1.0);
            double pValue = 1.0 - chi2Dist.cumulativeProbability(waldStat);
            System.out.println("  Stage 1: Nominal (screening) Wald p-value: " + String.format("%.4e", pValue));
            return pValue;
        }
        System.out.println("  Stage 1: Could not calculate nominal Wald p-value (k_tStat or k_stdErr is NaN/problematic).");
        return 1.0;
    }

    /** Stage 2: replaces {@code finalResult.pValueK} with the bootstrap p-value and adds the normal-fit p-value. */
    private void runBootstrapStage(ThresholdModelResult finalResult, int randSeed) {
        ModelFitDetails h0Fit = this.estimateH0Logistic(this.y_int_orig);
        if (Double.isNaN(h0Fit.b_hat)) {
            System.err.println("Stage 2: H0 fit failed for Bootstrap. Bootstrap p-value will be NaN.");
            // Report NaN as announced instead of silently keeping the nominal p-value.
            finalResult.pValueK = Double.NaN;
            return;
        }
        List<Double> bootstrapTStats = this.collectBootstrapTStats(h0Fit, randSeed);
        if (bootstrapTStats == null) {
            finalResult.pValueK = Double.NaN;
            return;
        }
        if ((double) bootstrapTStats.size() < (double) this.numBootstrapReplications * 0.8) {
            System.err.println("Stage 2: Warning - High Bootstrap failure rate (" + bootstrapTStats.size() + "/" + this.numBootstrapReplications + ").");
        }
        if (Double.isNaN(finalResult.k_tStat) || bootstrapTStats.isEmpty()) {
            finalResult.pValueK = Double.NaN;
            return;
        }
        double observedAbsT = Math.abs(finalResult.k_tStat);
        long countExceeding = 0L;
        for (double t : bootstrapTStats) {
            if (Math.abs(t) >= observedAbsT) {
                ++countExceeding;
            }
        }
        finalResult.pValueK = ((double) countExceeding + 1.0) / ((double) bootstrapTStats.size() + 1.0);
        System.out.println("  Stage 2: Bootstrap p-value: " + String.format("%.4f", finalResult.pValueK));
        this.fitNormalToBootstrap(finalResult, bootstrapTStats);
    }

    /**
     * Runs the bootstrap replicates in parallel. Each replicate uses its own seed derived from
     * {@code randSeed}.
     *
     * @return the finite bootstrap t-statistics in replicate order, or {@code null} when execution
     *     failed (already reported on stderr)
     */
    private List<Double> collectBootstrapTStats(ModelFitDetails h0Fit, int randSeed) {
        final double b0 = h0Fit.b_hat;
        final double[] z0 = h0Fit.z_hat != null ? h0Fit.z_hat.clone() : new double[0];
        RandomDataGenerator outerRandomGen = new RandomDataGenerator();
        outerRandomGen.reSeed(randSeed);
        long[] seeds = new long[this.numBootstrapReplications];
        for (int b = 0; b < this.numBootstrapReplications; ++b) {
            seeds[b] = outerRandomGen.nextLong(Long.MIN_VALUE, Long.MAX_VALUE);
        }
        int numThreads = Math.max(1, Runtime.getRuntime().availableProcessors() - 1);
        ExecutorService bootstrapExecutor = Executors.newFixedThreadPool(numThreads);
        try {
            List<Future<Double>> futures = new ArrayList<Future<Double>>(this.numBootstrapReplications);
            for (int b = 0; b < this.numBootstrapReplications; ++b) {
                final long seed = seeds[b];
                Callable<Double> replicate = () -> {
                    int[] y_boot = this.simulateNullResponse(b0, z0, seed);
                    return this.estimateParametersInternalLogistic(y_boot, true).k_tStat;
                };
                futures.add(bootstrapExecutor.submit(replicate));
            }
            bootstrapExecutor.shutdown();
            List<Double> tStats = new ArrayList<Double>(this.numBootstrapReplications);
            for (Future<Double> future : futures) {
                Double t_boot = future.get();
                if (t_boot == null || Double.isNaN(t_boot) || Double.isInfinite(t_boot)) {
                    continue;
                }
                tStats.add(t_boot);
            }
            return tStats;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            System.err.println("Stage 2: Error during Bootstrap execution: " + e.getMessage());
            e.printStackTrace();
            return null;
        } catch (ExecutionException e) {
            System.err.println("Stage 2: Error during Bootstrap execution: " + e.getMessage());
            e.printStackTrace();
            return null;
        } finally {
            // Stops replicates that are still queued or running after a failure.
            bootstrapExecutor.shutdownNow();
        }
    }

    /** Draws a 0/1 response from the null model {@code logistic(b0 + z0'Z)} using a stream seeded with {@code seed}. */
    private int[] simulateNullResponse(double b0, double[] z0, long seed) {
        RandomDataGenerator localRandomGen = new RandomDataGenerator();
        localRandomGen.reSeed(seed);
        int[] y_boot = new int[this.N];
        for (int i = 0; i < this.N; ++i) {
            double linearPredictorH0 = b0;
            if (this.p > 0 && z0.length > 0 && this.Z_orig[i] != null) {
                for (int j = 0; j < this.p; ++j) {
                    linearPredictorH0 += this.Z_orig[i][j] * z0[j];
                }
            }
            double prob_y1 = 1.0 / (1.0 + Math.exp(-linearPredictorH0));
            y_boot[i] = localRandomGen.nextUniform(0.0, 1.0) < prob_y1 ? 1 : 0;
        }
        return y_boot;
    }

    /**
     * Fits N(mean, sd) to the bootstrap t-statistics and stores the two-sided tail probability of
     * the observed t-statistic around the bootstrap mean in {@code pValueK_NormalFit}.
     */
    private void fitNormalToBootstrap(ThresholdModelResult finalResult, List<Double> bootstrapTStats) {
        DescriptiveStatistics stats = new DescriptiveStatistics();
        for (double t : bootstrapTStats) {
            stats.addValue(t);
        }
        double mean_t_boot = stats.getMean();
        double std_t_boot = stats.getStandardDeviation();
        if (std_t_boot > 1.0E-9) {
            NormalDistribution normalDist = new NormalDistribution(mean_t_boot, std_t_boot);
            // By symmetry 1 - F(mean + a) == F(mean - a). The lower tail avoids the cancellation
            // that rounds very small p-values to 0 with 1 - F.
            double deviation = Math.abs(finalResult.k_tStat - mean_t_boot);
            finalResult.pValueK_NormalFit = 2.0 * normalDist.cumulativeProbability(mean_t_boot - deviation);
            System.out.println("  Stage 2: Bootstrap t_stats details: N_boot_samples=" + stats.getN() + ", Mean=" + String.format("%.4f", mean_t_boot) + ", StdDev=" + String.format("%.4f", std_t_boot));
            System.out.println("  Stage 2: Bootstrap p-value (Normal Fit): " + String.format("%.6e", finalResult.pValueK_NormalFit));
        } else {
            System.out.println("  Stage 2: StdDev of bootstrap t_stats too small for Normal fit.");
        }
    }

    /**
     * Fits the null model (intercept plus covariates, no threshold term) to {@code yData}.
     *
     * @return intercept in {@code b_hat}, covariate coefficients in {@code z_hat} and the
     *     log-likelihood; {@code b_hat} is NaN when the fit does not converge
     */
    private ModelFitDetails estimateH0Logistic(int[] yData) {
        ModelFitDetails h0FitDetails = new ModelFitDetails();
        double[][] designMatrixForH0 = this.p > 0 && this.Z_orig != null ? this.Z_orig : new double[this.N][0];
        LogisticRegressionResult h0LogitRes = this.performLogisticRegression(designMatrixForH0, yData, null);
        if (h0LogitRes == null || !h0LogitRes.converged) {
            h0FitDetails.b_hat = Double.NaN;
            return h0FitDetails;
        }
        h0FitDetails.b_hat = h0LogitRes.coefficients[0];
        if (this.p > 0 && h0LogitRes.coefficients.length > 1) {
            h0FitDetails.z_hat = new double[this.p];
            System.arraycopy(h0LogitRes.coefficients, 1, h0FitDetails.z_hat, 0, this.p);
        } else {
            h0FitDetails.z_hat = new double[0];
        }
        h0FitDetails.logLikelihood = h0LogitRes.logLikelihood;
        return h0FitDetails;
    }

    /**
     * Demonstration on simulated data. X is drawn uniformly from the integers in
     * {@code [0, maxXValue]}, the covariates from uniform(-1, 1), and y from the true threshold
     * model. The two-stage procedure is then run and the estimates printed.
     */
    public static void main(String[] args) {
        int N_samples = 2000;
        int p_dim = 4;
        int maxXValue = 100;
        RandomDataGenerator rng = new RandomDataGenerator();
        rng.reSeed(20025L);
        int[] y_data_int = new int[N_samples];
        int[] X_data_scalar = new int[N_samples];
        double[][] Z_data = p_dim > 0 ? new double[N_samples][p_dim] : null;
        double b_true = -0.5;
        double k_true = 1.0;
        int d_true = 10;
        double[] z_true = null;
        if (p_dim > 0) {
            z_true = new double[]{1.0, -0.5, 1.0, 5.0};
        }
        System.out.println("--- Generating Simplified Logistic Test Data ---");
        System.out.println("N=" + N_samples + ", p=" + p_dim + ", X range up to " + maxXValue);
        System.out.println("True params: b=" + b_true + ", k=" + k_true + ", d_true=" + d_true + (p_dim > 0 && z_true != null ? ", z=" + Arrays.toString(z_true) : ""));
        for (int i = 0; i < N_samples; ++i) {
            X_data_scalar[i] = rng.nextInt(0, maxXValue);
            if (p_dim > 0 && Z_data != null) {
                for (int j = 0; j < p_dim; ++j) {
                    Z_data[i][j] = rng.nextUniform(-1.0, 1.0);
                }
            }
            boolean I_val = X_data_scalar[i] >= d_true;
            double W_i = b_true + k_true * (I_val ? 1.0 : 0.0);
            if (p_dim > 0 && Z_data[i] != null && z_true != null) {
                for (int j = 0; j < p_dim; ++j) {
                    W_i += Z_data[i][j] * z_true[j];
                }
            }
            double prob_y1 = 1.0 / (1.0 + Math.exp(-W_i));
            y_data_int[i] = rng.nextUniform(0.0, 1.0) < prob_y1 ? 1 : 0;
        }
        Logistic1ThresholdEstimator2 estimator = new Logistic1ThresholdEstimator2(y_data_int, X_data_scalar, Z_data);
        estimator.setTauMinProportion(0.05);
        estimator.setNumBootstrapReplications(9999);
        double screeningAlpha = 0.05;
        System.out.println("\n--- Starting Simplified Logistic Threshold Estimation (Two-Stage, EJML version) ---");
        System.out.println("Settings: tau=" + estimator.tauMinProportion + ", numBootstrap (for Stage 2)=" + estimator.numBootstrapReplications + ", MAX_ITER_LOGISTIC=" + MAX_ITER_LOGISTIC + ", Screening Alpha=" + screeningAlpha);
        long startTime = System.currentTimeMillis();
        ThresholdModelResult result = estimator.estimateParametersAndTestTwoStage(screeningAlpha, 1000);
        long endTime = System.currentTimeMillis();
        System.out.println("Simplified Logistic Estimation (Two-Stage, EJML version) finished in " + (endTime - startTime) + " ms.");
        System.out.println("\n--- Simplified Logistic Estimation Results (Two-Stage, EJML version) ---");
        System.out.println(result);
        System.out.println("True params for comparison: b=" + String.format("%.4f", b_true) + ", k=" + String.format("%.4f", k_true) + ", d_true=" + d_true + (p_dim > 0 && z_true != null ? ", z=" + Arrays.stream(z_true).mapToObj(val -> String.format("%.4f", val)).collect(Collectors.joining(", ", "[", "]")) : ""));
    }

    /** Output of {@link #performLogisticRegression}: coefficients are intercept first. */
    private static class LogisticRegressionResult {
        double[] coefficients;
        double[] stdErrors;
        double logLikelihood = Double.NEGATIVE_INFINITY;
        boolean converged = false;

        private LogisticRegressionResult() {
        }
    }

    /** Estimates for one threshold candidate (or the null model); NaN / -1 / -inf mean "not fitted". */
    private static class ModelFitDetails {
        public double b_hat = Double.NaN;
        public double k_hat = Double.NaN;
        public double[] z_hat;
        public int d_hat = -1;
        public double logLikelihood = Double.NEGATIVE_INFINITY;
        public double k_stdErr = Double.NaN;
        public double k_tStat = Double.NaN;

        private ModelFitDetails() {
        }
    }

    /** Public result of {@link #estimateParametersAndTestTwoStage}. */
    public static class ThresholdModelResult {
        public double b_hat = Double.NaN;
        public double k_hat = Double.NaN;
        public double[] z_hat;
        public int d_hat = -1;
        public double logLikelihood = Double.NEGATIVE_INFINITY;
        public double k_stdErr = Double.NaN;
        public double k_tStat = Double.NaN;
        public double pValueK = Double.NaN;
        public double pValueK_NormalFit = Double.NaN;

        @Override
        public String toString() {
            String z_hat_str = "N/A";
            if (this.z_hat != null && this.z_hat.length > 0) {
                z_hat_str = Arrays.stream(this.z_hat).mapToObj(val -> String.format("%.4f", val)).collect(Collectors.joining(", "));
            }
            return "ThresholdModelResult{\nb_hat=" + String.format("%.4f", this.b_hat) + ", \nk_hat=" + String.format("%.4f", this.k_hat) + ", \nz_hat=[" + z_hat_str + "], \nd_hat=" + this.d_hat + ", \nlogLikelihood=" + String.format("%.4f", this.logLikelihood) + ", \nk_stdErr=" + String.format("%.4f", this.k_stdErr) + ", \nk_tStat=" + String.format("%.4f", this.k_tStat) + ", \npValueK (Bootstrap)=" + (Double.isNaN(this.pValueK) ? "NaN" : String.format("%.4f", this.pValueK)) + "\n, \npValueK (Normal Fit)=" + (Double.isNaN(this.pValueK_NormalFit) ? "NaN" : String.format("%.6e", this.pValueK_NormalFit)) + "\n}";
        }
    }
}

