package edu.sysu.pmglab.stat;

import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import cern.colt.matrix.linalg.QRDecomposition;
import cern.jet.stat.Probability;
import java.util.Arrays;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;

/* loaded from: input_file:edu/sysu/pmglab/stat/LinearRegression.class */
public class LinearRegression {
    /**
     * Small ridge term added to the diagonal of X'X before it is inverted in
     * {@link #calculateStandardErrors()}, so a nearly singular X'X can still be inverted.
     */
    private static final double STANDARD_ERROR_RIDGE = 1.0E-6d;

    /** Fitted coefficients as a (p x 1) column vector; {@code null} until a fit has run. */
    private DoubleMatrix2D beta;
    /** Residual sum of squares from the most recent fit. */
    private double sse;
    /** Total sum of squares of y around its mean, from the most recent fit. */
    private double sst;
    /** Observation weights initialised to 1 by {@link #robustLinearRegression(double[], double[][], int, int)}. */
    double[] weights;
    /** Residuals {@code y - X*beta} from the most recent fit; {@code null} until a fit has run. */
    double[] residuals;
    /** Standard errors stored by the most recent {@link #calculateTValues()} call. */
    double[] se;
    /** t statistics stored by the most recent {@link #calculateTValues()} call. */
    double[] tValues;
    /** Design matrix, one row per observation. */
    double[][] x;
    /** Response vector, one entry per observation. */
    double[] y;
    /** Whether this object has already prepended an intercept column of ones to {@link #x}. */
    boolean addIntercept;

    /** Creates an empty model; supply data with {@link #setX(double[][])} and {@link #setY(double[])}. */
    public LinearRegression() {
        this.addIntercept = false;
    }

    /**
     * Returns the standard errors stored by the last {@link #calculateTValues()} call.
     *
     * @return the standard errors, or {@code null} if t values have not been computed since the last fit
     */
    public double[] getSE() {
        return this.se;
    }

    /**
     * Prepends a column of ones to the current design matrix, at most once per design matrix.
     *
     * @throws IllegalStateException if no design matrix has been set
     */
    public void addIntercept() {
        if (this.addIntercept) {
            return;
        }
        if (this.x == null) {
            throw new IllegalStateException("X must be set before adding an intercept");
        }
        this.x = withInterceptColumn(this.x);
        // Only mark the intercept as added once the new matrix was actually built.
        this.addIntercept = true;
    }

    /**
     * Replaces the design matrix. Prints a warning to stderr when it has no rows or columns.
     * Because a new matrix has not had an intercept added by this object, the intercept flag is
     * reset so that a later {@link #addIntercept()} call takes effect.
     *
     * @param dArr design matrix, one row per observation
     */
    public void setX(double[][] dArr) {
        this.x = dArr;
        this.addIntercept = false;
        if (dArr == null || dArr.length == 0 || dArr[0].length == 0) {
            System.err.println("No independent variables!!!");
        }
    }

    /**
     * Replaces the response vector. Prints a warning to stderr when it is null or empty.
     * Any real-valued response is accepted, since this is ordinary least squares.
     *
     * @param dArr response values, one per observation
     */
    public void setY(double[] dArr) {
        this.y = dArr;
        if (dArr == null || dArr.length == 0) {
            System.err.println("No dependent variables!!!");
        }
    }

    /**
     * Creates a model over the given data without fitting it.
     *
     * @param dArr  design matrix, one row per observation
     * @param dArr2 response values, one per observation
     * @param z     when {@code true}, a column of ones is prepended to {@code dArr} (copied, not modified)
     */
    public LinearRegression(double[][] dArr, double[] dArr2, boolean z) {
        this.addIntercept = z;
        this.x = z ? withInterceptColumn(dArr) : dArr;
        this.y = dArr2;
    }

    /**
     * Fits the model by least squares (Colt QR decomposition). It stores the coefficients, the
     * residuals, and the residual and total sums of squares. Refitting recomputes all of these
     * from scratch and clears previously computed standard errors and t values.
     *
     * @throws IllegalStateException    if X or y has not been set
     * @throws IllegalArgumentException if X and y have different numbers of rows, if there are no
     *                                  observations, or if Colt's QR solve rejects X (e.g. rank deficient)
     */
    public void fit() {
        fitModel();
    }

