Batch optimization
CodingTil committed Oct 31, 2023
1 parent 32d9359 commit f2c76c2
Showing 2 changed files with 53 additions and 22 deletions.
eiuie/fusion_model.py (21 additions, 12 deletions)
@@ -278,14 +278,22 @@ def process_image(self, image: np.ndarray) -> np.ndarray:
 
     def train_model(
         self,
-        total_epochs=50,
-        patience=5,
-        data_to_use=0.005,
-        train_ratio=0.8,
-        batch_size=1024,
+        total_epochs: int = 50,
+        patience: int = 5,
+        data_to_use: float = 0.005,
+        train_ratio: float = 0.8,
+        batch_size: int = 1024,
+        pre_shuffle: bool = False,
+        shuffle: bool = True,
     ):
         print("Loading dataset...")
-        dataset = pxds.PixelDataset(use_fraction=data_to_use, use_exposures="both")
+        dataset = pxds.PixelDataset(
+            use_fraction=data_to_use,
+            use_exposures="both",
+            batch_size=batch_size,
+            pre_shuffle=pre_shuffle,
+            shuffle_batch=shuffle,
+        )
         # Splitting dataset into training and validation subsets
         print("Splitting dataset into training and validation subsets...")
         data_len = len(dataset)
@@ -302,17 +310,18 @@ def train_model(
 
         train_loader = DataLoader(
             train_dataset,
-            batch_sampler=BatchSampler(
-                RandomSampler(train_dataset), batch_size=batch_size, drop_last=False
-            ),
+            batch_size=1,
+            shuffle=shuffle,
+            collate_fn=lambda x: x[0],
             num_workers=0,
             pin_memory=True,
         )
 
         val_loader = DataLoader(
             val_dataset,
-            batch_sampler=BatchSampler(
-                RandomSampler(val_dataset), batch_size=batch_size, drop_last=False
-            ),
+            batch_size=1,
+            shuffle=False,
+            collate_fn=lambda x: x[0],
             num_workers=0,
             pin_memory=True,
         )
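Note: the DataLoader change above works because PixelDataset now yields whole batches from __getitem__. With batch_size=1, PyTorch wraps each batch in a singleton list, and collate_fn=lambda x: x[0] unwraps it, so the per-sample indexing and collation overhead of the old BatchSampler path is paid once per batch instead of once per pixel. A minimal, self-contained sketch of the pattern (ToyBatchDataset is a hypothetical stand-in, not part of this repo):

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyBatchDataset(Dataset):
    """Hypothetical stand-in: each item is already a full batch."""

    def __init__(self, n_batches: int = 4, batch_size: int = 1024):
        self.n_batches = n_batches
        self.batch_size = batch_size

    def __len__(self) -> int:
        # Length counts batches, not individual samples.
        return self.n_batches

    def __getitem__(self, idx: int):
        # One "item" is an entire batch: 12 input and 3 output features.
        inputs = torch.rand(self.batch_size, 12)
        outputs = torch.rand(self.batch_size, 3)
        return inputs, outputs


loader = DataLoader(
    ToyBatchDataset(),
    batch_size=1,               # the dataset already batches
    shuffle=True,               # shuffles the order of batches
    collate_fn=lambda x: x[0],  # unwrap the singleton list
)

for inputs, outputs in loader:
    print(inputs.shape, outputs.shape)  # torch.Size([1024, 12]) torch.Size([1024, 3])
```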
eiuie/pixel_dataset.py (32 additions, 10 deletions)
@@ -1,4 +1,5 @@
 from typing import Tuple, Literal
+import random
 
 import numpy as np
 import torch
@@ -12,22 +13,27 @@
 class PixelDataset(Dataset):
     def __init__(
         self,
-        chunk_size=10000,
-        use_fraction=1.0,
+        chunk_size: int = 10000,
+        use_fraction: float = 0.5,
         use_exposures: Literal["high", "low", "both"] = "both",
+        pre_shuffle: bool = False,
+        shuffle_batch: bool = True,
+        batch_size: int = 1,
     ):
         # Ensure use_fraction is within valid bounds
         use_fraction = max(0.0, min(1.0, use_fraction))
 
         # Use numpy's memory mapping
+        f = (0.25 + random.random() * 0.5) * use_fraction
+        fractions = [f, use_fraction - f]
         raw_data_low = np.memmap(FILE_LOW, dtype=np.uint8, mode="r").reshape(-1, 15)
         if use_fraction < 1.0:
-            n_samples = int(len(raw_data_low) * use_fraction)
+            n_samples = int(len(raw_data_low) * fractions[0])
             idxs = np.random.choice(len(raw_data_low), n_samples, replace=False)
             raw_data_low = raw_data_low[idxs]
         raw_data_high = np.memmap(FILE_HIGH, dtype=np.uint8, mode="r").reshape(-1, 15)
         if use_fraction < 1.0:
-            n_samples = int(len(raw_data_high) * use_fraction)
+            n_samples = int(len(raw_data_high) * fractions[1])
             idxs = np.random.choice(len(raw_data_high), n_samples, replace=False)
             raw_data_high = raw_data_high[idxs]
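Note: the new fractions logic splits the overall sampling budget randomly between the two exposure files rather than applying use_fraction to each independently; f lands uniformly between 25% and 75% of use_fraction, and the remainder goes to the high-exposure file. A quick worked example with the train_model default (numbers are illustrative):

```python
import random

use_fraction = 0.005
f = (0.25 + random.random() * 0.5) * use_fraction  # e.g. ~0.0024 when random() ~ 0.48
fractions = [f, use_fraction - f]                  # e.g. [0.0024, 0.0026]

# The per-file fractions always sum back to the overall budget; the combined
# fraction of all rows matches use_fraction exactly only when both files
# contain the same number of rows.
assert abs(sum(fractions) - use_fraction) < 1e-12
```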

@@ -36,8 +42,10 @@ def __init__(
         match use_exposures:
             case "high":
                 raw_data = raw_data_high
+                del raw_data_low
             case "low":
                 raw_data = raw_data_low
+                del raw_data_high
             case "both":
                 raw_data = np.concatenate((raw_data_low, raw_data_high), axis=0)

@@ -58,14 +66,28 @@ def __init__(
             data_array[start_idx:end_idx] = np.concatenate(hsi_data_list, axis=1)
 
         # Shuffle data_array
-        np.random.shuffle(data_array)
+        if pre_shuffle:
+            np.random.shuffle(data_array)
 
-        self.data_array = torch.from_numpy(data_array)
+        self.data_array = data_array
+
+        self.shuffle_batch = shuffle_batch
+        self.batch_size = batch_size
 
     def __len__(self) -> int:
-        return len(self.data_array)
+        return len(self.data_array) // self.batch_size
 
-    def __getitem__(self, index) -> Tuple[torch.Tensor, torch.Tensor]:
-        inputs = self.data_array[index, :12]
-        outputs = self.data_array[index, 12:]
+    def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
+        start = idx * self.batch_size
+        end = start + self.batch_size
+
+        batch_data = self.data_array[start:end]
+
+        # Shuffle batch_data
+        if self.shuffle_batch:
+            np.random.shuffle(batch_data)
+
+        inputs = torch.tensor(batch_data[:, :12], dtype=torch.float32)
+        outputs = torch.tensor(batch_data[:, 12:], dtype=torch.float32)
         return inputs, outputs
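Note: with batching folded into the dataset, __len__ now counts whole batches (the integer division silently drops a partial trailing batch) and __getitem__ slices a contiguous block, shuffling it in place before the tensor conversion when shuffle_batch is set. A hedged usage sketch, assuming pixel_dataset is importable as in fusion_model.py and that each 15-column row splits into 12 input and 3 output features:

```python
from torch.utils.data import DataLoader

import pixel_dataset as pxds  # import path assumed from fusion_model.py

dataset = pxds.PixelDataset(
    use_fraction=0.005,
    use_exposures="both",
    batch_size=1024,
    pre_shuffle=False,   # skip the expensive full-array shuffle
    shuffle_batch=True,  # shuffle rows within each batch instead
)

loader = DataLoader(
    dataset,
    batch_size=1,               # items are already batches
    shuffle=True,               # shuffle the batch order each epoch
    collate_fn=lambda x: x[0],  # unwrap the singleton list
)

for inputs, outputs in loader:
    # inputs: (1024, 12) float32, outputs: (1024, 3) float32
    break
```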
