Skip to content

Commit

Permalink
Cleanup & More options
Browse files Browse the repository at this point in the history
  • Loading branch information
CodingTil committed Oct 31, 2023
1 parent ed2a176 commit b03887d
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 66 deletions.
3 changes: 2 additions & 1 deletion data/.gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
lol_dataset/
intermediate_images/
pixel_dataset.ds
pixel_dataset/
checkpoints/

33 changes: 22 additions & 11 deletions eiuie/batch_process.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List
from typing import Dict, List, Literal, Tuple
from multiprocessing import Pool, Queue, Process, Manager, cpu_count
import glob
import os
Expand Down Expand Up @@ -32,7 +32,11 @@ def write_to_file(queue: Queue) -> None:


def process_and_enqueue(
models: List[bm.BaseModel], image_name: str, image_path: str, queue: Queue
models: List[bm.BaseModel],
exposure_type: Literal["high", "low"],
image_name: str,
image_path: str,
queue: Queue,
) -> None:
"""
Process the image using the models and enqueue the result for writing.
Expand All @@ -50,14 +54,18 @@ def process_and_enqueue(
"""
image = cv2.imread(image_path)
for model in models:
print(f"Processing {image_name} with {model.name}")
print(f"Processing {exposure_type}/{image_name} with {model.name}")
processed_image = model.process_image(image)
absolute_file_name = f"{SAVE_LOCATION}/{model.name}/{image_name}.png"
absolute_file_name = (
f"{SAVE_LOCATION}/{exposure_type}/{model.name}/{image_name}.png"
)
absolute_file_name = os.path.abspath(absolute_file_name)
queue.put((absolute_file_name, processed_image))


