diff --git a/examples/report.rs b/examples/report.rs index a9b3063..fa3ec15 100644 --- a/examples/report.rs +++ b/examples/report.rs @@ -1,7 +1,7 @@ use std::fs; +use aggregation::{generate_all_lines, AggActionTree, ExistingReportBehavior}; use postflop_solver::*; -use utils::batch::report::{generate_all_lines, AggActionTree, ExistingReportBehavior}; // Uncomment if reloading from previous game saves use utils::flop_helper::flop_to_string; @@ -70,7 +70,7 @@ fn main() { ) .unwrap(); - report_tree.update_report_for_game(&mut game, &flop.into_iter().collect()); + report_tree.update_report_for_game(&mut game); // Log progress //if (i + 1) % 10 == 0 { diff --git a/src/game/aggregation.rs b/src/game/aggregation.rs index 34497d0..bb35a63 100644 --- a/src/game/aggregation.rs +++ b/src/game/aggregation.rs @@ -115,7 +115,10 @@ fn folder_name_from_action(action: Action) -> String { Action::Raise(x) => format!("raise{}", x), Action::Check => "check".to_string(), Action::Call => "call".to_string(), - _ => unimplemented!("Cannot currently make folder for terminating action"), + _ => unimplemented!( + "Cannot currently make folder for terminating action {:?}", + action + ), } } @@ -156,7 +159,7 @@ fn report_header(actions: &Vec) -> String { /// actions: list of action likelihoods (0-100), corresponding to the list of action from the owning AggActionTree /// action_evs: expected value (in chips) resulting from each action pub struct AggRow { - flop: [u8; 3], // could also have turn/river here + flop: [u8; 3], turn: Option, river: Option, ip_equity: f32, @@ -277,8 +280,10 @@ impl AggActionTree { let available_actions = action_tree.available_actions().to_vec(); - // Add child node if it doesn't exist, and set current node to child node - current_node = current_node.child_or_add(action, available_actions); + // When node isn't terminal, add child node if it doesn't exist and set current node to child node + if !action_tree.is_terminal_node() { + current_node = current_node.child_or_add(action, available_actions); + } } } @@ -479,33 +484,74 @@ mod tests { get_current_player(&prev_actions[1..], (starting_player + 1) % 2) } - fn check_row(row: &AggRow, player: usize) { + // TODO this is jank + fn get_total_bets(prev_actions: &[Action]) -> i32 { + let mut total = 0; + let mut prev_bet = 0; + let mut street_starting_total = 0; + for &action in prev_actions { + match action { + Action::AllIn(x) | Action::Bet(x) | Action::Raise(x) => { + total = street_starting_total + prev_bet + x; + prev_bet = x; + } + Action::Call => { + total = street_starting_total + prev_bet * 2; + prev_bet = 0; + street_starting_total = total; + } + _ => (), + } + } + total + } + + fn check_row(row: &AggRow, player: usize, pot: f32) { // Check that equities sum to ~1 - assert!((row.ip_equity + row.oop_equity - 1.0).abs() < 1e-3); + println!("OOP equity: {:?}", row.oop_equity); + println!("IP equity: {:?}", row.ip_equity); + // Skip check if both are NaN + if !(row.ip_equity.is_nan() && row.oop_equity.is_nan()) { + assert!((row.ip_equity + row.oop_equity - 1.0).abs() < 1e-3); + } // Check that actions sum to ~1 let action_freq_total: f32 = row.action_frequencies.iter().sum(); - assert!((action_freq_total - 1.0).abs() < 1e-3); + println!("Action freq total: {:?}", action_freq_total); + if !action_freq_total.is_nan() { + assert!((action_freq_total - 1.0).abs() < 1e-3); + } // Check evs sum to pot - // TODO + println!("OOP EV: {:?}", row.oop_ev); + println!("IP EV: {:?}", row.ip_ev); + println!("Pot: {:?}", pot); + if !(row.ip_ev.is_nan() && row.oop_ev.is_nan()) { + assert!((row.oop_ev + row.ip_ev - pot).abs() < 1e-3); + } // Check action EVs weighted sum to player ev - let action_ev_weighted_sum = compute_average(&row.action_evs, &row.action_frequencies); let player_ev = if player == 0 { row.oop_ev } else { row.ip_ev }; + println!("Action EV sum: {:?}", action_ev_weighted_sum); + println!("Player EV: {:?}", player_ev); assert!((action_ev_weighted_sum - player_ev).abs() < 1e-3); } - fn check_tree(tree: &AggActionTree) { + fn check_tree(tree: &AggActionTree, config: &TreeConfig) { let current_player = get_current_player(&tree.prev_actions, 0); + println!("Prev actions: {:?}", tree.prev_actions); for row in &tree.data { - check_row(row, current_player); + if let Some(c) = row.river { + println!("River: {c:?}"); + } + let pot = config.starting_pot + get_total_bets(&tree.prev_actions); + check_row(row, current_player, pot as f32); } for (_, child) in &tree.child_trees { - check_tree(child); + check_tree(child, &config); } } @@ -514,8 +560,12 @@ mod tests { let (mut game, config) = load_game_and_config(); let all_lines = generate_all_lines(config.clone()).unwrap(); - let mut tree = AggActionTree::init_root(all_lines, config).unwrap(); + let mut tree = AggActionTree::init_root(all_lines, config.clone()).unwrap(); tree.update_report_for_game(&mut game); - check_tree(&tree); + // Output the report for debugging + // let report_dir = "reports/agg_test"; + // tree.write_self_and_children(&report_dir, "report.csv", ExistingReportBehavior::Overwrite) + // .expect("Problem writing to files"); + check_tree(&tree, &config); } } diff --git a/src/game/interpreter.rs b/src/game/interpreter.rs index 6076f5c..aa7bdef 100644 --- a/src/game/interpreter.rs +++ b/src/game/interpreter.rs @@ -782,6 +782,7 @@ impl PostFlopGame { ret.chunks_exact_mut(num_hands) .enumerate() .for_each(|(action, row)| { + // TODO: This is an insane way to check if the action is a fold let is_fold = have_actions && self.node().play(action).prev_action == Action::Fold; self.apply_swap(row, player, false); row.iter_mut() diff --git a/src/game/mod.rs b/src/game/mod.rs index 9d29e75..032f08d 100644 --- a/src/game/mod.rs +++ b/src/game/mod.rs @@ -1,3 +1,4 @@ +pub mod aggregation; mod base; mod evaluation; mod interpreter; diff --git a/src/game/utils/mod.rs b/src/game/utils/mod.rs index 874806b..45810f3 100644 --- a/src/game/utils/mod.rs +++ b/src/game/utils/mod.rs @@ -1,2 +1 @@ -pub mod batch; pub mod flop_helper; diff --git a/src/utility.rs b/src/utility.rs index c834de7..a31b783 100644 --- a/src/utility.rs +++ b/src/utility.rs @@ -63,8 +63,10 @@ pub fn compute_average(slice: &[f32], weights: &[f32]) -> f32 { let mut weight_sum = 0.0; let mut value_sum = 0.0; for (&v, &w) in slice.iter().zip(weights.iter()) { - weight_sum += w as f64; - value_sum += v as f64 * w as f64; + if w != 0.0 { + weight_sum += w as f64; + value_sum += v as f64 * w as f64; + } } (value_sum / weight_sum) as f32 } diff --git a/test-artifacts/Td9d6hQc.pfs b/test-artifacts/Td9d6hQc.pfs new file mode 100644 index 0000000..77c75ce Binary files /dev/null and b/test-artifacts/Td9d6hQc.pfs differ