From c01fa4878ccf06acc0e4a30404b6e0d433c9cd30 Mon Sep 17 00:00:00 2001
From: Minsoo Kang <30644997+eatpk@users.noreply.github.com>
Date: Mon, 29 Jan 2024 22:42:22 -0800
Subject: [PATCH] Metadata Diet (#87)

* Metadata Diet

* test fix

* nitfix

* nit
---
 .gitignore                              |  6 ++-
 analog/logging/log_loader.py            | 32 +++++++++------
 analog/logging/log_loader_util.py       |  2 +-
 analog/logging/log_saver.py             | 23 -----------
 analog/logging/mmap.py                  | 26 +++++++-----
 .../expected_data_metadata.json          | 40 +++++++++++++------
 tests/storage/test_util_integration.py  | 20 ++++++----
 7 files changed, 82 insertions(+), 67 deletions(-)

diff --git a/.gitignore b/.gitignore
index 03d33c1c..0a90f594 100644
--- a/.gitignore
+++ b/.gitignore
@@ -172,4 +172,8 @@ cython_debug/
 !tests/storage/test_mmap_data/*.mmap
 !tests/examples/checkpoints/*.pt
 
-files/
\ No newline at end of file
+files/
+
+# Generated Config files.
+examples/**/*.yaml
+**/config.yaml
diff --git a/analog/logging/log_loader.py b/analog/logging/log_loader.py
index 5356ec71..4d953098 100644
--- a/analog/logging/log_loader.py
+++ b/analog/logging/log_loader.py
@@ -1,4 +1,6 @@
 from collections import OrderedDict
+from functools import reduce
+
 import numpy as np
 import torch
 from torch.utils.data import Dataset
@@ -42,30 +44,34 @@ def fetch_data(self):
 
     def __getitem__(self, index):
         data_id = list(self.data_id_to_chunk.keys())[index]
-        chunk_idx, entries = self.data_id_to_chunk[data_id]
+        chunk_idx, entry = self.data_id_to_chunk[data_id]
         nested_dict = {}
         mmap = self.memmaps[chunk_idx]
 
         if self.flatten:
-            blocksize, dtype = get_entry_metadata(entries)
-            return data_id, get_flatten_item(mmap, index, blocksize, dtype)
+            return data_id, get_flatten_item(
+                mmap, index, entry["block_size"], entry["dtype"]
+            )
 
-        for entry in entries:
-            # Read the data and put it into the nested dictionary
-            path = entry["path"]
-            offset = entry["offset"]
-            shape = tuple(entry["shape"])
-            dtype = np.dtype(entry["dtype"])
-            array = np.ndarray(shape, dtype, buffer=mmap, offset=offset, order="C")
-            tensor = torch.from_numpy(array)
+        offset = entry["offset"]
+        dtype = entry["dtype"]
+        for i in range(len(entry["path"])):
+            path = entry["path"][i]
+            shape = tuple(entry["shape"][i])
+            tensor = torch.from_numpy(
+                np.ndarray(shape, dtype, buffer=mmap, offset=offset, order="C")
+            )
 
-            # Place the tensor in the correct location within the nested dictionary
             current_level = nested_dict
             for key in path[:-1]:
                 if key not in current_level:
                     current_level[key] = {}
                 current_level = current_level[key]
             current_level[path[-1]] = tensor
+            offset += reduce(lambda x, y: x * y, shape) * np.dtype(dtype).itemsize
+
+        assert (
+            offset == entry["offset"] + entry["block_size"] * np.dtype(dtype).itemsize
+        ), f"the block_size does not match the shape for data_id: {entry['data_id']}"
         return data_id, nested_dict
 
     def __len__(self):
diff --git a/analog/logging/log_loader_util.py b/analog/logging/log_loader_util.py
index 8922925b..bb43914b 100644
--- a/analog/logging/log_loader_util.py
+++ b/analog/logging/log_loader_util.py
@@ -70,7 +70,7 @@ def get_mmap_metadata(
             # Append to the existing list for this data_id
             data_id_to_chunk[data_id][1].append(entry)
             continue
-        data_id_to_chunk[data_id] = (chunk_index, [entry])
+        data_id_to_chunk[data_id] = (chunk_index, entry)
     return data_id_to_chunk
 
 
diff --git a/analog/logging/log_saver.py b/analog/logging/log_saver.py
index 44d89427..26b25408 100644
--- a/analog/logging/log_saver.py
+++ b/analog/logging/log_saver.py
@@ -25,7 +25,6 @@ def buffer_write(self, binfo):
         """
        Add log state on
        exit.
        """
-
         data_id = binfo.data_id
         log = binfo.log
@@ -38,29 +37,7 @@ def _add(log, buffer, idx):
                 continue
             _add(value, buffer[key], idx)
 
-        # Confirm delete.
-        # def _get_numpy_value(log, key_tuple, idx):
-        #     current_level = log
-        #     for key in key_tuple:
-        #         if key in current_level:
-        #             current_level = current_level[key]
-        #             if isinstance(current_level, torch.Tensor):
-        #                 current_level = current_level[idx]
-        #         else:
-        #             raise ValueError(f"no path {key} exist in the log")
-        #     return current_level
-
         for idx, did in enumerate(data_id):
-            # if self.flatten:
-            #     paths = self.state.get_state("model_module")
-            #     concat_numpys = []
-            #     for path in paths['path']:
-            #         numpy_value = to_numpy(_get_numpy_value(log, path, idx))
-            #         self.buffer_size += numpy_value.size
-            #         concat_numpys.append(numpy_value)
-            #
-            #     self.buffer[did] = concat_numpys
-            #     continue
             _add(log, self.buffer[did], idx)
 
     def _flush_unsafe(self, log_dir, buffer, flush_count) -> str:
diff --git a/analog/logging/mmap.py b/analog/logging/mmap.py
index d20d84fa..80ed191f 100644
--- a/analog/logging/mmap.py
+++ b/analog/logging/mmap.py
@@ -2,6 +2,7 @@
 import json
 from contextlib import contextmanager
+from functools import reduce
 
 import numpy as np
 
@@ -42,22 +43,27 @@ def write(save_path, filename, data_buffer, write_order_key, dtype="uint8"):
     metadata = []
     offset = 0
     for data_id, nested_dict in data_buffer:
+        data_level_offset = offset
+        data_shape = []
+        block_size = 0
         # Enforcing the insert order based on the module path.
         for key in write_order_key:
             arr = get_from_nested_dict(nested_dict, key)
             bytes = arr.nbytes
+            data_shape.append(arr.shape)
+            block_size += reduce(lambda x, y: x * y, arr.shape)
             mmap[offset : offset + bytes] = arr.ravel().view(dtype)
-            metadata.append(
-                {
-                    "data_id": data_id,
-                    "size": bytes,
-                    "path": key,
-                    "offset": offset,
-                    "shape": arr.shape,
-                    "dtype": str(arr.dtype),
-                }
-            )
             offset += arr.nbytes
+        metadata.append(
+            {
+                "data_id": data_id,
+                "path": write_order_key,
+                "offset": data_level_offset,
+                "shape": data_shape,
+                "block_size": block_size,
+                "dtype": str(arr.dtype),
+            }
+        )
     mmap.flush()
     del mmap  # Release the memmap object
diff --git a/tests/storage/test_mmap_data/expected_data_metadata.json b/tests/storage/test_mmap_data/expected_data_metadata.json
index 592ade1a..9645c390 100644
--- a/tests/storage/test_mmap_data/expected_data_metadata.json
+++ b/tests/storage/test_mmap_data/expected_data_metadata.json
@@ -1,50 +1,66 @@
 [
   {
     "data_id": 0,
-    "size": 32,
     "path": [
-      "dummy_data"
+      [
+        "dummy_data"
+      ]
     ],
     "offset": 0,
     "shape": [
-      8
+      [
+        8
+      ]
     ],
+    "block_size": 8,
     "dtype": "float32"
   },
   {
     "data_id": 1,
-    "size": 64,
     "path": [
-      "dummy_data"
+      [
+        "dummy_data"
+      ]
     ],
     "offset": 32,
     "shape": [
-      8
+      [
+        8
+      ]
     ],
+    "block_size": 8,
     "dtype": "float64"
   },
   {
     "data_id": 2,
-    "size": 64,
     "path": [
-      "dummy_data"
+      [
+        "dummy_data"
+      ]
    ],
     "offset": 96,
     "shape": [
-      8
+      [
+        8
+      ]
     ],
+    "block_size": 8,
     "dtype": "float64"
   },
   {
     "data_id": 3,
-    "size": 64,
     "path": [
-      "dummy_data"
+      [
+        "dummy_data"
+      ]
     ],
     "offset": 160,
     "shape": [
-      8
+      [
+        8
+      ]
     ],
+    "block_size": 8,
     "dtype": "float64"
   }
 ]
\ No newline at end of file
diff --git a/tests/storage/test_util_integration.py b/tests/storage/test_util_integration.py
index de4e22d7..4a806d1a 100644
--- a/tests/storage/test_util_integration.py
+++ b/tests/storage/test_util_integration.py
@@ -63,12 +63,15 @@ def test_write_and_read(self):
 
         for item in metadata:
             offset = item["offset"]
-            size = item["size"]
-            shape = tuple(item["shape"])
+            shape = []
+            for path in item["shape"]:
+                shape.append(path)
+            shape = tuple(shape) if len(shape) > 1 else shape[0]
             dtype = np.dtype(item["dtype"])
+            block_size = item["block_size"]
             expected_data = data_buffer[item["data_id"]][1]["dummy_data"]
             read_data = np.frombuffer(
-                mmap, dtype=dtype, count=size // dtype.itemsize, offset=offset
+                mmap, dtype=dtype, count=block_size, offset=offset
             ).reshape(shape)
             # Test if expected value and read value equals
             self.assertTrue(np.array_equal(read_data, expected_data), "Data mismatch")
@@ -93,14 +96,17 @@ def test_read(self):
             expected_mmap = mm
         for item in metadata:
             offset = item["offset"]
-            size = item["size"]
-            shape = tuple(item["shape"])
+            shape = []
+            for path in item["shape"]:
+                shape.append(path)
+            shape = tuple(shape) if len(shape) > 1 else shape[0]
             dtype = np.dtype(item["dtype"])
+            block_size = item["block_size"]
             test_data = np.frombuffer(
-                mmap, dtype=dtype, count=size // dtype.itemsize, offset=offset
+                mmap, dtype=dtype, count=block_size, offset=offset
             ).reshape(shape)
             expected_data = np.frombuffer(
-                mmap, dtype=dtype, count=size // dtype.itemsize, offset=offset
+                mmap, dtype=dtype, count=block_size, offset=offset
             ).reshape(shape)
             self.assertTrue(np.allclose(test_data, expected_data), "Data mismatch")
         cleanup(expected_files_path, filename)
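
For readers skimming the diff: the "diet" replaces per-tensor metadata entries (each carrying its own `size`, `path`, `offset`, and `shape`) with a single entry per `data_id` whose `path` and `shape` fields are lists, plus one block-level `offset` and a `block_size` counted in elements rather than bytes. The sketch below mirrors that layout with plain numpy buffers; it is illustrative only — `write_block` and `read_block` are made-up names, not part of the analog API.

```python
# Minimal, self-contained sketch of the consolidated metadata layout
# (assumed helper names, not the analog API). One metadata entry
# describes ALL tensors logged for a data_id.
from functools import reduce

import numpy as np


def write_block(arrays, data_id, offset):
    """Serialize an ordered {key: array} dict and emit ONE metadata entry."""
    paths, shapes, chunks, block_size = [], [], [], 0
    for key, arr in arrays.items():
        paths.append([key])                 # nested-dict path to this tensor
        shapes.append(list(arr.shape))
        block_size += reduce(lambda x, y: x * y, arr.shape)  # element count
        chunks.append(arr.ravel().tobytes())
    entry = {
        "data_id": data_id,
        "path": paths,               # list of paths, not one entry per tensor
        "shape": shapes,
        "offset": offset,            # byte offset of the whole block
        "block_size": block_size,    # total elements across all tensors
        "dtype": str(arr.dtype),     # NOTE: assumes one dtype for the block
    }
    return b"".join(chunks), entry


def read_block(buf, entry):
    """Rebuild the nested dict by advancing a running offset, as
    log_loader.py's __getitem__ does after this patch."""
    dtype = np.dtype(entry["dtype"])
    offset, nested = entry["offset"], {}
    for path, shape in zip(entry["path"], entry["shape"]):
        shape = tuple(shape)
        arr = np.frombuffer(buf, dtype, count=int(np.prod(shape)), offset=offset)
        level = nested
        for key in path[:-1]:
            level = level.setdefault(key, {})
        level[path[-1]] = arr.reshape(shape)
        offset += arr.nbytes
    # Same consistency check the patch adds to log_loader.py:
    assert offset == entry["offset"] + entry["block_size"] * dtype.itemsize
    return nested


logs = {
    "forward": np.arange(6, dtype=np.float32).reshape(2, 3),
    "backward": np.ones((2, 3), dtype=np.float32),
}
blob, meta = write_block(logs, data_id=0, offset=0)
restored = read_block(blob, meta)
assert np.array_equal(restored["forward"], logs["forward"])
```

One trade-off the diet makes explicit: the consolidated entry stores a single `dtype` (taken from the last array written in `mmap.py`), so every tensor saved for one `data_id` must share a dtype — a constraint the old per-tensor metadata did not impose.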