From 61c954303babb669fb91c5f4f1c8d04bc00c9cf7 Mon Sep 17 00:00:00 2001 From: CodingTil <36734749+CodingTil@users.noreply.github.com> Date: Mon, 30 Oct 2023 16:27:54 +0100 Subject: [PATCH] Formatting --- eiuie/fusion_model.py | 6 ++---- eiuie/pixel_dataset.py | 15 ++++++++------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/eiuie/fusion_model.py b/eiuie/fusion_model.py index 8006477..b888ccc 100644 --- a/eiuie/fusion_model.py +++ b/eiuie/fusion_model.py @@ -43,9 +43,7 @@ def forward(self, x): class EarlyStopping: - def __init__( - self, patience=5, verbose=False, delta=0, trace_func=print - ): + def __init__(self, patience=5, verbose=False, delta=0, trace_func=print): """ Parameters ---------- @@ -170,7 +168,7 @@ def _get_latest_checkpoint(self) -> Optional[str]: if "best_model.pt" in checkpoint_files: return "best_model.pt" - + # Sort based on epoch number checkpoint_files.sort(key=lambda x: int(x.split("_")[-1].split(".")[0])) diff --git a/eiuie/pixel_dataset.py b/eiuie/pixel_dataset.py index 432d407..533ae76 100644 --- a/eiuie/pixel_dataset.py +++ b/eiuie/pixel_dataset.py @@ -7,20 +7,21 @@ FILE = "data/pixel_dataset.ds" + class PixelDataset(Dataset): def __init__(self, batch_size=1, chunk_size=10000, use_fraction=1.0): # Ensure use_fraction is within valid bounds use_fraction = max(0.0, min(1.0, use_fraction)) - + # Use numpy's memory mapping raw_data = np.memmap(FILE, dtype=np.uint8, mode="r").reshape(-1, 15) - + # Randomly select a fraction of the data if use_fraction < 1.0 if use_fraction < 1.0: n_samples = int(len(raw_data) * use_fraction) idxs = np.random.choice(len(raw_data), n_samples, replace=False) raw_data = raw_data[idxs] - + self.data_array = np.zeros_like(raw_data, dtype=np.float32) n_rows = raw_data.shape[0] @@ -28,15 +29,15 @@ def __init__(self, batch_size=1, chunk_size=10000, use_fraction=1.0): for start_idx in range(0, n_rows, chunk_size): end_idx = min(start_idx + chunk_size, n_rows) chunk_bgr = raw_data[start_idx:end_idx] - + hsi_data_list = [] for i in range(0, chunk_bgr.shape[1], 3): - bgr_img = chunk_bgr[:, i:i+3].reshape(-1, 1, 3) + bgr_img = chunk_bgr[:, i : i + 3].reshape(-1, 1, 3) hsi_img = bm.BGR2HSI(bgr_img) hsi_data_list.append(hsi_img.reshape(-1, 3)) - + self.data_array[start_idx:end_idx] = np.concatenate(hsi_data_list, axis=1) - + self.batch_size = batch_size def __len__(self) -> int: