From 4cc5d7caef3481c75c3acbc5dce4122d3aff77fe Mon Sep 17 00:00:00 2001 From: Minsoo Patrick Kang Date: Sun, 14 Jan 2024 18:57:06 -0800 Subject: [PATCH] few updates --- analog/analog.py | 7 +++-- analog/logging/log_loader.py | 43 ++++++++++++++++++++++++++++++- analog/logging/log_loader_util.py | 2 +- analog/logging/mmap.py | 1 - 4 files changed, 48 insertions(+), 5 deletions(-) diff --git a/analog/analog.py b/analog/analog.py index b95f4a9f..b8281ec2 100644 --- a/analog/analog.py +++ b/analog/analog.py @@ -240,7 +240,7 @@ def build_log_dataset(self): """ Constructs the log dataset from the storage handler. """ - return LogDataset(log_dir=self.log_dir) + return LogDataset(log_dir=self.log_dir, config=self.config) def build_log_dataloader( self, batch_size: int = 16, num_workers: int = 0, pin_memory: bool = False @@ -249,13 +249,16 @@ def build_log_dataloader( Constructs the log dataloader from the storage handler. """ log_dataset = self.build_log_dataset() + collate_fn = None + if not self.config.logging_config["flatten"]: + collate_fn = collate_nested_dicts log_dataloader = torch.utils.data.DataLoader( log_dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=False, - collate_fn=collate_nested_dicts, + collate_fn=collate_fn, ) return log_dataloader diff --git a/analog/logging/log_loader.py b/analog/logging/log_loader.py index 4fe6a341..8eb63ae7 100644 --- a/analog/logging/log_loader.py +++ b/analog/logging/log_loader.py @@ -1,4 +1,5 @@ from collections import OrderedDict +from functools import reduce import numpy as np import torch from torch.utils.data import Dataset @@ -11,11 +12,13 @@ class LogDataset(Dataset): - def __init__(self, log_dir): + def __init__(self, log_dir, config): self.chunk_indices = None + # TODO(@eatpk): Change to pre-set sized array. self.memmaps = [] self.data_id_to_chunk = OrderedDict() self.log_dir = log_dir + self.logging_config = config.logging_config # Find all chunk indices self.chunk_indices = find_chunk_indices(self.log_dir) @@ -36,12 +39,30 @@ def fetch_data(self): idx, ) + if self.logging_config["flatten"]: + # paths and path_size must equal for all data_ids. + self.paths = [] + self.path_size = [] + self.path_shape = [] + + key = next(iter(self.data_id_to_chunk)) + _, entries = self.data_id_to_chunk[key] + + self.dtype = np.dtype(entries[0]["dtype"]) + for entry in entries: + self.paths.append(entry["path"]) + self.path_size.append(reduce(lambda x, y: x * y, entry["shape"])) + self.path_shape.append(entry["shape"]) + def __getitem__(self, index): data_id = list(self.data_id_to_chunk.keys())[index] chunk_idx, entries = self.data_id_to_chunk[data_id] nested_dict = {} mmap = self.memmaps[chunk_idx] + if self.logging_config["flatten"]: + return data_id, self._get_flatten_item(mmap, index) + for entry in entries: # Read the data and put it into the nested dictionary path = entry["path"] @@ -58,8 +79,28 @@ def __getitem__(self, index): current_level[key] = {} current_level = current_level[key] current_level[path[-1]] = tensor + return data_id, nested_dict + def _get_flatten_item(self, mmap, index): + block_size = reduce(lambda x, y: x + y, self.path_size) + arr = [] + offset = index * block_size + block_offset = 0 + for i in range(len(self.path_size)): + array = np.ndarray( + self.path_shape[i], + self.dtype, + buffer=mmap, + offset=offset * np.dtype(self.dtype).itemsize, + order="C", + ) + offset += self.path_size[i] + arr.append(torch.from_numpy(array)) + block_offset += self.path_size[i] + + return arr + def __len__(self): return len(self.data_id_to_chunk) diff --git a/analog/logging/log_loader_util.py b/analog/logging/log_loader_util.py index e7a13d98..022dfaa5 100644 --- a/analog/logging/log_loader_util.py +++ b/analog/logging/log_loader_util.py @@ -49,7 +49,7 @@ def get_mmap_data(path, mmap_filename, dtype="uint8") -> List: mmap_filename (str): Filename of the mmap file. Returns: - List: A list of memory maps and an ordered dictionary mapping data_ids to chunks. + List: A list of memory maps. """ with MemoryMapHandler.read(path, mmap_filename, dtype) as mm: return mm diff --git a/analog/logging/mmap.py b/analog/logging/mmap.py index 8c1cdc75..208e842f 100644 --- a/analog/logging/mmap.py +++ b/analog/logging/mmap.py @@ -65,7 +65,6 @@ def read(path, filename, dtype="uint8"): filename (str): filename for the path to mmap. Returns: mmap (np.mmap): memory mapped buffer read from filename. - metadata (json): """ _, file_ext = os.path.splitext(filename) if file_ext == "":