/*
 * Decompiled with CFR 0.152.
 */
package edu.sysu.pmglab.stat;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
import org.apache.commons.math3.distribution.ChiSquaredDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.random.RandomDataGenerator;
import org.apache.commons.math3.special.Gamma;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;

public class Logistic1ThresholdEstimator {
    /** Response vector (defensive copy); the fitting code expects it to be coded 0/1. */
    private final int[] y_int_orig;
    /** Integer threshold variable X (defensive copy); the regime indicator is D = 1[X >= d]. */
    private final int[] X_scalar_orig;
    /** Optional covariates, one row per observation (defensive copy); null when none were given. */
    private final double[][] Z_orig;
    /** Number of observations (length of the response). */
    private final int N;
    /** Number of covariate columns (0 when there are no covariates). */
    private final int p;
    /** Minimum fraction of observations each regime must hold for a candidate threshold in non-bootstrap fits. */
    private double tauMinProportion = 0.1;
    /** Number of parametric-bootstrap replicates; a value of zero or less disables the bootstrap. */
    private int numBootstrapReplications = 199;
    /** Maximum number of Newton-Raphson iterations per logistic fit. */
    private final int MAX_ITER_LOGISTIC = 25;
    /** Newton-Raphson convergence tolerance on the Euclidean norm of the coefficient update. */
    private final double CONVERGENCE_TOL_LOGISTIC = 1.0E-6;

    /**
     * Creates an estimator for a logistic model with a single threshold on {@code X}.
     *
     * <p>All arrays are defensively copied. The covariate dimension {@code p} is taken from the
     * first non-null row of {@code Z_param}; {@code null} rows are replaced by zero rows when
     * {@code p > 0}.
     *
     * @param y binary response, one entry per observation; every value must be 0 or 1
     * @param X_scalar_param integer threshold variable, same length as {@code y}
     * @param Z_param optional covariates ({@code null} for none); one row per observation, and
     *     every non-null row must have the same length
     * @throws NullPointerException if {@code y} or {@code X_scalar_param} is null
     * @throws IllegalArgumentException on length mismatches, ragged covariate rows, or responses
     *     not coded 0/1
     */
    public Logistic1ThresholdEstimator(int[] y, int[] X_scalar_param, double[][] Z_param) {
        Objects.requireNonNull(y, "y");
        Objects.requireNonNull(X_scalar_param, "X_scalar_param");
        this.N = y.length;
        if (X_scalar_param.length != this.N) {
            throw new IllegalArgumentException("X_scalar length mismatch with Y");
        }
        // The likelihood treats y == 1 as an event, but the score uses (y - p), so any other
        // coding (e.g. 1/2 case-control labels) would yield a silently wrong fit.
        for (int i = 0; i < this.N; ++i) {
            if (y[i] != 0 && y[i] != 1) {
                throw new IllegalArgumentException("Y must be coded 0/1, found " + y[i] + " at index " + i);
            }
        }
        this.y_int_orig = y.clone();
        this.X_scalar_orig = X_scalar_param.clone();
        if (Z_param == null) {
            this.Z_orig = null;
            this.p = 0;
        } else {
            if (Z_param.length != this.N) {
                throw new IllegalArgumentException("Z length mismatch with Y");
            }
            // Use the first non-null row so a missing first row does not drop every covariate.
            int cols = 0;
            for (double[] candidate : Z_param) {
                if (candidate != null) {
                    cols = candidate.length;
                    break;
                }
            }
            double[][] copy = new double[this.N][];
            for (int i = 0; i < this.N; ++i) {
                double[] row = Z_param[i];
                if (row == null) {
                    // Missing rows become all-zero covariates (left null when there are none).
                    copy[i] = cols > 0 ? new double[cols] : null;
                } else if (row.length != cols) {
                    throw new IllegalArgumentException("Z row " + i + " has " + row.length + " columns, expected " + cols);
                } else {
                    copy[i] = row.clone();
                }
            }
            this.Z_orig = copy;
            this.p = cols;
        }
    }

    /**
     * Sets the minimum fraction of observations each regime must contain for a candidate
     * threshold to be considered in the non-bootstrap grid search.
     *
     * @param tau proportion in {@code [0, 0.5]}; above 0.5 both regimes can never be satisfied
     *     at once because their sizes sum to {@code N}
     * @throws IllegalArgumentException if {@code tau} is NaN or outside {@code [0, 0.5]}
     */
    public void setTauMinProportion(double tau) {
        if (!(tau >= 0.0 && tau <= 0.5)) {
            throw new IllegalArgumentException("tauMinProportion must be in [0, 0.5], got " + tau);
        }
        this.tauMinProportion = tau;
    }

    /**
     * Sets how many parametric-bootstrap replicates are drawn when testing the threshold effect.
     * A value of zero or less disables the bootstrap stage.
     *
     * @param num number of bootstrap replicates
     */
    public void setNumBootstrapReplications(int num) {
        numBootstrapReplications = num;
    }

