Skip to content

Commit

Permalink
Implement tensor random and zeros functions. Add size to shape
Browse files Browse the repository at this point in the history
  • Loading branch information
chaseWillden committed Nov 26, 2024
1 parent 3304265 commit fb9f343
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 11 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,6 @@ rust-project.json
### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide
.ionide

.cache
1 change: 1 addition & 0 deletions delta_common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ publish = false
maintenance = { status = "actively-developed" }

[dependencies]
rand = "0.8.5"
16 changes: 15 additions & 1 deletion delta_common/src/shape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,24 @@ impl Shape {
pub fn len(&self) -> usize {
self.0.iter().product()
}

/// Returns the number of elements in the shape
///
/// # Examples
///
/// ```
/// use delta_common::shape::Shape;
///
/// let shape = Shape::new(vec![2, 3, 4]);
/// assert_eq!(shape.size(), 24);
/// ```
pub fn size(&self) -> usize {
self.0.iter().product()
}
}

impl From<(usize, usize)> for Shape {
fn from(dimensions: (usize, usize)) -> Self {
Shape(vec![dimensions.0, dimensions.1])
}
}
}
52 changes: 49 additions & 3 deletions delta_common/src/tensor_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
//! OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
//! OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use rand::Rng;

use crate::shape::Shape;

#[derive(Debug)]
Expand All @@ -46,12 +48,38 @@ impl Tensor {
Tensor::new(vec![], self.shape.clone())
}

/// Create a tensor filled with zeros
///
/// # Arguments
///
/// * `shape` - The shape of the tensor
///
/// # Returns
///
/// A tensor filled with zeros
pub fn zeros(shape: &Shape) -> Self {
todo!("Create a tensor filled with zeros")
let size = shape.size();
let data = vec![0.0; size];

Self {
data,
shape: shape.clone(),
}
}

/// Create a tensor filled with random values
///
/// # Arguments
///
/// * `shape` - The shape of the tensor
pub fn random(shape: &Shape) -> Self {
todo!("Create a tensor filled with random values")
let size = shape.size();
let data = generate_random_data(size);

Self {
data,
shape: shape.clone(),
}
}

pub fn matmul(&self, other: &Tensor) -> Tensor {
Expand All @@ -72,4 +100,22 @@ impl Tensor {
pub fn shape(&self) -> &Shape {
&self.shape
}
}
}

/// Generate a vector of random numbers
///
/// # Arguments
///
/// * `length` - The length of the vector
///
/// # Returns
///
/// A vector of random numbers
fn generate_random_data(length: usize) -> Vec<f32> {
let mut random_number_generator = rand::thread_rng();
let mut data = Vec::with_capacity(length);
for _ in 0..length {
data.push(random_number_generator.gen::<f32>());
}
data
}
73 changes: 67 additions & 6 deletions delta_data/src/mnist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,15 @@ impl MnistDataset {
const TRAIN_EXAMPLES: usize = 60000;
const TEST_EXAMPLES: usize = 10000;

/// Load the MNIST dataset
///
/// # Arguments
///
/// * `is_train` - Whether to load the training or test dataset
///
/// # Returns
///
/// A dataset containing the MNIST data
async fn load_data(is_train: bool) -> Dataset {
let (data_filename, labels_filename, num_examples) = if is_train {
(
Expand All @@ -74,6 +83,16 @@ impl MnistDataset {
Dataset::new(data, labels)
}

/// Parse the images from the MNIST dataset
///
/// # Arguments
///
/// * `data` - The data to parse
/// * `num_images` - The number of images to parse
///
/// # Returns
///
/// A tensor containing the parsed images
fn parse_images(data: &[u8], num_images: usize) -> Tensor {
let image_data = &data[16..]; // Skip the 16-byte header
let num_pixels = Self::MNIST_IMAGE_SIZE * Self::MNIST_IMAGE_SIZE;
Expand All @@ -98,6 +117,16 @@ impl MnistDataset {
)
}

/// Parse the labels from the MNIST dataset
///
/// # Arguments
///
/// * `data` - The data to parse
/// * `num_labels` - The number of labels to parse
///
/// # Returns
///
/// A tensor containing the parsed labels
fn parse_labels(data: &[u8], num_labels: usize) -> Tensor {
let label_data = &data[8..]; // Skip the 8-byte header
let mut tensor_data = vec![0.0; num_labels * Self::MNIST_NUM_CLASSES];
Expand All @@ -112,6 +141,15 @@ impl MnistDataset {
)
}

/// Download a file from the MNIST dataset
///
/// # Arguments
///
/// * `name` - The name of the file to download
///
/// # Returns
///
/// A vector of bytes containing the downloaded data
async fn get_bytes_data(name: &str) -> Vec<u8> {
let file_path = format!(".cache/data/mnist/{}", name);
if std::path::Path::new(&file_path).exists() {
Expand All @@ -135,6 +173,15 @@ impl MnistDataset {
Self::decompress_gz(&file_path).unwrap()
}

/// Decompress a gzip file
///
/// # Arguments
///
/// * `file_path` - The path to the gzip file
///
/// # Returns
///
/// A vector of bytes containing the decompressed data
fn decompress_gz(file_path: &str) -> io::Result<Vec<u8>> {
let file = File::open(file_path)?;
let mut decoder = GzDecoder::new(file);
Expand All @@ -145,20 +192,34 @@ impl MnistDataset {
}

impl DatasetOps for MnistDataset {
/// Load the training dataset
///
/// # Examples
///
/// ```rust
/// use delta_data::mnist::MnistDataset;
///
/// let dataset = MnistDataset::load_train().await;
/// ```
async fn load_train() -> Self {
let train = tokio::runtime::Runtime::new()
.unwrap()
.block_on(Self::load_data(true));
let train = Self::load_data(true).await;
MnistDataset {
train: Some(train),
test: None,
}
}

/// Load the test dataset
///
/// # Examples
///
/// ```rust
/// use delta_data::mnist::MnistDataset;
///
/// let dataset = MnistDataset::load_test().await;
/// ```
async fn load_test() -> Self {
let test = tokio::runtime::Runtime::new()
.unwrap()
.block_on(Self::load_data(false));
let test = Self::load_data(false).await;
MnistDataset {
train: None,
test: Some(test),
Expand Down
1 change: 1 addition & 0 deletions examples/mnist/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ async fn main() {
model.compile(optimizer);

// Train the model
println!("Training...");
let train_data = MnistDataset::load_train().await;
let test_data = MnistDataset::load_test().await;

Expand Down

0 comments on commit fb9f343

Please sign in to comment.