From f2c76c2751b054ff26cf61a0589ca8b5490980df Mon Sep 17 00:00:00 2001 From: CodingTil <36734749+CodingTil@users.noreply.github.com> Date: Tue, 31 Oct 2023 13:56:32 +0100 Subject: [PATCH] Batch optimization --- eiuie/fusion_model.py | 33 +++++++++++++++++++++------------ eiuie/pixel_dataset.py | 42 ++++++++++++++++++++++++++++++++---------- 2 files changed, 53 insertions(+), 22 deletions(-) diff --git a/eiuie/fusion_model.py b/eiuie/fusion_model.py index 157d62e..716cf79 100644 --- a/eiuie/fusion_model.py +++ b/eiuie/fusion_model.py @@ -278,14 +278,22 @@ def process_image(self, image: np.ndarray) -> np.ndarray: def train_model( self, - total_epochs=50, - patience=5, - data_to_use=0.005, - train_ratio=0.8, - batch_size=1024, + total_epochs: int = 50, + patience: int = 5, + data_to_use: float = 0.005, + train_ratio: float = 0.8, + batch_size: int = 1024, + pre_shuffle: bool = False, + shuffle: bool = True, ): print("Loading dataset...") - dataset = pxds.PixelDataset(use_fraction=data_to_use, use_exposures="both") + dataset = pxds.PixelDataset( + use_fraction=data_to_use, + use_exposures="both", + batch_size=batch_size, + pre_shuffle=pre_shuffle, + shuffle_batch=shuffle, + ) # Splitting dataset into training and validation subsets print("Splitting dataset into training and validation subsets...") data_len = len(dataset) @@ -302,17 +310,18 @@ def train_model( train_loader = DataLoader( train_dataset, - batch_sampler=BatchSampler( - RandomSampler(train_dataset), batch_size=batch_size, drop_last=False - ), + batch_size=1, + shuffle=shuffle, + collate_fn=lambda x: x[0], num_workers=0, pin_memory=True, ) + val_loader = DataLoader( val_dataset, - batch_sampler=BatchSampler( - RandomSampler(val_dataset), batch_size=batch_size, drop_last=False - ), + batch_size=1, + shuffle=False, + collate_fn=lambda x: x[0], num_workers=0, pin_memory=True, ) diff --git a/eiuie/pixel_dataset.py b/eiuie/pixel_dataset.py index 1cfd996..64c94ae 100644 --- a/eiuie/pixel_dataset.py +++ b/eiuie/pixel_dataset.py @@ -1,4 +1,5 @@ from typing import Tuple, Literal +import random import numpy as np import torch @@ -12,22 +13,27 @@ class PixelDataset(Dataset): def __init__( self, - chunk_size=10000, - use_fraction=1.0, + chunk_size: int = 10000, + use_fraction: float = 0.5, use_exposures: Literal["high", "low", "both"] = "both", + pre_shuffle: int = False, + shuffle_batch: int = True, + batch_size: int = 1, ): # Ensure use_fraction is within valid bounds use_fraction = max(0.0, min(1.0, use_fraction)) # Use numpy's memory mapping + f = (0.25 + random.random() * 0.5) * use_fraction + fractions = [f, use_fraction - f] raw_data_low = np.memmap(FILE_LOW, dtype=np.uint8, mode="r").reshape(-1, 15) if use_fraction < 1.0: - n_samples = int(len(raw_data_low) * use_fraction) + n_samples = int(len(raw_data_low) * fractions[0]) idxs = np.random.choice(len(raw_data_low), n_samples, replace=False) raw_data_low = raw_data_low[idxs] raw_data_high = np.memmap(FILE_HIGH, dtype=np.uint8, mode="r").reshape(-1, 15) if use_fraction < 1.0: - n_samples = int(len(raw_data_high) * use_fraction) + n_samples = int(len(raw_data_high) * fractions[1]) idxs = np.random.choice(len(raw_data_high), n_samples, replace=False) raw_data_high = raw_data_high[idxs] @@ -36,8 +42,10 @@ def __init__( match use_exposures: case "high": raw_data = raw_data_high + del raw_data_low case "low": raw_data = raw_data_low + del raw_data_high case "both": raw_data = np.concatenate((raw_data_low, raw_data_high), axis=0) @@ -58,14 +66,28 @@ def __init__( data_array[start_idx:end_idx] = np.concatenate(hsi_data_list, axis=1) # Shuffle data_array - np.random.shuffle(data_array) + if pre_shuffle: + np.random.shuffle(data_array) - self.data_array = torch.from_numpy(data_array) + self.data_array = data_array + + self.shuffle_batch = shuffle_batch + self.batch_size = batch_size def __len__(self) -> int: - return len(self.data_array) + return len(self.data_array) // self.batch_size + + def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]: + start = idx * self.batch_size + end = start + self.batch_size + + batch_data = self.data_array[start:end] + + # Shuffle batch_data + if self.shuffle_batch: + np.random.shuffle(batch_data) + + inputs = torch.tensor(batch_data[:, :12], dtype=torch.float32) + outputs = torch.tensor(batch_data[:, 12:], dtype=torch.float32) - def __getitem__(self, index) -> Tuple[torch.Tensor, torch.Tensor]: - inputs = self.data_array[index, :12] - outputs = self.data_array[index, 12:] return inputs, outputs