    /**
     * Builds the ascending list of candidate thresholds {@code d} for the regime indicator
     * {@code D = 1[X >= d]}.
     *
     * <p>Exactly one candidate is produced per distinct split of the sample. For each distinct
     * value {@code x_k}, the candidate is the smallest integer {@code d} that selects exactly the
     * observations with {@code X >= x_k}. That is {@code min(X)} for the first value and
     * {@code x_(k-1) + 1} for each later one. The list ends with {@code max(X) + 1}.
     *
     * <p>This yields the same splits, with the same smallest-{@code d} representative, as
     * enumerating every integer in {@code [min(X), max(X) + 1]}. Its size, however, is bounded by
     * the number of distinct values rather than by the range of {@code X}. The range-based
     * enumeration could exhaust memory and run redundant fits for sparse data.
     *
     * <p>The first candidate (D all ones) and the last (D all zeros) give a constant indicator.
     * They are kept for completeness, and the grid search skips them. {@code max(X) + 1} is
     * omitted when it would overflow.
     *
     * @param x_values threshold variable; may be null or empty
     * @return sorted candidate thresholds, or an empty list for null/empty input
     */
    private List<Integer> createDGrid(int[] x_values) {
        if (x_values == null || x_values.length == 0) {
            return Collections.emptyList();
        }
        int[] distinct = Arrays.stream(x_values).distinct().sorted().toArray();
        List<Integer> grid = new ArrayList<>(distinct.length + 1);
        grid.add(distinct[0]);
        for (int k = 1; k < distinct.length; ++k) {
            // distinct[k - 1] < distinct[k], so this cannot overflow.
            grid.add(distinct[k - 1] + 1);
        }
        int maxX = distinct[distinct.length - 1];
        if (maxX != Integer.MAX_VALUE) {
            grid.add(maxX + 1);
        }
        return grid;
    }

    /**
     * Computes {@code -(X^T W X)} with {@code W = diag(diag)}. This is the Hessian of the
     * logistic log-likelihood when {@code diag} holds the weights {@code p_i (1 - p_i)}.
     *
     * @param X2 design matrix; its row count must equal {@code row} for the product to be defined
     * @param row number of observations (entries of {@code diag} that are used)
     * @param col number of coefficients (columns of {@code X2} that are weighted)
     * @param diag per-observation weights
     * @return the matrix {@code -(X^T W X)}
     */
    public RealMatrix computeHessian(RealMatrix X2, int row, int col, double[] diag) {
        RealMatrix weightedTranspose = MatrixUtils.createRealMatrix(col, row);
        for (int coef = 0; coef < col; ++coef) {
            for (int obs = 0; obs < row; ++obs) {
                double scaled = X2.getEntry(obs, coef) * diag[obs];
                weightedTranspose.setEntry(coef, obs, scaled);
            }
        }
        RealMatrix crossProduct = weightedTranspose.multiply(X2);
        return crossProduct.scalarMultiply(-1.0);
    }

    /**
     * Computes the weighted cross-product {@code X^T W X} with {@code W = diag(diag)}.
     *
     * @param X2 design matrix; its row count must equal {@code row} for the product to be defined
     * @param row number of observations (entries of {@code diag} that are used)
     * @param col number of coefficients (columns of {@code X2} that are weighted)
     * @param diag per-observation weights
     * @return the matrix {@code X^T W X}
     */
    public RealMatrix computeHTWX(RealMatrix X2, int row, int col, double[] diag) {
        RealMatrix scaledXt = MatrixUtils.createRealMatrix(col, row);
        for (int c = 0; c < col; ++c) {
            for (int r = 0; r < row; ++r) {
                double value = X2.getEntry(r, c) * diag[r];
                scaledXt.setEntry(c, r, value);
            }
        }
        RealMatrix product = scaledXt.multiply(X2);
        return product;
    }

