/*
 * Decompiled with CFR 0.152.
 */
package umontreal.iro.lecuyer.probdist;

import umontreal.iro.lecuyer.functions.MathFunction;
import umontreal.iro.lecuyer.probdist.DiscreteDistributionInt;
import umontreal.iro.lecuyer.util.RootFinder;

public class LogarithmicDist
extends DiscreteDistributionInt {
    private double theta;
    private double t;

    public LogarithmicDist(double theta) {
        this.setTheta(theta);
    }

    public double prob(int x) {
        if (x < 1) {
            return 0.0;
        }
        return this.t * Math.pow(this.theta, x) / (double)x;
    }

    public double cdf(int x) {
        double res;
        if (x < 1) {
            return 0.0;
        }
        double term = res = this.prob(1);
        for (int i = 2; i <= x; ++i) {
            res += (term *= this.theta) / (double)i;
        }
        return res;
    }

    public double barF(int x) {
        double res;
        if (x <= 1) {
            return 1.0;
        }
        double term = res = this.prob(x);
        int i = x + 1;
        while (term > EPSILON) {
            res += (term *= this.theta * (double)(i - 1) / (double)i);
        }
        return res;
    }

    public int inverseFInt(double u) {
        return LogarithmicDist.inverseF(this.theta, u);
    }

    public double getMean() {
        return LogarithmicDist.getMean(this.theta);
    }

    public double getVariance() {
        return LogarithmicDist.getVariance(this.theta);
    }

    public double getStandardDeviation() {
        return LogarithmicDist.getStandardDeviation(this.theta);
    }

    public static double prob(double theta, int x) {
        if (theta <= 0.0 || theta >= 1.0) {
            throw new IllegalArgumentException("theta not in range (0,1)");
        }
        if (x < 1) {
            return 0.0;
        }
        return -1.0 / Math.log1p(-theta) * Math.pow(theta, x) / (double)x;
    }

    public static double cdf(double theta, int x) {
        double res;
        if (theta <= 0.0 || theta >= 1.0) {
            throw new IllegalArgumentException("theta not in range (0,1)");
        }
        if (x < 1) {
            return 0.0;
        }
        double term = res = LogarithmicDist.prob(theta, 1);
        for (int i = 2; i <= x; ++i) {
            res += (term *= theta) / (double)i;
        }
        return res;
    }

    public static double barF(double theta, int x) {
        double res;
        if (theta <= 0.0 || theta >= 1.0) {
            throw new IllegalArgumentException("theta not in range (0,1)");
        }
        if (x <= 1) {
            return 1.0;
        }
        double term = res = LogarithmicDist.prob(theta, x);
        int i = x + 1;
        while (term > EPSILON) {
            res += (term *= theta * (double)(i - 1) / (double)i);
        }
        return res;
    }

    public static int inverseF(double theta, double u) {
        throw new UnsupportedOperationException();
    }

    public static double[] getMLE(int[] x, int n) {
        if (n <= 0) {
            throw new IllegalArgumentException("n <= 0");
        }
        double[] parameters = new double[1];
        double sum = 0.0;
        for (int i = 0; i < n; ++i) {
            sum += (double)x[i];
        }
        double mean = sum / (double)n;
        Function f2 = new Function(mean);
        parameters[0] = RootFinder.brentDekker(1.0E-15, 0.999999999999999, f2, 1.0E-7);
        return parameters;
    }

    public static LogarithmicDist getInstanceFromMLE(int[] x, int n) {
        double[] parameters = LogarithmicDist.getMLE(x, n);
        return new LogarithmicDist(parameters[0]);
    }

    public static double getMean(double theta) {
        if (theta <= 0.0 || theta >= 1.0) {
            throw new IllegalArgumentException("theta not in range (0,1)");
        }
        return -1.0 / Math.log1p(-theta) * (theta / (1.0 - theta));
    }

    public static double getVariance(double theta) {
        if (theta <= 0.0 || theta >= 1.0) {
            throw new IllegalArgumentException("theta not in range (0,1)");
        }
        double v = Math.log1p(-theta);
        return -theta * (theta + v) / ((1.0 - theta) * (1.0 - theta) * v * v);
    }

    public static double getStandardDeviation(double theta) {
        return Math.sqrt(LogarithmicDist.getVariance(theta));
    }

    public double getTheta() {
        return this.theta;
    }

    public void setTheta(double theta) {
        if (theta <= 0.0 || theta >= 1.0) {
            throw new IllegalArgumentException("theta not in range (0,1)");
        }
        this.theta = theta;
        this.t = -1.0 / Math.log1p(-theta);
        this.supportA = 1;
    }

    public double[] getParams() {
        double[] retour = new double[]{this.theta};
        return retour;
    }

    public String toString() {
        return this.getClass().getSimpleName() + " : theta = " + this.theta;
    }

    private static class Function
    implements MathFunction {
        protected double mean;

        public Function(double mean) {
            this.mean = mean;
        }

        public double evaluate(double x) {
            if (x <= 0.0 || x >= 1.0) {
                return 1.0E200;
            }
            return x + this.mean * (1.0 - x) * Math.log1p(-x);
        }
    }
}

