From 4b4f24640b712012cf2995c17fb92d0b530bbcb0 Mon Sep 17 00:00:00 2001 From: Jacob Van Geffen Date: Wed, 18 Dec 2024 09:08:54 -0600 Subject: [PATCH] Use separate function for computing average and filtering NaNs/0s --- src/game/aggregation.rs | 9 +++------ src/game/utils/stats.rs | 12 +++++++----- src/utility.rs | 21 +++++++++++++++++---- 3 files changed, 27 insertions(+), 15 deletions(-) diff --git a/src/game/aggregation.rs b/src/game/aggregation.rs index 3abc3a9..0ec7692 100644 --- a/src/game/aggregation.rs +++ b/src/game/aggregation.rs @@ -12,9 +12,6 @@ use crate::{ Action, ActionTree, PostFlopGame, TreeConfig, }; -// TODO do we ever realistically want the Skip option? -// TODO also, for `Error`, we may actually want to error if the directory exists at all -// (as opposed to what's done now, which may not error until after some reports are already written) /// Describes possible behaviors for writing reports to files that already exist. #[derive(Clone, Copy)] pub enum ExistingReportBehavior { @@ -473,7 +470,7 @@ pub fn load_test_game_and_config() -> (PostFlopGame, TreeConfig) { #[cfg(test)] mod tests { - use crate::compute_average; + use crate::compute_average_filter_0s; use super::*; @@ -489,7 +486,6 @@ mod tests { get_current_player(&prev_actions[1..], (starting_player + 1) % 2) } - // TODO this is jank fn get_total_bets(prev_actions: &[Action]) -> i32 { let mut total = 0; let mut prev_bet = 0; @@ -532,7 +528,8 @@ mod tests { } // Check action EVs weighted sum to player ev - let action_ev_weighted_sum = compute_average(&row.action_evs, &row.action_frequencies); + let action_ev_weighted_sum = + compute_average_filter_0s(&row.action_evs, &row.action_frequencies); let player_ev = if player == 0 { row.oop_ev } else { row.ip_ev }; if !(action_ev_weighted_sum.is_nan() && player_ev.is_nan()) { assert!((action_ev_weighted_sum - player_ev).abs() < 1e-3); diff --git a/src/game/utils/stats.rs b/src/game/utils/stats.rs index 1aeefc3..95e83c4 100644 --- a/src/game/utils/stats.rs +++ b/src/game/utils/stats.rs @@ -1,4 +1,4 @@ -use crate::{compute_average, Game, PostFlopGame}; +use crate::{compute_average_filter_0s, Game, PostFlopGame}; /// Returns the player's equity, EV, and EQR (in that order). /// *Requires game.cache_normalized_weights() to be called beforehand.* @@ -18,8 +18,8 @@ pub fn get_player_stats(game: &PostFlopGame, player: usize) -> (f32, f32, f32) { let equity = game.equity(player); let ev = game.expected_values(player); let weights = game.normalized_weights(player); - let average_equity = compute_average(&equity, weights); - let average_ev = compute_average(&ev, weights); + let average_equity = compute_average_filter_0s(&equity, weights); + let average_ev = compute_average_filter_0s(&ev, weights); ( average_equity, average_ev, @@ -47,7 +47,9 @@ pub fn get_action_frequencies(game: &PostFlopGame) -> Vec { let weights = game.normalized_weights(player); (0..actions.len()) - .map(|i| compute_average(&strategy[i * cards.len()..(i + 1) * cards.len()], weights)) + .map(|i| { + compute_average_filter_0s(&strategy[i * cards.len()..(i + 1) * cards.len()], weights) + }) .collect() } @@ -80,7 +82,7 @@ pub fn get_action_evs(game: &mut PostFlopGame) -> Vec { .map(|(s, w)| s * w) .collect(); - compute_average(evs_detail, &weights) + compute_average_filter_0s(evs_detail, &weights) }) .collect() } diff --git a/src/utility.rs b/src/utility.rs index 7b8fe7c..442845c 100644 --- a/src/utility.rs +++ b/src/utility.rs @@ -63,14 +63,27 @@ pub fn compute_average(slice: &[f32], weights: &[f32]) -> f32 { let mut weight_sum = 0.0; let mut value_sum = 0.0; for (&v, &w) in slice.iter().zip(weights.iter()) { - if w != 0.0 { - weight_sum += w as f64; - value_sum += v as f64 * w as f64; - } + weight_sum += w as f64; + value_sum += v as f64 * w as f64; } (value_sum / weight_sum) as f32 } +/// Works like compute_average, but ignores values where the correspondding weight is 0.0. +/// This means that `compute_average_filter_0s` can output a non-NaN value, even when some values are NaN +/// (so long as their corresponding weights are 0.0). +/// +/// Use this function only when performance is not a concern. +#[inline] +pub fn compute_average_filter_0s(slice: &[f32], weights: &[f32]) -> f32 { + let (filtered_slice, filtered_weights): (Vec, Vec) = slice + .iter() + .zip(weights.iter()) + .filter(|(_, &w)| w != 0.0) + .unzip(); + compute_average(&filtered_slice, &filtered_weights) +} + #[inline] fn weighted_sum(values: &[f32], weights: &[f32]) -> f32 { let f = |sum: f64, (&v, &w): (&f32, &f32)| sum + v as f64 * w as f64;