Skip to content

Commit

Permalink
optimal transport formulation. wip
Browse files Browse the repository at this point in the history
  • Loading branch information
krukah committed Oct 20, 2024
1 parent 32e0b6e commit 2feda7a
Show file tree
Hide file tree
Showing 5 changed files with 260 additions and 141 deletions.
11 changes: 7 additions & 4 deletions src/clustering/abstractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ impl Abstractor {
.inner() // cluster turn
.save()
.inner() // cluster flop
.save()
.inner() // cluster preflop (but really just save flop.metric)
.save();
}
}
Expand Down Expand Up @@ -236,9 +238,10 @@ impl Abstractor {
/// 3. Write the extension (4 bytes)
/// 4. Write the observation and abstraction pairs
/// 5. Write the trailer (2 bytes)
pub fn save(&self, name: String) {
log::info!("saving abstraction lookup {}", name);
let ref mut file = File::create(format!("{}.abstraction.pgcopy", name)).expect("new file");
pub fn save(&self, street: Street) {
log::info!("{:<32}{:<32}", "saving abstraction lookup", street);
let ref mut file =
File::create(format!("{}.abstraction.pgcopy", street)).expect("new file");
file.write_all(b"PGCOPY\n\xff\r\n\0").expect("header");
file.write_u32::<BigEndian>(0).expect("flags");
file.write_u32::<BigEndian>(0).expect("extension");
Expand Down Expand Up @@ -272,7 +275,7 @@ mod tests {
.map(|o| (o, Abstraction::random()))
.collect(),
);
save.save(street.to_string());
save.save(street);
// Load from disk
let load = Abstractor::load_street(street);
std::iter::empty()
Expand Down
15 changes: 13 additions & 2 deletions src/clustering/histogram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,14 @@ impl Histogram {
/// all witnessed Abstractions.
/// treat this like an unordered array
/// even though we use BTreeMap for struct.
pub fn support(&self) -> Vec<&Abstraction> {
self.contribution.keys().collect()
pub fn support(&self) -> impl Iterator<Item = &Abstraction> {
self.contribution.keys()
}
/// view this Histogram as a distribution:
/// every Abstraction in the support is mapped to
/// its local weight divided by `self.mass`.
pub fn normalized(&self) -> BTreeMap<Abstraction, f32> {
    // the denominator is shared by every entry, so name it once
    let total = self.mass as f32;
    let mut distribution = BTreeMap::new();
    for (&abstraction, &weight) in self.contribution.iter() {
        distribution.insert(abstraction, weight as f32 / total);
    }
    distribution
}

/// useful only for k-means edge case of centroid drift
Expand All @@ -39,6 +45,11 @@ impl Histogram {
self.contribution.is_empty()
}

/// size of the support,
/// i.e. the number of entries (distinct keys)
/// currently stored in `contribution`.
pub fn size(&self) -> usize {
self.contribution.len()
}

/// insert the Abstraction into our support,
/// incrementing its local weight,
/// incrementing our global norm.
Expand Down
97 changes: 52 additions & 45 deletions src/clustering/layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,6 @@ use rayon::iter::IntoParallelRefIterator;
use rayon::iter::ParallelIterator;
use std::collections::BTreeMap;

/// number of kmeans centroids.
/// this determines the granularity of the abstraction space.
///
/// - CPU: O(N^2) for kmeans initialization
/// - CPU: O(N) for kmeans clustering
/// - RAM: O(N^2) for learned metric
/// - RAM: O(N) for learned centroids
const N_KMEANS_CENTROIDS: usize = 256;

/// number of kmeans iterations.
/// this controls the precision of the abstraction space.
///
/// - CPU: O(N) for kmeans clustering
const N_KMEANS_ITERATION: usize = 64;

/// Hierarchical K Means Learner.
/// this is decomposed into the necessary data structures
/// for kmeans clustering to occur for a given `Street`.
Expand Down Expand Up @@ -61,6 +46,35 @@ pub struct Layer {
}

impl Layer {
/// how many kmeans centroids are learned for a given `Street`.
/// more centroids means a finer-grained abstraction space.
///
/// - CPU: O(N^2) for kmeans initialization
/// - CPU: O(N) for kmeans clustering
/// - RAM: O(N^2) for learned metric
/// - RAM: O(N) for learned centroids
///
/// panics on `Street::Rive`, which this function treats as unreachable.
const fn k(street: Street) -> usize {
    match street {
        Street::Rive => unreachable!(),
        Street::Flop | Street::Turn => 8,
        Street::Pref => 169,
    }
}

/// how many kmeans iterations are run for a given `Street`.
/// more iterations means a more precise abstraction space.
///
/// - CPU: O(N) for kmeans clustering
///
/// `Street::Pref` runs zero iterations.
/// panics on `Street::Rive`, which this function treats as unreachable.
const fn t(street: Street) -> usize {
    match street {
        Street::Rive => unreachable!(),
        Street::Turn => 32,
        Street::Flop => 128,
        Street::Pref => 0,
    }
}

/// start with the River layer. everything is empty because we
/// can generate `Abstractor` and `SmallSpace` from "scratch".
/// - `lookup`: lazy equity calculation of river observations
Expand Down Expand Up @@ -95,8 +109,8 @@ impl Layer {
}
/// save the current layer's `Metric` and `Abstractor` to disk
pub fn save(self) -> Self {
self.metric.save(format!("{}", self.street.next())); // outer layer generates this purely (metric over projections)
self.lookup.save(format!("{}", self.street)); // while inner layer generates this (clusters)
self.metric.save(self.street.next()); // outer layer generates this purely (metric over projections)
self.lookup.save(self.street); // while inner layer generates this (clusters)
self
}

Expand All @@ -115,7 +129,7 @@ impl Layer {
///
/// we symmetrize the distance by averaging the EMDs in both directions.
/// the distance isn't symmetric in the first place only because our heuristic algo is not fully accurate
pub fn inner_metric(&self) -> Metric {
fn inner_metric(&self) -> Metric {
log::info!(
"{:<32}{:<32}",
"computing metric",
Expand Down Expand Up @@ -170,13 +184,13 @@ impl Layer {
log::info!(
"{:<32}{:<32}",
"declaring abstractions",
format!("{} {} clusters", self.street, N_KMEANS_CENTROIDS)
format!("{} {} clusters", self.street, Self::k(self.street))
);
let ref mut rng = rand::thread_rng();
let progress = Self::progress(N_KMEANS_CENTROIDS);
let progress = Self::progress(Self::k(self.street));
self.kmeans.expand(self.sample_uniform(rng));
progress.inc(1);
while self.kmeans.0.len() < N_KMEANS_CENTROIDS {
while self.kmeans.0.len() < Self::k(self.street) {
self.kmeans.expand(self.sample_outlier(rng));
progress.inc(1);
}
Expand All @@ -189,17 +203,16 @@ impl Layer {
log::info!(
"{:<32}{:<32}",
"clustering observations",
format!("{} {} iterations", self.street, N_KMEANS_ITERATION)
format!("{} {} iterations", self.street, Self::t(self.street))
);
let progress = Self::progress(N_KMEANS_ITERATION);
for _ in 0..N_KMEANS_ITERATION {
let progress = Self::progress(Self::t(self.street));
for _ in 0..Self::t(self.street) {
let neighbors = self
.points
.0
.par_iter()
.map(|(_, h)| self.nearest_neighbor(h))
.collect::<Vec<(Abstraction, f32)>>();
self.kmeans.clear();
self.assign_nearest_neighbor(neighbors);
self.assign_orphans_randomly();
progress.inc(1);
Expand All @@ -211,36 +224,33 @@ impl Layer {
/// by computing the EMD distance between the `Observation`'s `Histogram` and each `Centroid`'s `Histogram`
/// and returning the `Abstraction` of the nearest `Centroid`
fn assign_nearest_neighbor(&mut self, neighbors: Vec<(Abstraction, f32)>) {
self.kmeans.clear();
let mut loss = 0.;
for ((observation, histogram), (abstraction, distance)) in
std::iter::zip(self.points.0.iter_mut(), neighbors.iter())
{
loss += distance * distance;
self.lookup.assign(abstraction, observation);
self.kmeans.absorb(abstraction, histogram);
for ((obs, hist), (abs, dist)) in self.points.0.iter_mut().zip(neighbors.iter()) {
loss += dist * dist;
self.lookup.assign(abs, obs);
self.kmeans.absorb(abs, hist);
}
log::debug!("LOSS {:>12.8}", loss / self.points.0.len() as f32);
let loss = loss / self.points.0.len() as f32;
log::trace!("LOSS {:>12.8}", loss);
}
/// centroid drift may make it such that some centroids are empty
/// so we reinitialize empty centroids with random Observations if necessary
fn assign_orphans_randomly(&mut self) {
for ref a in self.kmeans.orphans() {
log::warn!(
"{:<32}{:<32}",
"reassigning empty centroid",
format!("0x{}", a)
);
let ref mut rng = rand::thread_rng();
let ref sample = self.sample_uniform(rng);
self.kmeans.absorb(a, sample);
log::debug!(
"{:<32}{:<32}",
"reassigned empty centroid",
format!("0x{}", a)
);
}
}

/// the first Centroid is uniformly random across all `Observation` `Histogram`s
fn sample_uniform<R>(&self, rng: &mut R) -> Histogram
where
R: Rng,
{
fn sample_uniform<R: Rng>(&self, rng: &mut R) -> Histogram {
self.points
.0
.values()
Expand All @@ -251,10 +261,7 @@ impl Layer {
/// each next Centroid is selected with probability proportional to
/// the squared distance to the nearest neighboring Centroid.
/// faster convergence, i guess. on the shoulders of giants
fn sample_outlier<R>(&self, rng: &mut R) -> Histogram
where
R: Rng,
{
fn sample_outlier<R: Rng>(&self, rng: &mut R) -> Histogram {
let weights = self
.points
.0
Expand Down
Loading

0 comments on commit 2feda7a

Please sign in to comment.