Metadata Diet (#87)
* Metadata Diet

* test fix

* nitfix

* nit
eatpk authored Jan 30, 2024
1 parent 70ada65 commit c01fa48
Showing 7 changed files with 82 additions and 67 deletions.
6 changes: 5 additions & 1 deletion .gitignore
@@ -172,4 +172,8 @@ cython_debug/
 !tests/storage/test_mmap_data/*.mmap
 !tests/examples/checkpoints/*.pt
 
-files/
+files/
+
+# Generated Config files.
+examples/**/*.yaml
+**/config.yaml
32 changes: 19 additions & 13 deletions analog/logging/log_loader.py
@@ -1,4 +1,6 @@
 from collections import OrderedDict
+from functools import reduce
+
 import numpy as np
 import torch
 from torch.utils.data import Dataset
@@ -42,30 +44,34 @@ def fetch_data(self):

     def __getitem__(self, index):
         data_id = list(self.data_id_to_chunk.keys())[index]
-        chunk_idx, entries = self.data_id_to_chunk[data_id]
+        chunk_idx, entry = self.data_id_to_chunk[data_id]
         nested_dict = {}
         mmap = self.memmaps[chunk_idx]
 
         if self.flatten:
-            blocksize, dtype = get_entry_metadata(entries)
-            return data_id, get_flatten_item(mmap, index, blocksize, dtype)
+            return data_id, get_flatten_item(
+                mmap, index, entry["block_size"], entry["dtype"]
+            )
 
-        for entry in entries:
-            # Read the data and put it into the nested dictionary
-            path = entry["path"]
-            offset = entry["offset"]
-            shape = tuple(entry["shape"])
-            dtype = np.dtype(entry["dtype"])
-            array = np.ndarray(shape, dtype, buffer=mmap, offset=offset, order="C")
-            tensor = torch.from_numpy(array)
+        offset = entry["offset"]
+        dtype = entry["dtype"]
+        for i in range(len(entry["path"])):
+            path = entry["path"][i]
+            shape = tuple(entry["shape"][i])
+            tensor = torch.from_numpy(
+                np.ndarray(shape, dtype, buffer=mmap, offset=offset, order="C")
+            )
 
             # Place the tensor in the correct location within the nested dictionary
             current_level = nested_dict
             for key in path[:-1]:
                 if key not in current_level:
                     current_level[key] = {}
                 current_level = current_level[key]
             current_level[path[-1]] = tensor
+            offset += reduce(lambda x, y: x * y, shape) * np.dtype(dtype).itemsize
 
+        assert (
+            offset == entry["offset"] + entry["block_size"] * np.dtype(dtype).itemsize
+        ), f"the block_size does not match the shape for data_id: {entry['data_id']}"
         return data_id, nested_dict
 
     def __len__(self):
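Taken together, the loader now walks a single consolidated metadata entry per `data_id`, advancing a byte offset across the concatenated tensors instead of reading one metadata record per module. A minimal self-contained sketch of that walk, assuming the post-commit entry format; the `entry` values and byte buffer below are hypothetical stand-ins for the real memmap and metadata:

```python
from functools import reduce

import numpy as np

# Hypothetical consolidated entry in the post-commit format.
entry = {
    "path": [["grad", "weight"], ["grad", "bias"]],
    "offset": 0,
    "shape": [[2, 3], [3]],
    "block_size": 9,  # 2*3 + 3 elements in total
    "dtype": "float32",
}
buf = np.arange(9, dtype=np.float32).tobytes()  # stand-in for the memmap

offset = entry["offset"]
dtype = np.dtype(entry["dtype"])
nested = {}
for i, path in enumerate(entry["path"]):
    shape = tuple(entry["shape"][i])
    arr = np.ndarray(shape, dtype, buffer=buf, offset=offset, order="C")
    # Walk/create the nested dict down to the leaf key.
    level = nested
    for key in path[:-1]:
        level = level.setdefault(key, {})
    level[path[-1]] = arr
    # Advance by the element count of this tensor, in bytes.
    offset += reduce(lambda x, y: x * y, shape) * dtype.itemsize

# Mirrors the assertion added in __getitem__ above.
assert offset == entry["offset"] + entry["block_size"] * dtype.itemsize
```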
2 changes: 1 addition & 1 deletion analog/logging/log_loader_util.py
@@ -70,7 +70,7 @@ def get_mmap_metadata(
             # Append to the existing list for this data_id
             data_id_to_chunk[data_id][1].append(entry)
             continue
-        data_id_to_chunk[data_id] = (chunk_index, [entry])
+        data_id_to_chunk[data_id] = (chunk_index, entry)
     return data_id_to_chunk


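The loader-side change above depends on this one: `get_mmap_metadata` now maps each `data_id` to the single consolidated entry rather than a one-element list. Illustratively, with hypothetical values:

```python
# Before: data_id -> (chunk_index, [entry, entry, ...])  # one record per module
# After:  data_id -> (chunk_index, entry)                # one consolidated record
data_id_to_chunk = {
    0: (0, {
        "data_id": 0,
        "path": [["dummy_data"]],
        "offset": 0,
        "shape": [[8]],
        "block_size": 8,
        "dtype": "float32",
    }),
}
chunk_idx, entry = data_id_to_chunk[0]  # unpacked exactly as in __getitem__
```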
23 changes: 0 additions & 23 deletions analog/logging/log_saver.py
@@ -25,7 +25,6 @@ def buffer_write(self, binfo):
"""
Add log state on exit.
"""

data_id = binfo.data_id
log = binfo.log

@@ -38,29 +37,7 @@ def _add(log, buffer, idx):
                 continue
             _add(value, buffer[key], idx)
 
-        # Confirm delete.
-        # def _get_numpy_value(log, key_tuple, idx):
-        #     current_level = log
-        #     for key in key_tuple:
-        #         if key in current_level:
-        #             current_level = current_level[key]
-        #             if isinstance(current_level, torch.Tensor):
-        #                 current_level = current_level[idx]
-        #         else:
-        #             raise ValueError(f"no path {key} exist in the log")
-        #     return current_level
-
         for idx, did in enumerate(data_id):
-            # if self.flatten:
-            #     paths = self.state.get_state("model_module")
-            #     concat_numpys = []
-            #     for path in paths['path']:
-            #         numpy_value = to_numpy(_get_numpy_value(log, path, idx))
-            #         self.buffer_size += numpy_value.size
-            #         concat_numpys.append(numpy_value)
-            #
-            #     self.buffer[did] = concat_numpys
-            #     continue
             _add(log, self.buffer[did], idx)
 
     def _flush_unsafe(self, log_dir, buffer, flush_count) -> str:
26 changes: 16 additions & 10 deletions analog/logging/mmap.py
@@ -2,6 +2,7 @@

 import json
 from contextlib import contextmanager
+from functools import reduce
 
 import numpy as np

@@ -42,22 +43,27 @@ def write(save_path, filename, data_buffer, write_order_key, dtype="uint8"):
     metadata = []
     offset = 0
     for data_id, nested_dict in data_buffer:
+        data_level_offset = offset
+        data_shape = []
+        block_size = 0
         # Enforcing the insert order based on the module path.
         for key in write_order_key:
             arr = get_from_nested_dict(nested_dict, key)
             bytes = arr.nbytes
+            data_shape.append(arr.shape)
+            block_size += reduce(lambda x, y: x * y, arr.shape)
             mmap[offset : offset + bytes] = arr.ravel().view(dtype)
-            metadata.append(
-                {
-                    "data_id": data_id,
-                    "size": bytes,
-                    "path": key,
-                    "offset": offset,
-                    "shape": arr.shape,
-                    "dtype": str(arr.dtype),
-                }
-            )
             offset += arr.nbytes
+        metadata.append(
+            {
+                "data_id": data_id,
+                "path": write_order_key,
+                "offset": data_level_offset,
+                "shape": data_shape,
+                "block_size": block_size,
+                "dtype": str(arr.dtype),
+            }
+        )
 
     mmap.flush()
     del mmap  # Release the memmap object
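This hunk is the "diet" itself: the writer emits one metadata record per `data_id` covering all modules, instead of one record per (data_id, module) pair. A runnable sketch under the same assumptions as the diff; the module names and arrays are made up, and the repo's `get_from_nested_dict` helper is inlined as a plain dict lookup:

```python
from functools import reduce

import numpy as np

write_order_key = [("layer1", "weight"), ("layer2", "weight")]  # hypothetical modules
data_buffer = [
    (0, {"layer1": {"weight": np.ones((2, 2), dtype=np.float32)},
         "layer2": {"weight": np.zeros((4,), dtype=np.float32)}}),
]

metadata, offset = [], 0
for data_id, nested_dict in data_buffer:
    data_level_offset, data_shape, block_size = offset, [], 0
    for key in write_order_key:
        arr = nested_dict[key[0]][key[1]]  # stands in for get_from_nested_dict
        data_shape.append(arr.shape)
        block_size += reduce(lambda x, y: x * y, arr.shape)
        offset += arr.nbytes  # the real writer also copies bytes into the mmap here
    # One record per data_id: 4 + 4 = 8 elements spanning 32 bytes from offset 0.
    # Note: like the diff, this records the dtype of the last array written.
    metadata.append({
        "data_id": data_id,
        "path": write_order_key,
        "offset": data_level_offset,
        "shape": data_shape,
        "block_size": block_size,
        "dtype": str(arr.dtype),
    })

print(metadata[0]["block_size"])  # 8
```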
40 changes: 28 additions & 12 deletions tests/storage/test_mmap_data/expected_data_metadata.json
@@ -1,50 +1,66 @@
 [
   {
     "data_id": 0,
-    "size": 32,
     "path": [
-      "dummy_data"
+      [
+        "dummy_data"
+      ]
     ],
     "offset": 0,
     "shape": [
-      8
+      [
+        8
+      ]
     ],
+    "block_size": 8,
     "dtype": "float32"
   },
   {
     "data_id": 1,
-    "size": 64,
     "path": [
-      "dummy_data"
+      [
+        "dummy_data"
+      ]
     ],
     "offset": 32,
     "shape": [
-      8
+      [
+        8
+      ]
     ],
+    "block_size": 8,
     "dtype": "float64"
   },
   {
     "data_id": 2,
-    "size": 64,
     "path": [
-      "dummy_data"
+      [
+        "dummy_data"
+      ]
     ],
     "offset": 96,
     "shape": [
-      8
+      [
+        8
+      ]
     ],
+    "block_size": 8,
     "dtype": "float64"
   },
   {
     "data_id": 3,
-    "size": 64,
     "path": [
-      "dummy_data"
+      [
+        "dummy_data"
+      ]
     ],
     "offset": 160,
     "shape": [
-      8
+      [
+        8
+      ]
     ],
+    "block_size": 8,
     "dtype": "float64"
   }
 ]
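Since each record now carries `block_size` in elements rather than `size` in bytes, the fixture's offsets can be checked for contiguity. A quick sanity check one might run against this file (path assumed relative to the repo root):

```python
import json

import numpy as np

with open("tests/storage/test_mmap_data/expected_data_metadata.json") as f:
    metadata = json.load(f)

offset = 0
for item in metadata:
    assert item["offset"] == offset, f"gap before data_id {item['data_id']}"
    offset += item["block_size"] * np.dtype(item["dtype"]).itemsize
# Walks 0 -> 32 (8 * float32) -> 96 -> 160 -> 224 (8 * float64 each).
```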
20 changes: 13 additions & 7 deletions tests/storage/test_util_integration.py
@@ -63,12 +63,15 @@ def test_write_and_read(self):

         for item in metadata:
             offset = item["offset"]
-            size = item["size"]
-            shape = tuple(item["shape"])
+            shape = []
+            for path in item["shape"]:
+                shape.append(path)
+            shape = tuple(shape) if len(shape) > 1 else shape[0]
             dtype = np.dtype(item["dtype"])
+            block_size = item["block_size"]
             expected_data = data_buffer[item["data_id"]][1]["dummy_data"]
             read_data = np.frombuffer(
-                mmap, dtype=dtype, count=size // dtype.itemsize, offset=offset
+                mmap, dtype=dtype, count=block_size, offset=offset
             ).reshape(shape)
             # Test if expected value and read value equals
             self.assertTrue(np.array_equal(read_data, expected_data), "Data mismatch")
@@ -93,14 +96,17 @@ def test_read(self):
             expected_mmap = mm
             for item in metadata:
                 offset = item["offset"]
-                size = item["size"]
-                shape = tuple(item["shape"])
+                shape = []
+                for path in item["shape"]:
+                    shape.append(path)
+                shape = tuple(shape) if len(shape) > 1 else shape[0]
                 dtype = np.dtype(item["dtype"])
+                block_size = item["block_size"]
                 test_data = np.frombuffer(
-                    mmap, dtype=dtype, count=size // dtype.itemsize, offset=offset
+                    mmap, dtype=dtype, count=block_size, offset=offset
                 ).reshape(shape)
                 expected_data = np.frombuffer(
-                    mmap, dtype=dtype, count=size // dtype.itemsize, offset=offset
+                    mmap, dtype=dtype, count=block_size, offset=offset
                 ).reshape(shape)
                 self.assertTrue(np.allclose(test_data, expected_data), "Data mismatch")
             cleanup(expected_files_path, filename)
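
The tests' switch from `count=size // dtype.itemsize` to `count=block_size` is behavior-preserving whenever `block_size` equals the block's element count, since `np.frombuffer` takes `count` in elements. A small check of that equivalence, on an illustrative array:

```python
import numpy as np

arr = np.arange(8, dtype=np.float64)
buf, dtype = arr.tobytes(), np.dtype("float64")
size, block_size = arr.nbytes, arr.size  # 64 bytes vs. 8 elements

old = np.frombuffer(buf, dtype=dtype, count=size // dtype.itemsize)
new = np.frombuffer(buf, dtype=dtype, count=block_size)
assert np.array_equal(old, new)
```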