def batch_process(models: List[bm.BaseModel], images: Dict[str, str]) -> None:
def batch_process(
models: List[bm.BaseModel], images: Dict[Tuple[Literal["high", "low"], str], str]
) -> None:
"""
Batch process images using the model, and saves the results.
Expand All @@ -81,10 +89,10 @@ def batch_process(models: List[bm.BaseModel], images: Dict[str, str]) -> None:
# Create a pool for parallel processing
async_results = [] # Collect all the AsyncResult objects here
with Pool(cpu_count() - 1) as pool:
for image_name, image_path in images.items():
for (exposure_type, image_name), image_path in images.items():
res = pool.apply_async(
process_and_enqueue,
args=(models, image_name, image_path, write_queue),
args=(models, exposure_type, image_name, image_path, write_queue),
)
async_results.append(res)

Expand All @@ -103,11 +111,14 @@ def batch_process(models: List[bm.BaseModel], images: Dict[str, str]) -> None:


def batch_process_dataset() -> None:
glob_pattern = "data/lol_dataset/*/low/*.png"
glob_pattern = "data/lol_dataset/*/*/*.png"
images = glob.glob(glob_pattern)
images_dict = {
image.split("/")[-1].split(".")[0]: os.path.abspath(image) for image in images
}
images_dict: Dict[Tuple[Literal["high", "low"], str], str] = {}
for image in images:
exposure_type: str = image.split("/")[-2]
assert exposure_type in ["high", "low"]
image_name = image.split("/")[-1].split(".")[0]
images_dict[(exposure_type, image_name)] = os.path.abspath(image)
print(f"Found {len(images_dict)} images")
models = [
unsharp_masking.UnsharpMasking(),
Expand Down
86 changes: 77 additions & 9 deletions eiuie/consolidate_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@


def prepare_dataset() -> None:
generator = __consolidate_data()
if os.path.exists(pxds.FILE):
os.remove(pxds.FILE)
with open(pxds.FILE, "wb") as file:
for data in generator:
generator_low = __consolidate_data_low()
if os.path.exists(pxds.FILE_LOW):
os.remove(pxds.FILE_LOW)
os.makedirs(os.path.dirname(pxds.FILE_LOW), exist_ok=True)
with open(pxds.FILE_LOW, "wb") as file:
for data in generator_low:
combined = np.hstack(
(
data["original"],
Expand All @@ -26,8 +27,26 @@ def prepare_dataset() -> None:
for row in combined:
file.write(bytes(row))

generator_high = __consolidate_data_high()
if os.path.exists(pxds.FILE_HIGH):
os.remove(pxds.FILE_HIGH)
os.makedirs(os.path.dirname(pxds.FILE_HIGH), exist_ok=True)
with open(pxds.FILE_HIGH, "wb") as file:
for data in generator_high:
combined = np.hstack(
(
data["original"],
data["unsharp"],
data["homomorphic"],
data["retinex"],
data["ground_truth"],
)
)
for row in combined:
file.write(bytes(row))

def __consolidate_data() -> (

def __consolidate_data_low() -> (
Generator[
Dict[
Literal["original", "retinex", "unsharp", "homomorphic", "ground_truth"],
Expand All @@ -38,9 +57,9 @@ def __consolidate_data() -> (
]
):
# Path to intermediate images
path_retinex = "data/intermediate_images/retinex/"
path_unsharp = "data/intermediate_images/unsharp_masking/"
path_homomorphic = "data/intermediate_images/homomorphic_filtering/"
path_retinex = "data/intermediate_images/low/retinex/"
path_unsharp = "data/intermediate_images/low/unsharp_masking/"
path_homomorphic = "data/intermediate_images/low/homomorphic_filtering/"

files = glob.glob("data/lol_dataset/*/low/*.png")

Expand Down Expand Up @@ -78,3 +97,52 @@ def __consolidate_data() -> (
"ground_truth": image2D_ground_truth,
}
yield data


def __consolidate_data_high() -> (
Generator[
Dict[
Literal["original", "retinex", "unsharp", "homomorphic", "ground_truth"],
np.ndarray,
],
None,
None,
]
):
# Path to intermediate images
path_retinex = "data/intermediate_images/high/retinex/"
path_unsharp = "data/intermediate_images/high/unsharp_masking/"
path_homomorphic = "data/intermediate_images/high/homomorphic_filtering/"

files = glob.glob("data/lol_dataset/*/high/*.png")

for image in files:
# read original image
image_original = cv2.imread(image)

# extract image id
i = image.split("/")[-1].split(".")[0]

# read corresponding intermediate images
image_retinex = cv2.imread(path_retinex + str(i) + ".png")
image_unsharp = cv2.imread(path_unsharp + str(i) + ".png")
image_homomorphic = cv2.imread(path_homomorphic + str(i) + ".png")

# reshape image to 2D array
image2D_original = image_original.reshape(-1, 3)
image2D_retinex = image_retinex.reshape(-1, 3)
image2D_unsharp = image_unsharp.reshape(-1, 3)
image2D_homomorphic = image_homomorphic.reshape(-1, 3)

# convert to single pandas dataframe
data: Dict[
Literal["original", "retinex", "unsharp", "homomorphic", "ground_truth"],
np.ndarray,
] = {
"original": image2D_original,
"retinex": image2D_retinex,
"unsharp": image2D_unsharp,
"homomorphic": image2D_homomorphic,
"ground_truth": image2D_original,
}
yield data
83 changes: 59 additions & 24 deletions eiuie/fusion_model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import Optional
from typing import Optional, Callable, Any
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.cuda as cuda
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.data import DataLoader, random_split
from torch.utils.data.sampler import BatchSampler, RandomSampler

import base_model as bm
import unsharp_masking as um
Expand All @@ -21,9 +22,9 @@
class ChannelNet(nn.Module):
"""Single layer perceptron for individual channels."""

def __init__(self, input_size=4, output_size=1):
def __init__(self, input_size: int):
super(ChannelNet, self).__init__()
self.fc = nn.Linear(input_size, output_size)
self.fc = nn.Linear(input_size, 1)

def forward(self, x):
return self.fc(x)
Expand All @@ -32,20 +33,29 @@ def forward(self, x):
class FusionNet(nn.Module):
"""Unifying model for all channels."""

def __init__(self):
use_original: bool

def __init__(self, use_original: bool):
super(FusionNet, self).__init__()
self.h_net = ChannelNet()
self.s_net = ChannelNet()
self.i_net = ChannelNet()
self.use_original = use_original
self.h_net = ChannelNet(4 if use_original else 3)
self.s_net = ChannelNet(4 if use_original else 3)
self.i_net = ChannelNet(4 if use_original else 3)

def forward(self, x):
# Flatten the middle dimensions
x = x.view(-1, 12) # This will reshape the input to (batch_size, 12)

# Splitting the input for the three channels
h_channel = x[:, 0::3] # Every third value starting from index 0
s_channel = x[:, 1::3] # Every third value starting from index 1
i_channel = x[:, 2::3] # Every third value starting from index 2
h_channel = x[
:, 0 if self.use_original else 3 :: 3
] # Every third value starting from index 0
s_channel = x[
:, 1 if self.use_original else 4 :: 3
] # Every third value starting from index 1
i_channel = x[
:, 2 if self.use_original else 5 :: 3
] # Every third value starting from index 2

# Getting the outputs
h_out = self.h_net(h_channel)
Expand All @@ -57,7 +67,24 @@ def forward(self, x):


class EarlyStopping:
def __init__(self, patience=5, verbose=False, delta=0, trace_func=print):
"""Early stops the training if validation loss doesn't improve after a given patience."""