    /**
     * Fits a logistic regression with an intercept by Newton-Raphson.
     *
     * <p>A column of ones is prepended to {@code designMatrix}. A {@code null} row contributes
     * only the intercept; its predictor entries stay zero. Fitted probabilities are clamped to
     * {@code [1e-15, 1 - 1e-15]}. The weights {@code p(1-p)} are floored at {@code 1e-10} so the
     * linear systems stay well conditioned.
     *
     * @param designMatrix predictor rows without the intercept. The predictor count is taken from
     *     the first row. May be {@code null} for an intercept-only model.
     * @param yResponse response vector, expected to be coded 0/1
     * @param initialBetaArray optional starting coefficients (intercept first). Ignored unless its
     *     length equals {@code 1 + numPredictors}.
     * @return the fit. When converged, it holds the coefficients, the log-likelihood and Wald
     *     standard errors (NaN if the information matrix is singular). Otherwise
     *     {@code converged == false}, {@code logLikelihood == -Infinity} and every standard error
     *     is NaN.
     */
    private LogisticRegressionResult performLogisticRegression(double[][] designMatrix, int[] yResponse, double[] initialBetaArray) {
        final double probFloor = 1.0E-15;
        final double probCeiling = 0.999999999999999;
        final double weightFloor = 1.0E-10;
        int nObs = yResponse.length;
        int numPredictors = designMatrix != null && nObs > 0 && designMatrix[0] != null ? designMatrix[0].length : 0;
        int numCoeffs = 1 + numPredictors;
        boolean validStart = initialBetaArray != null && initialBetaArray.length == numCoeffs;
        LogisticRegressionResult result = new LogisticRegressionResult();
        result.stdErrors = new double[numCoeffs];
        if (nObs == 0) {
            // MatrixUtils cannot build a zero-row matrix; report a failed fit instead of throwing.
            result.converged = false;
            result.coefficients = validStart ? initialBetaArray.clone() : new double[numCoeffs];
            result.logLikelihood = Double.NEGATIVE_INFINITY;
            Arrays.fill(result.stdErrors, Double.NaN);
            return result;
        }
        RealMatrix X_full = MatrixUtils.createRealMatrix(nObs, numCoeffs);
        for (int i = 0; i < nObs; ++i) {
            X_full.setEntry(i, 0, 1.0);
            if (numPredictors == 0 || designMatrix == null || designMatrix[i] == null) {
                continue;
            }
            for (int j = 0; j < numPredictors; ++j) {
                X_full.setEntry(i, j + 1, designMatrix[i][j]);
            }
        }
        // X^T does not change between iterations, so build it once.
        RealMatrix X_fullT = X_full.transpose();
        RealVector beta = validStart ? new ArrayRealVector(initialBetaArray, true) : new ArrayRealVector(numCoeffs, 0.0);
        double[] weights = new double[nObs];
        boolean converged = false;
        for (int iter = 0; iter < this.MAX_ITER_LOGISTIC; ++iter) {
            RealVector linearPredictor = X_full.operate(beta);
            RealVector residuals = new ArrayRealVector(nObs);
            for (int i = 0; i < nObs; ++i) {
                double pi = 1.0 / (1.0 + Math.exp(-linearPredictor.getEntry(i)));
                pi = Math.max(probFloor, Math.min(probCeiling, pi));
                residuals.setEntry(i, yResponse[i] - pi);
                weights[i] = Math.max(pi * (1.0 - pi), weightFloor);
            }
            RealVector gradient = X_fullT.operate(residuals);
            RealVector deltaBeta;
            try {
                // Newton step: beta += (X^T W X)^{-1} X^T (y - p). Solving the system is cheaper
                // and numerically better than forming the inverse explicitly.
                RealMatrix information = this.computeHTWX(X_full, nObs, numCoeffs, weights);
                deltaBeta = new LUDecomposition(information).getSolver().solve(gradient);
            } catch (SingularMatrixException e) {
                result.converged = false;
                result.coefficients = beta.toArray();
                Arrays.fill(result.stdErrors, Double.NaN);
                result.logLikelihood = Double.NEGATIVE_INFINITY;
                return result;
            }
            beta = beta.add(deltaBeta);
            if (deltaBeta.getNorm() < this.CONVERGENCE_TOL_LOGISTIC) {
                converged = true;
                break;
            }
        }
        result.converged = converged;
        result.coefficients = beta.toArray();
        if (!converged) {
            result.logLikelihood = Double.NEGATIVE_INFINITY;
            Arrays.fill(result.stdErrors, Double.NaN);
            return result;
        }
        RealVector finalLP = X_full.operate(beta);
        double ll = 0.0;
        for (int i = 0; i < nObs; ++i) {
            double pi = 1.0 / (1.0 + Math.exp(-finalLP.getEntry(i)));
            pi = Math.max(probFloor, Math.min(probCeiling, pi));
            ll += yResponse[i] == 1 ? Math.log(pi) : Math.log(1.0 - pi);
            weights[i] = Math.max(pi * (1.0 - pi), weightFloor);
        }
        result.logLikelihood = ll;
        try {
            // Wald standard errors from the inverse Fisher information at the final estimate.
            RealMatrix XtWX = this.computeHTWX(X_full, nObs, numCoeffs, weights);
            RealMatrix covMatrix = new LUDecomposition(XtWX).getSolver().getInverse();
            for (int i = 0; i < numCoeffs; ++i) {
                double variance = covMatrix.getEntry(i, i);
                result.stdErrors[i] = variance >= 0.0 ? Math.sqrt(variance) : Double.NaN;
            }
        } catch (SingularMatrixException e) {
            Arrays.fill(result.stdErrors, Double.NaN);
        }
        return result;
    }

    /**
     * Grid-searches the threshold {@code d} of the model
     * {@code logit P(y = 1) = b + k * 1[X >= d] + Z z} by maximum likelihood.
     *
     * <p>Every candidate from {@code createDGrid} whose indicator splits the sample into two
     * non-empty regimes is fitted. Outside bootstrap runs, each regime must also hold at least
     * {@code tauMinProportion * N} observations. The candidate with the highest log-likelihood
     * wins, and ties keep the earliest candidate.
     *
     * <p>For the observed data with at least five candidates, the grid is fitted in parallel on
     * about half of the available processors. Otherwise (bootstrap runs or small grids) the
     * candidates are fitted sequentially. Each sequential fit is warm-started from the previous
     * converged coefficients.
     *
     * <p>NOTE(review): bootstrap refits skip the {@code tauMinProportion} trimming, so each
     * bootstrap fit searches a wider threshold grid than the observed fit. Confirm that this
     * asymmetry is intended.
     *
     * @param currentY response vector to fit (observed or simulated)
     * @param isBootstrapRun true for bootstrap replicates (relaxed trimming, single-threaded)
     * @return the best fit; if no candidate converges, the returned object keeps its default values
     */
    private ModelFitDetails estimateParametersInternalLogistic(int[] currentY, boolean isBootstrapRun) {
        List<Integer> d_grid = this.createDGrid(this.X_scalar_orig);
        ModelFitDetails overallBestFit = new ModelFitDetails();
        if (d_grid.isEmpty()) {
            return overallBestFit;
        }
        int innerThreads = 1;
        if (!isBootstrapRun && d_grid.size() >= 5) {
            innerThreads = Math.max(1, Runtime.getRuntime().availableProcessors() / 2);
            if (d_grid.size() < innerThreads * 2) {
                // Keep at least two candidates per worker.
                innerThreads = Math.max(1, d_grid.size() / 2);
            }
        }
        if (innerThreads > 1) {
            ExecutorService dExecutor = Executors.newFixedThreadPool(innerThreads);
            try {
                List<Future<ModelFitDetails>> dFutures = new ArrayList<>(d_grid.size());
                for (int d_candidate : d_grid) {
                    Callable<ModelFitDetails> dTask = () -> {
                        double[][] design = this.buildThresholdDesignMatrix(d_candidate, isBootstrapRun);
                        if (design == null) {
                            return null;
                        }
                        LogisticRegressionResult fit = this.performLogisticRegression(design, currentY, null);
                        return fit != null && fit.converged ? this.toThresholdFitDetails(d_candidate, fit) : null;
                    };
                    dFutures.add(dExecutor.submit(dTask));
                }
                // Futures are consumed in grid order, so ties keep the smallest d.
                for (Future<ModelFitDetails> future : dFutures) {
                    ModelFitDetails fitFromTask = future.get();
                    if (fitFromTask != null && fitFromTask.logLikelihood > overallBestFit.logLikelihood) {
                        overallBestFit = fitFromTask;
                    }
                }
            } catch (InterruptedException | ExecutionException e) {
                if (e instanceof InterruptedException) {
                    Thread.currentThread().interrupt();
                }
                System.err.println("Error parallel d_grid (simplified internal screening): " + e.getMessage());
            } finally {
                // Cancels any still-running candidates if we bailed out early; a no-op otherwise.
                dExecutor.shutdownNow();
            }
            return overallBestFit;
        }
        double[] lastBeta = null;
        for (int d_candidate : d_grid) {
            double[][] design = this.buildThresholdDesignMatrix(d_candidate, isBootstrapRun);
            if (design == null) {
                // Rejected candidates leave the warm start untouched.
                continue;
            }
            LogisticRegressionResult fit = this.performLogisticRegression(design, currentY, lastBeta);
            if (fit == null || !fit.converged) {
                // Never warm-start from a failed fit.
                lastBeta = null;
                continue;
            }
            lastBeta = fit.coefficients;
            if (fit.logLikelihood > overallBestFit.logLikelihood) {
                overallBestFit = this.toThresholdFitDetails(d_candidate, fit);
            }
        }
        return overallBestFit;
    }

