Skip to content

Commit

Permalink
Merge branch 'main' into docs
Browse files Browse the repository at this point in the history
  • Loading branch information
bkushigian committed Oct 7, 2024
2 parents 3dad754 + a2aedc9 commit 202fbd0
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 55 deletions.
4 changes: 2 additions & 2 deletions src/alloc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ struct StackAllocData {
}

thread_local! {
static STACK_ALLOC_DATA: RefCell<StackAllocData> = RefCell::new(StackAllocData {
static STACK_ALLOC_DATA: RefCell<StackAllocData> = const {RefCell::new(StackAllocData {
index: usize::MAX,
base: Vec::new(),
current: Vec::new(),
});
})};
}

impl StackAllocData {
Expand Down
100 changes: 54 additions & 46 deletions src/bunching.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,12 @@ pub struct BunchingData {
#[inline]
fn mask_to_index(mut mask: u64, k: usize) -> usize {
let mut index = 0;
for i in 0..k {
COMB_TABLE.iter().take(k).for_each(|xs| {
assert!(mask != 0);
let tz = mask.trailing_zeros();
index += COMB_TABLE[i][tz as usize];
index += xs[tz as usize];
mask &= mask - 1;
}
});
index
}

Expand All @@ -175,8 +175,8 @@ fn next_combination(mask: u64) -> u64 {
#[inline]
fn compress_mask(mut mask: u64, flop: [Card; 3]) -> u64 {
assert!(flop[0] < flop[1] && flop[1] < flop[2]);
for i in 0..3 {
let m = (1 << (flop[i] as usize - i)) - 1;
for (i, &c) in flop.iter().enumerate() {
let m = (1 << (c as usize - i)) - 1;
mask = (mask & m) | ((mask >> 1) & !m);
}
mask
Expand Down Expand Up @@ -716,36 +716,40 @@ impl BunchingData {
let chunk_end_index = usize::min(chunk_start_index + 100, end_index);
let mut src_mask = index_to_mask(chunk_start_index, K);

for src_index in chunk_start_index..chunk_end_index {
for entry in src_table
.iter()
.take(chunk_end_index)
.skip(chunk_start_index)
{
let mut src_mask_copy = src_mask;
src_mask = next_combination(src_mask);

let freq = src_table[src_index].load();
let freq = entry.load();
if freq == 0.0 {
continue;
}

let mut src_mask_bit = [0; K];
for i in 0..K {
src_mask_bit.iter_mut().for_each(|bit| {
let lsb = src_mask_copy & src_mask_copy.wrapping_neg();
src_mask_copy ^= lsb;
src_mask_bit[i] = lsb;
}
*bit = lsb;
});

for i in 0..(1 << K) - 1 {
if num_ones[i] > 6 {
for (i, &x) in num_ones.iter().enumerate() {
if x > 6 {
continue;
}

let mut dst_mask = 0;
for j in 0..K {
for (j, &y) in src_mask_bit.iter().enumerate() {
if i & (1 << j) != 0 {
dst_mask |= src_mask_bit[j];
dst_mask |= y;
}
}

let dst_index = mask_to_index(dst_mask, num_ones[i] as usize);
self.sum[num_ones[i] as usize][dst_index].add(freq);
let dst_index = mask_to_index(dst_mask, x as usize);
self.sum[x as usize][dst_index].add(freq);
}
}
});
Expand Down Expand Up @@ -773,37 +777,41 @@ impl BunchingData {
let dst_end_index = usize::min(dst_start_index + 100, end_index);
let mut mask = index_to_mask(dst_start_index, N);

for dst_index in dst_start_index..dst_end_index {
let mut mask_copy = mask;
mask = next_combination(mask);

let mut mask_bit = [0; N];
for i in 0..N {
let lsb = mask_copy & mask_copy.wrapping_neg();
mask_copy ^= lsb;
mask_bit[i] = lsb;
}

let mut result = 0.0;

for &(i, k) in &indices {
let mut src_mask = 0;
for j in 0..N {
if i & (1 << j) != 0 {
src_mask |= mask_bit[j];
dst_table
.iter()
.take(dst_end_index)
.skip(dst_start_index)
.for_each(|dst| {
let mut mask_copy = mask;
mask = next_combination(mask);

let mut mask_bit = [0; N];
mask_bit.iter_mut().for_each(|bit| {
let lsb = mask_copy & mask_copy.wrapping_neg();
mask_copy ^= lsb;
*bit = lsb;
});

let mut result = 0.0;

for &(i, k) in &indices {
let mut src_mask = 0;
mask_bit.iter().take(N).enumerate().for_each(|(j, mb)| {
if i & (1 << j) != 0 {
src_mask |= mb;
}
});

let src_index = mask_to_index(src_mask, k as usize);
if k & 1 == 0 {
result += self.sum[k as usize][src_index].load();
} else {
result -= self.sum[k as usize][src_index].load();
}
}

let src_index = mask_to_index(src_mask, k as usize);
if k & 1 == 0 {
result += self.sum[k as usize][src_index].load();
} else {
result -= self.sum[k as usize][src_index].load();
}
}

dst_table[dst_index].store(f32::max(result as f32, 0.0));
}
dst.store(f32::max(result as f32, 0.0));
});
});
}
}
Expand All @@ -820,8 +828,8 @@ mod tests {
];

let mut mask = 0b001111;
for i in 0..15 {
assert_eq!(mask, seq[i]);
for x in seq {
assert_eq!(mask, x);
mask = next_combination(mask);
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/game/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -477,8 +477,8 @@ impl PostFlopGame {
board_mask |= 1 << river;
}

for player in 0..2 {
let (hands, weights) = range[player].get_hands_weights(board_mask);
for (player, r) in range.iter().enumerate() {
let (hands, weights) = r.get_hands_weights(board_mask);
self.initial_weights[player] = weights;
self.private_cards[player] = hands;
}
Expand Down
8 changes: 4 additions & 4 deletions src/game/serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,10 @@ impl PostFlopGame {
static VERSION_STR: &str = "2023-03-19";

thread_local! {
static PTR_BASE: Cell<[*const u8; 2]> = Cell::new([ptr::null(); 2]);
static CHANCE_BASE: Cell<*const u8> = Cell::new(ptr::null());
static PTR_BASE_MUT: Cell<[*mut u8; 3]> = Cell::new([ptr::null_mut(); 3]);
static CHANCE_BASE_MUT: Cell<*mut u8> = Cell::new(ptr::null_mut());
static PTR_BASE: Cell<[*const u8; 2]> = const {Cell::new([ptr::null(); 2])};
static CHANCE_BASE: Cell<*const u8> = const {Cell::new(ptr::null())};
static PTR_BASE_MUT: Cell<[*mut u8; 3]> = const {Cell::new([ptr::null_mut(); 3])};
static CHANCE_BASE_MUT: Cell<*mut u8> = const {Cell::new(ptr::null_mut())};
}

impl Encode for PostFlopGame {
Expand Down
2 changes: 1 addition & 1 deletion src/mutex_like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ impl<T: ?Sized> MutexLike<T> {
}
}

impl<T: ?Sized + Default> Default for MutexLike<T> {
impl<T: Default> Default for MutexLike<T> {
#[inline]
fn default() -> Self {
Self::new(Default::default())
Expand Down
1 change: 1 addition & 0 deletions src/range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,7 @@ impl FromStr for Range {
}
}

#[allow(clippy::to_string_trait_impl)]
impl ToString for Range {
#[inline]
fn to_string(&self) -> String {
Expand Down

0 comments on commit 202fbd0

Please sign in to comment.