Skip to content

Commit

Permalink
Merge pull request #120 from dsekercioglu/dense-fix
Browse files Browse the repository at this point in the history
Fix dense layer memory layout, remove unnecessary buckets argument
  • Loading branch information
jnlt3 authored Jul 16, 2022
2 parents e739c96 + 082b559 commit 21bf762
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 24 deletions.
2 changes: 1 addition & 1 deletion src/bm/bm_util/position.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ impl Position {

let frc_score = frc::frc_corner_bishop(&board);

Evaluation::new(self.evaluator.feed_forward(&board, 0) + frc_score + eval_bonus)
Evaluation::new(self.evaluator.feed_forward(&board) + frc_score + eval_bonus)
}

pub fn insufficient_material(&self) -> bool {
Expand Down
6 changes: 3 additions & 3 deletions src/bm/nnue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ pub struct Nnue {
impl Nnue {
pub fn new() -> Self {
let mut bytes = &NN_BYTES[12..];
let incremental = Arc::new(*include::dense_from_bytes_i16::<i16, INPUT, MID>(bytes));
let incremental = Arc::new(*include::sparse_from_bytes_i16::<i16, INPUT, MID>(bytes));
bytes = &bytes[INPUT * MID * 2..];
let incremental_bias = include::bias_from_bytes_i16::<i16, MID>(bytes);
bytes = &bytes[MID * 2..];
Expand Down Expand Up @@ -208,7 +208,7 @@ impl Nnue {
}

#[inline]
pub fn feed_forward(&mut self, board: &Board, bucket: usize) -> i16 {
pub fn feed_forward(&mut self, board: &Board) -> i16 {
let acc = &mut self.accumulator[self.head];
let mut incr = [0; MID * 2];
let (stm, nstm) = match board.side_to_move() {
Expand All @@ -218,6 +218,6 @@ impl Nnue {
layers::clipped_relu(*stm.get(), &mut incr);
layers::clipped_relu(*nstm.get(), &mut incr[MID..]);

layers::out(self.out_layer.ff(&incr, bucket)[bucket])
layers::out(self.out_layer.ff(&incr)[0])
}
}
10 changes: 5 additions & 5 deletions src/bm/nnue/include.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub fn dense_from_bytes_i16<
pub fn sparse_from_bytes_i16<
T: From<i16> + Copy + Default,
const INPUT: usize,
const OUTPUT: usize,
Expand Down Expand Up @@ -34,16 +34,16 @@ pub fn dense_from_bytes_i8<
const OUTPUT: usize,
>(
bytes: &[u8],
) -> Box<[[T; OUTPUT]; INPUT]> {
) -> Box<[[T; INPUT]; OUTPUT]> {
let mut weights = vec![];
for &byte in bytes.iter().take(INPUT * OUTPUT) {
weights.push(i8::from_le_bytes([byte]))
}
let mut dense = Box::new([[T::default(); OUTPUT]; INPUT]);
for (i, weights) in weights.chunks(OUTPUT).enumerate() {
let mut dense = Box::new([[T::default(); INPUT]; OUTPUT]);
for (i, weights) in weights.chunks(INPUT).enumerate() {
for (j, &weight) in weights.into_iter().enumerate() {
dense[i][j] = T::from(weight);
}
}
dense
}
}
13 changes: 5 additions & 8 deletions src/bm/nnue/layers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,20 @@ impl<'a, const INPUT: usize, const OUTPUT: usize> Incremental<INPUT, OUTPUT> {

#[derive(Debug, Clone)]
pub struct Dense<const INPUT: usize, const OUTPUT: usize> {
weights: Arc<[[i8; OUTPUT]; INPUT]>,
weights: Arc<[[i8; INPUT]; OUTPUT]>,
bias: [i32; OUTPUT],
}

impl<const INPUT: usize, const OUTPUT: usize> Dense<INPUT, OUTPUT> {
pub fn new(weights: Arc<[[i8; OUTPUT]; INPUT]>, bias: [i32; OUTPUT]) -> Self {
pub fn new(weights: Arc<[[i8; INPUT]; OUTPUT]>, bias: [i32; OUTPUT]) -> Self {
Self { weights, bias }
}

#[inline]
pub fn ff(&self, inputs: &[u8; INPUT], bucket: usize) -> [i32; OUTPUT] {
pub fn ff(&self, inputs: &[u8; INPUT]) -> [i32; OUTPUT] {
let mut out = self.bias;
for (&input, weights) in inputs.iter().zip(&*self.weights) {
for (out, &weight) in out[bucket..bucket + 1]
.iter_mut()
.zip(weights[bucket..bucket + 1].iter())
{
for (out, weights) in out.iter_mut().zip(&*self.weights) {
for (&input, &weight) in inputs.iter().zip(weights.iter()) {
*out += weight as i32 * input as i32;
}
}
Expand Down
7 changes: 0 additions & 7 deletions src/bm/uci.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use crate::bm::bm_runner::ab_runner::AbRunner;
use crate::bm::bm_runner::config::{NoInfo, Run, UciInfo};

use crate::bm::bm_runner::time::{TimeManagementInfo, TimeManager};
use crate::bm::nnue::Nnue;

const VERSION: &str = "6.0";

Expand Down Expand Up @@ -121,12 +120,6 @@ impl UciAdapter {
let runner = &mut *self.bm_runner.lock().unwrap();

println!("eval : {}", runner.raw_eval().raw());
{
let mut nnue = Nnue::new();
for i in 0..1 {
println!("bucket {}: {}", i, nnue.feed_forward(runner.get_board(), i));
}
}
}
UciCommand::Go(commands) => self.go(commands),
UciCommand::NewGame => {
Expand Down

0 comments on commit 21bf762

Please sign in to comment.