    /**
     * Builds the design matrix {@code [1[X >= d], Z]} (no intercept column) for one candidate.
     *
     * @return the design, or {@code null} when the split leaves a regime empty or, outside
     *     bootstrap runs, smaller than {@code tauMinProportion * N}
     */
    private double[][] buildThresholdDesignMatrix(int d_candidate, boolean isBootstrapRun) {
        int countRegime1 = 0;
        for (int i = 0; i < this.N; ++i) {
            if (this.X_scalar_orig[i] >= d_candidate) {
                ++countRegime1;
            }
        }
        int countRegime0 = this.N - countRegime1;
        if (countRegime0 == 0 || countRegime1 == 0) {
            // Constant indicator: k is not identifiable.
            return null;
        }
        if (!isBootstrapRun) {
            double minPerRegime = (double) this.N * this.tauMinProportion;
            if (countRegime0 < minPerRegime || countRegime1 < minPerRegime) {
                return null;
            }
        }
        double[][] design = new double[this.N][1 + this.p];
        for (int i = 0; i < this.N; ++i) {
            design[i][0] = this.X_scalar_orig[i] >= d_candidate ? 1.0 : 0.0;
        }
        if (this.p > 0 && this.Z_orig != null) {
            for (int i = 0; i < this.N; ++i) {
                if (this.Z_orig[i] != null) {
                    System.arraycopy(this.Z_orig[i], 0, design[i], 1, this.p);
                }
            }
        }
        return design;
    }

    /** Copies a converged threshold-model fit (coefficients ordered b, k, z...) into a ModelFitDetails. */
    private ModelFitDetails toThresholdFitDetails(int d_candidate, LogisticRegressionResult fit) {
        ModelFitDetails details = new ModelFitDetails();
        details.d_hat = d_candidate;
        details.logLikelihood = fit.logLikelihood;
        details.b_hat = fit.coefficients[0];
        details.k_hat = fit.coefficients[1];
        details.z_hat = new double[this.p];
        if (this.p > 0) {
            System.arraycopy(fit.coefficients, 2, details.z_hat, 0, this.p);
        }
        if (fit.stdErrors != null && fit.stdErrors.length > 1) {
            details.k_stdErr = fit.stdErrors[1];
            details.k_tStat = details.k_stdErr > 1.0E-9 && !Double.isNaN(details.k_stdErr) ? details.k_hat / details.k_stdErr : Double.NaN;
        } else {
            details.k_stdErr = Double.NaN;
            details.k_tStat = Double.NaN;
        }
        return details;
    }

