Skip to content

Commit

Permalink
Re-enable Chi Square test.
Browse files Browse the repository at this point in the history
  • Loading branch information
RobbieMcKinstry committed Nov 8, 2024
1 parent 39fc29b commit a0abc1a
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 80 deletions.
94 changes: 21 additions & 73 deletions src/stats/chi.rs
Original file line number Diff line number Diff line change
@@ -1,34 +1,13 @@
use std::num::NonZeroU64;
use std::num::NonZeroUsize;

use statrs::distribution::{ChiSquared, ContinuousCDF};

use super::Categorical;

/// A `ContingencyTable` is a one-dimensional array of observation counts,
/// holding one count per category.
/// Categorical data with N categories uses the index `[0..N)` to reference the ith category.
trait ContingencyTable<const N: usize, C: Categorical<N>> {
    /// Return the expected number of observations for the category labeled `index`.
    fn expected(&self, index: usize) -> u32;

    /// Return the observed number of observations for the category labeled `index`.
    fn observed(&self, index: usize) -> u32;

    /// Returns the number of degrees of freedom for this table,
    /// computed as the number of categories minus one.
    /// # Panics
    /// This method panics if `N` is less than 2.
    fn degrees_of_freedom(&self) -> NonZeroU64 {
        // `checked_sub` rejects `N == 0` and `NonZeroU64::new` rejects `N == 1`,
        // so both undersized cases fall through to the same panic.
        (N as u64)
            .checked_sub(1)
            .and_then(NonZeroU64::new)
            .unwrap_or_else(|| {
                panic!("The experiment must have at least two groups. Only {N} groups provided")
            })
    }
}
use super::{Categorical, ContingencyTable};

