/*
 * Decompiled with CFR 0.152.
 */
package edu.sysu.pmglab.stat;

import cern.jet.stat.Probability;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.random.RandomDataGenerator;

public class Logistic2ThresholdEstimator {
    /** Defensive copy of the response vector exactly as passed to the constructor. */
    private final double[] y_orig;
    /** Response rounded to the nearest int; the logistic fits treat 1 as an event and anything else as a non-event. */
    private final int[] y_int_orig;
    /** Per-row defensive copy of X; rows that were null in the input stay null. */
    private final double[][] X_orig;
    /** Each non-null row of X sorted ascending, used to count entries {@code >= c} by binary search. */
    private final double[][] sortedX;
    /** Per-row defensive copy of the covariates Z, or null when no covariates were supplied. */
    private final double[][] Z_orig;
    /** Number of observations (length of y). */
    private final int N;
    /** Number of X columns, taken from the first row of X. */
    private final int m;
    /** Number of Z covariates, taken from the first row of Z; 0 when Z is absent. */
    private final int p;
    /** Minimum fraction of observations each regime must contain in the primary (non-bootstrap) fit. */
    private double tauMinProportion = 0.1;
    /** Number of parametric-bootstrap replicates; values <= 0 skip the bootstrap. */
    private int numBootstrapReplications = 199;
    /** Maximum number of threshold candidates c taken from the distinct X values. */
    private int numQuantilesForC = 50;
    /** Iteration cap for the Newton-Raphson logistic fit (a class-wide constant, not per-instance state). */
    private static final int MAX_ITER_LOGISTIC = 100;
    /** Newton-Raphson stops once the Euclidean norm of the coefficient update falls below this value. */
    private static final double CONVERGENCE_TOL_LOGISTIC = 1.0E-6;

    /**
     * Creates an estimator for the logistic two-threshold model.
     *
     * @param y       response; each value is rounded to the nearest int and 1 is treated as the event
     * @param X_param N x m matrix scanned for threshold exceedances; null rows are tolerated
     * @param Z_param optional N x p covariate matrix, or {@code null} for none; null rows are tolerated
     * @throws IllegalArgumentException if X or Z does not have {@code y.length} rows, if X has zero
     *     columns while N > 0, or if a non-null row of X (or Z) differs in length from the first row
     */
    public Logistic2ThresholdEstimator(double[] y, double[][] X_param, double[][] Z_param) {
        // Validate every shape before copying anything.
        this.N = y.length;
        if (X_param.length != this.N) {
            throw new IllegalArgumentException("X length mismatch with Y");
        }
        this.m = this.N > 0 && X_param[0] != null ? X_param[0].length : 0;
        if (this.m == 0 && this.N > 0) {
            throw new IllegalArgumentException("X cannot be empty or have zero dimension if N > 0");
        }
        if (Z_param != null && Z_param.length != this.N) {
            throw new IllegalArgumentException("Z length mismatch with Y");
        }
        this.p = Z_param != null && this.N > 0 && Z_param[0] != null ? Z_param[0].length : 0;
        // Later code indexes rows with j < m and copies p covariates per row. A ragged row would
        // otherwise fail much later, possibly inside a worker thread where the error is only logged.
        requireUniformRowLength(X_param, this.m, "X");
        if (Z_param != null) {
            requireUniformRowLength(Z_param, this.p, "Z");
        }

        this.y_orig = y.clone();
        this.y_int_orig = Arrays.stream(this.y_orig).mapToInt(val -> (int) Math.round(val)).toArray();
        this.X_orig = copyRowsPreservingNulls(X_param);
        this.sortedX = copyRowsPreservingNulls(X_param);
        for (double[] row : this.sortedX) {
            if (row != null) {
                Arrays.sort(row);
            }
        }
        this.Z_orig = Z_param == null ? null : copyRowsPreservingNulls(Z_param);
    }

    /** Throws if any non-null row of {@code rows} does not have exactly {@code expected} entries. */
    private static void requireUniformRowLength(double[][] rows, int expected, String label) {
        for (int i = 0; i < rows.length; ++i) {
            if (rows[i] != null && rows[i].length != expected) {
                throw new IllegalArgumentException(
                        label + " row " + i + " has length " + rows[i].length + ", expected " + expected);
            }
        }
    }

    /** Returns a deep copy of {@code rows} in which null rows remain null. */
    private static double[][] copyRowsPreservingNulls(double[][] rows) {
        double[][] copy = new double[rows.length][];
        for (int i = 0; i < rows.length; ++i) {
            if (rows[i] != null) {
                copy[i] = rows[i].clone();
            }
        }
        return copy;
    }

    /**
     * Sets the minimum fraction of observations that each regime must contain for a (c, d)
     * candidate to be considered in the primary (non-bootstrap) fit. No validation is performed.
     *
     * @param tau minimum regime proportion
     */
    public void setTauMinProportion(double tau) {
        tauMinProportion = tau;
    }

    /**
     * Sets how many parametric-bootstrap replicates are drawn when testing k.
     * A value of zero or less makes the test skip the bootstrap entirely.
     *
     * @param num number of bootstrap replicates
     */
    public void setNumBootstrapReplications(int num) {
        numBootstrapReplications = num;
    }

    /**
     * Sets the maximum number of threshold candidates c drawn from the distinct X values.
     * Requests below 10 are raised to 10.
     *
     * @param numQuantiles requested number of c candidates
     */
    public void setNumQuantilesForC(int numQuantiles) {
        final int minimumQuantiles = 10;
        numQuantilesForC = numQuantiles < minimumQuantiles ? minimumQuantiles : numQuantiles;
    }

    /**
     * Counts the entries of an ascending-sorted row that are {@code >= c_val}.
     * <p>
     * {@link Arrays#sort(double[])} places NaN after every other value, and NaN never satisfies
     * {@code x >= c}. The predicate is therefore not monotone over a row containing NaN, so the
     * trailing NaN run is excluded before the binary search. A NaN {@code c_val} yields 0.
     *
     * @param sorted_x_row row sorted with {@link Arrays#sort(double[])}; may be null or empty
     * @param c_val        threshold
     * @return number of non-NaN entries that are {@code >= c_val}
     */
    private int countGreaterEqualSorted(double[] sorted_x_row, double c_val) {
        if (sorted_x_row == null || sorted_x_row.length == 0) {
            return 0;
        }
        int end = sorted_x_row.length;
        while (end > 0 && Double.isNaN(sorted_x_row[end - 1])) {
            --end;
        }
        // Lower-bound search: first index in [0, end) whose value is >= c_val, or end if none.
        int low = 0;
        int high = end;
        while (low < high) {
            int mid = (low + high) >>> 1;
            if (sorted_x_row[mid] >= c_val) {
                high = mid;
            } else {
                low = mid + 1;
            }
        }
        return end - low;
    }

    /**
     * Builds the candidate grid for the threshold c from the distinct values of X.
     * <p>
     * When there are at most {@code numQuantilesForC} distinct values, all of them are returned.
     * Otherwise {@code numQuantilesForC} evenly spaced order statistics of the sorted distinct
     * values are taken, with duplicates from index rounding collapsed. NaN entries are ignored:
     * no observation can satisfy {@code x >= NaN}, so such a candidate can never split the data.
     *
     * @return ascending, duplicate-free candidate thresholds; empty when there is no usable data
     */
    private List<Double> createCGrid() {
        if (this.N == 0 || this.m == 0) {
            return Collections.emptyList();
        }
        HashSet<Double> uniqueValues = new HashSet<>();
        for (int i = 0; i < this.N; ++i) {
            double[] row = this.X_orig[i];
            if (row == null) {
                continue;
            }
            for (int j = 0; j < this.m; ++j) {
                if (!Double.isNaN(row[j])) {
                    uniqueValues.add(row[j]);
                }
            }
        }
        if (uniqueValues.isEmpty()) {
            return Collections.emptyList();
        }
        List<Double> sortedUnique = new ArrayList<>(uniqueValues);
        Collections.sort(sortedUnique);
        int numValues = sortedUnique.size();
        if (numValues <= this.numQuantilesForC) {
            return sortedUnique;
        }
        if (this.numQuantilesForC <= 1) {
            // Single candidate: the median distinct value.
            List<Double> single = new ArrayList<>(1);
            single.add(sortedUnique.get(numValues / 2));
            return single;
        }
        List<Double> grid = new ArrayList<>(this.numQuantilesForC);
        double denominator = this.numQuantilesForC - 1.0;
        int lastIndex = -1;
        for (int k = 0; k < this.numQuantilesForC; ++k) {
            int index = (int) Math.round(k / denominator * (numValues - 1.0));
            index = Math.max(0, Math.min(index, numValues - 1));
            // Indices are non-decreasing in k, so duplicates are always adjacent.
            if (index != lastIndex) {
                grid.add(sortedUnique.get(index));
                lastIndex = index;
            }
        }
        return grid;
    }