    /**
     * Computes coefficient standard errors as {@code sqrt(sigma^2 * diag((X'X + 1e-6 I)^-1))},
     * with {@code sigma^2 = SSE / (n - p)}. The small ridge only stabilises the inversion.
     * If {@code n <= p} the divisor is non-positive and the results are not finite.
     *
     * @return one standard error per column of X
     * @throws IllegalStateException if the model has not been fitted
     */
    public double[] calculateStandardErrors() {
        requireFitted();
        int observations = this.y.length;
        int parameters = this.x[0].length;
        double sigma2 = this.sse / (observations - parameters);
        DoubleMatrix2D design = new DenseDoubleMatrix2D(this.x);
        DoubleMatrix2D gram = design.viewDice().zMult(design, (DoubleMatrix2D) null);
        for (int j = 0; j < parameters; j++) {
            gram.setQuick(j, j, gram.getQuick(j, j) + STANDARD_ERROR_RIDGE);
        }
        DoubleMatrix2D gramInverse = new Algebra().inverse(gram);
        double[] errors = new double[parameters];
        for (int j = 0; j < parameters; j++) {
            errors[j] = Math.sqrt(sigma2 * gramInverse.getQuick(j, j));
        }
        return errors;
    }

    /**
     * Computes {@code beta_j / se_j} for every coefficient. It also stores the standard errors
     * (see {@link #getSE()}) and the t values.
     *
     * @return one t statistic per coefficient
     * @throws IllegalStateException if the model has not been fitted
     */
    public double[] calculateTValues() {
        this.se = calculateStandardErrors();
        int parameters = this.beta.rows();
        this.tValues = new double[parameters];
        for (int j = 0; j < parameters; j++) {
            this.tValues[j] = this.beta.getQuick(j, 0) / this.se[j];
        }
        return this.tValues;
    }

    /**
     * Computes two-sided p-values {@code 2 * (1 - F_t(|t|))}, where {@code F_t} is the Student t CDF
     * with {@code n - p} degrees of freedom.
     *
     * @return one p-value per coefficient
     * @throws IllegalStateException if the model has not been fitted
     */
    public double[] calculatePValues() {
        double[] t = calculateTValues();
        int degreesOfFreedom = this.y.length - this.x[0].length;
        double[] pValues = new double[t.length];
        for (int j = 0; j < t.length; j++) {
            pValues[j] = 2.0d * (1.0d - Probability.studentT(degreesOfFreedom, Math.abs(t[j])));
        }
        return pValues;
    }

    /**
     * Returns the residuals {@code y - X*beta} from the most recent fit.
     *
     * @return the residuals, or {@code null} if the model has not been fitted
     */
    public double[] getResiduals() {
        return this.residuals;
    }

    /**
     * Creates a model over the given data and fits it immediately. No intercept column is added.
     *
     * @param dArr  design matrix, one row per observation
     * @param dArr2 response values, one per observation
     * @throws IllegalArgumentException if the row counts differ, or if the fit rejects the data
     */
    public LinearRegression(double[][] dArr, double[] dArr2) {
        this.addIntercept = false;
        this.x = dArr;
        this.y = dArr2;
        // Private helper rather than the overridable fit(), so subclasses cannot interfere during construction.
        fitModel();
    }

    /**
     * Initialises all weights to 1 and delegates to the 5-argument overload.
     * That overload could not be recovered from bytecode, so this call currently always throws.
     *
     * @throws UnsupportedOperationException always (see the 5-argument overload)
     */
    public boolean robustLinearRegression(double[] dArr, double[][] dArr2, int i, int i2) {
        this.weights = new double[dArr.length];
        Arrays.fill(this.weights, 1.0d);
        return robustLinearRegression(dArr, dArr2, this.weights, i, i2);
    }

    /**
     * Returns the weights array set by {@link #robustLinearRegression(double[], double[][], int, int)}.
     *
     * @return the weights, or {@code null} if that method has not been called
     */
    public double[] getWeights() {
        return this.weights;
    }

    /**
     * Robust regression entry point whose body could not be decompiled.
     * NOTE(review): the original implementation (about 500 bytecode instructions) must be restored
     * from source control; the argument roles (presumably y, X, weights and two integer
     * controls) are unverified.
     *
     * @throws UnsupportedOperationException always
     */
    public boolean robustLinearRegression(double[] r9, double[][] r10, double[] r11, int r12, int r13) {
        throw new UnsupportedOperationException("Method not decompiled: edu.sysu.pmglab.stat.LinearRegression.robustLinearRegression(double[], double[][], double[], int, int):boolean");
    }

    /**
     * Weighted least squares via the normal equations {@code (X'WX) b = X'Wy}.
     *
     * @param dArr  response values y (length n)
     * @param dArr2 design matrix X (n rows, p columns)
     * @param dArr3 observation weights (length n)
     * @param dArr4 output: fitted coefficients (length at least p)
     * @param dArr5 output: coefficient standard errors {@code sqrt(diag((X'WX)^-1) * RSS_w / (n - p))}
     * @param dArr6 output: residuals with sign {@code fitted - observed} (length at least n)
     * @return {@code false} if X is empty, if {@code n - p < 1}, or if X'WX cannot be inverted;
     *         otherwise {@code true}
     */
    public boolean weightedLeastSquaresArray(double[] dArr, double[][] dArr2, double[] dArr3, double[] dArr4, double[] dArr5, double[] dArr6) {
        if (dArr2.length == 0) {
            return false;
        }
        int n = dArr.length;
        int p = dArr2[0].length;
        int degreesOfFreedom = n - p;
        if (degreesOfFreedom < 1) {
            return false;
        }
        // Build X'WX and X'Wy.
        double[][] xtwx = new double[p][p];
        double[] xtwy = new double[p];
        for (int j = 0; j < p; j++) {
            for (int k = 0; k < p; k++) {
                double sum = 0.0d;
                for (int i = 0; i < n; i++) {
                    sum += dArr3[i] * dArr2[i][j] * dArr2[i][k];
                }
                xtwx[j][k] = sum;
            }
            double sum = 0.0d;
            for (int i = 0; i < n; i++) {
                sum += dArr3[i] * dArr2[i][j] * dArr[i];
            }
            xtwy[j] = sum;
        }
        if (!symmetricMatrixInvert(xtwx)) {
            return false;
        }
        for (int j = 0; j < p; j++) {
            double sum = 0.0d;
            for (int k = 0; k < p; k++) {
                sum += xtwx[j][k] * xtwy[k];
            }
            dArr4[j] = sum;
        }
        // Residuals and weighted residual sum of squares.
        double weightedRss = 0.0d;
        for (int i = 0; i < n; i++) {
            double fitted = 0.0d;
            for (int j = 0; j < p; j++) {
                fitted += dArr4[j] * dArr2[i][j];
            }
            dArr6[i] = fitted - dArr[i];
            weightedRss += dArr3[i] * dArr6[i] * dArr6[i];
        }
        double sigma2 = weightedRss / degreesOfFreedom;
        for (int j = 0; j < p; j++) {
            dArr5[j] = Math.sqrt(xtwx[j][j] * sigma2);
        }
        return true;
    }

    /**
     * Inverts a symmetric matrix in place. Each step pivots on the not-yet-used diagonal entry
     * with the largest absolute value. Only the upper triangle (row &lt;= column) is read and
     * updated; the lower triangle is overwritten with its mirror at the end.
     *
     * @param dArr square matrix, modified in place (also partially modified when {@code false} is returned)
     * @return {@code false} if, at some step, every unused diagonal entry is zero (or NaN)
     */
    public boolean symmetricMatrixInvert(double[][] dArr) {
        final int n = dArr.length;
        double[] rowFactor = new double[n];
        double[] colFactor = new double[n];
        boolean[] pending = new boolean[n];
        Arrays.fill(pending, true);
        for (int step = 0; step < n; step++) {
            int pivot = -1;
            double largest = 0.0d;
            for (int k = 0; k < n; k++) {
                double magnitude = Math.abs(dArr[k][k]);
                if (magnitude > largest && pending[k]) {
                    largest = magnitude;
                    pivot = k;
                }
            }
            if (largest == 0.0d) {
                return false;
            }
            pending[pivot] = false;
            colFactor[pivot] = 1.0d / dArr[pivot][pivot];
            rowFactor[pivot] = 1.0d;
            dArr[pivot][pivot] = 0.0d;
            // Column above the pivot.
            for (int k = 0; k < pivot; k++) {
                rowFactor[k] = dArr[k][pivot];
                colFactor[k] = (pending[k] ? -dArr[k][pivot] : dArr[k][pivot]) * colFactor[pivot];
                dArr[k][pivot] = 0.0d;
            }
            // Row to the right of the pivot.
            for (int k = pivot + 1; k < n; k++) {
                rowFactor[k] = pending[k] ? dArr[pivot][k] : -dArr[pivot][k];
                colFactor[k] = (-dArr[pivot][k]) * colFactor[pivot];
                dArr[pivot][k] = 0.0d;
            }
            // Rank-one update of the upper triangle.
            for (int r = 0; r < n; r++) {
                for (int c = r; c < n; c++) {
                    dArr[r][c] = dArr[r][c] + (rowFactor[r] * colFactor[c]);
                }
            }
        }
        for (int r = 1; r < n; r++) {
            for (int c = 0; c < r; c++) {
                dArr[r][c] = dArr[c][r];
            }
        }
        return true;
    }

