diff --git a/Cargo.lock b/Cargo.lock index 33f32d9..bb17d58 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -81,6 +81,15 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + [[package]] name = "async-stream" version = "0.3.6" @@ -504,6 +513,18 @@ version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" +[[package]] +name = "bytemuck" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8334215b81e418a0a7bdb8ef0849474f40bb10c8b71f1c4ed315cff49f32494d" + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "bytes" version = "1.7.2" @@ -537,6 +558,7 @@ dependencies = [ "serde", "serde_json", "static_assertions", + "statrs", "thiserror", "tokio", "tokio-stream", @@ -1036,6 +1058,12 @@ version = "0.2.161" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" + [[package]] name = "linux-raw-sys" version = "0.4.14" @@ -1058,6 +1086,16 @@ version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +[[package]] +name = "matrixmultiply" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "memchr" version = "2.7.4" @@ -1116,6 +1154,44 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "nalgebra" +version = "0.32.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5c17de023a86f59ed79891b2e5d5a94c705dbe904a5b5c9c952ea6221b03e4" +dependencies = [ + "approx", + "matrixmultiply", + "nalgebra-macros", + "num-complex", + "num-rational", + "num-traits", + "rand", + "rand_distr", + "simba", + "typenum", +] + +[[package]] +name = "nalgebra-macros" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "254a5372af8fc138e36684761d3c0cdb758a4410e938babcff1c860ce14ddbfc" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + [[package]] name = "num-conv" version = "0.1.0" @@ -1131,6 +1207,16 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -1138,6 +1224,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", + "libm", ] [[package]] @@ -1196,6 +1283,12 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + 
[[package]] name = "percent-encoding" version = "2.3.1" @@ -1220,6 +1313,15 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" +[[package]] +name = "ppv-lite86" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] + [[package]] name = "pretty_assertions" version = "1.4.1" @@ -1248,6 +1350,52 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "redox_syscall" version = "0.5.7" @@ -1355,6 +1503,15 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +[[package]] +name = "safe_arch" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3460605018fdc9612bce72735cba0d27efbcd9904780d44c7e3a9948f96148a" +dependencies = [ + "bytemuck", +] + [[package]] name = "schannel" version = "0.1.26" @@ -1467,6 +1624,19 @@ dependencies = [ "libc", ] +[[package]] +name = "simba" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae" +dependencies = [ + "approx", + "num-complex", + "num-traits", + "paste", + "wide", +] + [[package]] name = "slab" version = "0.4.9" @@ -1510,6 +1680,18 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "statrs" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f697a07e4606a0a25c044de247e583a330dbb1731d11bc7350b81f48ad567255" +dependencies = [ + "approx", + "nalgebra", + "num-traits", + "rand", +] + [[package]] name = "strsim" version = "0.11.1" @@ -1902,6 +2084,16 @@ version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" +[[package]] +name = "wide" +version = "0.7.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b828f995bf1e9622031f8009f8481a85406ce1f4d4588ff746d872043e855690" +dependencies = [ + "bytemuck", + "safe_arch", +] + [[package]] name = "windows-core" version = "0.52.0" @@ -2071,6 +2263,27 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" +[[package]] +name = "zerocopy" +version = "0.7.35" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "byteorder", + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "zeroize" version = "1.8.1" diff --git a/Cargo.toml b/Cargo.toml index 8c03e5b..215f36e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ futures-util = "0.3.31" miette = { version = "7", features = ["fancy"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" +statrs = "0.17.1" thiserror = "1.0.61" tokio = { version = "1.37.0", features = ["full"] } tokio-stream = { version = "0.1", features = ["time"] } diff --git a/src/pipeline/mod.rs b/src/pipeline/mod.rs index 786d3db..b1de561 100644 --- a/src/pipeline/mod.rs +++ b/src/pipeline/mod.rs @@ -42,7 +42,7 @@ pub fn repeat_query( // be moved between iterations. pin!(timer); // Each iteration of the loop represents one unit of tiem. - while let Some(_) = timer.next().await { + while timer.next().await.is_some() { // • We perform the query then dump the results into the stream. let items = observer.query().await; for item in items { diff --git a/src/stats/chi.rs b/src/stats/chi.rs index 7d42862..fccf035 100644 --- a/src/stats/chi.rs +++ b/src/stats/chi.rs @@ -1,33 +1,40 @@ use std::collections::{HashMap, HashSet}; use std::hash::Hash; +use std::num::NonZeroU64; -/// A ContingencyTable expresses the frequency with which a category was observed. -/// Usually, it tracks the number of observations in ecah category, but when the +use statrs::distribution::{ChiSquared, ContinuousCDF}; + +/// A ContingencyTable expresses the frequency with which a group was observed. 
+/// Usually, it tracks the number of observations in each group, but when the /// number is already known (i.e. its fixed, like a fair dice or coin), it can -/// expose just the frequencies for each category. +/// expose just the frequencies for each group. pub trait ContingencyTable { - /// return the frequency of the provided category as a number in the range [0, 1]. - /// If this is an empirical table (i.e. its values were from observations), - /// then this is the number of times the category was observed - /// divided by the total number of observations. - fn group_count(&self, cat: &Group) -> usize; + /// return the number of observations in the provided group. + fn group_count(&self, cat: &Group) -> u64; + /// Return the set of groups that serve as columns of the contingency table. fn groups(&self) -> Box>; - /// returns the number of degrees of freedom for this table. - /// This is typically the number of groups minus one. - fn degrees_of_freedom(&self) -> usize { - self.groups().count() - 1 - } - // returns the total number of observations made. This should be the sum // of the group count for every group. - fn total_count(&self) -> usize { + fn total_count(&self) -> u64 { self.groups() .fold(0, |sum, group| sum + self.group_count(&group)) } } +/// returns the number of degrees of freedom for this table. +/// This is typically the number of groups minus one. +/// # Panics +/// This method panics if the number of groups returned by `groups` is less than 2. +fn degrees_of_freedom(table: &impl ContingencyTable) -> NonZeroU64 { + let group_count = table.groups().count() as u64; + if group_count < 2 { + panic!("The experiment must have at least two groups. Only {group_count} groups provided"); + } + NonZeroU64::new(group_count - 1).unwrap() +} + /// This helper trait identifies a category with a known set of groups. /// For example, if modeling bools, the groups are True and False. If modeling /// a six sided die, the groups would be 1 through 6.
@@ -68,7 +75,7 @@ pub struct FixedContingencyTable where C: EnumerableCategory + Hash + Eq, { - counts: HashMap, + counts: HashMap, } impl FixedContingencyTable @@ -87,12 +94,13 @@ where } /// Sets the expected count of the category to the value provided. - pub fn set_group_count(&mut self, cat: C, count: usize) { + pub fn set_group_count(&mut self, cat: C, count: u64) { self.counts.insert(cat, count); } - /// Returns the frequency of the provide category. - pub fn group_count(&self, cat: &C) -> usize { + /// Returns the number of observations that were classified as + /// having this group/category. + pub fn group_count(&self, cat: &C) -> u64 { self.counts[cat] } } @@ -101,14 +109,7 @@ impl ContingencyTable for FixedContingencyTable where C: EnumerableCategory + Hash + Eq, { - /// Return the number of degrees of freedom, which is the number of - /// groups minus 1. - fn degrees_of_freedom(&self) -> usize { - // The number of degrees of freedom is the number of groups minus one. - self.counts.len() - 1 - } - - fn group_count(&self, cat: &C) -> usize { + fn group_count(&self, cat: &C) -> u64 { // delegate to the method on the base class. Self::group_count(self, cat) } @@ -119,10 +120,31 @@ where } } +/// Alpha represents the alpha cutoff, expressed as a floating point number in the range [0, 1). +/// For example, 0.95 is the standard 5% confidence interval. +pub fn chi_square_test( + observed: &impl ContingencyTable, + expected: &impl ContingencyTable, + alpha: f64, +) -> bool +where + Cat: EnumerableCategory + Hash + Eq, +{ + assert!(alpha < 1.0); + assert_eq!( + degrees_of_freedom(observed), + degrees_of_freedom(expected), + "Expected the degrees of freedom from both groups to be the same." + ); + let stat = test_statistic(expected, observed); + let pval = p_value(stat, degrees_of_freedom(observed)); + pval < alpha +} + // calculate the chi square test statistic using the provided contingency tables.
fn test_statistic( - control: impl ContingencyTable, - experimental: impl ContingencyTable, + control: &impl ContingencyTable, + experimental: &impl ContingencyTable, ) -> f64 { // • First, get the set of groups. We can't assume that // both table have the same groups, so we deduplicate them using @@ -142,12 +164,18 @@ fn test_statistic( }) } +fn p_value(test_statistic: f64, degrees_of_freedom: NonZeroU64) -> f64 { + let freedom = u64::from(degrees_of_freedom) as f64; + let distribution = ChiSquared::new(freedom).expect("Degrees of freedom must be >= 0"); + 1.0 - distribution.cdf(test_statistic) +} + #[cfg(test)] mod tests { - use std::collections::HashSet; + use std::{collections::HashSet, num::NonZeroU64}; - use crate::stats::chi::FixedContingencyTable; + use crate::stats::chi::{degrees_of_freedom, p_value, FixedContingencyTable}; use super::{test_statistic, ContingencyTable, EnumerableCategory}; use pretty_assertions::assert_eq; @@ -176,7 +204,7 @@ mod tests { #[test] fn enumerable_table() { let mut table = FixedContingencyTable::new(); - let groups = [(true, 30), (false, 70)]; + let groups = [(true, 30u64), (false, 70u64)]; // Put the values into the table. for (group, freq) in groups { table.set_group_count(group, freq); @@ -188,7 +216,7 @@ mod tests { assert_eq!(expected, observed); } // Demonstrate the number of degrees of freedom matches expectations. - assert_eq!(table.degrees_of_freedom(), 1); + assert_eq!(degrees_of_freedom(&table), NonZeroU64::new(1).unwrap()); } /// Scenario: You flip a coin 50 times, and get 21 Heads and 29 Tails. 
@@ -202,12 +230,21 @@ mod tests { let mut experimental_group = FixedContingencyTable::new(); experimental_group.set_group_count(true, 21); experimental_group.set_group_count(false, 29); - assert_eq!(control_group.degrees_of_freedom(), 1); - assert_eq!(experimental_group.degrees_of_freedom(), 1); - let stat = test_statistic(control_group, experimental_group); + assert_eq!( + degrees_of_freedom(&control_group), + NonZeroU64::new(1).unwrap() + ); + assert_eq!( + degrees_of_freedom(&experimental_group), + NonZeroU64::new(1).unwrap() + ); + let stat = test_statistic(&control_group, &experimental_group); // Round the statistic to two decimal places. let observed = (stat * 100.0).round() / 100.0; let expected = 1.28; assert_eq!(observed, expected); + // Now, calculate the p-value using the test statistic. + let pval = p_value(stat, degrees_of_freedom(&control_group)); + assert!(0.25 < pval && pval < 0.30); } } diff --git a/src/stats/mod.rs b/src/stats/mod.rs index f218db7..e6b530a 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -20,6 +20,12 @@ pub struct ChiSquareEngine { alpha_cutoff: f64, } +impl Default for ChiSquareEngine { + fn default() -> Self { + Self::new() + } +} + impl ChiSquareEngine { pub fn new() -> Self { Self {