Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
CodingTil committed Oct 30, 2023
1 parent b1d8556 commit 61c9543
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 11 deletions.
6 changes: 2 additions & 4 deletions eiuie/fusion_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@ def forward(self, x):


class EarlyStopping:
def __init__(
self, patience=5, verbose=False, delta=0, trace_func=print
):
def __init__(self, patience=5, verbose=False, delta=0, trace_func=print):
"""
Parameters
----------
Expand Down Expand Up @@ -170,7 +168,7 @@ def _get_latest_checkpoint(self) -> Optional[str]:

if "best_model.pt" in checkpoint_files:
return "best_model.pt"

# Sort based on epoch number
checkpoint_files.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))

Expand Down
15 changes: 8 additions & 7 deletions eiuie/pixel_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,37 @@

FILE = "data/pixel_dataset.ds"


class PixelDataset(Dataset):
def __init__(self, batch_size=1, chunk_size=10000, use_fraction=1.0):
# Ensure use_fraction is within valid bounds
use_fraction = max(0.0, min(1.0, use_fraction))

# Use numpy's memory mapping
raw_data = np.memmap(FILE, dtype=np.uint8, mode="r").reshape(-1, 15)

# Randomly select a fraction of the data if use_fraction < 1.0
if use_fraction < 1.0:
n_samples = int(len(raw_data) * use_fraction)
idxs = np.random.choice(len(raw_data), n_samples, replace=False)
raw_data = raw_data[idxs]

self.data_array = np.zeros_like(raw_data, dtype=np.float32)
n_rows = raw_data.shape[0]

# Convert each set of BGR values to HSI in chunks
for start_idx in range(0, n_rows, chunk_size):
end_idx = min(start_idx + chunk_size, n_rows)
chunk_bgr = raw_data[start_idx:end_idx]

hsi_data_list = []
for i in range(0, chunk_bgr.shape[1], 3):
bgr_img = chunk_bgr[:, i:i+3].reshape(-1, 1, 3)
bgr_img = chunk_bgr[:, i : i + 3].reshape(-1, 1, 3)
hsi_img = bm.BGR2HSI(bgr_img)
hsi_data_list.append(hsi_img.reshape(-1, 3))

self.data_array[start_idx:end_idx] = np.concatenate(hsi_data_list, axis=1)

self.batch_size = batch_size

def __len__(self) -> int:
Expand Down

0 comments on commit 61c9543

Please sign in to comment.