    /**
     * Returns one fitted coefficient.
     *
     * @param i coefficient index (column of X)
     * @return the fitted coefficient
     * @throws IllegalStateException if the model has not been fitted
     */
    public double beta(int i) {
        requireFitted();
        return this.beta.get(i, 0);
    }

    /**
     * Returns the coefficient of determination {@code 1 - SSE/SST}.
     * The result is not finite when every y value is identical (SST = 0).
     *
     * @throws IllegalStateException if the model has not been fitted
     */
    public double R2() {
        requireFitted();
        return 1.0d - (this.sse / this.sst);
    }

    /** Small demonstration: fits a 3-coefficient model and prints coefficients, t values, SEs and a p-value. */
    public static void main(String[] strArr) {
        double[][] design = {
            {1.0d, 10.0d, 4343.0d},
            {1.0d, 20.0d, 4356.0d},
            {1.0d, 40.0d, 123.0d},
            {1.0d, 80.0d, 14343.0d},
            {1.0d, 160.0d, 2567.0d},
            {1.0d, 200.0d, 1321.0d}
        };
        double[] response = {243.0d, 483.0d, 508.0d, 1503.0d, 1764.0d, 2129.0d};
        LinearRegression model = new LinearRegression(design, response, false);
        model.fit();
        double pValue = model.calculatePValues()[1];
        double[] tValues = model.calculateTValues();
        double[] errors = model.getSE();
        System.out.printf("%.2f + %.2f beta1 + %.2f beta2  (R^2 = %.2f)\n", Double.valueOf(model.beta(0)), Double.valueOf(model.beta(1)), Double.valueOf(model.beta(2)), Double.valueOf(model.R2()));
        for (int i = 0; i < tValues.length; i++) {
            System.out.println(tValues[i] + "__" + errors[i]);
        }
        System.out.println(pValue);
    }

    /** Shared fitting routine used by {@link #fit()} and the fitting constructor. */
    private void fitModel() {
        if (this.x == null || this.y == null) {
            throw new IllegalStateException("X and y must be set before fitting");
        }
        if (this.x.length != this.y.length) {
            throw new IllegalArgumentException("matrix dimensions don't agree");
        }
        final int n = this.y.length;
        if (n == 0) {
            throw new IllegalArgumentException("no observations to fit");
        }
        DoubleMatrix2D design = new DenseDoubleMatrix2D(this.x);
        DoubleMatrix2D response = new DenseDoubleMatrix2D(n, 1);
        for (int i = 0; i < n; i++) {
            response.setQuick(i, 0, this.y[i]);
        }
        this.beta = new QRDecomposition(design).solve(response);

        double sum = 0.0d;
        for (int i = 0; i < n; i++) {
            sum += this.y[i];
        }
        double mean = sum / n;
        double totalSs = 0.0d;
        for (int i = 0; i < n; i++) {
            double deviation = this.y[i] - mean;
            totalSs += deviation * deviation;
        }

        DoubleMatrix2D fitted = new DenseDoubleMatrix2D(n, 1);
        design.zMult(this.beta, fitted);
        double[] res = new double[n];
        double residualSs = 0.0d;
        for (int i = 0; i < n; i++) {
            res[i] = this.y[i] - fitted.getQuick(i, 0);
            residualSs += res[i] * res[i];
        }
        // Overwrite (not accumulate) so repeated fits stay correct.
        this.sst = totalSs;
        this.sse = residualSs;
        this.residuals = res;
        this.se = null;
        this.tValues = null;
    }

    /** Throws if no coefficients have been estimated yet. */
    private void requireFitted() {
        if (this.beta == null) {
            throw new IllegalStateException("the model has not been fitted; call fit() first");
        }
    }

    /** Returns a copy of {@code design} with a leading column of ones. */
    private static double[][] withInterceptColumn(double[][] design) {
        int rows = design.length;
        int cols = rows == 0 ? 0 : design[0].length;
        double[][] out = new double[rows][cols + 1];
        for (int i = 0; i < rows; i++) {
            out[i][0] = 1.0d;
            System.arraycopy(design[i], 0, out[i], 1, cols);
        }
        return out;
    }
}
