diff --git a/DESIGN.md b/DESIGN.md new file mode 100644 index 0000000..3da8edc --- /dev/null +++ b/DESIGN.md @@ -0,0 +1,200 @@ +# Design + +This document is a description, as far as I understand it, of the inner design +of the solver and PostFlopGame. This is a working document for me to get my +bearings. + +## PostFlopGame + +### Build/Allocate/Initialize + +To set up a `PostFlopGame` we need to **create a `PostFlopGame` instance**, +**allocate global storage and `PostFlopNode`s**, and **initialize the +`PostFlopNode` child/parent relationship**. This is done in several steps. + +We begin by creating a `PostFlopGame` instance. + +```rust +let mut game = PostFlopGame::with_config(card_config, action_tree).unwrap(); +``` + +A `PostFlopGame` requires an +`ActionTree` which describes all possible actions and lines (no runout +information), and a `CardConfig`, which describes player ranges and +flop/turn/river data. + +Once we have created a `PostFlopGame` instance we need to allocate the following +memory and initialize its values: + ++ `game.node_arena` ++ `game.storage1` ++ `game.storage2` ++ `game.storage_ip` ++ `game.storage_chance` + +These fields are not allocated/initialized at the same time: + ++ `game.node_arena` is allocated and initialized via `with_config()` (i.e., when + we created our `PostFlopGame`), ++ other storage is allocated via `game.allocate_memory()`. + +#### Allocating and Initializing `node_arena` + +We constructed a `PostFlopGame` by calling +`PostFlopGame::with_config(card_config, action_tree)`, which under the hood +actually calls: + +```rust + let mut game = Self::new(); + game.update_config(card_config, action_tree)?; +``` + +`PostFlopGame::update_config` sets up configuration data, sanity checks things +are correct, and then calls `self.init_root()`. + +`init_root` is responsible for: + +1. Counting number of `PostFlopNode`s to be allocated (`self.nodes_per_street`), + broken up by flop, turn, and river +2. Allocating `PostFlopNode`s in the `node_arena` field +3. Clearing storage: `self.clear_storage()` sets each storage item to a new + `Vec` +4. Invoking `build_tree_recursive` which initializes each node's child/parent + relationship via `child_offset` (through calls to `push_actions` and + `push_chances`). + +Each `PostFlopNode` points to node-specific data (e.g., strategies and +cfregrets) that is located inside of `PostFlopGame.storage*` fields (which is +currently unallocated) via similarly named fields `PostFlopNode.storage*`. + +Additionally, each node points to the children offset with `children_offset`, +which records where in `node_arena` relative to the current node that node's +children begin. We allocate this memory via: + +```rust +game.allocate_memory(false); // pass `true` to use compressed memory +``` + +This allocates the following memory: + ++ `self.storage1` ++ `self.storage2` ++ `self.storage3` ++ `self.storage_chance` + +Next, `allocate_memory()` calls `allocate_memory_nodes(&mut self)`, which +iterates through each node in `node_arena` and sets storage pointers. + +After `allocate_memory` returns we still need to set `child_offset`s. + +### Storage + +There are several fields marked as `// global storage` in `game::mod::PostFlopGame`: + +```rust + // global storage + // `storage*` are used as a global storage and are referenced by `PostFlopNode::storage*`. + // Methods like `PostFlopNode::strategy` define how the storage is used. + node_arena: Vec>, + storage1: Vec, + storage2: Vec, + storage_ip: Vec, + storage_chance: Vec, + locking_strategy: BTreeMap>, +``` + +These are referenced from `PostFlopNode`: + +```rust + storage1: *mut u8, // strategy + storage2: *mut u8, // regrets or cfvalues + storage3: *mut u8, // IP cfvalues +``` + ++ `storage1` seems to store the strategy ++ `storage2` seems to store regrets/cfvalues, and ++ `storage3` stores IP's cf values (does that make `storage2` store OOP's cfvalues?) + +Storage is a byte vector `Vec`, and these store floating point values. + +> [!IMPORTANT] +> Why are these stored as `Vec`s? Is this for swapping between +> `f16` and `f32`s? + +Some storage is allocated in `game::base::allocate_memory`: + +```rust + let storage_bytes = (num_bytes * self.num_storage) as usize; + let storage_ip_bytes = (num_bytes * self.num_storage_ip) as usize; + let storage_chance_bytes = (num_bytes * self.num_storage_chance) as usize; + + self.storage1 = vec![0; storage_bytes]; + self.storage2 = vec![0; storage_bytes]; + self.storage_ip = vec![0; storage_ip_bytes]; + self.storage_chance = vec![0; storage_chance_bytes]; +``` + +`node_arena` is allocated in `game::base::init_root()`: + +```rust + let num_nodes = self.count_nodes_per_street(); + let total_num_nodes = num_nodes[0] + num_nodes[1] + num_nodes[2]; + + if total_num_nodes > u32::MAX as u64 + || mem::size_of::() as u64 * total_num_nodes > isize::MAX as u64 + { + return Err("Too many nodes".to_string()); + } + + self.num_nodes = num_nodes; + self.node_arena = (0..total_num_nodes) + .map(|_| MutexLike::new(PostFlopNode::default())) + .collect::>(); + self.clear_storage(); +``` + +`locking_strategy` maps node indexes (`PostFlopGame::node_index`) to a locked +strategy. `locking_strategy` is initialized to an empty `BTreeMap>` by deriving Default. It is inserted into via +`PostFlopGame::lock_current_strategy` + +### Serialization/Deserialization + +Serialization relies on the `bincode` library's `Encode` and `Decode`. We can set +the `target_storage_mode` to allow for a non-full save. For instance, + +```rust +game.set_target_storage_mode(BoardState::Turn); +``` + +will ensure that when `game` is encoded, it will only save Flop and Turn data. +When a serialized tree is deserialized, if it is a partial save (e.g., a Turn +save) you will not be able to navigate to unsaved streets. + +Several things break when we deserialize a partial save: + ++ `node_arena` is only partially populated ++ `node.children()` points to raw data when `node` points to an street that is + not serialized (e.g., a chance node before the river for a Turn save). + +### Allocating `node_arena` + +We want to first allocate nodes for `node_arena`, and then run some form of +`build_tree_recursive`. This assumes that `node_arena` is already allocated, and +recursively visits children of nodes and modifies them to + +### Data Coupling/Relations/Invariants + ++ A node is locked IFF it is contained in the game's locking_strategy ++ `PostFlopGame.node_arena` is pointed to by `PostFlopNode.children_offset`. For + instance, this is the basic definition of the `PostFlopNode.children()` + function: + + ```rust + slice::from_raw_parts( + self_ptr.add(self.children_offset as usize), + self.num_children as usize, + ) + ``` + + We get a pointer to `self` and add children offset. diff --git a/examples/simple.rs b/examples/simple.rs new file mode 100644 index 0000000..00782d0 --- /dev/null +++ b/examples/simple.rs @@ -0,0 +1,53 @@ +use postflop_solver::*; + +fn main() { + // ranges of OOP and IP in string format + // see the documentation of `Range` for more details about the format + let oop_range = "66+"; + let ip_range = "66+"; + + let card_config = CardConfig { + range: [oop_range.parse().unwrap(), ip_range.parse().unwrap()], + flop: flop_from_str("Td9d6h").unwrap(), + turn: NOT_DEALT, + river: NOT_DEALT, + }; + + // bet sizes -> 60% of the pot, geometric size, and all-in + // raise sizes -> 2.5x of the previous bet + // see the documentation of `BetSizeOptions` for more details + let bet_sizes = BetSizeOptions::try_from(("100%", "100%")).unwrap(); + + let tree_config = TreeConfig { + initial_state: BoardState::Flop, // must match `card_config` + starting_pot: 200, + effective_stack: 200, + rake_rate: 0.0, + rake_cap: 0.0, + flop_bet_sizes: [bet_sizes.clone(), bet_sizes.clone()], // [OOP, IP] + turn_bet_sizes: [bet_sizes.clone(), bet_sizes.clone()], + river_bet_sizes: [bet_sizes.clone(), bet_sizes], + turn_donk_sizes: None, // use default bet sizes + river_donk_sizes: Some(DonkSizeOptions::try_from("100%").unwrap()), + add_allin_threshold: 1.5, // add all-in if (maximum bet size) <= 1.5x pot + force_allin_threshold: 0.15, // force all-in if (SPR after the opponent's call) <= 0.15 + merging_threshold: 0.1, + }; + + // build the game tree + // `ActionTree` can be edited manually after construction + let action_tree = ActionTree::new(tree_config).unwrap(); + let mut game = PostFlopGame::with_config(card_config, action_tree).unwrap(); + + // allocate memory without compression (use 32-bit float) + game.allocate_memory(false); + + // solve the game + let max_num_iterations = 20; + let target_exploitability = game.tree_config().starting_pot as f32 * 0.100; // 10.0% of the pot + let exploitability = solve(&mut game, max_num_iterations, target_exploitability, true); + println!("Exploitability: {:.2}", exploitability); + + // get equity and EV of a specific hand + game.cache_normalized_weights(); +} diff --git a/src/card.rs b/src/card.rs index 2bba763..e36121a 100644 --- a/src/card.rs +++ b/src/card.rs @@ -11,6 +11,8 @@ use bincode::{Decode, Encode}; /// - `card_id = 4 * rank + suit` (where `0 <= card_id < 52`) /// - `rank`: 2 => `0`, 3 => `1`, 4 => `2`, ..., A => `12` /// - `suit`: club => `0`, diamond => `1`, heart => `2`, spade => `3` +/// +/// An undealt card is represented by Card::MAX (see `NOT_DEALT`). pub type Card = u8; /// Constant representing that the card is not yet dealt. diff --git a/src/game/base.rs b/src/game/base.rs index 383327a..e7501f7 100644 --- a/src/game/base.rs +++ b/src/game/base.rs @@ -7,7 +7,7 @@ use std::mem::{self, MaybeUninit}; #[cfg(feature = "rayon")] use rayon::prelude::*; -#[derive(Default)] +#[derive(Default, Debug)] struct BuildTreeInfo { flop_index: usize, turn_index: usize, @@ -519,10 +519,18 @@ impl PostFlopGame { ) = self.card_config.isomorphism(&self.private_cards); } - /// Initializes the root node of game tree. + /// Initializes the root node of game tree and recursively build the tree. + /// + /// This function is responsible for computing the number of nodes required + /// for each street (via `count_nodes_per_street()`), allocating + /// `PostFlopNode`s to `self.node_arena`, and calling `build_tree_recursive`, + /// which recursively visits all nodes and, among other things, initializes + /// the child/parent relation. + /// + /// This does _not_ allocate global storage (e.g., `self.storage1`, etc). fn init_root(&mut self) -> Result<(), String> { - let num_nodes = self.count_num_nodes(); - let total_num_nodes = num_nodes[0] + num_nodes[1] + num_nodes[2]; + let nodes_per_street = self.count_nodes_per_street(); + let total_num_nodes = nodes_per_street[0] + nodes_per_street[1] + nodes_per_street[2]; if total_num_nodes > u32::MAX as u64 || mem::size_of::() as u64 * total_num_nodes > isize::MAX as u64 @@ -530,15 +538,15 @@ impl PostFlopGame { return Err("Too many nodes".to_string()); } - self.num_nodes = num_nodes; + self.num_nodes_per_street = nodes_per_street; self.node_arena = (0..total_num_nodes) .map(|_| MutexLike::new(PostFlopNode::default())) .collect::>(); self.clear_storage(); let mut info = BuildTreeInfo { - turn_index: num_nodes[0] as usize, - river_index: (num_nodes[0] + num_nodes[1]) as usize, + turn_index: nodes_per_street[0] as usize, + river_index: (nodes_per_street[0] + nodes_per_street[1]) as usize, ..Default::default() }; @@ -584,9 +592,10 @@ impl PostFlopGame { self.storage_chance = Vec::new(); } - /// Counts the number of nodes in the game tree. + /// Counts the number of nodes in the game tree per street, accounting for + /// isomorphism. #[inline] - fn count_num_nodes(&self) -> [u64; 3] { + fn count_nodes_per_street(&self) -> [u64; 3] { let (turn_coef, river_coef) = match (self.card_config.turn, self.card_config.river) { (NOT_DEALT, _) => { let mut river_coef = 0; @@ -988,7 +997,7 @@ impl PostFlopGame { if self.card_config.river != NOT_DEALT { self.bunching_arena = arena; - self.assign_zero_weights(); + self.assign_zero_weights_to_dead_cards(); return Ok(()); } @@ -1043,7 +1052,7 @@ impl PostFlopGame { let player_swap = swap_option.map(|swap| { let mut tmp = (0..player_len).collect::>(); - apply_swap(&mut tmp, &swap[player]); + apply_swap_list(&mut tmp, &swap[player]); tmp }); @@ -1065,8 +1074,8 @@ impl PostFlopGame { let slices = if let Some(swap) = swap_option { tmp.0.extend_from_slice(&arena[index..index + opponent_len]); tmp.1.extend_from_slice(opponent_strength); - apply_swap(&mut tmp.0, &swap[player ^ 1]); - apply_swap(&mut tmp.1, &swap[player ^ 1]); + apply_swap_list(&mut tmp.0, &swap[player ^ 1]); + apply_swap_list(&mut tmp.1, &swap[player ^ 1]); (tmp.0.as_slice(), &tmp.1) } else { (&arena[index..index + opponent_len], opponent_strength) @@ -1103,7 +1112,7 @@ impl PostFlopGame { if self.card_config.turn != NOT_DEALT { self.bunching_arena = arena; - self.assign_zero_weights(); + self.assign_zero_weights_to_dead_cards(); return Ok(()); } @@ -1137,7 +1146,7 @@ impl PostFlopGame { let player_swap = swap_option.map(|swap| { let mut tmp = (0..player_len).collect::>(); - apply_swap(&mut tmp, &swap[player]); + apply_swap_list(&mut tmp, &swap[player]); tmp }); @@ -1154,7 +1163,7 @@ impl PostFlopGame { let slice = &arena[index..index + opponent_len]; let slice = if let Some(swap) = swap_option { tmp.extend_from_slice(slice); - apply_swap(&mut tmp, &swap[player ^ 1]); + apply_swap_list(&mut tmp, &swap[player ^ 1]); &tmp } else { slice @@ -1181,7 +1190,7 @@ impl PostFlopGame { } self.bunching_arena = arena; - self.assign_zero_weights(); + self.assign_zero_weights_to_dead_cards(); Ok(()) } @@ -1416,7 +1425,7 @@ impl PostFlopGame { Ok(info) } - /// Allocates memory recursively. + /// Assigns allocated storage memory. fn allocate_memory_nodes(&mut self) { let num_bytes = if self.is_compression_enabled { 2 } else { 4 }; let mut action_counter = 0; @@ -1447,4 +1456,8 @@ impl PostFlopGame { } } } + + pub fn get_state(&self) -> &State { + &self.state + } } diff --git a/src/game/interpreter.rs b/src/game/interpreter.rs index d44d63a..d7845b6 100644 --- a/src/game/interpreter.rs +++ b/src/game/interpreter.rs @@ -30,7 +30,7 @@ impl PostFlopGame { self.weights[0].copy_from_slice(&self.initial_weights[0]); self.weights[1].copy_from_slice(&self.initial_weights[1]); - self.assign_zero_weights(); + self.assign_zero_weights_to_dead_cards(); } /// Returns the history of the current node. @@ -366,7 +366,7 @@ impl PostFlopGame { } // update the weights - self.assign_zero_weights(); + self.assign_zero_weights_to_dead_cards(); } // player node else { @@ -1005,8 +1005,9 @@ impl PostFlopGame { unsafe { node_ptr.offset_from(self.node_arena.as_ptr()) as usize } } - /// Assigns zero weights to the hands that are not possible. - pub(super) fn assign_zero_weights(&mut self) { + /// Assigns zero weights to the hands that are not possible (e.g, by a card + /// being removed by a turn or river). + pub(super) fn assign_zero_weights_to_dead_cards(&mut self) { if self.bunching_num_dead_cards == 0 { let mut board_mask: u64 = 0; if self.turn != NOT_DEALT { diff --git a/src/game/mod.rs b/src/game/mod.rs index 33d9a19..08b23bc 100644 --- a/src/game/mod.rs +++ b/src/game/mod.rs @@ -17,10 +17,10 @@ use std::collections::BTreeMap; #[cfg(feature = "bincode")] use bincode::{Decode, Encode}; -#[derive(Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)] #[repr(u8)] #[cfg_attr(feature = "bincode", derive(Decode, Encode))] -enum State { +pub enum State { ConfigError = 0, #[default] Uninitialized = 1, @@ -82,7 +82,7 @@ pub struct PostFlopGame { // store options storage_mode: BoardState, target_storage_mode: BoardState, - num_nodes: [u64; 3], + num_nodes_per_street: [u64; 3], is_compression_enabled: bool, num_storage: u64, num_storage_ip: u64, diff --git a/src/game/serialization.rs b/src/game/serialization.rs index a9a7858..171ae48 100644 --- a/src/game/serialization.rs +++ b/src/game/serialization.rs @@ -58,7 +58,13 @@ impl PostFlopGame { } } - /// Returns the number of storage elements required for the target storage mode. + /// Returns the number of storage elements required for the target storage mode: + /// `[|storage1|, |storage2|, |storage_ip|, |storage_chance|]` + /// + /// If this is a River save (`target_storage_mode == BoardState::River`) + /// then do not store cfvalues. + /// + /// If this is a Flop save, fn num_target_storage(&self) -> [usize; 4] { if self.state <= State::TreeBuilt { return [0; 4]; @@ -71,8 +77,8 @@ impl PostFlopGame { } let mut node_index = match self.target_storage_mode { - BoardState::Flop => self.num_nodes[0], - _ => self.num_nodes[0] + self.num_nodes[1], + BoardState::Flop => self.num_nodes_per_street[0], + _ => self.num_nodes_per_street[0] + self.num_nodes_per_street[1], } as usize; let mut num_storage = [0; 4]; @@ -128,7 +134,7 @@ impl Encode for PostFlopGame { self.removed_lines.encode(encoder)?; self.action_root.encode(encoder)?; self.target_storage_mode.encode(encoder)?; - self.num_nodes.encode(encoder)?; + self.num_nodes_per_street.encode(encoder)?; self.is_compression_enabled.encode(encoder)?; self.num_storage.encode(encoder)?; self.num_storage_ip.encode(encoder)?; @@ -140,8 +146,10 @@ impl Encode for PostFlopGame { self.storage_chance[0..num_storage[3]].encode(encoder)?; let num_nodes = match self.target_storage_mode { - BoardState::Flop => self.num_nodes[0] as usize, - BoardState::Turn => (self.num_nodes[0] + self.num_nodes[1]) as usize, + BoardState::Flop => self.num_nodes_per_street[0] as usize, + BoardState::Turn => { + (self.num_nodes_per_street[0] + self.num_nodes_per_street[1]) as usize + } BoardState::River => self.node_arena.len(), }; @@ -193,7 +201,7 @@ impl Decode for PostFlopGame { removed_lines: Decode::decode(decoder)?, action_root: Decode::decode(decoder)?, storage_mode: Decode::decode(decoder)?, - num_nodes: Decode::decode(decoder)?, + num_nodes_per_street: Decode::decode(decoder)?, is_compression_enabled: Decode::decode(decoder)?, num_storage: Decode::decode(decoder)?, num_storage_ip: Decode::decode(decoder)?, diff --git a/src/sliceop.rs b/src/sliceop.rs index bf726b6..f99ad70 100644 --- a/src/sliceop.rs +++ b/src/sliceop.rs @@ -1,16 +1,52 @@ use crate::utility::*; use std::mem::MaybeUninit; +/// Subtracts each element of the left-hand side (`lhs`) slice by the +/// corresponding element of the right-hand side (`rhs`) slice, modifying the +/// `lhs` slice in place. +/// +/// # Arguments +/// +/// - `lhs`: A mutable reference to the left-hand side slice, which will be +/// modified in place. +/// - `rhs`: A reference to the right-hand side slice, which provides the +/// values to be subtracted from each corresponding element in `lhs`. #[inline] pub(crate) fn sub_slice(lhs: &mut [f32], rhs: &[f32]) { lhs.iter_mut().zip(rhs).for_each(|(l, r)| *l -= *r); } +/// Multiplies each element of the left-hand side (`lhs`) slice by the +/// corresponding element of the right-hand side (`rhs`) slice, modifying the +/// `lhs` slice in place. +/// +/// # Arguments +/// +/// - `lhs`: A mutable reference to the left-hand side slice, which will be +/// modified in place. +/// - `rhs`: A reference to the right-hand side slice, which provides the +/// values to be multiplied against each corresponding element in `lhs`. #[inline] pub(crate) fn mul_slice(lhs: &mut [f32], rhs: &[f32]) { lhs.iter_mut().zip(rhs).for_each(|(l, r)| *l *= *r); } +/// Divides each element of the left-hand side (`lhs`) slice by the +/// corresponding element of the right-hand side (`rhs`) slice, modifying the +/// `lhs` slice in place. If an element in `rhs` is zero, the corresponding +/// element in `lhs` is set to a specified `default` value instead of performing +/// the division. +/// +/// # Arguments +/// +/// - `lhs`: A mutable reference to the left-hand side slice, which will be +/// modified in place. Each element of this slice is divided by the +/// corresponding element in the `rhs` slice, or set to `default` if the +/// corresponding element in `rhs` is zero. +/// - `rhs`: A reference to the right-hand side slice, which provides the +/// divisor for each element in `lhs`. +/// - `default: f32`: A fallback value that is used for elements in `lhs` where +/// the corresponding element in `rhs` is zero. #[inline] pub(crate) fn div_slice(lhs: &mut [f32], rhs: &[f32], default: f32) { lhs.iter_mut() @@ -18,6 +54,22 @@ pub(crate) fn div_slice(lhs: &mut [f32], rhs: &[f32], default: f32) { .for_each(|(l, r)| *l = if is_zero(*r) { default } else { *l / *r }); } +/// Divides each element of the left-hand side (`lhs`) slice by the +/// corresponding element of the right-hand side (`rhs`) slice, modifying the +/// `lhs` slice in place. If an element in `rhs` is zero, the corresponding +/// element in `lhs` is set to a specified `default` value instead of performing +/// the division. +/// +/// # Arguments +/// +/// - `lhs`: A mutable reference to the left-hand side slice, which will be +/// modified in place. Each element of this slice is divided by the +/// corresponding element in the `rhs` slice, or set to `default` if the +/// corresponding element in `rhs` is zero. +/// - `rhs`: A reference to the right-hand side slice, which provides the +/// divisor for each element in `lhs`. +/// - `default: f32`: A fallback value that is used for elements in `lhs` where +/// the corresponding element in `rhs` is zero. #[inline] pub(crate) fn div_slice_uninit( dst: &mut [MaybeUninit], @@ -32,6 +84,7 @@ pub(crate) fn div_slice_uninit( }); } +/// Multiply a source slice by a scalar and store in a destination slice #[inline] pub(crate) fn mul_slice_scalar_uninit(dst: &mut [MaybeUninit], src: &[f32], scalar: f32) { dst.iter_mut().zip(src).for_each(|(d, s)| { @@ -39,6 +92,17 @@ pub(crate) fn mul_slice_scalar_uninit(dst: &mut [MaybeUninit], src: &[f32], }); } +/// Compute a _strided summation_ of `f32` elements in `src`, where the stride +/// length is `dst.len()`. +/// +/// In more detail, break source slice `src` into `N` chunks `C0...CN-1`, where +/// `N = dst.len()`, and set the `i`th element of `dst` to be the sum of the +/// `i`th element of each chunk `Ck`: +/// +/// - `dst[0] = SUM(k=0..N-1, Ck[0])` +/// - `dst[1] = SUM(k=0..N-1, Ck[1])` +/// - `dst[2] = SUM(k=0..N-1, Ck[2])` +/// - ... #[inline] pub(crate) fn sum_slices_uninit<'a>(dst: &'a mut [MaybeUninit], src: &[f32]) -> &'a mut [f32] { let len = dst.len(); @@ -54,6 +118,17 @@ pub(crate) fn sum_slices_uninit<'a>(dst: &'a mut [MaybeUninit], src: &[f32] dst } +/// Compute a _strided summation_ of `f32` elements in `src`, where the stride +/// length is `dst.len()`, and store as `f64` in `dst`. +/// +/// In more detail, break source slice `src` into `N` chunks `C0...CN-1`, where +/// `N = dst.len()`, and set the `i`th element of `dst` to be the sum of the +/// `i`th element of each chunk `Ck`: +/// +/// - `dst[0] = SUM(k=0..N-1, Ck[0])` +/// - `dst[1] = SUM(k=0..N-1, Ck[1])` +/// - `dst[2] = SUM(k=0..N-1, Ck[2])` +/// - ... #[inline] pub(crate) fn sum_slices_f64_uninit<'a>( dst: &'a mut [MaybeUninit], @@ -72,6 +147,30 @@ pub(crate) fn sum_slices_f64_uninit<'a>( dst } +/// Performs a fused multiply-add (FMA) operation on slices, storing the result +/// in a destination slice. +/// +/// This function multiplies the first `dst.len()` corresponding elements of the +/// two source slices (`src1` and `src2`) and stores the results in the +/// destination slice (`dst`). After the initial multiplication, it continues +/// to perform additional multiply-add operations using subsequent chunks of +/// `src1` and `src2`, adding the products to the already computed values in +/// `dst`. +/// +/// # Arguments +/// +/// - `dst`: A mutable reference to a slice of uninitialized memory where the +/// results will be stored. The length of this slice dictates how many +/// elements are processed in the initial operation. +/// - `src1`: A reference to the first source slice, providing the +/// multiplicands. +/// - `src2`: A reference to the second source slice, providing the multipliers. +/// +/// # Returns +/// +/// A mutable reference to the `dst` slice, now reinterpreted as a fully +/// initialized slice of `f32` values, containing the results of the fused +/// multiply-add operations. #[inline] pub(crate) fn fma_slices_uninit<'a>( dst: &'a mut [MaybeUninit], @@ -236,11 +335,27 @@ pub(crate) fn inner_product_cond( acc.iter().sum::() as f32 } +/// Extract a reference to a specific "row" from a one-dimensional slice, where +/// the data is conceptually arranged as a two-dimensional array. +/// +/// # Arguments +/// +/// * `slice` - slice to extract a reference from +/// * `index` - the index of the conceptual "row" to reference +/// * `row_size` - the size of the conceptual "row" to reference #[inline] pub(crate) fn row(slice: &[T], index: usize, row_size: usize) -> &[T] { &slice[index * row_size..(index + 1) * row_size] } +/// Extract a mutable reference to a specific "row" from a one-dimensional +/// slice, where the data is conceptually arranged as a two-dimensional array. +/// +/// # Arguments +/// +/// * `slice` - slice to extract a mutable reference from +/// * `index` - the index of the conceptual "row" to reference +/// * `row_size` - the size of the conceptual "row" to reference #[inline] pub(crate) fn row_mut(slice: &mut [T], index: usize, row_size: usize) -> &mut [T] { &mut slice[index * row_size..(index + 1) * row_size] diff --git a/src/solver.rs b/src/solver.rs index 5a1cc5a..1548a1a 100644 --- a/src/solver.rs +++ b/src/solver.rs @@ -9,8 +9,11 @@ use std::mem::MaybeUninit; use crate::alloc::*; struct DiscountParams { + // coefficient for accumulated positive regrets alpha_t: f32, + // coefficient for accumulated negative regrets beta_t: f32, + // contributions to average strategy gamma_t: f32, } @@ -132,7 +135,16 @@ pub fn solve_step(game: &T, current_iteration: u32) { } } -/// Recursively solves the counterfactual values. +/// Recursively solves the counterfactual values and store them in `result`. +/// +/// # Arguments +/// +/// * `result` - slice to store resulting counterfactual regret values +/// * `game` - reference to the game we are solving +/// * `node` - current node we are solving +/// * `player` - current player we are solving for +/// * `cfreach` - the probability of reaching this point with a particular private hand +/// * `params` - the DiscountParams that parametrize the solver fn solve_recursive( result: &mut [MaybeUninit], game: &T, @@ -157,7 +169,14 @@ fn solve_recursive( return; } - // allocate memory for storing the counterfactual values + // Allocate memory for storing the counterfactual values. Conceptually this + // is a `num_actions * num_hands` 2-dimensional array, where the `i`th + // row (which has length `num_hands`) corresponds to the cfvalues of each + // hand after taking the `i`th action. + // + // Rows are obtained using operations from `sliceop` (e.g., `sliceop::row_mut()`). + // + // `cfv_actions` will be written to by recursive calls to `solve_recursive`. #[cfg(feature = "custom-alloc")] let cfv_actions = MutexLike::new(Vec::with_capacity_in(num_actions * num_hands, StackAlloc)); #[cfg(not(feature = "custom-alloc"))] @@ -189,13 +208,15 @@ fn solve_recursive( ); }); - // use 64-bit floating point values + // use 64-bit floating point values for precision during summations + // before demoting back to f32 #[cfg(feature = "custom-alloc")] let mut result_f64 = Vec::with_capacity_in(num_hands, StackAlloc); #[cfg(not(feature = "custom-alloc"))] let mut result_f64 = Vec::with_capacity(num_hands); - // sum up the counterfactual values + // compute the strided summation of the counterfactual values for each + // hand and store in `result_f64` let mut cfv_actions = cfv_actions.lock(); unsafe { cfv_actions.set_len(num_actions * num_hands) }; sum_slices_f64_uninit(result_f64.spare_capacity_mut(), &cfv_actions); @@ -209,13 +230,13 @@ fn solve_recursive( let swap_list = &game.isomorphic_swap(node, i)[player]; let tmp = row_mut(&mut cfv_actions, isomorphic_index as usize, num_hands); - apply_swap(tmp, swap_list); + apply_swap_list(tmp, swap_list); result_f64.iter_mut().zip(&*tmp).for_each(|(r, &v)| { *r += v as f64; }); - apply_swap(tmp, swap_list); + apply_swap_list(tmp, swap_list); } result.iter_mut().zip(&result_f64).for_each(|(r, &v)| { @@ -236,7 +257,7 @@ fn solve_recursive( ); }); - // compute the strategy by regret-maching algorithm + // compute the strategy by regret-matching algorithm let mut strategy = if game.is_compression_enabled() { regret_matching_compressed(node.regrets_compressed(), num_actions) } else { @@ -247,7 +268,17 @@ fn solve_recursive( let locking = game.locking_strategy(node); apply_locking_strategy(&mut strategy, locking); - // sum up the counterfactual values + // Compute the counterfactual values for each hand, which for hand `h` is + // computed to be the sum over actions `a` of the frequency with which + // `h` takes action `a` and the regret of hand `h` taking action `a`. + // In pseudocode, this is: + // + // ``` + // result[h] = sum([freq(h, a) * regret(h, a) for a in actions]) + // ``` + // + // This sum-of-products us computed as a fused multiply-add using + // `fma_slices_uninit` and is stored in `result`. let mut cfv_actions = cfv_actions.lock(); unsafe { cfv_actions.set_len(num_actions * num_hands) }; let result = fma_slices_uninit(result, &strategy, &cfv_actions); @@ -299,6 +330,7 @@ fn solve_recursive( node.set_regret_scale(new_scale); } else { // update the cumulative strategy + // - `gamma` is used to discount cumulative strategy contributions let gamma = params.gamma_t; let cum_strategy = node.strategy_mut(); cum_strategy.iter_mut().zip(&strategy).for_each(|(x, y)| { @@ -306,6 +338,8 @@ fn solve_recursive( }); // update the cumulative regret + // - alpha is used to discount positive cumulative regrets + // - beta is used to discount negative cumulative regrets let (alpha, beta) = (params.alpha_t, params.beta_t); let cum_regret = node.regrets_mut(); cum_regret.iter_mut().zip(&*cfv_actions).for_each(|(x, y)| { @@ -380,6 +414,18 @@ fn regret_matching(regret: &[f32], num_actions: usize) -> Vec { } /// Computes the strategy by regret-matching algorithm. +/// +/// The resulting strategy has each element (e.g., a hand like **AdQs**) take +/// an action proportional to its regret, where negative regrets are interpreted +/// as zero. +/// +/// # Arguments +/// +/// * `regret` - slice of regrets for the current decision point, one "row" of +/// for each action. The `i`th row contains the regrets of each strategically +/// distinct element (e.g., in holdem an element would be a hole card) for +/// taking the `i`th action. +/// * `num_actions` - the number of actions represented in `regret`. #[cfg(not(feature = "custom-alloc"))] #[inline] fn regret_matching(regret: &[f32], num_actions: usize) -> Vec { @@ -391,10 +437,15 @@ fn regret_matching(regret: &[f32], num_actions: usize) -> Vec { unsafe { strategy.set_len(regret.len()) }; let row_size = regret.len() / num_actions; + + // We want to normalize each element's strategy, so compute the element-wise + // denominator by computing the strided summation of strategy let mut denom = Vec::with_capacity(row_size); sum_slices_uninit(denom.spare_capacity_mut(), &strategy); unsafe { denom.set_len(row_size) }; + // We set the default to be equally distributed across all options. This is + // used when a strategy for a particular hand is uniformly zero. let default = 1.0 / num_actions as f32; strategy.chunks_exact_mut(row_size).for_each(|row| { div_slice(row, &denom, default); diff --git a/src/utility.rs b/src/utility.rs index c834de7..d38208b 100644 --- a/src/utility.rs +++ b/src/utility.rs @@ -227,9 +227,14 @@ pub(crate) fn encode_unsigned_slice(dst: &mut [u16], slice: &[f32]) -> f32 { scale } -/// Applies the given swap to the given slice. +/// Applies the given list of swaps to the given slice. +/// +/// # Arguments +/// +/// * `slice` - mutable slice to perform swaps on +/// * `swap_list` - a list of index pairs to swap #[inline] -pub(crate) fn apply_swap(slice: &mut [T], swap_list: &[(u16, u16)]) { +pub(crate) fn apply_swap_list(slice: &mut [T], swap_list: &[(u16, u16)]) { for &(i, j) in swap_list { unsafe { ptr::swap( @@ -425,13 +430,13 @@ fn compute_cfvalue_recursive( let swap_list = &game.isomorphic_swap(node, i)[player]; let tmp = row_mut(&mut cfv_actions, isomorphic_index as usize, num_hands); - apply_swap(tmp, swap_list); + apply_swap_list(tmp, swap_list); result_f64.iter_mut().zip(&*tmp).for_each(|(r, &v)| { *r += v as f64; }); - apply_swap(tmp, swap_list); + apply_swap_list(tmp, swap_list); } result.iter_mut().zip(&result_f64).for_each(|(r, &v)| { @@ -637,13 +642,13 @@ fn compute_best_cfv_recursive( let swap_list = &game.isomorphic_swap(node, i)[player]; let tmp = row_mut(&mut cfv_actions, isomorphic_index as usize, num_hands); - apply_swap(tmp, swap_list); + apply_swap_list(tmp, swap_list); result_f64.iter_mut().zip(&*tmp).for_each(|(r, &v)| { *r += v as f64; }); - apply_swap(tmp, swap_list); + apply_swap_list(tmp, swap_list); } result.iter_mut().zip(&result_f64).for_each(|(r, &v)| {