    /**
     * Two-stage test of the threshold effect {@code k}.
     *
     * <p>Stage 1 fits the threshold model to the observed data and computes a nominal Wald
     * p-value for {@code k}. Stage 2 runs only when that p-value is below {@code screeningAlpha}
     * and bootstrap replications are enabled. It performs a parametric bootstrap under the
     * no-threshold null model. The bootstrap p-value then replaces the nominal one in
     * {@code pValueK}. {@code pValueK_NormalFit} receives a two-sided p-value from a normal
     * distribution fitted to the bootstrap t-statistics.
     *
     * @param screeningAlpha Stage 1 significance level that triggers the bootstrap
     * @param randSeed seed for the generator that draws one seed per bootstrap replicate, which
     *     makes Stage 2 reproducible
     * @return the fitted parameters and p-values; {@code pValueK} is NaN when the initial fit,
     *     the null fit, or the bootstrap fails
     */
    public ThresholdModelResult estimateParametersAndTestTwoStage(double screeningAlpha, int randSeed) {
        ModelFitDetails initialFit = this.estimateParametersInternalLogistic(this.y_int_orig, false);
        ThresholdModelResult finalResult = new ThresholdModelResult();
        finalResult.b_hat = initialFit.b_hat;
        finalResult.k_hat = initialFit.k_hat;
        finalResult.z_hat = initialFit.z_hat != null ? initialFit.z_hat.clone() : null;
        finalResult.d_hat = initialFit.d_hat;
        finalResult.logLikelihood = initialFit.logLikelihood;
        finalResult.k_stdErr = initialFit.k_stdErr;
        finalResult.k_tStat = initialFit.k_tStat;
        if (initialFit.logLikelihood == Double.NEGATIVE_INFINITY || Double.isNaN(initialFit.k_hat)) {
            System.err.println("Stage 1: Initial fit failed. Skipping Bootstrap.");
            finalResult.pValueK = Double.NaN;
            return finalResult;
        }

        double nominalPValue = 1.0;
        if (!Double.isNaN(finalResult.k_tStat) && !Double.isNaN(finalResult.k_stdErr) && finalResult.k_stdErr > 1.0E-12) {
            double z = finalResult.k_hat / finalResult.k_stdErr;
            double waldStat = z * z;
            // P(chi2_1 > w) = Q(1/2, w/2). Evaluating the upper tail directly avoids the
            // cancellation in 1 - CDF, which collapses to 0 for large Wald statistics.
            nominalPValue = Gamma.regularizedQ(0.5, waldStat / 2.0);
            System.out.println("  Stage 1: Nominal (screening) Wald p-value: " + String.format("%.4e", nominalPValue));
        } else {
            System.out.println("  Stage 1: Could not calculate nominal Wald p-value (k_tStat or k_stdErr is NaN/problematic).");
        }
        finalResult.pValueK = nominalPValue;

        int reps = this.numBootstrapReplications;
        if (reps <= 0) {
            System.out.println("  Bootstrap skipped as numBootstrapReplications is 0. pValue is nominal.");
            return finalResult;
        }
        if (!(nominalPValue < screeningAlpha)) {
            System.out.println("  Stage 1 p-value >= " + screeningAlpha + ". Skipping Bootstrap confirmation.");
            return finalResult;
        }

        System.out.println("  Stage 1 p-value < " + screeningAlpha + ". Proceeding to Bootstrap confirmation (" + reps + " reps).");
        ModelFitDetails h0Fit = this.estimateH0Logistic(this.y_int_orig);
        if (Double.isNaN(h0Fit.b_hat)) {
            // Do not report the nominal screening p-value as if it were the confirmed result.
            System.err.println("Stage 2: H0 fit failed for Bootstrap. Bootstrap p-value will be NaN.");
            finalResult.pValueK = Double.NaN;
            return finalResult;
        }
        double final_b0_hat_H0 = h0Fit.b_hat;
        double[] final_z0_hat_H0 = h0Fit.z_hat != null ? h0Fit.z_hat.clone() : new double[0];

        // Draw every replicate's seed up front so results do not depend on thread scheduling.
        RandomDataGenerator outerRandomGen = new RandomDataGenerator();
        outerRandomGen.reSeed(randSeed);
        long[] seeds = new long[reps];
        for (int b = 0; b < reps; ++b) {
            seeds[b] = outerRandomGen.nextLong(0L, 1000000000L);
        }

        List<Double> bootstrap_t_stats = new ArrayList<>(reps);
        int numThreads = Math.max(1, Runtime.getRuntime().availableProcessors() / 2);
        ExecutorService bootstrapExecutor = Executors.newFixedThreadPool(numThreads);
        try {
            List<Future<Double>> futures = new ArrayList<>(reps);
            for (int b = 0; b < reps; ++b) {
                long replicateSeed = seeds[b];
                Callable<Double> callable = () -> {
                    RandomDataGenerator localRandomGen = new RandomDataGenerator();
                    localRandomGen.reSeed(replicateSeed);
                    // Simulate a response from the fitted null (no-threshold) model.
                    int[] y_boot = new int[this.N];
                    for (int i = 0; i < this.N; ++i) {
                        double linearPredictorH0 = final_b0_hat_H0;
                        if (this.p > 0 && final_z0_hat_H0.length > 0 && this.Z_orig != null && this.Z_orig[i] != null) {
                            int usable = Math.min(this.p, final_z0_hat_H0.length);
                            for (int j = 0; j < usable; ++j) {
                                linearPredictorH0 += this.Z_orig[i][j] * final_z0_hat_H0[j];
                            }
                        }
                        double prob_y1 = 1.0 / (1.0 + Math.exp(-linearPredictorH0));
                        y_boot[i] = localRandomGen.nextUniform(0.0, 1.0) < prob_y1 ? 1 : 0;
                    }
                    return this.estimateParametersInternalLogistic(y_boot, true).k_tStat;
                };
                futures.add(bootstrapExecutor.submit(callable));
            }
            for (Future<Double> future : futures) {
                Double t_boot = future.get();
                if (t_boot != null && !t_boot.isNaN() && !t_boot.isInfinite()) {
                    bootstrap_t_stats.add(t_boot);
                }
            }
        } catch (InterruptedException | ExecutionException e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            System.err.println("Stage 2: Error during Bootstrap execution: " + e.getMessage());
            finalResult.pValueK = Double.NaN;
            return finalResult;
        } finally {
            // Cancels outstanding replicates after a failure; harmless once all have completed.
            bootstrapExecutor.shutdownNow();
        }

        int successfulBootstraps = bootstrap_t_stats.size();
        if ((double) successfulBootstraps < (double) reps * 0.8) {
            System.err.println("Stage 2: Warning - High Bootstrap failure rate (" + successfulBootstraps + "/" + reps + ").");
        }
        if (Double.isNaN(finalResult.k_tStat) || bootstrap_t_stats.isEmpty()) {
            finalResult.pValueK = Double.NaN;
            return finalResult;
        }
        double observedAbsT = Math.abs(finalResult.k_tStat);
        long countExceeding = bootstrap_t_stats.stream().filter(t_b -> Math.abs(t_b) >= observedAbsT).count();
        // The add-one correction keeps the Monte Carlo p-value strictly positive.
        finalResult.pValueK = ((double) countExceeding + 1.0) / ((double) bootstrap_t_stats.size() + 1.0);
        System.out.println("  Stage 2: Bootstrap p-value: " + String.format("%.4f", finalResult.pValueK));

        DescriptiveStatistics stats = new DescriptiveStatistics();
        for (double t_b : bootstrap_t_stats) {
            stats.addValue(t_b);
        }
        double mean_t_boot = stats.getMean();
        double std_t_boot = stats.getStandardDeviation();
        if (std_t_boot > 1.0E-9) {
            NormalDistribution normalDist = new NormalDistribution(mean_t_boot, std_t_boot);
            double cdf = normalDist.cumulativeProbability(finalResult.k_tStat);
            // Two-sided tail probability measured from the bootstrap mean.
            finalResult.pValueK_NormalFit = finalResult.k_tStat >= mean_t_boot ? 2.0 * (1.0 - cdf) : 2.0 * cdf;
            System.out.println("  Stage 2: Bootstrap t_stats details: N_boot_samples=" + stats.getN() + ", Mean=" + String.format("%.4f", mean_t_boot) + ", StdDev=" + String.format("%.4f", std_t_boot));
            System.out.println("  Stage 2: Bootstrap p-value (Normal Fit): " + String.format("%.6e", finalResult.pValueK_NormalFit));
        } else {
            System.out.println("  Stage 2: StdDev of bootstrap t_stats too small for Normal fit.");
        }
        return finalResult;
    }