    /**
     * Fits a logistic regression with an intercept by Newton-Raphson.
     * <p>
     * An intercept column is prepended to {@code designMatrix}; a null row contributes only the
     * intercept. Each step solves {@code (X'WX) delta = X'(y - p)} with
     * {@code W = diag(max(p_i (1 - p_i), 1e-10))}, and the fitted probabilities are clamped to
     * {@code [1e-15, 1 - 1e-15]}. {@code X'WX} is accumulated row by row in O(n k^2) instead of
     * materialising a dense n x n weight matrix.
     *
     * @param designMatrix     n x k predictors without intercept; may be null for an intercept-only model
     * @param yResponse        response values; 1 counts as an event in the log-likelihood, anything
     *                         else as a non-event
     * @param initialBetaArray warm-start coefficients of length k + 1; ignored when null or of the
     *                         wrong length; never modified
     * @return coefficients (intercept first). When converged, also the log-likelihood and standard
     *         errors from the inverse of {@code X'WX}; otherwise log-likelihood {@code -Infinity} and
     *         NaN standard errors
     */
    private LogisticRegressionResult performLogisticRegression(double[][] designMatrix, int[] yResponse, double[] initialBetaArray) {
        int nObs = yResponse.length;
        int numPredictors = designMatrix != null && nObs > 0 && designMatrix[0] != null ? designMatrix[0].length : 0;
        int numCoeffs = 1 + numPredictors;
        LogisticRegressionResult result = new LogisticRegressionResult();
        result.stdErrors = new double[numCoeffs];

        double[][] xFull = new double[nObs][numCoeffs];
        for (int i = 0; i < nObs; ++i) {
            xFull[i][0] = 1.0;
            if (numPredictors > 0 && designMatrix[i] != null) {
                System.arraycopy(designMatrix[i], 0, xFull[i], 1, numPredictors);
            }
        }

        // Private working copy, so the caller's warm-start array is never mutated.
        double[] beta = initialBetaArray != null && initialBetaArray.length == numCoeffs
                ? initialBetaArray.clone()
                : new double[numCoeffs];
        double[] probabilities = new double[nObs];
        for (int iter = 0; iter < MAX_ITER_LOGISTIC; ++iter) {
            fillFittedProbabilities(xFull, beta, probabilities);
            double[] gradient = new double[numCoeffs];
            for (int i = 0; i < nObs; ++i) {
                double residual = yResponse[i] - probabilities[i];
                double[] row = xFull[i];
                for (int a = 0; a < numCoeffs; ++a) {
                    gradient[a] += row[a] * residual;
                }
            }
            RealVector deltaBeta;
            try {
                // Newton step; identical to -H^{-1} g with Hessian H = -X'WX.
                deltaBeta = new LUDecomposition(weightedCrossProduct(xFull, probabilities, numCoeffs))
                        .getSolver()
                        .solve(new ArrayRealVector(gradient, false));
            } catch (SingularMatrixException e) {
                result.converged = false;
                result.coefficients = beta;
                Arrays.fill(result.stdErrors, Double.NaN);
                result.logLikelihood = Double.NEGATIVE_INFINITY;
                return result;
            }
            for (int a = 0; a < numCoeffs; ++a) {
                beta[a] += deltaBeta.getEntry(a);
            }
            if (deltaBeta.getNorm() < CONVERGENCE_TOL_LOGISTIC) {
                result.converged = true;
                break;
            }
        }
        result.coefficients = beta;
        if (!result.converged) {
            result.logLikelihood = Double.NEGATIVE_INFINITY;
            Arrays.fill(result.stdErrors, Double.NaN);
            return result;
        }

        fillFittedProbabilities(xFull, beta, probabilities);
        double ll = 0.0;
        for (int i = 0; i < nObs; ++i) {
            ll += yResponse[i] == 1 ? Math.log(probabilities[i]) : Math.log(1.0 - probabilities[i]);
        }
        result.logLikelihood = ll;
        try {
            RealMatrix covMatrix = new LUDecomposition(weightedCrossProduct(xFull, probabilities, numCoeffs))
                    .getSolver()
                    .getInverse();
            for (int a = 0; a < numCoeffs; ++a) {
                double variance = covMatrix.getEntry(a, a);
                result.stdErrors[a] = variance >= 0.0 ? Math.sqrt(variance) : Double.NaN;
            }
        } catch (SingularMatrixException e) {
            Arrays.fill(result.stdErrors, Double.NaN);
        }
        return result;
    }

    /** Writes sigmoid(x_i . beta), clamped to [1e-15, 1 - 1e-15] so log terms stay finite, into {@code out}. */
    private static void fillFittedProbabilities(double[][] xFull, double[] beta, double[] out) {
        for (int i = 0; i < xFull.length; ++i) {
            double[] row = xFull[i];
            double eta = 0.0;
            for (int a = 0; a < beta.length; ++a) {
                eta += row[a] * beta[a];
            }
            double prob = 1.0 / (1.0 + Math.exp(-eta));
            out[i] = Math.max(1.0E-15, Math.min(0.999999999999999, prob));
        }
    }

    /** Returns X'WX for W = diag(max(p_i (1 - p_i), 1e-10)), accumulated one observation at a time. */
    private static RealMatrix weightedCrossProduct(double[][] xFull, double[] probabilities, int numCoeffs) {
        double[][] xtwx = new double[numCoeffs][numCoeffs];
        for (int i = 0; i < xFull.length; ++i) {
            double pi = probabilities[i];
            double weight = pi * (1.0 - pi);
            if (weight < 1.0E-10) {
                weight = 1.0E-10;
            }
            double[] row = xFull[i];
            for (int a = 0; a < numCoeffs; ++a) {
                double weightedA = weight * row[a];
                for (int b = a; b < numCoeffs; ++b) {
                    xtwx[a][b] += weightedA * row[b];
                }
            }
        }
        // Only the upper triangle was accumulated; mirror it into the lower triangle.
        for (int a = 1; a < numCoeffs; ++a) {
            for (int b = 0; b < a; ++b) {
                xtwx[a][b] = xtwx[b][a];
            }
        }
        return MatrixUtils.createRealMatrix(xtwx);
    }

