Skip to content

Commit

Permalink
Use separate function for computing average and filtering NaNs/0s
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob Van Geffen committed Dec 18, 2024
1 parent bb80bc7 commit 4b4f246
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 15 deletions.
9 changes: 3 additions & 6 deletions src/game/aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@ use crate::{
Action, ActionTree, PostFlopGame, TreeConfig,
};

// TODO do we ever realistically want the Skip option?
// TODO also, for `Error`, we may actually want to error if the directory exists at all
// (as opposed to what's done now, which may not error until after some reports are already written)
/// Describes possible behaviors for writing reports to files that already exist.
#[derive(Clone, Copy)]
pub enum ExistingReportBehavior {
Expand Down Expand Up @@ -473,7 +470,7 @@ pub fn load_test_game_and_config() -> (PostFlopGame, TreeConfig) {

#[cfg(test)]
mod tests {
use crate::compute_average;
use crate::compute_average_filter_0s;

use super::*;

Expand All @@ -489,7 +486,6 @@ mod tests {
get_current_player(&prev_actions[1..], (starting_player + 1) % 2)
}

// TODO this is jank
fn get_total_bets(prev_actions: &[Action]) -> i32 {
let mut total = 0;
let mut prev_bet = 0;
Expand Down Expand Up @@ -532,7 +528,8 @@ mod tests {
}

// Check action EVs weighted sum to player ev
let action_ev_weighted_sum = compute_average(&row.action_evs, &row.action_frequencies);
let action_ev_weighted_sum =
compute_average_filter_0s(&row.action_evs, &row.action_frequencies);
let player_ev = if player == 0 { row.oop_ev } else { row.ip_ev };
if !(action_ev_weighted_sum.is_nan() && player_ev.is_nan()) {
assert!((action_ev_weighted_sum - player_ev).abs() < 1e-3);
Expand Down
12 changes: 7 additions & 5 deletions src/game/utils/stats.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{compute_average, Game, PostFlopGame};
use crate::{compute_average_filter_0s, Game, PostFlopGame};

/// Returns the player's equity, EV, and EQR (in that order).
/// *Requires game.cache_normalized_weights() to be called beforehand.*
Expand All @@ -18,8 +18,8 @@ pub fn get_player_stats(game: &PostFlopGame, player: usize) -> (f32, f32, f32) {
let equity = game.equity(player);
let ev = game.expected_values(player);
let weights = game.normalized_weights(player);
let average_equity = compute_average(&equity, weights);
let average_ev = compute_average(&ev, weights);
let average_equity = compute_average_filter_0s(&equity, weights);
let average_ev = compute_average_filter_0s(&ev, weights);
(
average_equity,
average_ev,
Expand Down Expand Up @@ -47,7 +47,9 @@ pub fn get_action_frequencies(game: &PostFlopGame) -> Vec<f32> {
let weights = game.normalized_weights(player);

(0..actions.len())
.map(|i| compute_average(&strategy[i * cards.len()..(i + 1) * cards.len()], weights))
.map(|i| {
compute_average_filter_0s(&strategy[i * cards.len()..(i + 1) * cards.len()], weights)
})
.collect()
}

Expand Down Expand Up @@ -80,7 +82,7 @@ pub fn get_action_evs(game: &mut PostFlopGame) -> Vec<f32> {
.map(|(s, w)| s * w)
.collect();

compute_average(evs_detail, &weights)
compute_average_filter_0s(evs_detail, &weights)
})
.collect()
}
21 changes: 17 additions & 4 deletions src/utility.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,27 @@ pub fn compute_average(slice: &[f32], weights: &[f32]) -> f32 {
let mut weight_sum = 0.0;
let mut value_sum = 0.0;
for (&v, &w) in slice.iter().zip(weights.iter()) {
if w != 0.0 {
weight_sum += w as f64;
value_sum += v as f64 * w as f64;
}
weight_sum += w as f64;
value_sum += v as f64 * w as f64;
}
(value_sum / weight_sum) as f32
}

/// Works like compute_average, but ignores values where the correspondding weight is 0.0.
/// This means that `compute_average_filter_0s` can output a non-NaN value, even when some values are NaN
/// (so long as their corresponding weights are 0.0).
///
/// Use this function only when performance is not a concern.
#[inline]
pub fn compute_average_filter_0s(slice: &[f32], weights: &[f32]) -> f32 {
let (filtered_slice, filtered_weights): (Vec<f32>, Vec<f32>) = slice
.iter()
.zip(weights.iter())
.filter(|(_, &w)| w != 0.0)
.unzip();
compute_average(&filtered_slice, &filtered_weights)
}

#[inline]
fn weighted_sum(values: &[f32], weights: &[f32]) -> f32 {
let f = |sum: f64, (&v, &w): (&f32, &f32)| sum + v as f64 * w as f64;
Expand Down

0 comments on commit 4b4f246

Please sign in to comment.