    /**
     * Fits the null (no-threshold) logistic model: an intercept plus the covariates in
     * {@code Z}, or an intercept only when there are no covariates.
     *
     * @param yData response vector to fit
     * @return details populated with {@code b_hat}, {@code z_hat} (empty when {@code p == 0})
     *     and {@code logLikelihood} on convergence; otherwise only {@code b_hat} is set, to NaN
     */
    private ModelFitDetails estimateH0Logistic(int[] yData) {
        ModelFitDetails details = new ModelFitDetails();
        boolean hasCovariates = this.p > 0 && this.Z_orig != null;
        double[][] nullDesign = hasCovariates ? this.Z_orig : new double[this.N][0];
        LogisticRegressionResult nullFit = this.performLogisticRegression(nullDesign, yData, null);
        if (nullFit == null || !nullFit.converged) {
            details.b_hat = Double.NaN;
            return details;
        }
        double[] coefs = nullFit.coefficients;
        details.b_hat = coefs[0];
        if (this.p <= 0 || coefs.length <= 1) {
            details.z_hat = new double[0];
        } else {
            double[] slopes = new double[this.p];
            System.arraycopy(coefs, 1, slopes, 0, this.p);
            details.z_hat = slopes;
        }
        details.logLikelihood = nullFit.logLikelihood;
        return details;
    }

    /**
     * Fits the threshold model and tests {@code k} with a parametric bootstrap under the
     * no-threshold null model. There is no screening stage.
     *
     * <p>Replicate seeds are drawn from an unseeded generator, so p-values vary slightly between
     * runs. Use {@link #estimateParametersAndTestLogistic(long)} for reproducible results.
     *
     * @return fitted parameters; {@code pValueK} is the bootstrap p-value, or NaN when the
     *     bootstrap is disabled or any stage fails
     */
    public ThresholdModelResult estimateParametersAndTestLogistic() {
        return this.estimateAndTestLogisticWithSeeds(new RandomDataGenerator());
    }

    /**
     * Reproducible variant of {@link #estimateParametersAndTestLogistic()}.
     *
     * @param randSeed seed for the generator that draws one seed per bootstrap replicate
     * @return fitted parameters; {@code pValueK} is the bootstrap p-value, or NaN when the
     *     bootstrap is disabled or any stage fails
     */
    public ThresholdModelResult estimateParametersAndTestLogistic(long randSeed) {
        RandomDataGenerator seedSource = new RandomDataGenerator();
        seedSource.reSeed(randSeed);
        return this.estimateAndTestLogisticWithSeeds(seedSource);
    }

