diff --git a/src/bm/bm_util/position.rs b/src/bm/bm_util/position.rs index ae901895..14343c5a 100644 --- a/src/bm/bm_util/position.rs +++ b/src/bm/bm_util/position.rs @@ -106,7 +106,7 @@ impl Position { let frc_score = frc::frc_corner_bishop(&board); - Evaluation::new(self.evaluator.feed_forward(&board, 0) + frc_score + eval_bonus) + Evaluation::new(self.evaluator.feed_forward(&board) + frc_score + eval_bonus) } pub fn insufficient_material(&self) -> bool { diff --git a/src/bm/nnue.rs b/src/bm/nnue.rs index 0a5531dd..e211ee9c 100644 --- a/src/bm/nnue.rs +++ b/src/bm/nnue.rs @@ -58,7 +58,7 @@ pub struct Nnue { impl Nnue { pub fn new() -> Self { let mut bytes = &NN_BYTES[12..]; - let incremental = Arc::new(*include::dense_from_bytes_i16::(bytes)); + let incremental = Arc::new(*include::sparse_from_bytes_i16::(bytes)); bytes = &bytes[INPUT * MID * 2..]; let incremental_bias = include::bias_from_bytes_i16::(bytes); bytes = &bytes[MID * 2..]; @@ -208,7 +208,7 @@ impl Nnue { } #[inline] - pub fn feed_forward(&mut self, board: &Board, bucket: usize) -> i16 { + pub fn feed_forward(&mut self, board: &Board) -> i16 { let acc = &mut self.accumulator[self.head]; let mut incr = [0; MID * 2]; let (stm, nstm) = match board.side_to_move() { @@ -218,6 +218,6 @@ impl Nnue { layers::clipped_relu(*stm.get(), &mut incr); layers::clipped_relu(*nstm.get(), &mut incr[MID..]); - layers::out(self.out_layer.ff(&incr, bucket)[bucket]) + layers::out(self.out_layer.ff(&incr)[0]) } } diff --git a/src/bm/nnue/include.rs b/src/bm/nnue/include.rs index 453e0e3d..cd03062c 100644 --- a/src/bm/nnue/include.rs +++ b/src/bm/nnue/include.rs @@ -1,4 +1,4 @@ -pub fn dense_from_bytes_i16< +pub fn sparse_from_bytes_i16< T: From + Copy + Default, const INPUT: usize, const OUTPUT: usize, @@ -34,16 +34,16 @@ pub fn dense_from_bytes_i8< const OUTPUT: usize, >( bytes: &[u8], -) -> Box<[[T; OUTPUT]; INPUT]> { +) -> Box<[[T; INPUT]; OUTPUT]> { let mut weights = vec![]; for &byte in bytes.iter().take(INPUT * OUTPUT) { weights.push(i8::from_le_bytes([byte])) } - let mut dense = Box::new([[T::default(); OUTPUT]; INPUT]); - for (i, weights) in weights.chunks(OUTPUT).enumerate() { + let mut dense = Box::new([[T::default(); INPUT]; OUTPUT]); + for (i, weights) in weights.chunks(INPUT).enumerate() { for (j, &weight) in weights.into_iter().enumerate() { dense[i][j] = T::from(weight); } } dense -} \ No newline at end of file +} diff --git a/src/bm/nnue/layers.rs b/src/bm/nnue/layers.rs index e5b5007d..359ac8f5 100644 --- a/src/bm/nnue/layers.rs +++ b/src/bm/nnue/layers.rs @@ -35,23 +35,20 @@ impl<'a, const INPUT: usize, const OUTPUT: usize> Incremental { #[derive(Debug, Clone)] pub struct Dense { - weights: Arc<[[i8; OUTPUT]; INPUT]>, + weights: Arc<[[i8; INPUT]; OUTPUT]>, bias: [i32; OUTPUT], } impl Dense { - pub fn new(weights: Arc<[[i8; OUTPUT]; INPUT]>, bias: [i32; OUTPUT]) -> Self { + pub fn new(weights: Arc<[[i8; INPUT]; OUTPUT]>, bias: [i32; OUTPUT]) -> Self { Self { weights, bias } } #[inline] - pub fn ff(&self, inputs: &[u8; INPUT], bucket: usize) -> [i32; OUTPUT] { + pub fn ff(&self, inputs: &[u8; INPUT]) -> [i32; OUTPUT] { let mut out = self.bias; - for (&input, weights) in inputs.iter().zip(&*self.weights) { - for (out, &weight) in out[bucket..bucket + 1] - .iter_mut() - .zip(weights[bucket..bucket + 1].iter()) - { + for (out, weights) in out.iter_mut().zip(&*self.weights) { + for (&input, &weight) in inputs.iter().zip(weights.iter()) { *out += weight as i32 * input as i32; } } diff --git a/src/bm/uci.rs b/src/bm/uci.rs index 7c8411bf..57753fe1 100644 --- a/src/bm/uci.rs +++ b/src/bm/uci.rs @@ -9,7 +9,6 @@ use crate::bm::bm_runner::ab_runner::AbRunner; use crate::bm::bm_runner::config::{NoInfo, Run, UciInfo}; use crate::bm::bm_runner::time::{TimeManagementInfo, TimeManager}; -use crate::bm::nnue::Nnue; const VERSION: &str = "6.0"; @@ -121,12 +120,6 @@ impl UciAdapter { let runner = &mut *self.bm_runner.lock().unwrap(); println!("eval : {}", runner.raw_eval().raw()); - { - let mut nnue = Nnue::new(); - for i in 0..1 { - println!("bucket {}: {}", i, nnue.feed_forward(runner.get_board(), i)); - } - } } UciCommand::Go(commands) => self.go(commands), UciCommand::NewGame => {