Skip to content

Commit

Permalink
Adding Debug for structs and adding summary function to Sequential implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
mjovanc committed Nov 26, 2024
1 parent 9013e32 commit 55aa3e3
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 4 deletions.
4 changes: 3 additions & 1 deletion delta_common/src/layer.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use std::fmt::Debug;
use crate::tensor_ops::Tensor;

pub trait Layer {
/// Core abstraction for a neural-network layer.
///
/// `Debug` is a supertrait so that layers held as `Box<dyn Layer>` can be
/// formatted with `{:?}` (used by `Sequential::summary` in this commit).
pub trait Layer: Debug {
/// Runs the forward pass, producing this layer's output tensor for `input`.
fn forward(&self, input: &Tensor) -> Tensor;
/// Runs the backward pass given the incoming gradient `grad`.
/// NOTE(review): presumably returns the gradient w.r.t. this layer's input
/// and may update internal state (`&mut self`) — confirm against implementors.
fn backward(&mut self, grad: &Tensor) -> Tensor;
}

#[derive(Debug)]
pub struct LayerOutput {
pub output: Tensor,
pub gradients: Tensor,
Expand Down
4 changes: 3 additions & 1 deletion delta_common/src/optimizer.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use std::fmt::Debug;
use crate::tensor_ops::Tensor;

pub trait Optimizer {
/// Abstraction over a parameter-update algorithm (e.g. Adam).
///
/// `Debug` is a supertrait so optimizers held as `Box<dyn Optimizer>` can be
/// formatted with `{:?}`.
pub trait Optimizer: Debug {
/// Applies one optimization step.
/// NOTE(review): `gradients` is `&mut` — looks like the tensors are updated
/// or consumed in place; confirm against implementors.
fn step(&mut self, gradients: &mut [Tensor]);
}

/// Configuration values shared by optimizer implementations.
#[derive(Debug)]
pub struct OptimizerConfig {
/// Step size used when applying gradient updates.
pub learning_rate: f32,
}
1 change: 1 addition & 0 deletions delta_common/src/tensor_ops.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::shape::Shape;

#[derive(Debug)]
pub struct Tensor {
pub data: Vec<f32>,
pub shape: Shape,
Expand Down
2 changes: 2 additions & 0 deletions delta_nn/src/layers.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use delta_common::{Layer, Shape};
use delta_common::tensor_ops::Tensor;

#[derive(Debug)]
pub struct Dense {
weights: Tensor,
bias: Tensor,
Expand All @@ -25,6 +26,7 @@ impl Layer for Dense {
}
}

/// ReLU activation layer. It is stateless, hence a unit struct.
#[derive(Debug)]
pub struct Relu;

impl Relu {
Expand Down
9 changes: 9 additions & 0 deletions delta_nn/src/models.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use delta_common::{Dataset, Layer, Optimizer};
use delta_common::data::DatasetOps;

#[derive(Debug)]
pub struct Sequential {
layers: Vec<Box<dyn Layer>>,
optimizer: Option<Box<dyn Optimizer>>,
Expand Down Expand Up @@ -46,4 +47,12 @@ impl Sequential {
/*pub fn forward(&self, input: &Tensor) -> Tensor {
self.layers.iter().fold(input.clone(), |acc, layer| layer.forward(&acc))
}*/

/// Builds a human-readable, one-entry-per-layer description of the model.
///
/// Layers are numbered starting at 1 and rendered via their `Debug`
/// representation, each entry terminated by a newline.
pub fn summary(&self) -> String {
    self.layers
        .iter()
        .enumerate()
        .map(|(idx, layer)| format!("Layer {}: {:?}\n", idx + 1, layer))
        .collect()
}
}
15 changes: 13 additions & 2 deletions delta_optimizers/src/adam.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
use std::fmt;
use std::fmt::Debug;
use delta_common::Optimizer;
use delta_common::tensor_ops::Tensor;

/// Newtype wrapper around a learning-rate scheduler closure so that the
/// struct holding it can derive `Debug` (bare closures are not `Debug`).
struct DebuggableScheduler(Box<dyn Fn(usize) -> f32>);

impl Debug for DebuggableScheduler {
    /// Emits a fixed placeholder name, since the wrapped closure is opaque.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "DebuggableScheduler")
    }
}

#[derive(Debug)]
pub struct Adam {
learning_rate: f32,
scheduler: Option<Box<dyn Fn(usize) -> f32>>,
scheduler: Option<DebuggableScheduler>,
}

impl Adam {
Expand All @@ -15,7 +26,7 @@ impl Adam {
where
F: Fn(usize) -> f32 + 'static,
{
self.scheduler = Some(Box::new(scheduler));
self.scheduler = Some(DebuggableScheduler(Box::new(scheduler)));
}
}

Expand Down

0 comments on commit 55aa3e3

Please sign in to comment.