    /**
     * Shared implementation. Each replicate gets its own generator, seeded from
     * {@code seedSource}, so replicates use distinct streams. This also makes the result
     * independent of thread scheduling.
     */
    private ThresholdModelResult estimateAndTestLogisticWithSeeds(RandomDataGenerator seedSource) {
        ModelFitDetails initialFit = this.estimateParametersInternalLogistic(this.y_int_orig, false);
        ThresholdModelResult finalResult = new ThresholdModelResult();
        finalResult.b_hat = initialFit.b_hat;
        finalResult.k_hat = initialFit.k_hat;
        finalResult.z_hat = initialFit.z_hat != null ? initialFit.z_hat.clone() : null;
        finalResult.d_hat = initialFit.d_hat;
        finalResult.logLikelihood = initialFit.logLikelihood;
        finalResult.k_stdErr = initialFit.k_stdErr;
        finalResult.k_tStat = initialFit.k_tStat;
        if (initialFit.logLikelihood == Double.NEGATIVE_INFINITY || Double.isNaN(initialFit.k_hat)) {
            System.err.println("Failed: initial fit for simplified logistic model.");
            finalResult.pValueK = Double.NaN;
            return finalResult;
        }
        int reps = this.numBootstrapReplications;
        if (reps <= 0) {
            System.out.println("Skipping bootstrap.");
            finalResult.pValueK = Double.NaN;
            return finalResult;
        }
        ModelFitDetails h0Fit = this.estimateH0Logistic(this.y_int_orig);
        if (Double.isNaN(h0Fit.b_hat)) {
            System.err.println("Failed H0 fit for bootstrap (simplified model).");
            finalResult.pValueK = Double.NaN;
            return finalResult;
        }
        double final_b0_hat_H0 = h0Fit.b_hat;
        double[] final_z0_hat_H0 = h0Fit.z_hat != null ? h0Fit.z_hat.clone() : new double[0];
        long[] seeds = new long[reps];
        for (int b = 0; b < reps; ++b) {
            seeds[b] = seedSource.nextLong(0L, 1000000000L);
        }

        List<Double> bootstrap_t_stats = new ArrayList<>(reps);
        int numThreads = Math.max(1, Runtime.getRuntime().availableProcessors() / 2);
        ExecutorService bootstrapExecutor = Executors.newFixedThreadPool(numThreads);
        try {
            List<Future<Double>> futures = new ArrayList<>(reps);
            for (int b = 0; b < reps; ++b) {
                long replicateSeed = seeds[b];
                Callable<Double> task = () -> {
                    RandomDataGenerator localRandomGen = new RandomDataGenerator();
                    localRandomGen.reSeed(replicateSeed);
                    // Simulate a response from the fitted null (no-threshold) model.
                    int[] y_boot = new int[this.N];
                    for (int i = 0; i < this.N; ++i) {
                        double linearPredictorH0 = final_b0_hat_H0;
                        if (this.p > 0 && final_z0_hat_H0.length > 0 && this.Z_orig != null && this.Z_orig[i] != null) {
                            int usable = Math.min(this.p, final_z0_hat_H0.length);
                            for (int j = 0; j < usable; ++j) {
                                linearPredictorH0 += this.Z_orig[i][j] * final_z0_hat_H0[j];
                            }
                        }
                        double prob_y1 = 1.0 / (1.0 + Math.exp(-linearPredictorH0));
                        y_boot[i] = localRandomGen.nextUniform(0.0, 1.0) < prob_y1 ? 1 : 0;
                    }
                    return this.estimateParametersInternalLogistic(y_boot, true).k_tStat;
                };
                futures.add(bootstrapExecutor.submit(task));
            }
            for (Future<Double> future : futures) {
                Double t_boot = future.get();
                if (t_boot != null && !t_boot.isNaN() && !t_boot.isInfinite()) {
                    bootstrap_t_stats.add(t_boot);
                }
            }
        } catch (InterruptedException | ExecutionException e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            System.err.println("Error during bootstrap (simplified): " + e.getMessage());
            finalResult.pValueK = Double.NaN;
            return finalResult;
        } finally {
            // Cancels outstanding replicates after a failure; harmless once all have completed.
            bootstrapExecutor.shutdownNow();
        }

        int successfulBootstraps = bootstrap_t_stats.size();
        if ((double) successfulBootstraps < (double) reps * 0.8) {
            System.err.println("Warning: Bootstrap failures (simplified) (" + successfulBootstraps + "/" + reps + ")");
        }
        if (Double.isNaN(finalResult.k_tStat) || bootstrap_t_stats.isEmpty()) {
            finalResult.pValueK = Double.NaN;
        } else {
            double observedAbsT = Math.abs(finalResult.k_tStat);
            long countExceeding = bootstrap_t_stats.stream().filter(t_b -> Math.abs(t_b) >= observedAbsT).count();
            // The add-one correction keeps the Monte Carlo p-value strictly positive.
            finalResult.pValueK = ((double) countExceeding + 1.0) / ((double) bootstrap_t_stats.size() + 1.0);
        }
        return finalResult;
    }

    /**
     * Demo entry point.
     *
     * <p>Simulates {@code N_samples} observations from a logistic model whose linear predictor is
     * {@code b_true + k_true * 1[X >= d_true] + Z . z_true}. It then runs the two-stage threshold
     * estimator on the simulated data and prints the estimates next to the true parameters.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        int N_samples = 2000;
        int p_dim = 2;
        int maxXValue = 100;
        RandomDataGenerator rng = new RandomDataGenerator();
        rng.reSeed(20025L);
        int[] y_data_int = new int[N_samples];
        int[] X_data_scalar = new int[N_samples];
        // Covariate matrix only exists when there are extra covariates.
        double[][] Z_data = p_dim > 0 ? new double[N_samples][p_dim] : null;
        double b_true = -0.5;
        double k_true = 1.0;
        int d_true = 10;
        double[] z_true = null;
        if (p_dim > 0) {
            z_true = new double[p_dim];
            z_true[0] = 1.0;
            if (p_dim > 1) {
                z_true[1] = -0.5;
            }
        }
        System.out.println("--- Generating Simplified Logistic Test Data ---");
        System.out.println("N=" + N_samples + ", p=" + p_dim + ", X range up to " + maxXValue);
        System.out.println("True params: b=" + b_true + ", k=" + k_true + ", d_true=" + d_true + (p_dim > 0 && z_true != null ? ", z=" + Arrays.toString(z_true) : ""));
        for (int i = 0; i < N_samples; ++i) {
            // RandomDataGenerator.nextInt(lower, upper) includes both endpoints.
            X_data_scalar[i] = rng.nextInt(0, maxXValue);
            if (Z_data != null) {
                for (int j = 0; j < p_dim; ++j) {
                    Z_data[i][j] = rng.nextUniform(-1.0, 1.0);
                }
            }
            // Java has no boolean -> double cast; encode the step indicator explicitly.
            boolean I_val = X_data_scalar[i] >= d_true;
            double indicator = I_val ? 1.0 : 0.0;
            double W_i = b_true + k_true * indicator;
            if (Z_data != null && z_true != null) {
                int nCovariates = Math.min(p_dim, z_true.length);
                for (int j = 0; j < nCovariates; ++j) {
                    W_i += Z_data[i][j] * z_true[j];
                }
            }
            // Inverse-logit of the linear predictor, then a Bernoulli draw.
            double prob_y1 = 1.0 / (1.0 + Math.exp(-W_i));
            y_data_int[i] = rng.nextUniform(0.0, 1.0) < prob_y1 ? 1 : 0;
        }
        Logistic1ThresholdEstimator estimator = new Logistic1ThresholdEstimator(y_data_int, X_data_scalar, Z_data);
        estimator.setTauMinProportion(0.05);
        estimator.setNumBootstrapReplications(999);
        double screeningAlpha = 0.05;
        System.out.println("\n--- Starting Simplified Logistic Threshold Estimation (Two-Stage) ---");
        System.out.println("Settings: tau=" + estimator.tauMinProportion + ", numBootstrap (for Stage 2)=" + estimator.numBootstrapReplications + ", MAX_ITER_LOGISTIC=" + estimator.MAX_ITER_LOGISTIC + ", Screening Alpha=" + screeningAlpha);
        long startTime = System.currentTimeMillis();
        ThresholdModelResult result = estimator.estimateParametersAndTestTwoStage(screeningAlpha, 1000);
        long endTime = System.currentTimeMillis();
        System.out.println("Simplified Logistic Estimation (Two-Stage) finished in " + (endTime - startTime) + " ms.");
        System.out.println("\n--- Simplified Logistic Estimation Results (Two-Stage) ---");
        System.out.println(result);
        System.out.println("True params for comparison: b=" + String.format("%.4f", b_true) + ", k=" + String.format("%.4f", k_true) + ", d_true=" + d_true + (p_dim > 0 && z_true != null ? ", z=" + Arrays.stream(z_true).mapToObj(val -> String.format("%.4f", val)).collect(Collectors.joining(", ", "[", "]")) : ""));
    }

    /**
     * Mutable holder for the outcome of a single logistic-regression fit.
     *
     * <p>A freshly constructed instance has {@code null} coefficient and standard-error arrays,
     * a log-likelihood of negative infinity, and {@code converged == false}.
     */
    private static class LogisticRegressionResult {
        /** Fitted coefficients; {@code null} until assigned. */
        double[] coefficients;
        /** Standard errors reported by the fit (presumably aligned with {@link #coefficients}); {@code null} until assigned. */
        double[] stdErrors;
        /** Log-likelihood of the fit; negative infinity until assigned. */
        double logLikelihood;
        /** Whether the fit was flagged as converged. */
        boolean converged;

