diff --git a/src/stats/chi.rs b/src/stats/chi.rs index 66c44f1..2bec847 100644 --- a/src/stats/chi.rs +++ b/src/stats/chi.rs @@ -1,34 +1,13 @@ -use std::num::NonZeroU64; +use std::num::NonZeroUsize; use statrs::distribution::{ChiSquared, ContinuousCDF}; -use super::Categorical; - -/// A `ContingencyTable` represents a one-dimensional array of observation counts. -/// Each element in the array represents the count of observations in the given category. -/// Categorical data with N categories uses the index `[0..N)` to reference the ith category. -trait ContingencyTable> { - /// Return the expected number of observations for the category labeled `index`. - fn expected(&self, index: usize) -> u32; - /// Return the observed number of observations for the category labeled `index`. - fn observed(&self, index: usize) -> u32; - - /// returns the number of degrees of freedom for this table. - /// This is typically the number of categories minus one. - /// # Panics - /// This method panics if `N` is less than 2. - fn degrees_of_freedom(&self) -> NonZeroU64 { - if N < 2 { - panic!("The experiment must have at least two groups. Only {N} groups provided"); - } - NonZeroU64::new(N as u64 - 1).unwrap() - } -} +use super::{Categorical, ContingencyTable}; /// Alpha represents the alpha cutoff, expressed as a floating point from [0, 1] inclusive. /// For example, 0.95 is the standard 5% confidency interval. fn chi_square_test>( - table: &impl ContingencyTable, + table: &ContingencyTable, alpha: f64, ) -> bool { assert!(alpha < 1.0); @@ -38,12 +17,12 @@ fn chi_square_test>( } // calculate the chi square test statistic using the provided contingency tables. -fn test_statistic>(table: &impl ContingencyTable) -> f64 { +fn test_statistic>(table: &ContingencyTable) -> f64 { let mut sum = 0.0; // For each category, we calculate the square error between the expected and observed groups. for i in 0..N { - let expected_count = table.expected(i) as i64; - let observed_count = table.observed(i) as i64; + let expected_count = table.expected_by_index(i) as i64; + let observed_count = table.observed_by_index(i) as i64; let diff = observed_count - expected_count; let error = diff.pow(2) as f64; let incremental_error = error / (expected_count as f64); @@ -55,8 +34,8 @@ fn test_statistic>(table: &impl ContingencyTab /// calculates the p-value given the test statistic and the degrees of freedom. /// This is determined by the area of the Chi Square distribution (which is a special /// case of the gamma distribution). -fn p_value(test_statistic: f64, degrees_of_freedom: NonZeroU64) -> f64 { - let freedom = u64::from(degrees_of_freedom) as f64; +fn p_value(test_statistic: f64, degrees_of_freedom: NonZeroUsize) -> f64 { + let freedom = usize::from(degrees_of_freedom) as f64; let distribution = ChiSquared::new(freedom).expect("Degrees of freedom must be >= 0"); 1.0 - distribution.cdf(test_statistic) } @@ -64,62 +43,31 @@ fn p_value(test_statistic: f64, degrees_of_freedom: NonZeroU64) -> f64 { #[cfg(test)] mod tests { - use super::ContingencyTable; - use static_assertions::assert_obj_safe; + use std::num::NonZeroUsize; - // Require the contingency table is object-safe for certain commonly used categories. - assert_obj_safe!(ContingencyTable<5, String>); - - /* - /// This simple smoke test shows that the `FixedFrequencyTable` - /// can have its frequencies set and accessed. - #[test] - fn enumerable_table() { - let mut table = FixedTable::new(); - let groups = [(true, 30u64), (false, 70u64)]; - // Put the values into the table. - for (group, freq) in groups { - table.set_group_count(group, freq); - } - // Retreive the values from the table. - for (group, freq) in groups { - let expected = freq; - let observed = table.group_count(&group); - assert_eq!(expected, observed); - } - // Demonstrate the number of degrees of freedom matches expectations. - assert_eq!(degrees_of_freedom(&table), NonZeroU64::new(1).unwrap()); - } - */ + use super::{p_value, test_statistic}; + use crate::stats::{contingency::Coin, ContingencyTable}; - /* /// Scenario: You flip a coin 50 times, and get 21 Heads and 29 Tails. /// You want to determine if the coin is fair. Output the test statistic. /// Let True represent Heads and False represent Tails. #[test] fn calc_test_statistic() { - let mut control_group = FixedTable::new(); - control_group.set_group_count(true, 25); - control_group.set_group_count(false, 25); - let mut experimental_group = FixedTable::new(); - experimental_group.set_group_count(true, 21); - experimental_group.set_group_count(false, 29); - assert_eq!( - degrees_of_freedom(&control_group), - NonZeroU64::new(1).unwrap() - ); - assert_eq!( - degrees_of_freedom(&experimental_group), - NonZeroU64::new(1).unwrap() - ); - let stat = test_statistic(&control_group, &experimental_group); + let mut table = ContingencyTable::new(); + table.set_expected(&Coin::Heads, 25); + table.set_expected(&Coin::Tails, 25); + table.set_observed(&Coin::Heads, 21); + table.set_observed(&Coin::Tails, 29); + let degrees = table.degrees_of_freedom(); + // We expect one degree of freedom since there are only two categories. + assert_eq!(degrees, NonZeroUsize::new(1).unwrap()); + let stat = test_statistic(&table); // Round the statistic to two decimal places. let observed = (stat * 100.0).round() / 100.0; let expected = 1.28; assert_eq!(observed, expected); // Now, calculate the p-value using the test statistic. - let pval = p_value(stat, degrees_of_freedom(&control_group)); + let pval = p_value(stat, degrees); assert!(0.25 < pval && pval < 0.30); } - */ } diff --git a/src/stats/contingency.rs b/src/stats/contingency.rs index 58fd1d5..034cf90 100644 --- a/src/stats/contingency.rs +++ b/src/stats/contingency.rs @@ -1,4 +1,4 @@ -use std::num::NonZeroU64; +use std::num::NonZeroUsize; use crate::stats::{histogram::Histogram, Categorical}; @@ -28,9 +28,15 @@ impl> ContingencyTable { /// Calculate the expected number of elements. This is a ratio pub fn expected(&self, cat: &C) -> f64 { + let index = cat.category(); + self.expected_by_index(index) + } + + /// calculate the expected count for the category with index `i`. + pub fn expected_by_index(&self, i: usize) -> f64 { // • Calculate the expected number of elements as a ratio // of the total number of elements observed. - let expected_in_category = self.expected.get_count(cat) as f64; + let expected_in_category = self.expected.get_count_by_index(i) as f64; let expected_total = self.expected.total() as f64; // • Grab the total number of elements observed, and calculate // using the ratio. @@ -43,15 +49,20 @@ impl> ContingencyTable { expected_in_category * total_observed / expected_total } + /// calculate the expected count for the category with index `i`. + pub fn observed_by_index(&self, i: usize) -> u32 { + self.observed.get_count_by_index(i) + } + /// returns the number of degrees of freedom for this table. /// This is typically the number of categories minus one. /// # Panics /// This method panics if `N` is less than 2. - pub fn degrees_of_freedom(&self) -> NonZeroU64 { + pub fn degrees_of_freedom(&self) -> NonZeroUsize { if N < 2 { panic!("The experiment must have at least two groups. Only {N} groups provided"); } - NonZeroU64::new(N as u64 - 1).unwrap() + NonZeroUsize::new(N - 1).unwrap() } pub fn observed(&self, cat: &C) -> u32 { @@ -81,9 +92,12 @@ impl> Default for ContingencyTable { } } +#[cfg(test)] +pub(crate) use tests::Coin; + #[cfg(test)] mod tests { - use std::num::NonZeroU64; + use std::num::NonZeroUsize; use pretty_assertions::assert_eq; @@ -161,14 +175,14 @@ mod tests { #[test] fn calc_degrees_of_freedom() { let table: ContingencyTable<2, Coin> = ContingencyTable::new(); - let expected = NonZeroU64::new(1).unwrap(); + let expected = NonZeroUsize::new(1).unwrap(); let observed = table.degrees_of_freedom(); assert_eq!(observed, expected); } use crate::{metrics::ResponseStatusCode, stats::Categorical}; #[derive(PartialEq, Eq, Debug, Hash)] - enum Coin { + pub(crate) enum Coin { Heads, Tails, } diff --git a/src/stats/histogram.rs b/src/stats/histogram.rs index c8cb249..85a4d3b 100644 --- a/src/stats/histogram.rs +++ b/src/stats/histogram.rs @@ -70,6 +70,13 @@ where self.clear_category(cat); self.increment_by(cat, count); } + + pub(super) fn get_count_by_index(&self, i: usize) -> u32 { + if i >= N { + panic!("Index out of bounds. The index provided must be a natural number less than the number of categories."); + } + self.bins[i] + } } impl Default for Histogram