    /*
     * WARNING - void declaration
     */
    private ModelFitDetails estimateParametersInternalLogistic(int[] currentY, double[][] fixedSortedXRef, double[][] fixedZRef, boolean isBootstrapRun) {
        List<Double> c_grid = this.createCGrid();
        if (c_grid.isEmpty() && this.N > 0 && this.m > 0) {
            return new ModelFitDetails();
        }
        ModelFitDetails overallBestFit = new ModelFitDetails();
        int numRegressorsExcludingIntercept = 1 + this.p;
        int innerThreads = Math.max(1, Runtime.getRuntime().availableProcessors() / (isBootstrapRun ? 4 : 2));
        if (innerThreads > c_grid.size() || c_grid.size() < 2 || c_grid.isEmpty()) {
            innerThreads = 1;
        }
        if (innerThreads > 1) {
            ExecutorService cExecutor = Executors.newFixedThreadPool(innerThreads);
            ArrayList<Future<ModelFitDetails>> cFutures = new ArrayList<Future<ModelFitDetails>>();
            for (double d : c_grid) {
                Callable<ModelFitDetails> cTask = () -> {
                    ModelFitDetails localBestFitForC = new ModelFitDetails();
                    double[] lastBeta = null;
                    int[] counts_for_this_c = new int[this.N];
                    for (int i = 0; i < this.N; ++i) {
                        counts_for_this_c[i] = this.countGreaterEqualSorted(fixedSortedXRef[i], c_candidate);
                    }
                    HashSet<Integer> d_critical_points_set = new HashSet<Integer>();
                    d_critical_points_set.add(1);
                    for (int count_val : counts_for_this_c) {
                        int critical_d = count_val + 1;
                        if (critical_d >= 1 && critical_d <= this.m) {
                            d_critical_points_set.add(critical_d);
                        }
                        if (count_val < 1 || count_val > this.m) continue;
                        d_critical_points_set.add(count_val);
                    }
                    if (this.m > 0) {
                        d_critical_points_set.add(this.m);
                    }
                    List<Integer> reduced_d_grid = new ArrayList<Integer>(d_critical_points_set);
                    reduced_d_grid.removeIf(d_val -> d_val < 1 || d_val > this.m);
                    if (!reduced_d_grid.contains(1) && 1 <= this.m) {
                        reduced_d_grid.add(1);
                    }
                    if (!reduced_d_grid.contains(this.m) && this.m >= 1) {
                        reduced_d_grid.add(this.m);
                    }
                    Collections.sort(reduced_d_grid);
                    if (!reduced_d_grid.isEmpty()) {
                        reduced_d_grid = reduced_d_grid.stream().distinct().collect(Collectors.toList());
                    }
                    if (reduced_d_grid.isEmpty() && this.m >= 1) {
                        reduced_d_grid.add(1);
                        if (this.m > 1 && this.m != 1) {
                            reduced_d_grid.add((Integer)((Object)this.m));
                        }
                        if (reduced_d_grid.size() > 1) {
                            reduced_d_grid = reduced_d_grid.stream().distinct().collect(Collectors.toList());
                            Collections.sort(reduced_d_grid);
                        }
                    }
                    Iterator iterator2 = reduced_d_grid.iterator();
                    while (iterator2.hasNext()) {
                        LogisticRegressionResult fit;
                        int d_candidate = (Integer)iterator2.next();
                        double[][] designMatrixForLogit = new double[this.N][numRegressorsExcludingIntercept];
                        int countRegime1 = 0;
                        for (int i = 0; i < this.N; ++i) {
                            boolean D_val = counts_for_this_c[i] >= d_candidate;
                            designMatrixForLogit[i][0] = (double)D_val;
                            if (!D_val) continue;
                            ++countRegime1;
                        }
                        int countRegime0 = this.N - countRegime1;
                        if (!isBootstrapRun ? (double)countRegime0 < (double)this.N * this.tauMinProportion || (double)countRegime1 < (double)this.N * this.tauMinProportion : countRegime0 == 0 || countRegime1 == 0) continue;
                        boolean D_is_constant = true;
                        if (this.N > 0 && designMatrixForLogit[0] != null) {
                            double firstD = designMatrixForLogit[0][0];
                            for (int i = 1; i < this.N; ++i) {
                                if (designMatrixForLogit[i][0] == firstD) continue;
                                D_is_constant = false;
                                break;
                            }
                        }
                        if (D_is_constant && this.N > 0) continue;
                        if (this.p > 0 && fixedZRef != null) {
                            for (int i = 0; i < this.N; ++i) {
                                if (fixedZRef[i] == null) continue;
                                System.arraycopy(fixedZRef[i], 0, designMatrixForLogit[i], 1, this.p);
                            }
                        }
                        if ((fit = this.performLogisticRegression(designMatrixForLogit, currentY, lastBeta)) != null && fit.converged) {
                            lastBeta = fit.coefficients;
                            if (!(fit.logLikelihood > localBestFitForC.logLikelihood)) continue;
                            localBestFitForC.logLikelihood = fit.logLikelihood;
                            localBestFitForC.c_hat = c_candidate;
                            localBestFitForC.d_hat = d_candidate;
                            localBestFitForC.b_hat = fit.coefficients[0];
                            localBestFitForC.k_hat = fit.coefficients[1];
                            if (this.p > 0) {
                                localBestFitForC.z_hat = new double[this.p];
                                System.arraycopy(fit.coefficients, 2, localBestFitForC.z_hat, 0, this.p);
                            } else {
                                localBestFitForC.z_hat = new double[0];
                            }
                            if (fit.stdErrors != null && fit.stdErrors.length > 1) {
                                localBestFitForC.k_stdErr = fit.stdErrors[1];
                                if (localBestFitForC.k_stdErr > 1.0E-9 && !Double.isNaN(localBestFitForC.k_stdErr)) {
                                    localBestFitForC.k_tStat = localBestFitForC.k_hat / localBestFitForC.k_stdErr;
                                    continue;
                                }
                                localBestFitForC.k_tStat = Double.NaN;
                                continue;
                            }
                            localBestFitForC.k_stdErr = Double.NaN;
                            localBestFitForC.k_tStat = Double.NaN;
                            continue;
                        }
                        lastBeta = null;
                    }
                    return localBestFitForC;
                };
                cFutures.add(cExecutor.submit(cTask));
            }
            cExecutor.shutdown();
            try {
                for (Future future : cFutures) {
                    ModelFitDetails fitFromTask = (ModelFitDetails)future.get();
                    if (fitFromTask == null || !(fitFromTask.logLikelihood > overallBestFit.logLikelihood)) continue;
                    overallBestFit = fitFromTask;
                }
            }
            catch (Exception e) {
                System.err.println("Error parallel c_candidate (logistic estInternal): " + e.getMessage());
            }
        } else {
            for (double c_candidate : c_grid) {
                Object var12_18 = null;
                int[] counts_for_this_c = new int[this.N];
                for (int i = 0; i < this.N; ++i) {
                    counts_for_this_c[i] = this.countGreaterEqualSorted(fixedSortedXRef[i], c_candidate);
                }
                HashSet<Integer> d_critical_points_set = new HashSet<Integer>();
                d_critical_points_set.add(1);
                for (int count_val : counts_for_this_c) {
                    if (count_val >= 1 && count_val <= this.m) {
                        d_critical_points_set.add(count_val);
                    }
                    if (count_val + 1 < 1 || count_val + 1 > this.m) continue;
                    d_critical_points_set.add(count_val + 1);
                }
                d_critical_points_set.add(this.m);
                List<Object> reduced_d_grid = new ArrayList<Integer>(d_critical_points_set);
                reduced_d_grid.removeIf(d_val -> d_val < 1 || d_val > this.m);
                Collections.sort(reduced_d_grid);
                if (!reduced_d_grid.isEmpty()) {
                    reduced_d_grid = reduced_d_grid.stream().distinct().collect(Collectors.toList());
                }
                if (reduced_d_grid.isEmpty() && this.m >= 1) {
                    reduced_d_grid.add(1);
                    if (this.m > 1) {
                        reduced_d_grid.add(this.m);
                    }
                    Collections.sort(reduced_d_grid);
                }
                Iterator iterator2 = reduced_d_grid.iterator();
                while (iterator2.hasNext()) {
                    void var12_17;
                    LogisticRegressionResult fit;
                    int d_candidate = (Integer)iterator2.next();
                    double[][] designMatrixForLogit = new double[this.N][numRegressorsExcludingIntercept];
                    int countRegime1 = 0;
                    for (int i = 0; i < this.N; ++i) {
                        boolean D_val = counts_for_this_c[i] >= d_candidate;
                        designMatrixForLogit[i][0] = (double)D_val;
                        if (!D_val) continue;
                        ++countRegime1;
                    }
                    int countRegime0 = this.N - countRegime1;
                    if (!isBootstrapRun ? (double)countRegime0 < (double)this.N * this.tauMinProportion || (double)countRegime1 < (double)this.N * this.tauMinProportion : countRegime0 == 0 || countRegime1 == 0) continue;
                    boolean D_is_constant = true;
                    if (this.N > 0 && designMatrixForLogit[0] != null) {
                        double firstD = designMatrixForLogit[0][0];
                        for (int i = 1; i < this.N; ++i) {
                            if (designMatrixForLogit[i][0] == firstD) continue;
                            D_is_constant = false;
                            break;
                        }
                    }
                    if (D_is_constant && this.N > 0) continue;
                    if (this.p > 0 && fixedZRef != null) {
                        for (int i = 0; i < this.N; ++i) {
                            if (fixedZRef[i] == null) continue;
                            System.arraycopy(fixedZRef[i], 0, designMatrixForLogit[i], 1, this.p);
                        }
                    }
                    if ((fit = this.performLogisticRegression(designMatrixForLogit, currentY, (double[])var12_17)) != null && fit.converged) {
                        double[] dArray = fit.coefficients;
                        if (!(fit.logLikelihood > overallBestFit.logLikelihood)) continue;
                        overallBestFit.logLikelihood = fit.logLikelihood;
                        overallBestFit.c_hat = c_candidate;
                        overallBestFit.d_hat = d_candidate;
                        overallBestFit.b_hat = fit.coefficients[0];
                        overallBestFit.k_hat = fit.coefficients[1];
                        if (this.p > 0) {
                            overallBestFit.z_hat = new double[this.p];
                            System.arraycopy(fit.coefficients, 2, overallBestFit.z_hat, 0, this.p);
                        } else {
                            overallBestFit.z_hat = new double[0];
                        }
                        if (fit.stdErrors != null && fit.stdErrors.length > 1) {
                            overallBestFit.k_stdErr = fit.stdErrors[1];
                            if (overallBestFit.k_stdErr > 1.0E-9 && !Double.isNaN(overallBestFit.k_stdErr)) {
                                overallBestFit.k_tStat = overallBestFit.k_hat / overallBestFit.k_stdErr;
                                continue;
                            }
                            overallBestFit.k_tStat = Double.NaN;
                            continue;
                        }
                        overallBestFit.k_stdErr = Double.NaN;
                        overallBestFit.k_tStat = Double.NaN;
                        continue;
                    }
                    Object var12_20 = null;
                }
            }
        }
        return overallBestFit;
    }

    /**
     * Fits the null (no-threshold) logistic model: an intercept, plus the Z covariates when p > 0.
     *
     * @param yData             response to fit
     * @param zDataMatrixForH0  covariate rows; only used when p > 0
     * @return fit details; {@code b_hat} is NaN when the regression did not converge
     */
    private ModelFitDetails estimateH0Logistic(int[] yData, double[][] zDataMatrixForH0) {
        ModelFitDetails nullModel = new ModelFitDetails();
        double[][] covariates;
        if (this.p > 0) {
            covariates = zDataMatrixForH0;
        } else {
            covariates = new double[this.N][0];
        }
        LogisticRegressionResult fit = this.performLogisticRegression(covariates, yData, null);
        boolean usable = fit != null && fit.converged;
        if (!usable) {
            nullModel.b_hat = Double.NaN;
            return nullModel;
        }
        double[] coefficients = fit.coefficients;
        nullModel.b_hat = coefficients[0];
        nullModel.z_hat = Arrays.copyOfRange(coefficients, 1, 1 + this.p);
        nullModel.logLikelihood = fit.logLikelihood;
        return nullModel;
    }

    /**
     * Estimates the threshold model and tests k with a parametric bootstrap.
     * <p>
     * Bootstrap responses are simulated from the fitted null model (intercept plus Z, no threshold
     * term). Each replicate is re-estimated with the relaxed bootstrap regime constraint, and its
     * t-statistic for k is recorded. The p-value is {@code (#{|t*| >= |t|} + 1) / (B + 1)}, where B
     * counts only the replicates that produced a finite statistic.
     *
     * @param randSeed seed of the master generator that derives one seed per replicate
     * @return the fitted parameters. {@code pValueK} is NaN when the initial fit, the null fit or
     *         the bootstrap fails, when bootstrapping is disabled, or when no usable statistic exists
     */
    public ThresholdModelResult estimateParametersAndTestLogistic(long randSeed) {
        ModelFitDetails initialFit = this.estimateParametersInternalLogistic(this.y_int_orig, this.sortedX, this.Z_orig, false);
        ThresholdModelResult finalResult = new ThresholdModelResult();
        if (initialFit.logLikelihood == Double.NEGATIVE_INFINITY || Double.isNaN(initialFit.k_hat)) {
            System.err.println("Failed: initial fit for logistic model.");
            finalResult.logLikelihood = initialFit.logLikelihood;
            finalResult.pValueK = Double.NaN;
            return finalResult;
        }
        finalResult.b_hat = initialFit.b_hat;
        finalResult.k_hat = initialFit.k_hat;
        finalResult.z_hat = initialFit.z_hat != null ? initialFit.z_hat.clone() : null;
        finalResult.c_hat = initialFit.c_hat;
        finalResult.d_hat = initialFit.d_hat;
        finalResult.logLikelihood = initialFit.logLikelihood;
        finalResult.k_stdErr = initialFit.k_stdErr;
        finalResult.k_tStat = initialFit.k_tStat;
        finalResult.bootstrapNum = this.numBootstrapReplications;
        if (this.numBootstrapReplications <= 0) {
            System.out.println("Skipping bootstrap as numBootstrapReplications <= 0.");
            finalResult.pValueK = Double.NaN;
            return finalResult;
        }
        ModelFitDetails h0Fit = this.estimateH0Logistic(this.y_int_orig, this.Z_orig);
        if (Double.isNaN(h0Fit.b_hat)) {
            System.err.println("Failed to estimate H0 model for bootstrap.");
            finalResult.pValueK = Double.NaN;
            return finalResult;
        }
        double final_b0_hat_H0 = h0Fit.b_hat;
        double[] final_z0_hat_H0 = h0Fit.z_hat != null ? h0Fit.z_hat.clone() : new double[0];

        RandomDataGenerator masterRng = new RandomDataGenerator();
        masterRng.reSeed(randSeed);
        int numThreads = Math.max(1, Runtime.getRuntime().availableProcessors() / 2);
        ExecutorService bootstrapExecutor = Executors.newFixedThreadPool(numThreads);
        List<Future<Double>> futures = new ArrayList<>(this.numBootstrapReplications);
        // Only this thread adds to the list, so no synchronization is needed.
        List<Double> bootstrap_t_stats = new ArrayList<>(this.numBootstrapReplications);
        try {
            for (int b = 0; b < this.numBootstrapReplications; ++b) {
                // Seeds are drawn sequentially here, so results do not depend on thread scheduling.
                long seed = masterRng.nextLong(Long.MIN_VALUE, Long.MAX_VALUE);
                Callable<Double> task = () -> {
                    RandomDataGenerator localRandomGen = new RandomDataGenerator();
                    localRandomGen.reSeed(seed);
                    int[] y_boot = new int[this.N];
                    for (int i = 0; i < this.N; ++i) {
                        double linearPredictorH0 = final_b0_hat_H0;
                        if (this.p > 0 && final_z0_hat_H0.length > 0 && this.Z_orig != null && this.Z_orig[i] != null) {
                            for (int j = 0; j < this.p; ++j) {
                                linearPredictorH0 += this.Z_orig[i][j] * final_z0_hat_H0[j];
                            }
                        }
                        double prob_y1 = 1.0 / (1.0 + Math.exp(-linearPredictorH0));
                        y_boot[i] = localRandomGen.nextUniform(0.0, 1.0) < prob_y1 ? 1 : 0;
                    }
                    ModelFitDetails bootFit = this.estimateParametersInternalLogistic(y_boot, this.sortedX, this.Z_orig, true);
                    return bootFit.k_tStat;
                };
                futures.add(bootstrapExecutor.submit(task));
            }
            bootstrapExecutor.shutdown();
            for (Future<Double> future : futures) {
                Double t_boot = future.get();
                if (t_boot == null || t_boot.isNaN() || t_boot.isInfinite()) {
                    continue;
                }
                bootstrap_t_stats.add(t_boot);
            }
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            System.err.println("Error during bootstrap (logistic main): " + e.getMessage());
            finalResult.pValueK = Double.NaN;
            return finalResult;
        } finally {
            // Cancels outstanding replicates when the loop above exits early; a no-op otherwise.
            bootstrapExecutor.shutdownNow();
        }

        int successfulBootstraps = bootstrap_t_stats.size();
        if ((double) successfulBootstraps < (double) this.numBootstrapReplications * 0.8) {
            System.err.println("Warning: Bootstrap failures (logistic main) (" + successfulBootstraps + "/" + this.numBootstrapReplications + ")");
        }
        if (Double.isNaN(finalResult.k_tStat) || bootstrap_t_stats.isEmpty()) {
            finalResult.pValueK = Double.NaN;
        } else {
            double observedAbsT = Math.abs(finalResult.k_tStat);
            long countExceeding = bootstrap_t_stats.stream().filter(t_b -> Math.abs(t_b) >= observedAbsT).count();
            finalResult.pValueK = ((double) countExceeding + 1.0) / ((double) bootstrap_t_stats.size() + 1.0);
        }
        return finalResult;
    }

