From 4086ff11fa1238f521194b447549301e1bfb9ff8 Mon Sep 17 00:00:00 2001
From: CodingTil <36734749+CodingTil@users.noreply.github.com>
Date: Mon, 30 Oct 2023 16:04:41 +0100
Subject: [PATCH] Fixed issues in training

---
 eiuie/fusion_model.py  | 54 ++++++++++++++++++++++++++++++++++--------
 eiuie/pixel_dataset.py | 46 ++++++++++++++++-------------------
 2 files changed, 64 insertions(+), 36 deletions(-)

diff --git a/eiuie/fusion_model.py b/eiuie/fusion_model.py
index 3a91ec4..437997e 100644
--- a/eiuie/fusion_model.py
+++ b/eiuie/fusion_model.py
@@ -93,7 +93,9 @@ def save_checkpoint(self, val_loss, model):
         self.trace_func(
             f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ..."
         )
-        torch.save(model.state_dict(), self.path)
+        if not os.path.exists(CHECKPOINT_DIRECTORY):
+            os.makedirs(CHECKPOINT_DIRECTORY)
+        torch.save(model.state_dict(), f"{CHECKPOINT_DIRECTORY}/{self.path}")
         self.val_loss_min = val_loss


@@ -144,6 +146,9 @@ def __init__(
         self.load_checkpoint(latest_checkpoint)

     def save_checkpoint(self, epoch: int, checkpoint_path: str):
+        # Ensure checkpoint directory exists
+        if not os.path.exists(CHECKPOINT_DIRECTORY):
+            os.makedirs(CHECKPOINT_DIRECTORY)
         torch.save(
             {
                 "epoch": epoch,
@@ -171,11 +176,16 @@ def _get_latest_checkpoint(self) -> Optional[str]:
         if not os.path.exists(CHECKPOINT_DIRECTORY):
             return None
         checkpoint_files = [
-            f for f in os.listdir(CHECKPOINT_DIRECTORY) if "checkpoint_epoch_" in f
+            f
+            for f in os.listdir(CHECKPOINT_DIRECTORY)
+            if "checkpoint_epoch_" in f or "best_model" in f
         ]
         if not checkpoint_files:
             return None

+        if "best_model" in checkpoint_files:
+            return "best_model.pt"
+
         # Sort based on epoch number
         checkpoint_files.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))

@@ -242,23 +252,47 @@ def process_image(self, image: np.ndarray) -> np.ndarray:

     def train_model(
         self,
-        total_epochs=100,
+        total_epochs=50,
         patience=5,
+        data_to_use=0.005,
         train_ratio=0.8,
+        batch_size=1024,
     ):
-        dataset = pxds.PixelDataset()
+        print("Loading dataset...")
+        dataset = pxds.PixelDataset(batch_size=batch_size)

         # Splitting dataset into training and validation subsets
-        train_size = int(train_ratio * len(dataset))
-        val_size = len(dataset) - train_size
-        train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
+        print("Splitting dataset into training and validation subsets...")
+        data_len = int(data_to_use * len(dataset))
+        print("Data points to use:", data_len * batch_size)
+        train_size = int(train_ratio * data_len)
+        print("Training data points:", train_size * batch_size)
+        val_size = data_len - train_size
+        print("Validation data points:", val_size * batch_size)
+        _, train_dataset, val_dataset = random_split(
+            dataset,
+            [len(dataset) - data_len, train_size, val_size],
+            generator=torch.Generator().manual_seed(42),
+        )
-        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
-        val_loader = DataLoader(val_dataset, batch_size=32)
+        train_loader = DataLoader(
+            train_dataset,
+            batch_size=1,
+            shuffle=False,
+            num_workers=0,
+            pin_memory=True,
+        )
+        val_loader = DataLoader(
+            val_dataset,
+            batch_size=1,
+            shuffle=False,
+            num_workers=0,
+            pin_memory=True,
+        )

         early_stopping = EarlyStopping(
             patience=patience,
             verbose=True,
-            path=f"{CHECKPOINT_DIRECTORY}/best_model.pt",
+            path=f"best_model.pt",
         )

         self.net.train()
diff --git a/eiuie/pixel_dataset.py b/eiuie/pixel_dataset.py
index 86ff052..6ef5dbf 100644
--- a/eiuie/pixel_dataset.py
+++ b/eiuie/pixel_dataset.py
@@ -1,43 +1,37 @@
 from typing import Tuple

 import numpy as np
-import pandas as pd
 import torch
 from torch.utils.data import Dataset
+import base_model as bm

 FILE = "data/pixel_dataset.ds"


 class PixelDataset(Dataset):
-    """
-    PixelDataset class.
+    def __init__(self, batch_size=1):
+        # Use numpy's memory mapping
+        raw_data = np.memmap(FILE, dtype=np.uint8, mode="r").reshape(-1, 15)

-    Attributes
-    ----------
-    df: pd.DataFrame
-        Dataframe.
-    """
+        # Convert each set of BGR values to HSI
+        hsi_data_list = []
+        for i in range(0, raw_data.shape[1], 3):
+            bgr_img = raw_data[:, i : i + 3].reshape(-1, 1, 3)
+            hsi_img = bm.BGR2HSI(bgr_img)
+            hsi_data_list.append(hsi_img.reshape(-1, 3))

-    df: pd.DataFrame
-
-    def __init__(self):
-        # Load binary data
-        with open(FILE, "rb") as f:
-            raw_data = f.read()
-
-        # Convert binary data to a numpy array of shape (num_rows, 15)
-        data_array = np.frombuffer(raw_data, dtype=np.uint8).reshape(-1, 15)
-
-        # Convert numpy array to pandas dataframe
-        self.df = pd.DataFrame(data_array)
+        self.data_array = np.concatenate(hsi_data_list, axis=1)
+        self.batch_size = batch_size

     def __len__(self) -> int:
-        return len(self.df)
+        return len(self.data_array) // self.batch_size

     def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
-        row = self.df.iloc[idx].values
+        start = idx * self.batch_size
+        end = start + self.batch_size
+
+        batch_data = self.data_array[start:end]

-        # Splitting the 15 values into two tensors: first 12 and last 3.
-        input_tensor = torch.tensor(row[:12], dtype=torch.float32)
-        output_tensor = torch.tensor(row[12:], dtype=torch.float32)
+        inputs = torch.tensor(batch_data[:, :12], dtype=torch.float32)
+        outputs = torch.tensor(batch_data[:, 12:], dtype=torch.float32)

-        return input_tensor, output_tensor
+        return inputs, outputs
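
For reference, the patched _get_latest_checkpoint prefers the early-stopping weights over numbered epoch checkpoints and sorts the latter numerically on their epoch suffix. A minimal standalone sketch of that resolution order follows; the function name and file names here are hypothetical, and the sketch tests for the full "best_model.pt" file name as written by EarlyStopping:

import os
from typing import Optional

def latest_checkpoint(directory: str) -> Optional[str]:
    # Prefer the early-stopping weights; otherwise fall back to the
    # highest-numbered checkpoint_epoch_<n>.pt file.
    if not os.path.exists(directory):
        return None
    files = [
        f
        for f in os.listdir(directory)
        if "checkpoint_epoch_" in f or "best_model" in f
    ]
    if not files:
        return None
    if "best_model.pt" in files:  # exact file name saved by EarlyStopping
        return "best_model.pt"
    # Numeric sort on the epoch suffix, so _10 sorts after _9, not after _1:
    files.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))
    return files[-1]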
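The three-way random_split in train_model keeps only a data_to_use fraction of the pre-batched dataset and routes the remainder into a discarded subset. A worked sketch of the arithmetic with the patch's defaults and a hypothetical dataset length of 100,000 batches:

import torch
from torch.utils.data import TensorDataset, random_split

batch_size = 1024
dataset = TensorDataset(torch.arange(100_000))  # stand-in: one entry per pre-built batch

data_len = int(0.005 * len(dataset))  # 500 batches kept = 512,000 pixels
train_size = int(0.8 * data_len)      # 400 batches for training
val_size = data_len - train_size      # 100 batches for validation

# The first (unused) split absorbs the other 99.5%; the seeded generator
# makes the selection reproducible across restarts of train_model.
_, train_ds, val_ds = random_split(
    dataset,
    [len(dataset) - data_len, train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)
assert (len(train_ds), len(val_ds)) == (400, 100)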
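In pixel_dataset.py, each 15-byte row packs five BGR triplets; after conversion via the project's bm.BGR2HSI, the first 12 columns become the model inputs and the last 3 the target, matching the 12/3 split in __getitem__. A small sketch of the column slicing alone, with the HSI conversion left out:

import numpy as np

row = np.arange(15, dtype=np.uint8).reshape(1, 15)  # hypothetical single row

triplets = [row[:, i : i + 3] for i in range(0, row.shape[1], 3)]
assert len(triplets) == 5           # five BGR pixels per row
assert triplets[0].shape == (1, 3)  # columns 0-11 -> inputs, 12-14 -> target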
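Because PixelDataset now returns whole batches itself, the DataLoaders in train_model use batch_size=1, which only prepends a singleton dimension to each pre-built batch. A sketch of the resulting tensor shapes; the module alias and the squeeze(0) step are assumptions about code outside this diff:

import torch
from torch.utils.data import DataLoader

import pixel_dataset as pxds  # assumed alias, matching pxds in fusion_model.py

dataset = pxds.PixelDataset(batch_size=1024)  # each item is already a full batch
loader = DataLoader(dataset, batch_size=1, shuffle=False, pin_memory=True)

inputs, outputs = next(iter(loader))
# The loader stacks one pre-built batch: shapes (1, 1024, 12) and (1, 1024, 3)
inputs, outputs = inputs.squeeze(0), outputs.squeeze(0)  # -> (1024, 12), (1024, 3)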