patience: int
verbose: bool
counter: int
best_score: Optional[float]
early_stop: bool
val_loss_min: float
delta: float
trace_func: Callable[[Any], None]

def __init__(
self,
patience: int = 5,
verbose: bool = False,
delta: float = 0.0,
trace_func: Callable[[Any], None] = print,
):
"""
Parameters
----------
Expand All @@ -79,7 +106,7 @@ def __init__(self, patience=5, verbose=False, delta=0, trace_func=print):
self.delta = delta
self.trace_func = trace_func

def __call__(self, val_loss, model):
def __call__(self, val_loss: float):
score = -val_loss

if self.best_score is None:
Expand Down Expand Up @@ -130,7 +157,7 @@ def __init__(
self.device = torch.device("cuda" if cuda.is_available() else "cpu")

# Neural Network Model
self.net = FusionNet().to(self.device)
self.net = FusionNet(use_original=False).to(self.device)
self.optimizer = optim.Adam(self.net.parameters())
self.criterion = nn.MSELoss() # assuming regression task
self.start_epoch = 0
Expand Down Expand Up @@ -258,15 +285,15 @@ def train_model(
batch_size=1024,
):
print("Loading dataset...")
dataset = pxds.PixelDataset(batch_size=batch_size, use_fraction=data_to_use)
dataset = pxds.PixelDataset(use_fraction=data_to_use, use_exposures="both")
# Splitting dataset into training and validation subsets
print("Splitting dataset into training and validation subsets...")
data_len = len(dataset)
print("Data points to use:", data_len * batch_size)
print("Data points to use:", data_len)
train_size = int(train_ratio * data_len)
print("Training data points:", train_size * batch_size)
print("Training data points:", train_size)
val_size = data_len - train_size
print("Validation data points:", val_size * batch_size)
print("Validation data points:", val_size)
train_dataset, val_dataset = random_split(
dataset,
[train_size, val_size],
Expand All @@ -275,15 +302,17 @@ def train_model(

train_loader = DataLoader(
train_dataset,
batch_size=1,
shuffle=False,
batch_sampler=BatchSampler(
RandomSampler(train_dataset), batch_size=batch_size, drop_last=False
),
num_workers=0,
pin_memory=True,
)
val_loader = DataLoader(
val_dataset,
batch_size=1,
shuffle=False,
batch_sampler=BatchSampler(
RandomSampler(val_dataset), batch_size=batch_size, drop_last=False
),
num_workers=0,
pin_memory=True,
)
Expand All @@ -293,6 +322,8 @@ def train_model(
verbose=True,
)

best_val_loss = float("inf")

self.net.train()
for epoch in range(self.start_epoch, total_epochs):
print()
Expand All @@ -309,12 +340,16 @@ def train_model(
val_loss = self.validate(val_loader)
print(f"Validation loss: {val_loss}")

if val_loss < best_val_loss:
best_val_loss = val_loss
print("Saving best model...")
self.save_checkpoint(epoch, "best_model.pt")

print("Checking early stopping...")
early_stopping(val_loss, self.net)
early_stopping(val_loss)

if early_stopping.early_stop:
print("Early stopping")
self.save_checkpoint(epoch, "best_model.pt")
break

# Save checkpoint after every epoch
Expand Down
Loading

0 comments on commit b03887d

Please sign in to comment.