    /**
     * Fits the threshold model once and reports an asymptotic Wald p-value for k, without a bootstrap.
     * <p>
     * The p-value is {@code P(chi2_1 > (k_hat / se)^2)}. It is set to 1.0 when the fit fails or
     * when the standard error is unusable (NaN or not above 1e-12) or the t-statistic is NaN.
     * The selection of (c, d) is not accounted for, so this is intended only for screening.
     *
     * @return the fitted parameters with {@code bootstrapNum = 0}
     */
    public ThresholdModelResult screenDataFastNoBootstrap() {
        ModelFitDetails fit = this.estimateParametersInternalLogistic(this.y_int_orig, this.sortedX, this.Z_orig, false);
        ThresholdModelResult screened = new ThresholdModelResult();
        boolean fitFailed = fit.logLikelihood == Double.NEGATIVE_INFINITY || Double.isNaN(fit.k_hat);
        if (fitFailed) {
            screened.logLikelihood = fit.logLikelihood;
            screened.pValueK = 1.0;
            return screened;
        }
        screened.b_hat = fit.b_hat;
        screened.k_hat = fit.k_hat;
        screened.z_hat = fit.z_hat == null ? null : fit.z_hat.clone();
        screened.c_hat = fit.c_hat;
        screened.d_hat = fit.d_hat;
        screened.logLikelihood = fit.logLikelihood;
        screened.k_stdErr = fit.k_stdErr;
        screened.k_tStat = fit.k_tStat;
        screened.bootstrapNum = 0;

        double stdErr = screened.k_stdErr;
        boolean waldUsable = !Double.isNaN(screened.k_tStat) && !Double.isNaN(stdErr) && stdErr > 1.0E-12;
        if (!waldUsable) {
            screened.pValueK = 1.0;
            return screened;
        }
        double z = screened.k_hat / stdErr;
        screened.pValueK = Probability.chiSquareComplemented(1.0, z * z);
        return screened;
    }