/// Alpha represents the alpha cutoff, expressed as a floating-point number in [0, 1); values of 1.0 or more trip an assertion.
/// For example, 0.95 is the standard 95% confidence level (i.e. a 5% significance level).
fn chi_square_test<const N: usize, C: Categorical<N>>(
table: &impl ContingencyTable<N, C>,
table: &ContingencyTable<N, C>,
alpha: f64,
) -> bool {
assert!(alpha < 1.0);
Expand All @@ -38,12 +17,12 @@ fn chi_square_test<const N: usize, C: Categorical<N>>(
}

// Calculate the chi-square test statistic from the provided contingency table.
fn test_statistic<const N: usize, C: Categorical<N>>(table: &impl ContingencyTable<N, C>) -> f64 {
fn test_statistic<const N: usize, C: Categorical<N>>(table: &ContingencyTable<N, C>) -> f64 {
let mut sum = 0.0;
// For each category, we calculate the squared error between the expected and observed counts.
for i in 0..N {
let expected_count = table.expected(i) as i64;
let observed_count = table.observed(i) as i64;
let expected_count = table.expected_by_index(i) as i64;
let observed_count = table.observed_by_index(i) as i64;
let diff = observed_count - expected_count;
let error = diff.pow(2) as f64;
let incremental_error = error / (expected_count as f64);
Expand All @@ -55,71 +34,40 @@ fn test_statistic<const N: usize, C: Categorical<N>>(table: &impl ContingencyTab
/// Calculates the p-value given the test statistic and the degrees of freedom.
/// This is the area under the upper tail of the chi-square distribution (a special
/// case of the gamma distribution), i.e. one minus the CDF evaluated at the test statistic.
fn p_value(test_statistic: f64, degrees_of_freedom: NonZeroU64) -> f64 {
let freedom = u64::from(degrees_of_freedom) as f64;
fn p_value(test_statistic: f64, degrees_of_freedom: NonZeroUsize) -> f64 {
let freedom = usize::from(degrees_of_freedom) as f64;
let distribution = ChiSquared::new(freedom).expect("Degrees of freedom must be >= 0");
1.0 - distribution.cdf(test_statistic)
}

#[cfg(test)]
mod tests {

use super::ContingencyTable;
use static_assertions::assert_obj_safe;
use std::num::NonZeroUsize;

// Require the contingency table is object-safe for certain commonly used categories.
assert_obj_safe!(ContingencyTable<5, String>);

/*
/// This simple smoke test shows that the `FixedFrequencyTable`
/// can have its frequencies set and accessed.
#[test]
fn enumerable_table() {
let mut table = FixedTable::new();
let groups = [(true, 30u64), (false, 70u64)];
// Put the values into the table.
for (group, freq) in groups {
table.set_group_count(group, freq);
}
// Retrieve the values from the table.
for (group, freq) in groups {
let expected = freq;
let observed = table.group_count(&group);
assert_eq!(expected, observed);
}
// Demonstrate the number of degrees of freedom matches expectations.
assert_eq!(degrees_of_freedom(&table), NonZeroU64::new(1).unwrap());
}
*/
use super::{p_value, test_statistic};
use crate::stats::{contingency::Coin, ContingencyTable};

/*
/// Scenario: You flip a coin 50 times, and get 21 Heads and 29 Tails.
/// You want to determine if the coin is fair. Output the test statistic.
/// Let True represent Heads and False represent Tails.
#[test]
fn calc_test_statistic() {
let mut control_group = FixedTable::new();
control_group.set_group_count(true, 25);
control_group.set_group_count(false, 25);
let mut experimental_group = FixedTable::new();
experimental_group.set_group_count(true, 21);
experimental_group.set_group_count(false, 29);
assert_eq!(
degrees_of_freedom(&control_group),
NonZeroU64::new(1).unwrap()
);
assert_eq!(
degrees_of_freedom(&experimental_group),
NonZeroU64::new(1).unwrap()
);
let stat = test_statistic(&control_group, &experimental_group);
let mut table = ContingencyTable::new();
table.set_expected(&Coin::Heads, 25);
table.set_expected(&Coin::Tails, 25);
table.set_observed(&Coin::Heads, 21);
table.set_observed(&Coin::Tails, 29);
let degrees = table.degrees_of_freedom();
// We expect one degree of freedom since there are only two categories.
assert_eq!(degrees, NonZeroUsize::new(1).unwrap());
let stat = test_statistic(&table);
// Round the statistic to two decimal places.
let observed = (stat * 100.0).round() / 100.0;
let expected = 1.28;
assert_eq!(observed, expected);
// Now, calculate the p-value using the test statistic.
let pval = p_value(stat, degrees_of_freedom(&control_group));
let pval = p_value(stat, degrees);
assert!(0.25 < pval && pval < 0.30);
}
*/
}
28 changes: 21 additions & 7 deletions src/stats/contingency.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::num::NonZeroU64;
use std::num::NonZeroUsize;

use crate::stats::{histogram::Histogram, Categorical};

Expand Down Expand Up @@ -28,9 +28,15 @@ impl<const N: usize, C: Categorical<N>> ContingencyTable<N, C> {

/// Calculate the expected number of elements in the given category, scaled as a ratio of the total number of elements observed.
pub fn expected(&self, cat: &C) -> f64 {
let index = cat.category();
self.expected_by_index(index)
}

/// calculate the expected count for the category with index `i`.
pub fn expected_by_index(&self, i: usize) -> f64 {
// • Calculate the expected number of elements as a ratio
// of the total number of elements observed.
let expected_in_category = self.expected.get_count(cat) as f64;
let expected_in_category = self.expected.get_count_by_index(i) as f64;
let expected_total = self.expected.total() as f64;
// • Grab the total number of elements observed, and calculate
// using the ratio.
Expand All @@ -43,15 +49,20 @@ impl<const N: usize, C: Categorical<N>> ContingencyTable<N, C> {
expected_in_category * total_observed / expected_total
}

/// Return the observed count for the category with index `i`.
///
/// This forwards `i` to `get_count_by_index` on the `observed` field.
/// NOTE(review): presumably panics when `i` is not less than `N`, because the
/// histogram lookup rejects out-of-range indices — confirm against `Histogram`.
pub fn observed_by_index(&self, i: usize) -> u32 {
self.observed.get_count_by_index(i)
}

/// returns the number of degrees of freedom for this table.
/// This is typically the number of categories minus one.
/// # Panics
/// This method panics if `N` is less than 2.
pub fn degrees_of_freedom(&self) -> NonZeroU64 {
pub fn degrees_of_freedom(&self) -> NonZeroUsize {
if N < 2 {
panic!("The experiment must have at least two groups. Only {N} groups provided");
}
NonZeroU64::new(N as u64 - 1).unwrap()
NonZeroUsize::new(N - 1).unwrap()
}

pub fn observed(&self, cat: &C) -> u32 {
Expand Down Expand Up @@ -81,9 +92,12 @@ impl<const N: usize, C: Categorical<N>> Default for ContingencyTable<N, C> {
}
}

#[cfg(test)]
pub(crate) use tests::Coin;

#[cfg(test)]
mod tests {
use std::num::NonZeroU64;
use std::num::NonZeroUsize;

use pretty_assertions::assert_eq;

Expand Down Expand Up @@ -161,14 +175,14 @@ mod tests {
#[test]
fn calc_degrees_of_freedom() {
let table: ContingencyTable<2, Coin> = ContingencyTable::new();
let expected = NonZeroU64::new(1).unwrap();
let expected = NonZeroUsize::new(1).unwrap();
let observed = table.degrees_of_freedom();
assert_eq!(observed, expected);
}

use crate::{metrics::ResponseStatusCode, stats::Categorical};
#[derive(PartialEq, Eq, Debug, Hash)]
enum Coin {
pub(crate) enum Coin {
Heads,
Tails,
}
Expand Down
7 changes: 7 additions & 0 deletions src/stats/histogram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,13 @@ where
self.clear_category(cat);
self.increment_by(cat, count);
}

/// Look up the count stored in the bin at position `i`.
///
/// # Panics
/// Panics if `i` is not less than the number of categories `N`.
pub(super) fn get_count_by_index(&self, i: usize) -> u32 {
    // Reject out-of-range indices up front so the failure carries a descriptive message.
    assert!(
        i < N,
        "Index out of bounds. The index provided must be a natural number less than the number of categories."
    );
    self.bins[i]
}
}

impl<const N: usize, C> Default for Histogram<N, C>
Expand Down

0 comments on commit a0abc1a

Please sign in to comment.