From a23e31906fabfb824ebb01a6bcaa8d061bd49325 Mon Sep 17 00:00:00 2001 From: bkushigian Date: Sat, 17 Aug 2024 13:36:24 -0700 Subject: [PATCH] Tmp commit --- examples/file_io_debug.rs | 267 +++++++++++++++++++++++++++----------- src/file.rs | 42 ++++++ src/game/base.rs | 102 ++++++++++----- src/game/interpreter.rs | 20 +++ src/solver.rs | 68 ++++++++++ 5 files changed, 387 insertions(+), 112 deletions(-) diff --git a/examples/file_io_debug.rs b/examples/file_io_debug.rs index f0348a6..d5a625f 100644 --- a/examples/file_io_debug.rs +++ b/examples/file_io_debug.rs @@ -1,97 +1,210 @@ -use std::fs::File; - use postflop_solver::*; -fn main() { - // see `basic.rs` for the explanation of the following code +fn recursive_compare_strategies_helper( + saved: &mut PostFlopGame, + loaded: &mut PostFlopGame, + storage_mode: BoardState, +) { + let history = saved.history().to_vec(); + saved.cache_normalized_weights(); + loaded.cache_normalized_weights(); + + // Check if OOP hands have the same evs + let evs_oop_1 = saved.expected_values(0); + let ws_oop_1 = saved.weights(0); + let evs_oop_2 = loaded.expected_values(1); + let ws_oop_2 = saved.weights(0); + + assert!(ws_oop_1.len() == ws_oop_2.len()); + for (w1, w2) in ws_oop_1.iter().zip(ws_oop_2) { + assert!((w1 - w2).abs() < 0.001); + } + for (i, (e1, e2)) in evs_oop_1.iter().zip(&evs_oop_2).enumerate() { + assert!((e1 - e2).abs() < 0.001, "ev diff({}): {}", i, e1 - e2); + } + + let ev_oop_1 = compute_average(&evs_oop_1, &ws_oop_1); + let ev_oop_2 = compute_average(&evs_oop_2, &ws_oop_2); + + let ev_diff = (ev_oop_1 - ev_oop_2).abs(); + println!("EV Diff: {:0.2}", ev_diff); + assert!((ev_oop_1 - ev_oop_2).abs() < 0.01); + for child_index in 0..saved.available_actions().len() { + saved.play(child_index); + loaded.play(child_index); + + recursive_compare_strategies_helper(saved, loaded, storage_mode); + + saved.apply_history(&history); + loaded.apply_history(&history); + } +} + +fn compare_strategies( + saved: &mut PostFlopGame, + loaded: &mut PostFlopGame, + storage_mode: BoardState, +) { + saved.back_to_root(); + loaded.back_to_root(); + saved.cache_normalized_weights(); + loaded.cache_normalized_weights(); + for (i, ((e1, e2), cards)) in saved + .expected_values(0) + .iter() + .zip(loaded.expected_values(0)) + .zip(saved.private_cards(0)) + .enumerate() + { + println!("ev {}: {}:{}", hole_to_string(*cards).unwrap(), e1, e2); + } + for (i, ((e1, e2), cards)) in saved + .expected_values(1) + .iter() + .zip(loaded.expected_values(1)) + .zip(saved.private_cards(1)) + .enumerate() + { + println!("ev {}: {}:{}", hole_to_string(*cards).unwrap(), e1, e2); + } + recursive_compare_strategies_helper(saved, loaded, storage_mode); +} + +fn print_strats_at_current_node( + g1: &mut PostFlopGame, + g2: &mut PostFlopGame, + actions: &Vec, +) { + let action_string = actions + .iter() + .map(|a| format!("{:?}", a)) + .collect::>() + .join(":"); - let oop_range = "66+,A8s+,A5s-A4s,AJo+,K9s+,KQo,QTs+,JTs,96s+,85s+,75s+,65s,54s"; - let ip_range = "QQ-22,AQs-A2s,ATo+,K5s+,KJo+,Q8s+,J8s+,T7s+,96s+,86s+,75s+,64s+,53s+"; + let player = g1.current_player(); + + println!( + "\x1B[32;1mActions To Reach Node\x1B[0m: [{}]", + action_string + ); + // Print high level node data + if g1.is_chance_node() { + println!("\x1B[32;1mPlayer\x1B[0m: Chance"); + } else if g1.is_terminal_node() { + if player == 0 { + println!("\x1B[32;1mPlayer\x1B[0m: OOP (Terminal)"); + } else { + println!("\x1B[32;1mPlayer\x1B[0m: IP (Terminal)"); + } + } else { + if player == 0 { + println!("\x1B[32;1mPlayer\x1B[0m: OOP"); + } else { + println!("\x1B[32;1mPlayer\x1B[0m: IP"); + } + let private_cards = g1.private_cards(player); + let strat1 = g1.strategy_by_private_hand(); + let strat2 = g2.strategy_by_private_hand(); + let weights1 = g1.weights(player); + let weights2 = g2.weights(player); + let actions = g1.available_actions(); + + // Print both games strategies + for ((cards, (w1, s1)), (w2, s2)) in private_cards + .iter() + .zip(weights1.iter().zip(strat1)) + .zip(weights2.iter().zip(strat2)) + { + let hole_cards = hole_to_string(*cards).unwrap(); + print!("\x1B[34;1m{hole_cards}\x1B[0m@({:.2} v {:.2}) ", w1, w2); + let mut action_frequencies = vec![]; + for (a, (freq1, freq2)) in actions.iter().zip(s1.iter().zip(s2)) { + action_frequencies.push(format!( + "\x1B[32;1m{:?}\x1B[0m: \x1B[31m{:0.3}\x1B[0m v \x1B[33m{:0>.3}\x1B[0m", + a, freq1, freq2 + )) + } + println!("{}", action_frequencies.join(" ")); + } + } +} + +fn main() { + let oop_range = "AA,QQ"; + let ip_range = "KK"; let card_config = CardConfig { range: [oop_range.parse().unwrap(), ip_range.parse().unwrap()], - flop: flop_from_str("Td9d6h").unwrap(), - turn: card_from_str("Qc").unwrap(), - river: NOT_DEALT, + flop: flop_from_str("3h3s3d").unwrap(), + ..Default::default() }; - let bet_sizes = BetSizeOptions::try_from(("60%, e, a", "2.5x")).unwrap(); - let tree_config = TreeConfig { - initial_state: BoardState::Turn, - starting_pot: 200, - effective_stack: 900, + starting_pot: 100, + effective_stack: 100, rake_rate: 0.0, rake_cap: 0.0, - flop_bet_sizes: [bet_sizes.clone(), bet_sizes.clone()], - turn_bet_sizes: [bet_sizes.clone(), bet_sizes.clone()], - river_bet_sizes: [bet_sizes.clone(), bet_sizes], - turn_donk_sizes: None, - river_donk_sizes: Some(DonkSizeOptions::try_from("50%").unwrap()), - add_allin_threshold: 1.5, - force_allin_threshold: 0.15, - merging_threshold: 0.1, + flop_bet_sizes: [("e", "").try_into().unwrap(), ("e", "").try_into().unwrap()], + turn_bet_sizes: [("e", "").try_into().unwrap(), ("e", "").try_into().unwrap()], + river_bet_sizes: [("e", "").try_into().unwrap(), ("e", "").try_into().unwrap()], + ..Default::default() }; let action_tree = ActionTree::new(tree_config).unwrap(); - let mut game = PostFlopGame::with_config(card_config, action_tree).unwrap(); - game.allocate_memory(false); - - let max_num_iterations = 20; - let target_exploitability = game.tree_config().starting_pot as f32 * 0.01; - solve(&mut game, max_num_iterations, target_exploitability, true); - let r = game.set_target_storage_mode(BoardState::Turn); - println!("{r:?}"); - - // save the solved game tree to a file - // 4th argument is zstd compression level (1-22); requires `zstd` feature to use - save_data_to_file(&game, "memo string", "filename.bin", None).unwrap(); - - // load the solved game tree from a file - // 2nd argument is the maximum memory usage in bytes - let (mut game2, _memo_string): (PostFlopGame, _) = - load_data_from_file("filename.bin", None).unwrap(); - - println!("Game 1 Internal Data"); - game.print_internal_data(); - println!("Game 2 Internal Data"); - game2.print_internal_data(); - - // check if the loaded game tree is the same as the original one - game.cache_normalized_weights(); - game2.cache_normalized_weights(); - assert_eq!(game.equity(0), game2.equity(0)); - - // discard information after the river deal when serializing - // this operation does not lose any information of the game tree itself - game2.set_target_storage_mode(BoardState::Turn).unwrap(); - - // compare the memory usage for serialization - println!( - "Memory usage of the original game tree: {:.2}MB", // 11.50MB - game.target_memory_usage() as f64 / (1024.0 * 1024.0) - ); - println!( - "Memory usage of the truncated game tree: {:.2}MB", // 0.79MB - game2.target_memory_usage() as f64 / (1024.0 * 1024.0) - ); + let mut game1 = PostFlopGame::with_config(card_config, action_tree).unwrap(); + game1.allocate_memory(false); + + solve(&mut game1, 100, 0.01, false); + + // save (turn) + game1.set_target_storage_mode(BoardState::Turn).unwrap(); + save_data_to_file(&game1, "", "tmpfile.flop", None).unwrap(); + + // load (turn) + let mut game2: PostFlopGame = load_data_from_file("tmpfile.flop", None).unwrap().0; + // compare_strategies(&mut game, &mut game2, BoardState::Turn); + assert!(game2.rebuild_and_resolve_forgotten_streets().is_ok()); + + let mut actions_so_far = vec![]; + + // Print Root Node + print_strats_at_current_node(&mut game1, &mut game2, &actions_so_far); - // overwrite the file with the truncated game tree - // game tree constructed from this file cannot access information after the river deal - save_data_to_file(&game2, "memo string", "filename.bin", None).unwrap(); - let (mut game3, _memo_string): (PostFlopGame, String) = - load_data_from_file("filename.bin", None).unwrap(); + // OOP: Check + actions_so_far.push(game1.available_actions()[0]); + game1.play(0); + game2.play(0); + print_strats_at_current_node(&mut game1, &mut game2, &actions_so_far); + + // IP: Check + actions_so_far.push(game1.available_actions()[0]); + game1.play(0); + game2.play(0); + print_strats_at_current_node(&mut game1, &mut game2, &actions_so_far); - game.play(0); - game.play(0); - println!("Game X/X Actions: {:?}", game.available_actions()); + // Chance: 2c + actions_so_far.push(game1.available_actions()[0]); + game1.play(0); game2.play(0); + print_strats_at_current_node(&mut game1, &mut game2, &actions_so_far); + + // OOP: CHECK + actions_so_far.push(game1.available_actions()[0]); + game1.play(0); game2.play(0); - println!("Game2 X/X Actions: {:?}", game.available_actions()); - game3.play(0); - game3.play(0); - println!("Game3 X/X Actions: {:?}", game3.available_actions()); + print_strats_at_current_node(&mut game1, &mut game2, &actions_so_far); + + // IP: CHECK + actions_so_far.push(game1.available_actions()[0]); + game1.play(0); + game2.play(0); + print_strats_at_current_node(&mut game1, &mut game2, &actions_so_far); + + // CHANCE: 0 + actions_so_far.push(game1.available_actions()[1]); + game1.play(1); + game2.play(1); + print_strats_at_current_node(&mut game1, &mut game2, &actions_so_far); - // delete the file - std::fs::remove_file("filename.bin").unwrap(); + // compare_strategies(&mut game, &mut game2, BoardState::Turn); } diff --git a/src/file.rs b/src/file.rs index 17e7d5b..d7e90f7 100644 --- a/src/file.rs +++ b/src/file.rs @@ -269,6 +269,7 @@ mod tests { use crate::action_tree::*; use crate::card::*; use crate::range::*; + use crate::solver::solve; use crate::utility::*; #[test] @@ -375,4 +376,45 @@ mod tests { assert!((root_ev_oop - 45.0).abs() < 1e-4); assert!((root_ev_ip - 15.0).abs() < 1e-4); } + + #[test] + fn test_reload_and_resolve() { + let oop_range = "AA,QQ"; + let ip_range = "KK"; + + let card_config = CardConfig { + range: [oop_range.parse().unwrap(), ip_range.parse().unwrap()], + flop: flop_from_str("3h3s3d").unwrap(), + ..Default::default() + }; + + let tree_config = TreeConfig { + starting_pot: 100, + effective_stack: 100, + rake_rate: 0.0, + rake_cap: 0.0, + flop_bet_sizes: [("e", "").try_into().unwrap(), ("e", "").try_into().unwrap()], + turn_bet_sizes: [("e", "").try_into().unwrap(), ("e", "").try_into().unwrap()], + river_bet_sizes: [("e", "").try_into().unwrap(), ("e", "").try_into().unwrap()], + ..Default::default() + }; + + let action_tree = ActionTree::new(tree_config).unwrap(); + let mut game = PostFlopGame::with_config(card_config, action_tree).unwrap(); + println!( + "memory usage: {:.2}GB", + game.memory_usage().0 as f64 / (1024.0 * 1024.0 * 1024.0) + ); + game.allocate_memory(false); + + solve(&mut game, 100, 0.01, false); + + // save (turn) + game.set_target_storage_mode(BoardState::Turn).unwrap(); + save_data_to_file(&game, "", "tmpfile.flop", None).unwrap(); + + // load (turn) + let mut game: PostFlopGame = load_data_from_file("tmpfile.flop", None).unwrap().0; + assert!(game.rebuild_and_resolve_forgotten_streets().is_ok()); + } } diff --git a/src/game/base.rs b/src/game/base.rs index 6dc26c8..94931c1 100644 --- a/src/game/base.rs +++ b/src/game/base.rs @@ -1,13 +1,14 @@ use super::*; use crate::bunching::*; use crate::interface::*; +use crate::solve_with_node_as_root; use crate::utility::*; use std::mem::{self, MaybeUninit}; #[cfg(feature = "rayon")] use rayon::prelude::*; -#[derive(Default)] +#[derive(Default, Debug)] struct BuildTreeInfo { flop_index: usize, turn_index: usize, @@ -768,6 +769,7 @@ impl PostFlopGame { node.num_children += 1; let mut child = node.children().last().unwrap().lock(); child.prev_action = Action::Chance(card); + child.parent_node_index = node_index; child.turn = node.turn; child.river = card; } @@ -845,12 +847,9 @@ impl PostFlopGame { self.num_nodes_per_street = nodes_per_street; - let total_new_nodes_to_allocate = total_num_nodes - self.node_arena.len() as u64; - self.node_arena.append( - &mut (0..total_new_nodes_to_allocate) - .map(|_| MutexLike::new(PostFlopNode::default())) - .collect::>(), - ); + self.node_arena = (0..total_num_nodes) + .map(|_| MutexLike::new(PostFlopNode::default())) + .collect::>(); // self.clear_storage(); let mut info = BuildTreeInfo { @@ -879,64 +878,97 @@ impl PostFlopGame { Ok(()) } - pub fn reload_and_resolve(&mut self, enable_compression: bool) -> Result<(), String> { - self.allocate_memory_after_load(enable_compression)?; + pub fn rebuild_and_resolve_forgotten_streets(&mut self) -> Result<(), String> { + self.check_card_config()?; self.reinit_root()?; + self.allocate_memory_after_load()?; + self.resolve_reloaded_nodes(1000, 0.01, false) + } - // Collect root nodes to resolve - let nodes_to_solve = match self.storage_mode { + /// Return the node index for each root of the forgotten gametrees that were + /// omitted during a partial save. + /// + /// When we perform a partial save (e.g., a flop save), we lose + /// cfvalues/strategy data for all subtrees rooted at the forgotten street + /// (in the case of a flop save, this would be all subtrees rooted at the + /// beginning of the turn). + /// + /// To regain this information we need to resolve each of these subtrees + /// individually. This function collects the index of each such root. + pub fn collect_unsolved_roots_after_reload(&mut self) -> Result, String> { + match self.storage_mode { BoardState::Flop => { let turn_root_nodes = self .node_arena .iter() - .filter(|n| { + .enumerate() + .filter(|(_, n)| { n.lock().turn != NOT_DEALT && n.lock().river == NOT_DEALT && matches!(n.lock().prev_action, Action::Chance(..)) }) + .map(|(i, _)| i) .collect::>(); - turn_root_nodes + Ok(turn_root_nodes) } BoardState::Turn => { let river_root_nodes = self .node_arena .iter() - .filter(|n| { + .enumerate() + .filter(|(_, n)| { n.lock().turn != NOT_DEALT && matches!(n.lock().prev_action, Action::Chance(..)) }) + .map(|(i, _)| i) .collect::>(); - river_root_nodes - } - BoardState::River => vec![], - }; - for node in nodes_to_solve { - // Get history of this node - // let mut history = vec![]; - let mut n = node.lock(); - while n.parent_node_index < usize::MAX { - let parent = self.node_arena[n.parent_node_index].lock(); - let action = n.prev_action; + Ok(river_root_nodes) } + BoardState::River => Ok(vec![]), } + } + + pub fn resolve_reloaded_nodes( + &mut self, + max_num_iterations: u32, + target_exploitability: f32, + print_progress: bool, + ) -> Result<(), String> { + let nodes_to_solve = self.collect_unsolved_roots_after_reload()?; + self.state = State::MemoryAllocated; + for node_idx in nodes_to_solve { + let node = self.node_arena.get(node_idx).ok_or("Invalid node index")?; + // let history = node + // .lock() + // .compute_history_recursive(&self) + // .ok_or("Unable to compute history for node".to_string())? + // .to_vec(); + // self.apply_history(&history); + solve_with_node_as_root( + self, + node.lock(), + max_num_iterations, + target_exploitability, + print_progress, + ); + } + finalize(self); Ok(()) } - /// Reallocate memory for full tree after performing a partial load - pub fn allocate_memory_after_load(&mut self, enable_compression: bool) -> Result<(), String> { + /// Reallocate memory for full tree after performing a partial load. This + /// must be called after `init_root()` + pub fn allocate_memory_after_load(&mut self) -> Result<(), String> { if self.state <= State::Uninitialized { return Err("Game is not successfully initialized".to_string()); } - if self.state == State::MemoryAllocated - && self.storage_mode == BoardState::River - && self.is_compression_enabled == enable_compression - { + if self.state == State::MemoryAllocated && self.storage_mode == BoardState::River { return Ok(()); } - let num_bytes = if enable_compression { 2 } else { 4 }; + let num_bytes = if self.is_compression_enabled { 2 } else { 4 }; if num_bytes * self.num_storage > isize::MAX as u64 || num_bytes * self.num_storage_chance > isize::MAX as u64 { @@ -944,7 +976,7 @@ impl PostFlopGame { } self.state = State::MemoryAllocated; - self.is_compression_enabled = enable_compression; + // self.is_compression_enabled = self.is_compression_enabled; let old_storage1 = std::mem::replace(&mut self.storage1, vec![]); let old_storage2 = std::mem::replace(&mut self.storage2, vec![]); @@ -960,7 +992,7 @@ impl PostFlopGame { self.storage_ip = vec![0; storage_ip_bytes]; self.storage_chance = vec![0; storage_chance_bytes]; - self.allocate_memory_nodes(); + self.allocate_memory_nodes(); // Assign node storage pointers self.storage_mode = BoardState::River; self.target_storage_mode = BoardState::River; @@ -1595,7 +1627,7 @@ impl PostFlopGame { Ok(info) } - /// Allocates memory recursively. + /// Assigns allocated storage memory. fn allocate_memory_nodes(&mut self) { let num_bytes = if self.is_compression_enabled { 2 } else { 4 }; let mut action_counter = 0; diff --git a/src/game/interpreter.rs b/src/game/interpreter.rs index d7845b6..619011d 100644 --- a/src/game/interpreter.rs +++ b/src/game/interpreter.rs @@ -763,6 +763,7 @@ impl PostFlopGame { node.cfvalues_ip().to_vec() } } else if player == self.current_player() { + println!("BINGO"); have_actions = true; if self.is_compression_enabled { let slice = node.cfvalues_compressed(); @@ -846,6 +847,25 @@ impl PostFlopGame { ret } + pub fn strategy_by_private_hand(&self) -> Vec> { + let strat = self.strategy(); + let player = self.current_player(); + let num_hands = self.private_cards(player).len(); + let num_actions = self.available_actions().len(); + assert!(num_hands * num_actions == strat.len()); + let mut strat_by_hand: Vec> = Vec::with_capacity(num_hands); + for j in 0..num_hands { + strat_by_hand.push(Vec::with_capacity(num_actions)); + } + + for i in 0..num_actions { + for j in 0..num_hands { + strat_by_hand[j].push(strat[i * num_hands + j]); + } + } + strat_by_hand + } + /// Returns the total bet amount of each player (OOP, IP). #[inline] pub fn total_bet_amount(&self) -> [i32; 2] { diff --git a/src/solver.rs b/src/solver.rs index 5a1cc5a..02db93c 100644 --- a/src/solver.rs +++ b/src/solver.rs @@ -132,6 +132,74 @@ pub fn solve_step(game: &T, current_iteration: u32) { } } +/// Performs Discounted CFR algorithm until the given number of iterations or exploitability is +/// satisfied. +/// +/// This method returns the exploitability of the obtained strategy. +pub fn solve_with_node_as_root( + game: &mut T, + mut root: MutexGuardLike, + max_num_iterations: u32, + target_exploitability: f32, + print_progress: bool, +) -> f32 { + if game.is_solved() { + panic!("Game is already solved"); + } + + if !game.is_ready() { + panic!("Game is not ready"); + } + + let mut exploitability = compute_exploitability(game); + + if print_progress { + print!("iteration: 0 / {max_num_iterations} "); + print!("(exploitability = {exploitability:.4e})"); + io::stdout().flush().unwrap(); + } + + for t in 0..max_num_iterations { + if exploitability <= target_exploitability { + break; + } + + let params = DiscountParams::new(t); + + // alternating updates + for player in 0..2 { + let mut result = Vec::with_capacity(game.num_private_hands(player)); + solve_recursive( + result.spare_capacity_mut(), + game, + &mut root, + player, + game.initial_weights(player ^ 1), + ¶ms, + ); + } + + if (t + 1) % 10 == 0 || t + 1 == max_num_iterations { + exploitability = compute_exploitability(game); + } + + if print_progress { + print!("\riteration: {} / {} ", t + 1, max_num_iterations); + print!("(exploitability = {exploitability:.4e})"); + io::stdout().flush().unwrap(); + } + } + + if print_progress { + println!(); + io::stdout().flush().unwrap(); + } + + finalize(game); + + exploitability +} + /// Recursively solves the counterfactual values. fn solve_recursive( result: &mut [MaybeUninit],