Skip to content

Commit

Permalink
Flattening the tensors before unflattening (#105)
Browse files Browse the repository at this point in the history
  • Loading branch information
eatpk authored Jun 4, 2024
1 parent 1ca5ec9 commit 91329c0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 17 deletions.
26 changes: 9 additions & 17 deletions logix/logging/log_loader.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
from collections import OrderedDict
from functools import reduce

import numpy as np
import torch
from torch.utils.data import Dataset

from logix.logging.log_loader_utils import (
get_entry_metadata,
unflatten_tensor,
get_flatten_item,
get_mmap_data,
get_mmap_metadata,
Expand Down Expand Up @@ -49,29 +45,25 @@ def __getitem__(self, index):
nested_dict = {}
mmap = self.memmaps[chunk_idx]
offset = entry["offset"]
flat_tensor = get_flatten_item(
mmap, offset, entry["block_size"], entry["dtype"]
)
if self.flatten:
return data_id, get_flatten_item(
mmap, offset, entry["block_size"], entry["dtype"]
)
dtype = entry["dtype"]
return data_id, flat_tensor
start = 0
for i in range(len(entry["path"])):
path = entry["path"][i]
shape = tuple(entry["shape"][i])
tensor = torch.from_numpy(
np.ndarray(shape, dtype, buffer=mmap, offset=offset, order="C")
).clone()

tensor, start = unflatten_tensor(flat_tensor, shape, start)
current_level = nested_dict
for key in path[:-1]:
if key not in current_level:
current_level[key] = {}
current_level = current_level[key]
current_level[path[-1]] = tensor
offset += reduce(lambda x, y: x * y, shape) * np.dtype(dtype).itemsize

assert (
offset == entry["offset"] + entry["block_size"] * np.dtype(dtype).itemsize
), f"the block_size does not match the shape for data_id: {entry['data_id']}"
entry["block_size"] == start
), f"block_size does not match with the shape for data_id: {entry['data_id']}"
return data_id, nested_dict

def __len__(self):
Expand Down
7 changes: 7 additions & 0 deletions logix/logging/log_loader_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,13 @@ def get_flatten_item(mmap, offset, block_size, dtype="float32"):
return torch.from_numpy(array).clone()


def unflatten_tensor(flat_tensor, shape, start):
    """Carve one tensor of the given shape out of a flattened tensor.

    Takes the ``prod(shape)`` elements of ``flat_tensor`` beginning at index
    ``start`` and reinterprets them with ``shape``. The result is produced
    with ``Tensor.view``, so no data is copied.

    Args:
        flat_tensor: Flattened tensor to read from (sliced along its first
            dimension).
        shape: Target shape as a sequence of ints. An empty shape denotes a
            0-dim (scalar) tensor and consumes exactly one element.
        start: Index into ``flat_tensor`` of the first element to consume.

    Returns:
        A ``(tensor, end)`` tuple, where ``end = start + prod(shape)`` is the
        start index for the next tensor stored in ``flat_tensor``.
    """
    shape = tuple(shape)
    # Seed the product with 1 so an empty shape (scalar) yields one element
    # instead of raising "reduce() of empty sequence with no initial value".
    num_elements = reduce(lambda x, y: x * y, shape, 1)
    end = start + num_elements
    # Pass the shape as a single tuple: ``view(*shape)`` would become a bare
    # ``view()`` call for an empty shape, which is not a valid invocation.
    unflattened_tensor = flat_tensor[start:end].view(shape)
    return unflattened_tensor, end


def _init_collate_structure(nested_dict):
# Initialize the collate structure based on the first item
if isinstance(nested_dict, dict):
Expand Down

0 comments on commit 91329c0

Please sign in to comment.