Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix inverse_cdf for Gamma #227

Merged
merged 5 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 80 additions & 9 deletions src/distribution/gamma.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::distribution::{Continuous, ContinuousCDF};
use crate::function::gamma;
use crate::prec;
use crate::statistics::*;
use crate::{Result, StatsError};
use rand::Rng;
Expand Down Expand Up @@ -132,20 +133,60 @@ impl ContinuousCDF<f64, f64> for Gamma {
fn sf(&self, x: f64) -> f64 {
if x <= 0.0 {
1.0
}
else if ulps_eq!(x, self.shape) && self.rate.is_infinite() {
} else if ulps_eq!(x, self.shape) && self.rate.is_infinite() {
0.0
}
else if self.rate.is_infinite() {
} else if self.rate.is_infinite() {
1.0
}
else if x.is_infinite() {
} else if x.is_infinite() {
0.0
}
else {
} else {
gamma::gamma_ur(self.shape, x * self.rate)
}
}

fn inverse_cdf(&self, p: f64) -> f64 {
if !(0.0..=1.0).contains(&p) {
panic!("default inverse_cdf implementation should be provided probability on [0,1]")
}
if p == 0.0 {
return self.min();
};
if p == 1.0 {
return self.max();
};

// Bisection search for MAX_ITERS.0 iterations
let mut high = 2.0;
let mut low = 1.0;
while self.cdf(low) > p {
low /= 2.0;
}
while self.cdf(high) < p {
high *= 2.0;
}
let mut x_0 = (high + low) / 2.0;

for _ in 0..8 {
if self.cdf(x_0) >= p {
high = x_0;
} else {
low = x_0;
}
if prec::convergence(&mut x_0, (high + low) / 2.0) {
break;
}
}

// Newton Raphson, for at least one step
for _ in 0..4 {
let x_next = x_0 - (self.cdf(x_0) - p) / self.pdf(x_0);
if prec::convergence(&mut x_0, x_next) {
break;
}
}

x_0
}
}

impl Min<f64> for Gamma {
Expand Down Expand Up @@ -456,7 +497,11 @@ mod tests {
for &(arg, res) in test.iter() {
test_case_special(arg, res, 10e-6, f);
}
let test = [((10.0, 10.0), 0.9), ((10.0, 1.0), 9.0), ((10.0, f64::INFINITY), 0.0)];
let test = [
((10.0, 10.0), 0.9),
((10.0, 1.0), 9.0),
((10.0, f64::INFINITY), 0.0),
];
for &(arg, res) in test.iter() {
test_case(arg, res, f);
}
Expand Down Expand Up @@ -562,6 +607,32 @@ mod tests {
test_case((1.0, 0.1), 0.0, |x| x.cdf(0.0));
}

#[test]
fn test_cdf_inverse_identity() {
let f = |p: f64| move |g: Gamma| g.cdf(g.inverse_cdf(p));
let params = [
(1.0, 0.1),
(1.0, 1.0),
(10.0, 10.0),
(10.0, 1.0),
(100.0, 200.0),
];

for param in params {
for n in -5..0 {
let p = 10.0f64.powi(n);
test_case(param, p, f(p));
}
}

// test case from issue #200
{
let x = 20.5567;
let f = |x: f64| move |g: Gamma| g.inverse_cdf(g.cdf(x));
test_case((3.0, 0.5), x, f(x))
}
}

#[test]
fn test_sf() {
let f = |arg: f64| move |x: Gamma| x.sf(arg);
Expand Down
9 changes: 9 additions & 0 deletions src/prec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,12 @@ pub fn almost_eq(a: f64, b: f64, acc: f64) -> bool {

(a - b).abs() < acc
}

/// Compares if two floats are close via `approx::relative_eq!`
/// and `crate::consts::ACC` relative precision.
/// Updates first argument to value of second argument
pub fn convergence(x: &mut f64, x_new: f64) -> bool {
let res = approx::relative_eq!(*x, x_new, max_relative = crate::consts::ACC);
*x = x_new;
res
}
Loading