diff --git a/src/mixins/__init__.py b/src/mixins/__init__.py index 77f2171..574b788 100644 --- a/src/mixins/__init__.py +++ b/src/mixins/__init__.py @@ -1,31 +1,4 @@ -from .debuggable import DebuggableMixin -from .multiprocessingable import MultiprocessingMixin -from .saveable import SaveableMixin from .seedable import SeedableMixin -from .swapcacheable import SwapcacheableMixin from .timeable import TimeableMixin -__all__ = [ - "DebuggableMixin", - "MultiprocessingMixin", - "SaveableMixin", - "SeedableMixin", - "SwapcacheableMixin", - "TimeableMixin", -] - -# Tensorable and Tqdmable rely on packages that may or may not be installed. - -try: - from .tensorable import TensorableMixin # noqa - - __all__.append("TensorableMixin") -except ImportError: - pass - -try: - from .tqdmable import TQDMableMixin # noqa - - __all__.append("TQDMableMixin") -except ImportError: - pass +__all__ = ["SeedableMixin", "TimeableMixin"] diff --git a/src/mixins/debuggable.py b/src/mixins/debuggable.py deleted file mode 100644 index 4ef58b0..0000000 --- a/src/mixins/debuggable.py +++ /dev/null @@ -1,47 +0,0 @@ -from __future__ import annotations - -import functools -import inspect -import pickle -from copy import deepcopy -from pathlib import Path - -from .utils import doublewrap - - -class DebuggableMixin: - @property - def _do_debug(self): - if hasattr(self, "do_debug"): - return self.do_debug - else: - return False - - @staticmethod - @doublewrap - def CaptureErrorState(fn, store_global: bool | None = None, filepath: Path | None = None): - if store_global is None: - store_global = filepath is None - - @functools.wraps(fn) - def debugging_wrapper(self, *args, seed: int | None = None, **kwargs): - if not self._do_debug: - return fn(self, *args, **kwargs) - - try: - return fn(self, *args, **kwargs) - except Exception: - T = inspect.trace() - for t in T: - if t[3] == fn.__name__: - break - - new_vars = deepcopy(t[0].f_locals) - if store_global: - __builtins__["_DEBUGGER_VARS"] = new_vars - if filepath: - with open(filepath, mode="wb") as f: - pickle.dump(new_vars, f) - raise - - return debugging_wrapper diff --git a/src/mixins/multiprocessingable.py b/src/mixins/multiprocessingable.py deleted file mode 100644 index b8443af..0000000 --- a/src/mixins/multiprocessingable.py +++ /dev/null @@ -1,32 +0,0 @@ -from __future__ import annotations - -from collections.abc import Callable, Sequence -from multiprocessing import Pool - - -class MultiprocessingMixin: - def __init__(self, *args, multiprocessing_pool_size: int | None = None, **kwargs): - self.multiprocessing_pool_size = multiprocessing_pool_size - - @property - def _multiprocessing_pool_size(self): - if hasattr(self, "multiprocessing_pool_size"): - return self.multiprocessing_pool_size - else: - return None - - @property - def _use_multiprocessing(self): - return self._multiprocessing_pool_size is not None and self._multiprocessing_pool_size > 1 - - def _map(self, fn: Callable, iterable: Sequence, tqdm: Callable | None = None, **tqdm_kwargs) -> Sequence: - if self._use_multiprocessing: - with Pool(self._multiprocessing_pool_size) as p: - if tqdm is None: - return p.map(fn, iterable) - else: - return list(tqdm(p.imap(fn, iterable), **tqdm_kwargs)) - elif tqdm is None: - return [fn(x) for x in iterable] - else: - return [fn(x) for x in tqdm(iterable, **tqdm_kwargs)] diff --git a/src/mixins/saveable.py b/src/mixins/saveable.py deleted file mode 100644 index fa6ea72..0000000 --- a/src/mixins/saveable.py +++ /dev/null @@ -1,85 +0,0 @@ -from __future__ import annotations - -import pickle as pickle - -try: - import dill - - dill_imported = True - dill_import_error = None -except ImportError as e: - dill_import_error = e - dill_imported = False - -from pathlib import Path - - -class SaveableMixin: - _DEL_BEFORE_SAVING_ATTRS = [] - # TODO(mmd): Make StrEnum upon conversion to python 3.11 - _PICKLER = "dill" if dill_imported else "pickle" - - def __init__(self, *args, **kwargs): - self.do_overwrite = kwargs.get("do_overwrite", False) - if self._PICKLER == "dill" and not dill_imported: - raise dill_import_error - - @classmethod - def _load(cls, filepath: Path, **add_kwargs) -> None: - if not filepath.exists(): - raise FileNotFoundError(f"{filepath} does not exist.") - elif not filepath.is_file(): - raise IsADirectoryError(f"{filepath} is not a file.") - - with open(filepath, mode="rb") as f: - match cls._PICKLER: - case "dill": - if not dill_imported: - raise dill_import_error - obj = dill.load(f) - case "pickle": - obj = pickle.load(f) - case _: - raise NotImplementedError(f"{cls._PICKLER} not supported! Options: {'dill', 'pickle'}") - - for a, v in add_kwargs.items(): - setattr(obj, a, v) - obj._post_load(add_kwargs) - - return obj - - def _post_load(self, load_add_kwargs: dict) -> None: - # Overwrite this in the base class if desired. - return - - def _save(self, filepath: Path, do_overwrite: bool | None = False) -> None: - if not hasattr(self, "do_overwrite"): - self.do_overwrite = False - if not (self.do_overwrite or do_overwrite): - if filepath.exists(): - raise FileExistsError(f"Filepath {filepath} already exists!") - - skipped_attrs = {} - for attr in self._DEL_BEFORE_SAVING_ATTRS: - if hasattr(self, attr): - skipped_attrs[attr] = self.__dict__.pop(attr) - - try: - with open(filepath, mode="wb") as f: - match self._PICKLER: - case "dill": - if not dill_imported: - raise dill_import_error - dill.dump(self, f) - case "pickle": - pickle.dump(self, f) - case _: - raise NotImplementedError( - f"{self._PICKLER} not supported! Options: {'dill', 'pickle'}" - ) - except Exception: - filepath.unlink() - raise - - for attr, val in skipped_attrs.items(): - setattr(self, attr, val) diff --git a/src/mixins/swapcacheable.py b/src/mixins/swapcacheable.py deleted file mode 100644 index 5103adb..0000000 --- a/src/mixins/swapcacheable.py +++ /dev/null @@ -1,74 +0,0 @@ -from __future__ import annotations - -import time -from collections.abc import Hashable - - -class SwapcacheableMixin: - def __init__(self, *args, **kwargs): - self._cache_size = kwargs.get("cache_size", 5) - - def _init_attrs(self): - if not hasattr(self, "_cache"): - self._cache = {"keys": [], "values": []} - if not hasattr(self, "_cache_size"): - self._cache_size = 5 - if not hasattr(self, "_front_attrs"): - self._front_attrs = [] - if not hasattr(self, "_front_cache_key"): - self._front_cache_key = None - if not hasattr(self, "_front_cache_idx"): - self._front_cache_idx = None - - def _set_swapcache_key(self, key: Hashable) -> None: - self._init_attrs() - if key == self._front_cache_key: - return - - seen_key = self._swapcache_has_key(key) - if seen_key: - idx = next(i for i, (k, t) in enumerate(self._seen_parameters) if k == key) - else: - self._cache["keys"].append((key, time.time())) - self._cache["values"].append({}) - - self._cache["keys"] = self._cache["keys"][-self._cache_size :] - self._cache["values"] = self._cache["values"][-self._cache_size :] - - idx = -1 - - # Clear out the old front-and-center attributes - for attr in self._front_attrs: - delattr(self, attr) - - self._front_cache_key = key - self._front_cache_idx = idx - - self._update_front_attrs() - - def _swapcache_has_key(self, key: Hashable) -> bool: - self._init_attrs() - return any(k == key for k, t in self._cache["keys"]) - - def _swap_to_key(self, key: Hashable) -> None: - self._init_attrs() - assert self._swapcache_has_key(key) - self._set_swapcache_key(key) - - def _update_front_attrs(self): - self._init_attrs() - # Set the new front-and-center attributes - for key, val in self._cache["values"][self._front_cache_idx].items(): - setattr(self, key, val) - - def _update_swapcache_key_and_swap(self, key: Hashable, values_dict: dict): - self._init_attrs() - assert key is not None - - self._set_swapcache_key(key) - self._cache["values"][self._front_cache_idx].update(values_dict) - self._update_front_attrs() - - def _update_current_swapcache_key(self, values_dict: dict): - self._init_attrs() - self._update_swapcache_key_and_swap(self._front_cache_key, values_dict) diff --git a/src/mixins/tensorable.py b/src/mixins/tensorable.py deleted file mode 100644 index c603fe1..0000000 --- a/src/mixins/tensorable.py +++ /dev/null @@ -1,39 +0,0 @@ -from __future__ import annotations - -from collections.abc import Hashable -from typing import Union - -import numpy as np -import torch - - -class TensorableMixin: - Tensorable_T = Union[np.ndarray, list[float], tuple["Tensorable_T"], dict[Hashable, "Tensorable_T"]] - Tensor_T = Union[torch.Tensor, tuple["Tensor_T"], dict[Hashable, "Tensor_T"]] - - def __init__(self, *args, **kwargs): - self.do_cuda = kwargs.get("do_cuda", torch.cuda.is_available()) - - def _cuda(self, T: torch.Tensor, do_cuda: bool | None = None): - if do_cuda is None: - do_cuda = self.do_cuda if hasattr(self, "do_cuda") else torch.cuda.is_available() - - return T.cuda() if do_cuda else T - - def _from_numpy(self, obj: np.ndarray) -> torch.Tensor: - # I keep getting errors about "RuntimeError: expected scalar type Float but found Double" - if obj.dtype == np.float64: - obj = obj.astype(np.float32) - return self._cuda(torch.from_numpy(obj)) - - def _nested_to_tensor(self, obj: TensorableMixin.Tensorable_T) -> TensorableMixin.Tensor_T: - if isinstance(obj, np.ndarray): - return self._from_numpy(obj) - elif isinstance(obj, list): - return self._from_numpy(np.array(obj)) - elif isinstance(obj, dict): - return {k: self._nested_to_tensor(v) for k, v in obj.items()} - elif isinstance(obj, tuple): - return tuple(self._nested_to_tensor(e) for e in obj) - - raise ValueError(f"Don't know how to convert {type(obj)} object {obj} to tensor!") diff --git a/src/mixins/tqdmable.py b/src/mixins/tqdmable.py deleted file mode 100644 index b7f129f..0000000 --- a/src/mixins/tqdmable.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -from tqdm.auto import tqdm - - -class TQDMableMixin: - _SKIP_TQDM_IF_LE = 3 - - def __init__(self, *args, **kwargs): - self.tqdm = kwargs.get("tqdm", tqdm) - - def _tqdm(self, rng, **kwargs): - if not hasattr(self, "tqdm"): - self.tqdm = tqdm - - if self.tqdm is None: - return rng - - try: - N = len(rng) - except Exception: - return rng - - if N <= self._SKIP_TQDM_IF_LE: - return rng - - return tqdm(rng, **kwargs) diff --git a/tests/test_saveable_mixin.py b/tests/test_saveable_mixin.py deleted file mode 100644 index 21b8865..0000000 --- a/tests/test_saveable_mixin.py +++ /dev/null @@ -1,82 +0,0 @@ -import unittest -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import Any - -from mixins import SaveableMixin - - -class Derived(SaveableMixin): - _PICKLER = "pickle" - - def __init__(self, a: int = -1, b: str = "unset", **kwargs): - super().__init__(**kwargs) - self.a = a - self.b = b - - def __eq__(self, other: Any) -> bool: - return type(self) is type(other) and (self.a == other.a) and (self.b == other.b) - - -class DillDerived(SaveableMixin): - _PICKLER = "dill" - - def __init__(self, a: int = -1, b: str = "unset", **kwargs): - super().__init__(**kwargs) - self.a = a - self.b = b - - def __eq__(self, other: Any) -> bool: - return type(self) is type(other) and (self.a == other.a) and (self.b == other.b) - - -class BadDerived(SaveableMixin): - _PICKLER = "not_supported" - - def __init__(self, a: int = -1, b: str = "unset", **kwargs): - super().__init__(**kwargs) - self.a = a - self.b = b - - def __eq__(self, other: Any) -> bool: - return type(self) is type(other) and (self.a == other.a) and (self.b == other.b) - - -class TestSaveableMixin(unittest.TestCase): - def test_saveable_mixin(self): - T = Derived(a=2, b="hi") - - with TemporaryDirectory() as d: - save_path = Path(d) / "save.pkl" - T._save(save_path) - - with self.assertRaises(FileExistsError): - new_t = Derived(a=3, b="bar") - new_t._save(save_path) - - got_T = Derived._load(save_path) - self.assertEqual(T, got_T) - - bad_T = BadDerived(a=2, b="hi") - with self.assertRaises(NotImplementedError): - bad_T._save(Path(d) / "no_save.pkl") - - # This should error as that pickler isn't supported. - with self.assertRaises(FileNotFoundError): - got_T = Derived._load(Path(d) / "no_save.pkl") - with self.assertRaises(IsADirectoryError): - got_T = Derived._load(Path(d)) - - # This should error as dill isn't installed. - with self.assertRaises(ImportError): - bad_T = DillDerived(a=3, b="baz") - T._PICKLER = "dill" - with self.assertRaises(ImportError): - T._save(Path(d) / "no_save.pkl") - Derived._PICKLER = "dill" - with self.assertRaises(ImportError): - got_T = Derived._load(save_path) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_swapcacheable_mixin.py b/tests/test_swapcacheable_mixin.py deleted file mode 100644 index eee0627..0000000 --- a/tests/test_swapcacheable_mixin.py +++ /dev/null @@ -1,12 +0,0 @@ -import unittest - -from mixins import SwapcacheableMixin - - -class TestSwapcacheableMixin(unittest.TestCase): - def test_constructs(self): - SwapcacheableMixin() - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_tensorable_mixin.py b/tests/test_tensorable_mixin.py deleted file mode 100644 index 9e485a0..0000000 --- a/tests/test_tensorable_mixin.py +++ /dev/null @@ -1,12 +0,0 @@ -import unittest - -from mixins import TensorableMixin - - -class TestTensorableMixin(unittest.TestCase): - def test_constructs(self): - TensorableMixin() - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_tqdmable_mixin.py b/tests/test_tqdmable_mixin.py deleted file mode 100644 index c881cad..0000000 --- a/tests/test_tqdmable_mixin.py +++ /dev/null @@ -1,12 +0,0 @@ -import unittest - -from mixins import TQDMableMixin - - -class TestTQDMableMixin(unittest.TestCase): - def test_constructs(self): - TQDMableMixin() - - -if __name__ == "__main__": - unittest.main()