Skip to content

Commit

Permalink
few updates
Browse files Browse the repository at this point in the history
  • Loading branch information
eatpk committed Jan 15, 2024
1 parent 35af02d commit 4cc5d7c
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 5 deletions.
7 changes: 5 additions & 2 deletions analog/analog.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def build_log_dataset(self):
"""
Constructs the log dataset from the storage handler.
"""
return LogDataset(log_dir=self.log_dir)
return LogDataset(log_dir=self.log_dir, config=self.config)

def build_log_dataloader(
self, batch_size: int = 16, num_workers: int = 0, pin_memory: bool = False
Expand All @@ -249,13 +249,16 @@ def build_log_dataloader(
Constructs the log dataloader from the storage handler.
"""
log_dataset = self.build_log_dataset()
collate_fn = None
if not self.config.logging_config["flatten"]:
collate_fn = collate_nested_dicts
log_dataloader = torch.utils.data.DataLoader(
log_dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
shuffle=False,
collate_fn=collate_nested_dicts,
collate_fn=collate_fn,
)
return log_dataloader

Expand Down
43 changes: 42 additions & 1 deletion analog/logging/log_loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections import OrderedDict
from functools import reduce
import numpy as np
import torch
from torch.utils.data import Dataset
Expand All @@ -11,11 +12,13 @@


class LogDataset(Dataset):
def __init__(self, log_dir):
def __init__(self, log_dir, config):
self.chunk_indices = None
# TODO(@eatpk): Change to pre-set sized array.
self.memmaps = []
self.data_id_to_chunk = OrderedDict()
self.log_dir = log_dir
self.logging_config = config.logging_config

# Find all chunk indices
self.chunk_indices = find_chunk_indices(self.log_dir)
Expand All @@ -36,12 +39,30 @@ def fetch_data(self):
idx,
)

if self.logging_config["flatten"]:
# paths and path_size must be equal for all data_ids.
self.paths = []
self.path_size = []
self.path_shape = []

key = next(iter(self.data_id_to_chunk))
_, entries = self.data_id_to_chunk[key]

self.dtype = np.dtype(entries[0]["dtype"])
for entry in entries:
self.paths.append(entry["path"])
self.path_size.append(reduce(lambda x, y: x * y, entry["shape"]))
self.path_shape.append(entry["shape"])

def __getitem__(self, index):
data_id = list(self.data_id_to_chunk.keys())[index]
chunk_idx, entries = self.data_id_to_chunk[data_id]
nested_dict = {}
mmap = self.memmaps[chunk_idx]

if self.logging_config["flatten"]:
return data_id, self._get_flatten_item(mmap, index)

for entry in entries:
# Read the data and put it into the nested dictionary
path = entry["path"]
Expand All @@ -58,8 +79,28 @@ def __getitem__(self, index):
current_level[key] = {}
current_level = current_level[key]
current_level[path[-1]] = tensor

return data_id, nested_dict

def _get_flatten_item(self, mmap, index):
    """
    Reads the flattened arrays of one data point out of a memory map.

    Args:
        mmap: Buffer the arrays are read from.
        index (int): Position of the requested data point; it is multiplied
            by the per-item block size (sum of ``self.path_size``) to find
            where the item starts inside ``mmap``.
    Returns:
        list: One tensor per entry of ``self.path_shape``, in that order.
            An empty list is returned when no paths are registered.
    """
    # NOTE(review): ``__getitem__`` passes the dataset-wide index while
    # ``mmap`` is the buffer of a single chunk; presumably the offset should
    # be relative to the chunk when logs span several chunks -- confirm.
    itemsize = np.dtype(self.dtype).itemsize
    # Offsets are tracked in elements and converted to bytes for each read.
    offset = index * sum(self.path_size)
    tensors = []
    for shape, size in zip(self.path_shape, self.path_size):
        array = np.ndarray(
            shape,
            self.dtype,
            buffer=mmap,
            offset=offset * itemsize,
            order="C",
        )
        tensors.append(torch.from_numpy(array))
        offset += size

    return tensors

def __len__(self):
    """
    Returns the number of data ids registered in ``data_id_to_chunk``.
    """
    id_to_chunk = self.data_id_to_chunk
    return len(id_to_chunk)

Expand Down
2 changes: 1 addition & 1 deletion analog/logging/log_loader_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def get_mmap_data(path, mmap_filename, dtype="uint8") -> List:
mmap_filename (str): Filename of the mmap file.
Returns:
List: A list of memory maps and an ordered dictionary mapping data_ids to chunks.
List: A list of memory maps.
"""
with MemoryMapHandler.read(path, mmap_filename, dtype) as mm:
return mm
Expand Down
1 change: 0 additions & 1 deletion analog/logging/mmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def read(path, filename, dtype="uint8"):
filename (str): filename for the path to mmap.
Returns:
mmap (np.mmap): memory mapped buffer read from filename.
metadata (json):
"""
_, file_ext = os.path.splitext(filename)
if file_ext == "":
Expand Down

0 comments on commit 4cc5d7c

Please sign in to comment.