diff --git a/data/.gitignore b/data/.gitignore
index f0c8d3a..3c3e7b4 100644
--- a/data/.gitignore
+++ b/data/.gitignore
@@ -1,4 +1,5 @@
 lol_dataset/
 intermediate_images/
-pixel_dataset.ds
+pixel_dataset/
+checkpoints/
diff --git a/eiuie/batch_process.py b/eiuie/batch_process.py
index 61c6cb3..f7f6caa 100644
--- a/eiuie/batch_process.py
+++ b/eiuie/batch_process.py
@@ -1,4 +1,4 @@
-from typing import Dict, List
+from typing import Dict, List, Literal, Tuple
 from multiprocessing import Pool, Queue, Process, Manager, cpu_count
 import glob
 import os
@@ -32,7 +32,11 @@ def write_to_file(queue: Queue) -> None:
 
 
 def process_and_enqueue(
-    models: List[bm.BaseModel], image_name: str, image_path: str, queue: Queue
+    models: List[bm.BaseModel],
+    exposure_type: Literal["high", "low"],
+    image_name: str,
+    image_path: str,
+    queue: Queue,
 ) -> None:
     """
     Process the image using the models and enqueue the result for writing.
@@ -50,14 +54,18 @@ def process_and_enqueue(
     """
     image = cv2.imread(image_path)
     for model in models:
-        print(f"Processing {image_name} with {model.name}")
+        print(f"Processing {exposure_type}/{image_name} with {model.name}")
         processed_image = model.process_image(image)
-        absolute_file_name = f"{SAVE_LOCATION}/{model.name}/{image_name}.png"
+        absolute_file_name = (
+            f"{SAVE_LOCATION}/{exposure_type}/{model.name}/{image_name}.png"
+        )
         absolute_file_name = os.path.abspath(absolute_file_name)
         queue.put((absolute_file_name, processed_image))
 
 
-def batch_process(models: List[bm.BaseModel], images: Dict[str, str]) -> None:
+def batch_process(
+    models: List[bm.BaseModel], images: Dict[Tuple[Literal["high", "low"], str], str]
+) -> None:
     """
     Batch process images using the models and save the results.
 
@@ -81,10 +89,10 @@ def batch_process(models: List[bm.BaseModel], images: Dict[str, str]) -> None:
     # Create a pool for parallel processing
     async_results = []  # Collect all the AsyncResult objects here
     with Pool(cpu_count() - 1) as pool:
-        for image_name, image_path in images.items():
+        for (exposure_type, image_name), image_path in images.items():
             res = pool.apply_async(
                 process_and_enqueue,
-                args=(models, image_name, image_path, write_queue),
+                args=(models, exposure_type, image_name, image_path, write_queue),
             )
             async_results.append(res)
 
@@ -103,11 +111,14 @@
 
 
 def batch_process_dataset() -> None:
-    glob_pattern = "data/lol_dataset/*/low/*.png"
+    glob_pattern = "data/lol_dataset/*/*/*.png"
     images = glob.glob(glob_pattern)
-    images_dict = {
-        image.split("/")[-1].split(".")[0]: os.path.abspath(image) for image in images
-    }
+    images_dict: Dict[Tuple[Literal["high", "low"], str], str] = {}
+    for image in images:
+        exposure_type: str = image.split("/")[-2]
+        assert exposure_type in ["high", "low"]
+        image_name = image.split("/")[-1].split(".")[0]
+        images_dict[(exposure_type, image_name)] = os.path.abspath(image)
     print(f"Found {len(images_dict)} images")
     models = [
         unsharp_masking.UnsharpMasking(),
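Note (reviewer sketch, not part of the patch): `batch_process_dataset` now keys each image by an `(exposure_type, image_name)` tuple recovered from its path. A minimal illustration of that keying, using a hypothetical sample path that matches the new glob pattern:

```python
import os

# Hypothetical path matching "data/lol_dataset/*/*/*.png"
path = "data/lol_dataset/our485/low/23.png"
exposure_type = path.split("/")[-2]  # parent directory is the exposure: "high" or "low"
assert exposure_type in ["high", "low"]
image_name = path.split("/")[-1].split(".")[0]  # file stem, e.g. "23"
print((exposure_type, image_name), "->", os.path.abspath(path))
```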
diff --git a/eiuie/consolidate_dataset.py b/eiuie/consolidate_dataset.py
index 58b4c7d..ae5ae02 100644
--- a/eiuie/consolidate_dataset.py
+++ b/eiuie/consolidate_dataset.py
@@ -9,11 +9,12 @@
 
 
 def prepare_dataset() -> None:
-    generator = __consolidate_data()
-    if os.path.exists(pxds.FILE):
-        os.remove(pxds.FILE)
-    with open(pxds.FILE, "wb") as file:
-        for data in generator:
+    generator_low = __consolidate_data_low()
+    if os.path.exists(pxds.FILE_LOW):
+        os.remove(pxds.FILE_LOW)
+    os.makedirs(os.path.dirname(pxds.FILE_LOW), exist_ok=True)
+    with open(pxds.FILE_LOW, "wb") as file:
+        for data in generator_low:
             combined = np.hstack(
                 (
                     data["original"],
@@ -26,8 +27,26 @@ def prepare_dataset() -> None:
             for row in combined:
                 file.write(bytes(row))
 
+    generator_high = __consolidate_data_high()
+    if os.path.exists(pxds.FILE_HIGH):
+        os.remove(pxds.FILE_HIGH)
+    os.makedirs(os.path.dirname(pxds.FILE_HIGH), exist_ok=True)
+    with open(pxds.FILE_HIGH, "wb") as file:
+        for data in generator_high:
+            combined = np.hstack(
+                (
+                    data["original"],
+                    data["unsharp"],
+                    data["homomorphic"],
+                    data["retinex"],
+                    data["ground_truth"],
+                )
+            )
+            for row in combined:
+                file.write(bytes(row))
 
-def __consolidate_data() -> (
+
+def __consolidate_data_low() -> (
     Generator[
         Dict[
             Literal["original", "retinex", "unsharp", "homomorphic", "ground_truth"],
             np.ndarray,
@@ -38,9 +57,9 @@ def __consolidate_data() -> (
         ],
         None,
         None,
     ]
 ):
     # Path to intermediate images
-    path_retinex = "data/intermediate_images/retinex/"
-    path_unsharp = "data/intermediate_images/unsharp_masking/"
-    path_homomorphic = "data/intermediate_images/homomorphic_filtering/"
+    path_retinex = "data/intermediate_images/low/retinex/"
+    path_unsharp = "data/intermediate_images/low/unsharp_masking/"
+    path_homomorphic = "data/intermediate_images/low/homomorphic_filtering/"
 
     files = glob.glob("data/lol_dataset/*/low/*.png")
@@ -78,3 +97,52 @@ def __consolidate_data() -> (
             "ground_truth": image2D_ground_truth,
         }
         yield data
+
+
+def __consolidate_data_high() -> (
+    Generator[
+        Dict[
+            Literal["original", "retinex", "unsharp", "homomorphic", "ground_truth"],
+            np.ndarray,
+        ],
+        None,
+        None,
+    ]
+):
+    # Path to intermediate images
+    path_retinex = "data/intermediate_images/high/retinex/"
+    path_unsharp = "data/intermediate_images/high/unsharp_masking/"
+    path_homomorphic = "data/intermediate_images/high/homomorphic_filtering/"
+
+    files = glob.glob("data/lol_dataset/*/high/*.png")
+
+    for image in files:
+        # read original image
+        image_original = cv2.imread(image)
+
+        # extract image id
+        i = image.split("/")[-1].split(".")[0]
+
+        # read corresponding intermediate images
+        image_retinex = cv2.imread(path_retinex + str(i) + ".png")
+        image_unsharp = cv2.imread(path_unsharp + str(i) + ".png")
+        image_homomorphic = cv2.imread(path_homomorphic + str(i) + ".png")
+
+        # reshape images to 2D pixel arrays
+        image2D_original = image_original.reshape(-1, 3)
+        image2D_retinex = image_retinex.reshape(-1, 3)
+        image2D_unsharp = image_unsharp.reshape(-1, 3)
+        image2D_homomorphic = image_homomorphic.reshape(-1, 3)
+
+        # bundle the per-method pixel arrays into a single dict
+        data: Dict[
+            Literal["original", "retinex", "unsharp", "homomorphic", "ground_truth"],
+            np.ndarray,
+        ] = {
+            "original": image2D_original,
+            "retinex": image2D_retinex,
+            "unsharp": image2D_unsharp,
+            "homomorphic": image2D_homomorphic,
+            "ground_truth": image2D_original,  # high-exposure originals serve as their own target
+        }
+        yield data
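Note (reviewer sketch, not part of the patch): `prepare_dataset` writes each pixel as one 15-byte record, five BGR triplets hstacked in the order original, unsharp, homomorphic, retinex, ground truth. A sketch of reading one file back, assuming `prepare_dataset` has already produced `data/pixel_dataset/low.ds`:

```python
import numpy as np

# One record per pixel: [original | unsharp | homomorphic | retinex | ground_truth],
# 3 uint8 values (BGR) per source, 15 bytes total.
records = np.memmap("data/pixel_dataset/low.ds", dtype=np.uint8, mode="r").reshape(-1, 15)
original = records[:, 0:3]        # input pixel
ground_truth = records[:, 12:15]  # target pixel
print(records.shape, original[0], ground_truth[0])
```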
diff --git a/eiuie/fusion_model.py b/eiuie/fusion_model.py
index 72b029f..4495eb7 100644
--- a/eiuie/fusion_model.py
+++ b/eiuie/fusion_model.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Callable, Any
 import os
 
 import numpy as np
@@ -6,7 +6,8 @@
 import torch.nn as nn
 import torch.optim as optim
 import torch.cuda as cuda
-from torch.utils.data import Dataset, DataLoader, random_split
+from torch.utils.data import DataLoader, random_split
+from torch.utils.data.sampler import BatchSampler, RandomSampler
 
 import base_model as bm
 import unsharp_masking as um
@@ -21,9 +22,9 @@ class ChannelNet(nn.Module):
     """Single layer perceptron for individual channels."""
 
-    def __init__(self, input_size=4, output_size=1):
+    def __init__(self, input_size: int):
         super(ChannelNet, self).__init__()
-        self.fc = nn.Linear(input_size, output_size)
+        self.fc = nn.Linear(input_size, 1)
 
     def forward(self, x):
         return self.fc(x)
@@ -32,20 +33,29 @@ def forward(self, x):
 class FusionNet(nn.Module):
     """Unifying model for all channels."""
 
-    def __init__(self):
+    use_original: bool
+
+    def __init__(self, use_original: bool):
         super(FusionNet, self).__init__()
-        self.h_net = ChannelNet()
-        self.s_net = ChannelNet()
-        self.i_net = ChannelNet()
+        self.use_original = use_original
+        self.h_net = ChannelNet(4 if use_original else 3)
+        self.s_net = ChannelNet(4 if use_original else 3)
+        self.i_net = ChannelNet(4 if use_original else 3)
 
     def forward(self, x):
         # Flatten the middle dimensions
        x = x.view(-1, 12)  # This will reshape the input to (batch_size, 12)
 
         # Splitting the input for the three channels
-        h_channel = x[:, 0::3]  # Every third value starting from index 0
-        s_channel = x[:, 1::3]  # Every third value starting from index 1
-        i_channel = x[:, 2::3]  # Every third value starting from index 2
+        h_channel = x[
+            :, 0 if self.use_original else 3 :: 3
+        ]  # Every third value, from index 0 (or 3 when the original is skipped)
+        s_channel = x[
+            :, 1 if self.use_original else 4 :: 3
+        ]  # Every third value, from index 1 (or 4 when the original is skipped)
+        i_channel = x[
+            :, 2 if self.use_original else 5 :: 3
+        ]  # Every third value, from index 2 (or 5 when the original is skipped)
 
         # Getting the outputs
         h_out = self.h_net(h_channel)
@@ -57,7 +67,24 @@ def forward(self, x):
 
 
 class EarlyStopping:
-    def __init__(self, patience=5, verbose=False, delta=0, trace_func=print):
+    """Stops training early if the validation loss doesn't improve within the given patience."""
+
+    patience: int
+    verbose: bool
+    counter: int
+    best_score: Optional[float]
+    early_stop: bool
+    val_loss_min: float
+    delta: float
+    trace_func: Callable[[Any], None]
+
+    def __init__(
+        self,
+        patience: int = 5,
+        verbose: bool = False,
+        delta: float = 0.0,
+        trace_func: Callable[[Any], None] = print,
+    ):
         """
         Parameters
         ----------
@@ -79,7 +106,7 @@ def __init__(self, patience=5, verbose=False, delta=0, trace_func=print):
         self.delta = delta
         self.trace_func = trace_func
 
-    def __call__(self, val_loss, model):
+    def __call__(self, val_loss: float):
         score = -val_loss
 
         if self.best_score is None:
@@ -130,7 +157,7 @@ def __init__(
         self.device = torch.device("cuda" if cuda.is_available() else "cpu")
 
         # Neural Network Model
-        self.net = FusionNet().to(self.device)
+        self.net = FusionNet(use_original=False).to(self.device)
         self.optimizer = optim.Adam(self.net.parameters())
         self.criterion = nn.MSELoss()  # assuming regression task
         self.start_epoch = 0
@@ -258,15 +285,15 @@ def train_model(
         batch_size=1024,
     ):
         print("Loading dataset...")
-        dataset = pxds.PixelDataset(batch_size=batch_size, use_fraction=data_to_use)
+        dataset = pxds.PixelDataset(use_fraction=data_to_use, use_exposures="both")
         # Splitting dataset into training and validation subsets
         print("Splitting dataset into training and validation subsets...")
         data_len = len(dataset)
-        print("Data points to use:", data_len * batch_size)
+        print("Data points to use:", data_len)
         train_size = int(train_ratio * data_len)
-        print("Training data points:", train_size * batch_size)
+        print("Training data points:", train_size)
         val_size = data_len - train_size
-        print("Validation data points:", val_size * batch_size)
+        print("Validation data points:", val_size)
         train_dataset, val_dataset = random_split(
             dataset,
             [train_size, val_size],
@@ -275,15 +302,17 @@ def train_model(
         train_loader = DataLoader(
             train_dataset,
-            batch_size=1,
-            shuffle=False,
+            batch_sampler=BatchSampler(
+                RandomSampler(train_dataset), batch_size=batch_size, drop_last=False
+            ),
             num_workers=0,
             pin_memory=True,
         )
         val_loader = DataLoader(
             val_dataset,
-            batch_size=1,
-            shuffle=False,
+            batch_sampler=BatchSampler(
+                RandomSampler(val_dataset), batch_size=batch_size, drop_last=False
+            ),
             num_workers=0,
             pin_memory=True,
         )
         early_stopping = EarlyStopping(
             patience=5,
             verbose=True,
         )
@@ -293,6 +322,8 @@
 
+        best_val_loss = float("inf")
+
         self.net.train()
         for epoch in range(self.start_epoch, total_epochs):
             print()
@@ -309,12 +340,16 @@
             val_loss = self.validate(val_loader)
             print(f"Validation loss: {val_loss}")
 
+            if val_loss < best_val_loss:
+                best_val_loss = val_loss
+                print("Saving best model...")
+                self.save_checkpoint(epoch, "best_model.pt")
+
             print("Checking early stopping...")
-            early_stopping(val_loss, self.net)
+            early_stopping(val_loss)
             if early_stopping.early_stop:
                 print("Early stopping")
-                self.save_checkpoint(epoch, "best_model.pt")
                 break
 
             # Save checkpoint after every epoch
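Note (reviewer sketch, not part of the patch): the strided slicing in `FusionNet.forward` is easiest to check with a dummy pixel. The 12 inputs interleave H, S, I triplets from the four sources in dataset order (original, unsharp, homomorphic, retinex), so a step of 3 picks the same channel from every source, and the start index decides whether the original's channel is included:

```python
import torch

# Index layout: 0-2 original HSI, 3-5 unsharp, 6-8 homomorphic, 9-11 retinex.
x = torch.arange(12, dtype=torch.float32).unsqueeze(0)  # one dummy pixel
use_original = False
h = x[:, 0 if use_original else 3 :: 3]  # indices 3, 6, 9
s = x[:, 1 if use_original else 4 :: 3]  # indices 4, 7, 10
i = x[:, 2 if use_original else 5 :: 3]  # indices 5, 8, 11
print(h, s, i)  # three features per channel; four when use_original=True
```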
diff --git a/eiuie/pixel_dataset.py b/eiuie/pixel_dataset.py
index 2d7ecf4..008ae79 100644
--- a/eiuie/pixel_dataset.py
+++ b/eiuie/pixel_dataset.py
@@ -1,20 +1,37 @@
-from typing import Tuple
+from typing import Tuple, Literal
 
 import numpy as np
 import torch
 from torch.utils.data import Dataset
 
 import base_model as bm
 
-FILE = "data/pixel_dataset.ds"
+FILE_LOW = "data/pixel_dataset/low.ds"
+FILE_HIGH = "data/pixel_dataset/high.ds"
 
 
 class PixelDataset(Dataset):
-    def __init__(self, batch_size=1, chunk_size=10000, use_fraction=1.0):
+    def __init__(
+        self,
+        chunk_size=10000,
+        use_fraction=1.0,
+        use_exposures: Literal["high", "low", "both"] = "both",
+    ):
         # Ensure use_fraction is within valid bounds
         use_fraction = max(0.0, min(1.0, use_fraction))
 
         # Use numpy's memory mapping
-        raw_data = np.memmap(FILE, dtype=np.uint8, mode="r").reshape(-1, 15)
+        raw_data_low = np.memmap(FILE_LOW, dtype=np.uint8, mode="r").reshape(-1, 15)
+        raw_data_high = np.memmap(FILE_HIGH, dtype=np.uint8, mode="r").reshape(-1, 15)
+
+        # Select exposures to use
+        raw_data: np.ndarray
+        match use_exposures:
+            case "high":
+                raw_data = raw_data_high
+            case "low":
+                raw_data = raw_data_low
+            case "both":
+                raw_data = np.concatenate((raw_data_low, raw_data_high), axis=0)
 
         # Randomly select a fraction of the data if use_fraction < 1.0
         if use_fraction < 1.0:
@@ -22,7 +39,7 @@ def __init__(self, batch_size=1, chunk_size=10000, use_fraction=1.0):
             idxs = np.random.choice(len(raw_data), n_samples, replace=False)
             raw_data = raw_data[idxs]
 
-        self.data_array = np.zeros_like(raw_data, dtype=np.float32)
+        data_array = np.zeros_like(raw_data, dtype=np.float32)
         n_rows = raw_data.shape[0]
 
         # Convert each set of BGR values to HSI in chunks
@@ -36,26 +53,17 @@ def __init__(self, batch_size=1, chunk_size=10000, use_fraction=1.0):
                 hsi_img = bm.BGR2HSI(bgr_img)
                 hsi_data_list.append(hsi_img.reshape(-1, 3))
 
-            self.data_array[start_idx:end_idx] = np.concatenate(hsi_data_list, axis=1)
+            data_array[start_idx:end_idx] = np.concatenate(hsi_data_list, axis=1)
 
         # Shuffle data_array
-        np.random.shuffle(self.data_array)
+        np.random.shuffle(data_array)
 
-        self.batch_size = batch_size
+        self.data_array = torch.from_numpy(data_array)
 
     def __len__(self) -> int:
-        return len(self.data_array) // self.batch_size
-
-    def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
-        start = idx * self.batch_size
-        end = start + self.batch_size
-
-        batch_data = self.data_array[start:end]
-
-        # Shuffle batch_data
-        np.random.shuffle(batch_data)
-
-        inputs = torch.tensor(batch_data[:, :12], dtype=torch.float32)
-        outputs = torch.tensor(batch_data[:, 12:], dtype=torch.float32)
+        return len(self.data_array)
 
+    def __getitem__(self, index) -> Tuple[torch.Tensor, torch.Tensor]:
+        inputs = self.data_array[index, :12]
+        outputs = self.data_array[index, 12:]
         return inputs, outputs
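Note (reviewer sketch, not part of the patch): with `__getitem__` now returning single pixels, batching and per-epoch shuffling move entirely into the sampler, replacing the old per-batch shuffle inside `__getitem__`. A smoke-test sketch mirroring `train_model`'s loader setup, assuming both consolidated `.ds` files exist and the script is run from the `eiuie/` directory:

```python
from torch.utils.data import DataLoader
from torch.utils.data.sampler import BatchSampler, RandomSampler

import pixel_dataset as pxds

# Load 1% of both exposure sets and draw one randomly sampled batch.
dataset = pxds.PixelDataset(use_fraction=0.01, use_exposures="both")
loader = DataLoader(
    dataset,
    batch_sampler=BatchSampler(
        RandomSampler(dataset), batch_size=1024, drop_last=False
    ),
    num_workers=0,
    pin_memory=True,
)
inputs, targets = next(iter(loader))
print(inputs.shape, targets.shape)  # expected: (1024, 12) and (1024, 3)
```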