    /**
     * Demonstration entry point: simulates data from a logistic two-threshold model with known
     * parameters, fits it with this estimator, and prints the estimates next to the true values.
     *
     * <p>Simulated model, per sample {@code i}: {@code X[i]} holds {@code m_dim} draws from N(0, 1),
     * {@code Z[i]} holds {@code p_dim} draws from Uniform(-1, 1), the indicator {@code I_i} is 1 when
     * at least {@code d_true} entries of {@code X[i]} are {@code >= c_true} (else 0), and
     * {@code P(y_i = 1) = 1 / (1 + exp(-(b_true + k_true * I_i + Z[i] . z_true)))}.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        int N_samples = 1000;
        int m_dim = 50;
        int p_dim = 2;
        RandomDataGenerator rng = new RandomDataGenerator();
        rng.reSeed(12347L);
        double[] y_data_double = new double[N_samples];
        double[][] X_data = new double[N_samples][m_dim];
        // Z_data and z_true are non-null exactly when p_dim > 0.
        double[][] Z_data = p_dim > 0 ? new double[N_samples][p_dim] : null;
        double b_true = -0.5;
        double k_true = 0.0;
        double c_true = 0.3;
        int d_true = Math.max(1, m_dim / 3);
        double[] z_true = null;
        if (p_dim > 0) {
            z_true = new double[p_dim];
            z_true[0] = 1.0;
            if (p_dim > 1) {
                z_true[1] = 0.0;
            }
        }
        System.out.println("--- Generating Logistic Test Data ---");
        System.out.println("N=" + N_samples + ", m=" + m_dim + ", p=" + p_dim);
        System.out.println("True params: b=" + b_true + ", k=" + k_true + ", c=" + c_true + ", d=" + d_true + (z_true != null ? ", z=" + Arrays.toString(z_true) : ""));
        for (int i = 0; i < N_samples; ++i) {
            for (int j = 0; j < m_dim; ++j) {
                X_data[i][j] = rng.nextGaussian(0.0, 1.0);
            }
            if (Z_data != null) {
                for (int j = 0; j < p_dim; ++j) {
                    Z_data[i][j] = rng.nextUniform(-1.0, 1.0);
                }
            }
            int count_X_ge_c = 0;
            for (double x_val : X_data[i]) {
                if (x_val >= c_true) {
                    ++count_X_ge_c;
                }
            }
            // Java forbids casting boolean to double, so encode the threshold indicator as 0/1.
            double I_val = count_X_ge_c >= d_true ? 1.0 : 0.0;
            double W_i = b_true + k_true * I_val;
            if (Z_data != null && z_true != null) {
                for (int j = 0; j < p_dim; ++j) {
                    W_i += Z_data[i][j] * z_true[j];
                }
            }
            double prob_y1 = 1.0 / (1.0 + Math.exp(-W_i));
            y_data_double[i] = rng.nextUniform(0.0, 1.0) < prob_y1 ? 1.0 : 0.0;
        }
        Logistic2ThresholdEstimator estimator = new Logistic2ThresholdEstimator(y_data_double, X_data, Z_data);
        estimator.setTauMinProportion(0.1);
        estimator.setNumQuantilesForC(30);
        estimator.setNumBootstrapReplications(199);
        System.out.println("\n--- Starting Logistic Threshold Estimation ---");
        System.out.println("Settings: tau=" + estimator.tauMinProportion + ", numCQuantiles=" + estimator.numQuantilesForC + ", numBootstrap=" + estimator.numBootstrapReplications);
        // NOTE: the reported elapsed time covers both the fast screening pass and the full estimation.
        long startTime = System.currentTimeMillis();
        ThresholdModelResult result1 = estimator.screenDataFastNoBootstrap();
        System.out.println(result1);
        ThresholdModelResult result = estimator.estimateParametersAndTestLogistic(1000L);
        long endTime = System.currentTimeMillis();
        System.out.println("Logistic Estimation finished in " + (endTime - startTime) + " ms.");
        System.out.println("\n--- Logistic Estimation Results ---");
        System.out.println(result);
        // Locale.ROOT keeps '.' as the decimal separator so values stay unambiguous in the comma-joined list.
        java.util.Locale fmtLocale = java.util.Locale.ROOT;
        System.out.println("True params for comparison: b=" + String.format(fmtLocale, "%.4f", b_true) + ", k=" + String.format(fmtLocale, "%.4f", k_true) + ", c=" + String.format(fmtLocale, "%.4f", c_true) + ", d=" + d_true + (z_true != null ? ", z=" + Arrays.stream(z_true).mapToObj(val -> String.format(fmtLocale, "%.4f", val)).collect(Collectors.joining(", ", "[", "]")) : ""));
    }

    /**
     * Mutable holder for the outcome of one logistic-regression fit.
     *
     * <p>A freshly constructed instance has no coefficients or standard errors, a log-likelihood of
     * negative infinity, and {@code converged == false}.
     */
    private static class LogisticRegressionResult {
        /** Whether the fit was marked as converged; {@code false} until set. */
        boolean converged;

        /** Log-likelihood of the fit; negative infinity until set. */
        double logLikelihood = Double.NEGATIVE_INFINITY;

        /** Fitted coefficient vector; {@code null} until set. */
        double[] coefficients;

        /** Standard errors, presumably one per entry of {@link #coefficients}; {@code null} until set. */
        double[] stdErrors;

        /** Private: only code inside the enclosing top-level class may create instances. */
        private LogisticRegressionResult() {
        }
    }

    /**
     * Internal record of a single fitted threshold model: point estimates, the selected thresholds,
     * the log-likelihood, and the standard error / t-statistic of {@code k_hat}.
     *
     * <p>Numeric estimates start as {@code NaN}, {@code d_hat} starts at -1, and the log-likelihood
     * starts at negative infinity, so an untouched instance is recognizable as "not fitted".
     */
    private static class ModelFitDetails {
        // Regression coefficients.
        public double b_hat = Double.NaN;
        public double k_hat = Double.NaN;
        public double[] z_hat;

        // Selected thresholds: value cut-off c and count cut-off d (-1 until chosen).
        public double c_hat = Double.NaN;
        public int d_hat = -1;

        // Fit quality and inference for k_hat.
        public double logLikelihood = Double.NEGATIVE_INFINITY;
        public double k_stdErr = Double.NaN;
        public double k_tStat = Double.NaN;
    }

    /**
     * Public result of a two-threshold logistic fit.
     *
     * <p>Holds the coefficient estimates, the selected thresholds {@code c_hat} and {@code d_hat},
     * the log-likelihood, inference for {@code k_hat}, and the p-value for {@code k} together with
     * the number of bootstrap replications used ({@code 0} when the p-value was not bootstrapped).
     * Unset numeric fields are {@code NaN}, {@code d_hat} is -1 and {@code logLikelihood} is negative
     * infinity by default.
     */
    public static class ThresholdModelResult {
        public double b_hat = Double.NaN;
        public double k_hat = Double.NaN;
        public double[] z_hat;
        public double c_hat = Double.NaN;
        public int d_hat = -1;
        public double logLikelihood = Double.NEGATIVE_INFINITY;
        public double k_stdErr = Double.NaN;
        public double k_tStat = Double.NaN;
        public double pValueK = Double.NaN;
        public int bootstrapNum = 0;

        /**
         * Renders all fields on separate lines with four decimal places.
         *
         * @return a multi-line, human-readable description of this result
         */
        @Override
        public String toString() {
            String z_hat_str = "N/A";
            if (this.z_hat != null && this.z_hat.length > 0) {
                z_hat_str = Arrays.stream(this.z_hat).mapToObj(ThresholdModelResult::format4).collect(Collectors.joining(", "));
            }
            return "ThresholdModelResult{\nb_hat=" + format4(this.b_hat) + ", \nk_hat=" + format4(this.k_hat) + ", \nz_hat=[" + z_hat_str + "], \nc_hat=" + format4(this.c_hat) + ", \nd_hat=" + this.d_hat + ", \nlogLikelihood=" + format4(this.logLikelihood) + ", \nk_stdErr=" + format4(this.k_stdErr) + ", \nk_tStat=" + format4(this.k_tStat) + ", \npValueK (Bootstrap[" + this.bootstrapNum + "])=" + (Double.isNaN(this.pValueK) ? "NaN" : format4(this.pValueK)) + "\n}";
        }

        /**
         * Formats a value with four decimals using {@code Locale.ROOT}.
         *
         * <p>The fixed locale keeps '.' as the decimal separator. Under the default locale a
         * comma-decimal setting would make the comma-joined {@code z_hat} list ambiguous.
         */
        private static String format4(double value) {
            return String.format(java.util.Locale.ROOT, "%.4f", value);
        }
    }
}