        private LogisticRegressionResult() {
            this.logLikelihood = Double.NEGATIVE_INFINITY;
            this.converged = false;
        }
    }

    /**
     * Mutable holder for the parameters and diagnostics of one candidate threshold fit.
     *
     * <p>Defaults mark "not yet fitted": numeric estimates are NaN, the threshold index is -1,
     * the log-likelihood is negative infinity and {@code z_hat} is {@code null}.
     */
    private static class ModelFitDetails {
        // --- point estimates ---
        public double b_hat = Double.NaN;
        public double k_hat = Double.NaN;
        public double[] z_hat;
        public int d_hat = -1;

        // --- fit quality ---
        public double logLikelihood = Double.NEGATIVE_INFINITY;

        // --- inference on k ---
        public double k_stdErr = Double.NaN;
        public double k_tStat = Double.NaN;
    }

    /**
     * Public result of the threshold-model estimation.
     *
     * <p>The result holds:
     * <ul>
     *   <li>the estimated intercept {@code b_hat}, step size {@code k_hat} and covariate
     *       coefficients {@code z_hat};</li>
     *   <li>the fitted threshold {@code d_hat}, where -1 means unset;</li>
     *   <li>the log-likelihood;</li>
     *   <li>the standard error and t-statistic of {@code k_hat};</li>
     *   <li>p-values for {@code k}.</li>
     * </ul>
     *
     * <p>{@code pValueK} is labelled "(Bootstrap)" in {@link #toString()}. {@code pValueK_NormalFit}
     * is not printed by {@link #toString()}.
     */
    public static class ThresholdModelResult {
        public double b_hat = Double.NaN;
        public double k_hat = Double.NaN;
        public double[] z_hat;
        public int d_hat = -1;
        public double logLikelihood = Double.NEGATIVE_INFINITY;
        public double k_stdErr = Double.NaN;
        public double k_tStat = Double.NaN;
        public double pValueK = Double.NaN;
        public double pValueK_NormalFit = Double.NaN;

        /**
         * Formats a value with four decimals, always using '.' as the decimal separator.
         * Locale.ROOT is used because the surrounding output separates fields and array
         * elements with ", ". A locale-specific decimal comma would make that output ambiguous.
         */
        private static String fmt4(double value) {
            return String.format(java.util.Locale.ROOT, "%.4f", value);
        }

        /**
         * Returns a multi-line, human-readable summary of all estimates except
         * {@code pValueK_NormalFit}. NaN and infinite values are printed as "NaN" / "-Infinity".
         * A null or empty {@code z_hat} is printed as {@code [N/A]}.
         */
        @Override
        public String toString() {
            String z_hat_str = "N/A";
            if (this.z_hat != null && this.z_hat.length > 0) {
                z_hat_str = Arrays.stream(this.z_hat).mapToObj(ThresholdModelResult::fmt4).collect(Collectors.joining(", "));
            }
            return "ThresholdModelResult{"
                    + "\nb_hat=" + fmt4(this.b_hat)
                    + ", \nk_hat=" + fmt4(this.k_hat)
                    + ", \nz_hat=[" + z_hat_str + "]"
                    + ", \nd_hat=" + this.d_hat
                    + ", \nlogLikelihood=" + fmt4(this.logLikelihood)
                    + ", \nk_stdErr=" + fmt4(this.k_stdErr)
                    + ", \nk_tStat=" + fmt4(this.k_tStat)
                    + ", \npValueK (Bootstrap)=" + (Double.isNaN(this.pValueK) ? "NaN" : fmt4(this.pValueK))
                    + "\n}";
        }
    }
}

