From b77d955ffc8d1328f4800facab968ba2a77a8b01 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 24 Sep 2024 09:46:39 -0400 Subject: [PATCH 1/8] Removed poorly supported mixins. --- src/mixins/debuggable.py | 47 ----------------- src/mixins/multiprocessingable.py | 32 ------------ src/mixins/saveable.py | 85 ------------------------------- src/mixins/swapcacheable.py | 74 --------------------------- src/mixins/tensorable.py | 39 -------------- src/mixins/tqdmable.py | 27 ---------- tests/test_saveable_mixin.py | 82 ----------------------------- tests/test_swapcacheable_mixin.py | 12 ----- tests/test_tensorable_mixin.py | 12 ----- tests/test_tqdmable_mixin.py | 12 ----- 10 files changed, 422 deletions(-) delete mode 100644 src/mixins/debuggable.py delete mode 100644 src/mixins/multiprocessingable.py delete mode 100644 src/mixins/saveable.py delete mode 100644 src/mixins/swapcacheable.py delete mode 100644 src/mixins/tensorable.py delete mode 100644 src/mixins/tqdmable.py delete mode 100644 tests/test_saveable_mixin.py delete mode 100644 tests/test_swapcacheable_mixin.py delete mode 100644 tests/test_tensorable_mixin.py delete mode 100644 tests/test_tqdmable_mixin.py diff --git a/src/mixins/debuggable.py b/src/mixins/debuggable.py deleted file mode 100644 index 4ef58b0..0000000 --- a/src/mixins/debuggable.py +++ /dev/null @@ -1,47 +0,0 @@ -from __future__ import annotations - -import functools -import inspect -import pickle -from copy import deepcopy -from pathlib import Path - -from .utils import doublewrap - - -class DebuggableMixin: - @property - def _do_debug(self): - if hasattr(self, "do_debug"): - return self.do_debug - else: - return False - - @staticmethod - @doublewrap - def CaptureErrorState(fn, store_global: bool | None = None, filepath: Path | None = None): - if store_global is None: - store_global = filepath is None - - @functools.wraps(fn) - def debugging_wrapper(self, *args, seed: int | None = None, **kwargs): - if not self._do_debug: - return fn(self, *args, **kwargs) - - try: - return fn(self, *args, **kwargs) - except Exception: - T = inspect.trace() - for t in T: - if t[3] == fn.__name__: - break - - new_vars = deepcopy(t[0].f_locals) - if store_global: - __builtins__["_DEBUGGER_VARS"] = new_vars - if filepath: - with open(filepath, mode="wb") as f: - pickle.dump(new_vars, f) - raise - - return debugging_wrapper diff --git a/src/mixins/multiprocessingable.py b/src/mixins/multiprocessingable.py deleted file mode 100644 index b8443af..0000000 --- a/src/mixins/multiprocessingable.py +++ /dev/null @@ -1,32 +0,0 @@ -from __future__ import annotations - -from collections.abc import Callable, Sequence -from multiprocessing import Pool - - -class MultiprocessingMixin: - def __init__(self, *args, multiprocessing_pool_size: int | None = None, **kwargs): - self.multiprocessing_pool_size = multiprocessing_pool_size - - @property - def _multiprocessing_pool_size(self): - if hasattr(self, "multiprocessing_pool_size"): - return self.multiprocessing_pool_size - else: - return None - - @property - def _use_multiprocessing(self): - return self._multiprocessing_pool_size is not None and self._multiprocessing_pool_size > 1 - - def _map(self, fn: Callable, iterable: Sequence, tqdm: Callable | None = None, **tqdm_kwargs) -> Sequence: - if self._use_multiprocessing: - with Pool(self._multiprocessing_pool_size) as p: - if tqdm is None: - return p.map(fn, iterable) - else: - return list(tqdm(p.imap(fn, iterable), **tqdm_kwargs)) - elif tqdm is None: - return [fn(x) for x in iterable] - else: - return [fn(x) for x in tqdm(iterable, **tqdm_kwargs)] diff --git a/src/mixins/saveable.py b/src/mixins/saveable.py deleted file mode 100644 index fa6ea72..0000000 --- a/src/mixins/saveable.py +++ /dev/null @@ -1,85 +0,0 @@ -from __future__ import annotations - -import pickle as pickle - -try: - import dill - - dill_imported = True - dill_import_error = None -except ImportError as e: - dill_import_error = e - dill_imported = False - -from pathlib import Path - - -class SaveableMixin: - _DEL_BEFORE_SAVING_ATTRS = [] - # TODO(mmd): Make StrEnum upon conversion to python 3.11 - _PICKLER = "dill" if dill_imported else "pickle" - - def __init__(self, *args, **kwargs): - self.do_overwrite = kwargs.get("do_overwrite", False) - if self._PICKLER == "dill" and not dill_imported: - raise dill_import_error - - @classmethod - def _load(cls, filepath: Path, **add_kwargs) -> None: - if not filepath.exists(): - raise FileNotFoundError(f"{filepath} does not exist.") - elif not filepath.is_file(): - raise IsADirectoryError(f"{filepath} is not a file.") - - with open(filepath, mode="rb") as f: - match cls._PICKLER: - case "dill": - if not dill_imported: - raise dill_import_error - obj = dill.load(f) - case "pickle": - obj = pickle.load(f) - case _: - raise NotImplementedError(f"{cls._PICKLER} not supported! Options: {'dill', 'pickle'}") - - for a, v in add_kwargs.items(): - setattr(obj, a, v) - obj._post_load(add_kwargs) - - return obj - - def _post_load(self, load_add_kwargs: dict) -> None: - # Overwrite this in the base class if desired. - return - - def _save(self, filepath: Path, do_overwrite: bool | None = False) -> None: - if not hasattr(self, "do_overwrite"): - self.do_overwrite = False - if not (self.do_overwrite or do_overwrite): - if filepath.exists(): - raise FileExistsError(f"Filepath {filepath} already exists!") - - skipped_attrs = {} - for attr in self._DEL_BEFORE_SAVING_ATTRS: - if hasattr(self, attr): - skipped_attrs[attr] = self.__dict__.pop(attr) - - try: - with open(filepath, mode="wb") as f: - match self._PICKLER: - case "dill": - if not dill_imported: - raise dill_import_error - dill.dump(self, f) - case "pickle": - pickle.dump(self, f) - case _: - raise NotImplementedError( - f"{self._PICKLER} not supported! Options: {'dill', 'pickle'}" - ) - except Exception: - filepath.unlink() - raise - - for attr, val in skipped_attrs.items(): - setattr(self, attr, val) diff --git a/src/mixins/swapcacheable.py b/src/mixins/swapcacheable.py deleted file mode 100644 index 5103adb..0000000 --- a/src/mixins/swapcacheable.py +++ /dev/null @@ -1,74 +0,0 @@ -from __future__ import annotations - -import time -from collections.abc import Hashable - - -class SwapcacheableMixin: - def __init__(self, *args, **kwargs): - self._cache_size = kwargs.get("cache_size", 5) - - def _init_attrs(self): - if not hasattr(self, "_cache"): - self._cache = {"keys": [], "values": []} - if not hasattr(self, "_cache_size"): - self._cache_size = 5 - if not hasattr(self, "_front_attrs"): - self._front_attrs = [] - if not hasattr(self, "_front_cache_key"): - self._front_cache_key = None - if not hasattr(self, "_front_cache_idx"): - self._front_cache_idx = None - - def _set_swapcache_key(self, key: Hashable) -> None: - self._init_attrs() - if key == self._front_cache_key: - return - - seen_key = self._swapcache_has_key(key) - if seen_key: - idx = next(i for i, (k, t) in enumerate(self._seen_parameters) if k == key) - else: - self._cache["keys"].append((key, time.time())) - self._cache["values"].append({}) - - self._cache["keys"] = self._cache["keys"][-self._cache_size :] - self._cache["values"] = self._cache["values"][-self._cache_size :] - - idx = -1 - - # Clear out the old front-and-center attributes - for attr in self._front_attrs: - delattr(self, attr) - - self._front_cache_key = key - self._front_cache_idx = idx - - self._update_front_attrs() - - def _swapcache_has_key(self, key: Hashable) -> bool: - self._init_attrs() - return any(k == key for k, t in self._cache["keys"]) - - def _swap_to_key(self, key: Hashable) -> None: - self._init_attrs() - assert self._swapcache_has_key(key) - self._set_swapcache_key(key) - - def _update_front_attrs(self): - self._init_attrs() - # Set the new front-and-center attributes - for key, val in self._cache["values"][self._front_cache_idx].items(): - setattr(self, key, val) - - def _update_swapcache_key_and_swap(self, key: Hashable, values_dict: dict): - self._init_attrs() - assert key is not None - - self._set_swapcache_key(key) - self._cache["values"][self._front_cache_idx].update(values_dict) - self._update_front_attrs() - - def _update_current_swapcache_key(self, values_dict: dict): - self._init_attrs() - self._update_swapcache_key_and_swap(self._front_cache_key, values_dict) diff --git a/src/mixins/tensorable.py b/src/mixins/tensorable.py deleted file mode 100644 index d07a2b5..0000000 --- a/src/mixins/tensorable.py +++ /dev/null @@ -1,39 +0,0 @@ -from __future__ import annotations - -from collections.abc import Hashable -from typing import Union - -import numpy as np -import torch - - -class TensorableMixin: - Tensorable_T = Union[np.ndarray, list[float], tuple["Tensorable_T"], dict[Hashable, "Tensorable_T"]] - Tensor_T = Union[torch.Tensor, tuple["Tensor_T"], dict[Hashable, "Tensor_T"]] - - def __init__(self, *args, **kwargs): - self.do_cuda = kwargs.get("do_cuda", torch.cuda.is_available) - - def _cuda(self, T: torch.Tensor, do_cuda: bool | None = None): - if do_cuda is None: - do_cuda = self.do_cuda if hasattr(self, "do_cuda") else torch.cuda.is_available - - return T.cuda() if do_cuda else T - - def _from_numpy(self, obj: np.ndarray) -> torch.Tensor: - # I keep getting errors about "RuntimeError: expected scalar type Float but found Double" - if obj.dtype == np.float64: - obj = obj.astype(np.float32) - return self._cuda(torch.from_numpy(obj)) - - def _nested_to_tensor(self, obj: TensorableMixin.Tensorable_T) -> TensorableMixin.Tensor_T: - if isinstance(obj, np.ndarray): - return self._from_numpy(obj) - elif isinstance(obj, list): - return self._from_numpy(np.array(obj)) - elif isinstance(obj, dict): - return {k: self._nested_to_tensor(v) for k, v in obj.items()} - elif isinstance(obj, tuple): - return tuple(self._nested_to_tensor(e) for e in obj) - - raise ValueError(f"Don't know how to convert {type(obj)} object {obj} to tensor!") diff --git a/src/mixins/tqdmable.py b/src/mixins/tqdmable.py deleted file mode 100644 index b7f129f..0000000 --- a/src/mixins/tqdmable.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -from tqdm.auto import tqdm - - -class TQDMableMixin: - _SKIP_TQDM_IF_LE = 3 - - def __init__(self, *args, **kwargs): - self.tqdm = kwargs.get("tqdm", tqdm) - - def _tqdm(self, rng, **kwargs): - if not hasattr(self, "tqdm"): - self.tqdm = tqdm - - if self.tqdm is None: - return rng - - try: - N = len(rng) - except Exception: - return rng - - if N <= self._SKIP_TQDM_IF_LE: - return rng - - return tqdm(rng, **kwargs) diff --git a/tests/test_saveable_mixin.py b/tests/test_saveable_mixin.py deleted file mode 100644 index 21b8865..0000000 --- a/tests/test_saveable_mixin.py +++ /dev/null @@ -1,82 +0,0 @@ -import unittest -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import Any - -from mixins import SaveableMixin - - -class Derived(SaveableMixin): - _PICKLER = "pickle" - - def __init__(self, a: int = -1, b: str = "unset", **kwargs): - super().__init__(**kwargs) - self.a = a - self.b = b - - def __eq__(self, other: Any) -> bool: - return type(self) is type(other) and (self.a == other.a) and (self.b == other.b) - - -class DillDerived(SaveableMixin): - _PICKLER = "dill" - - def __init__(self, a: int = -1, b: str = "unset", **kwargs): - super().__init__(**kwargs) - self.a = a - self.b = b - - def __eq__(self, other: Any) -> bool: - return type(self) is type(other) and (self.a == other.a) and (self.b == other.b) - - -class BadDerived(SaveableMixin): - _PICKLER = "not_supported" - - def __init__(self, a: int = -1, b: str = "unset", **kwargs): - super().__init__(**kwargs) - self.a = a - self.b = b - - def __eq__(self, other: Any) -> bool: - return type(self) is type(other) and (self.a == other.a) and (self.b == other.b) - - -class TestSaveableMixin(unittest.TestCase): - def test_saveable_mixin(self): - T = Derived(a=2, b="hi") - - with TemporaryDirectory() as d: - save_path = Path(d) / "save.pkl" - T._save(save_path) - - with self.assertRaises(FileExistsError): - new_t = Derived(a=3, b="bar") - new_t._save(save_path) - - got_T = Derived._load(save_path) - self.assertEqual(T, got_T) - - bad_T = BadDerived(a=2, b="hi") - with self.assertRaises(NotImplementedError): - bad_T._save(Path(d) / "no_save.pkl") - - # This should error as that pickler isn't supported. - with self.assertRaises(FileNotFoundError): - got_T = Derived._load(Path(d) / "no_save.pkl") - with self.assertRaises(IsADirectoryError): - got_T = Derived._load(Path(d)) - - # This should error as dill isn't installed. - with self.assertRaises(ImportError): - bad_T = DillDerived(a=3, b="baz") - T._PICKLER = "dill" - with self.assertRaises(ImportError): - T._save(Path(d) / "no_save.pkl") - Derived._PICKLER = "dill" - with self.assertRaises(ImportError): - got_T = Derived._load(save_path) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_swapcacheable_mixin.py b/tests/test_swapcacheable_mixin.py deleted file mode 100644 index eee0627..0000000 --- a/tests/test_swapcacheable_mixin.py +++ /dev/null @@ -1,12 +0,0 @@ -import unittest - -from mixins import SwapcacheableMixin - - -class TestSwapcacheableMixin(unittest.TestCase): - def test_constructs(self): - SwapcacheableMixin() - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_tensorable_mixin.py b/tests/test_tensorable_mixin.py deleted file mode 100644 index 9e485a0..0000000 --- a/tests/test_tensorable_mixin.py +++ /dev/null @@ -1,12 +0,0 @@ -import unittest - -from mixins import TensorableMixin - - -class TestTensorableMixin(unittest.TestCase): - def test_constructs(self): - TensorableMixin() - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_tqdmable_mixin.py b/tests/test_tqdmable_mixin.py deleted file mode 100644 index c881cad..0000000 --- a/tests/test_tqdmable_mixin.py +++ /dev/null @@ -1,12 +0,0 @@ -import unittest - -from mixins import TQDMableMixin - - -class TestTQDMableMixin(unittest.TestCase): - def test_constructs(self): - TQDMableMixin() - - -if __name__ == "__main__": - unittest.main() From a1975d8e5cd9fa412ee82239101ca95a19e6991a Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 24 Sep 2024 10:02:13 -0400 Subject: [PATCH 2/8] Removed missing references. --- src/mixins/__init__.py | 29 +---------------------------- 1 file changed, 1 insertion(+), 28 deletions(-) diff --git a/src/mixins/__init__.py b/src/mixins/__init__.py index 77f2171..574b788 100644 --- a/src/mixins/__init__.py +++ b/src/mixins/__init__.py @@ -1,31 +1,4 @@ -from .debuggable import DebuggableMixin -from .multiprocessingable import MultiprocessingMixin -from .saveable import SaveableMixin from .seedable import SeedableMixin -from .swapcacheable import SwapcacheableMixin from .timeable import TimeableMixin -__all__ = [ - "DebuggableMixin", - "MultiprocessingMixin", - "SaveableMixin", - "SeedableMixin", - "SwapcacheableMixin", - "TimeableMixin", -] - -# Tensorable and Tqdmable rely on packages that may or may not be installed. - -try: - from .tensorable import TensorableMixin # noqa - - __all__.append("TensorableMixin") -except ImportError: - pass - -try: - from .tqdmable import TQDMableMixin # noqa - - __all__.append("TQDMableMixin") -except ImportError: - pass +__all__ = ["SeedableMixin", "TimeableMixin"] From d30296dae2ec2ed2c92b39206017a69ba5fca122 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Mon, 14 Oct 2024 12:15:48 -0400 Subject: [PATCH 3/8] Added some new tests. --- .github/workflows/code-quality-main.yaml | 4 + .github/workflows/code-quality-pr.yaml | 4 + .github/workflows/python-build.yaml | 2 +- .github/workflows/tests.yaml | 2 +- pyproject.toml | 4 +- src/mixins/seedable.py | 173 ++++++++++++++++++++--- src/mixins/timeable.py | 14 ++ 7 files changed, 182 insertions(+), 21 deletions(-) diff --git a/.github/workflows/code-quality-main.yaml b/.github/workflows/code-quality-main.yaml index ec878bf..1fe53e6 100644 --- a/.github/workflows/code-quality-main.yaml +++ b/.github/workflows/code-quality-main.yaml @@ -23,5 +23,9 @@ jobs: with: python-version: ${{ matrix.python-version }} + - name: Install packages + run: | + pip install -e .[dev] + - name: Run pre-commits uses: pre-commit/action@v3.0.1 diff --git a/.github/workflows/code-quality-pr.yaml b/.github/workflows/code-quality-pr.yaml index 2e08be0..a942c5e 100644 --- a/.github/workflows/code-quality-pr.yaml +++ b/.github/workflows/code-quality-pr.yaml @@ -26,6 +26,10 @@ jobs: with: python-version: ${{ matrix.python-version }} + - name: Install packages + run: | + pip install -e .[dev] + - name: Find modified files id: file_changes uses: trilom/file-changes-action@v1.2.4 diff --git a/.github/workflows/python-build.yaml b/.github/workflows/python-build.yaml index a32827f..1420804 100644 --- a/.github/workflows/python-build.yaml +++ b/.github/workflows/python-build.yaml @@ -74,7 +74,7 @@ jobs: path: dist/ - name: Sign the dists with Sigstore - uses: sigstore/gh-action-sigstore-python@v2.1.1 + uses: sigstore/gh-action-sigstore-python@v3.0.0 with: inputs: >- ./dist/*.tar.gz diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index aface00..c82160d 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -28,7 +28,7 @@ jobs: - name: Install packages run: | - pip install -e .[tests,tqdmable,tensorable] + pip install -e .[tests] #---------------------------------------------- # run test suite diff --git a/pyproject.toml b/pyproject.toml index 573db20..5072361 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,10 +19,8 @@ classifiers = [ dependencies = ["numpy"] [project.optional-dependencies] -dev = ["pre-commit"] +dev = ["pre-commit<4"] tests = ["pytest", "pytest-cov"] -tqdmable = ["tqdm"] -tensorable = ["torch"] [tool.setuptools_scm] diff --git a/src/mixins/seedable.py b/src/mixins/seedable.py index b519d7a..14f9ede 100644 --- a/src/mixins/seedable.py +++ b/src/mixins/seedable.py @@ -7,41 +7,104 @@ import numpy as np +try: + import torch + + def seed_torch(seed: int): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +except ModuleNotFoundError: + + def seed_torch(seed: int): + pass + + from .utils import doublewrap -def seed_everything(seed: int | None = None, try_import_torch: bool | None = True) -> int: - max_seed_value = np.iinfo(np.uint32).max - min_seed_value = np.iinfo(np.uint32).min +def seed_everything(seed: int | None = None) -> int: + """A simple helper function to seed everything that needs to be seeded. + + Args: + seed: The seed to use. If None, a random seed is chosen. + + Returns: + The seed that was used. + + Examples: + >>> random.seed(0) + >>> np.random.seed(0) + >>> random.randint(0, 10) + 6 + >>> random.randint(0, 10) + 6 + >>> np.random.randint(0, 10) + 5 + >>> np.random.randint(0, 10) + 0 + >>> seed_everything(0) + 0 + >>> random.randint(0, 10) + 6 + >>> random.randint(0, 10) + 6 + >>> np.random.randint(0, 10) + 5 + >>> np.random.randint(0, 10) + 0 + """ try: if seed is None: seed = os.environ.get("PL_GLOBAL_SEED") seed = int(seed) except (TypeError, ValueError): + max_seed_value = np.iinfo(np.uint32).max + min_seed_value = np.iinfo(np.uint32).min seed = np.random.randint(min_seed_value, max_seed_value) - assert min_seed_value <= seed <= max_seed_value - random.seed(seed) np.random.seed(seed) - if try_import_torch: - try: - import torch - - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - except ModuleNotFoundError: - pass + seed_torch(seed) return seed class SeedableMixin: + """This class provides easy utilities to reliably seed stochastic processes. + + This seeding can be used to ensure reproducibility in experiments, both in individual examples with an + integral seed or in a stochastic process both at a per-event level and at a whole process level by seeding + with `None`, in which case a new seed is chosen for each event in the process based on the prior seed and + stored. + """ + def __init__(self, *args, **kwargs): self._past_seeds = kwargs.get("_past_seeds", []) - def _last_seed(self, key: str): + def _last_seed(self, key: str) -> tuple[int, int | None]: + """This returns the most recently used seed with a given key. + + Args: + key: The key to search for. + + Returns: + The index of the most recent seed with a given key in the list of past seeds and the seed itself. + + Examples: + >>> M = SeedableMixin() + >>> _ = M._seed(0, "foo") + >>> _ = M._seed(2, "bar") + >>> _ = M._seed(4, "foo") + >>> _ = M._seed(6, "baz") + >>> M._last_seed("foo") + (2, 4) + >>> M._last_seed("bar") + (1, 2) + >>> M._last_seed("baz") + (3, 6) + """ for idx, (s, k, time) in enumerate(self._past_seeds[::-1]): if k == key: idx = len(self._past_seeds) - 1 - idx @@ -49,7 +112,46 @@ def _last_seed(self, key: str): return -1, None - def _seed(self, seed: int | None = None, key: str | None = None): + def _seed(self, seed: int | None = None, key: str | None = None) -> int: + """This seeds the random number generators. + + Args: + seed: The seed to use. If None, a new seed is chosen. + key: The key to associate with this seed. + + Returns: + The seed that was used. + + Examples: + >>> M = SeedableMixin() + >>> M._seed(0, "foo") + 0 + >>> M._seed(2, "bar") + 2 + >>> M._seed(4, "foo") + 4 + + Note that by virtue of the fact that we've already seeded `M`, future seeds are deterministic (though + they are still pseudo-random, as they are simply random integers drawn from the current random + distribution, which in this test was seeded at 4 immediately prior to this call). + >>> M._seed() + 31681838 + + Past seeds and keys are stored in the `_past_seeds` attribute, which is created if the object does not + have it at the start. + >>> M = SeedableMixin() + >>> del M._past_seeds + >>> M._seed(0, "foo") + 0 + >>> M._seed(2, "bar") + 2 + >>> M._seed(4, "foo") + 4 + >>> M._seed() + 31681838 + >>> M._past_seeds + [(0, 'foo', ...), (2, 'bar', ...), (4, 'foo', ...), (31681838, '', ...)] + """ if seed is None: seed = random.randint(0, int(1e8)) if key is None: @@ -67,7 +169,46 @@ def _seed(self, seed: int | None = None, key: str | None = None): @staticmethod @doublewrap - def WithSeed(fn, key: str | None = None): + def WithSeed(fn, key: str | None = None) -> callable: + """This function is a decorator that returns a function that also takes a seed which seeds the RNG. + + This decorator can either be called with a `key` argument or without arguments. In the latter case, + the decorator is used like this: + + ``` + @SeedableMixin.WithSeed + def func(...): + ... + ``` + + In this case, the name of the function is used as the key to the associated seed call. If a key is + provided, the decorator is used like this: + + ``` + @SeedableMixin.WithSeed(key="foo") + def func(...): + ... + ``` + + In this case, the key is used as the key to the associated seed call. This is useful when the function + name is not the desired seed. + + Args: + fn: The function to wrap. This argument _does not need to be provided_ if a key is used; instead + the `doublewrap` decorator is used to allow the key to be passed as a keyword argument to a + meta-function that returns the true decorator applied to the target function. + key: The key to use for the seed. If None, the function name is used. + + Returns: + A function that takes all the input arguments of the wrapped function and a seed keyword argument. + If the seed is not provided, a new seed is chosen. The seed is used to seed the RNG before calling + the wrapped function, under the provided key. + + Note that if the function being wrapped explicitly takes a seed argument, this decorator will not + work, and the failure will not necessarily be graceful. + + Examples: + """ if key is None: key = fn.__name__ diff --git a/src/mixins/timeable.py b/src/mixins/timeable.py index 14fb8ca..f8e8e42 100644 --- a/src/mixins/timeable.py +++ b/src/mixins/timeable.py @@ -11,6 +11,20 @@ class TimeableMixin: + """A mixin class to add timing functionality to a class for profiling its methods. + + This mixin class provides the following functionality: + - Timing of methods using the TimeAs decorator. + - Timing of arbitrary code blocks using the _time_as context manager. + - Profiling of the durations of the timed methods. + + Attributes: + _timings: A dictionary of lists of dictionaries containing the start and end times of timed methods. + The keys of the dictionary are the names of the timed methods. + The values are lists of dictionaries containing the start and end times of each timed method call. + The dictionaries contain the keys "start" and "end" with the corresponding times. + """ + _START_TIME = "start" _END_TIME = "end" From 3d758802a213c5b386a21e37d6b0fed07a673d07 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Mon, 14 Oct 2024 12:43:54 -0400 Subject: [PATCH 4/8] Added benchmarking to the tests and converted to pytest syntax --- .github/workflows/tests.yaml | 2 +- pyproject.toml | 2 +- tests/test_seedable_mixin.py | 175 ++++++++++++++++++----------------- tests/test_timeable_mixin.py | 100 ++++++++++---------- 4 files changed, 143 insertions(+), 136 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index c82160d..56904e4 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -35,7 +35,7 @@ jobs: #---------------------------------------------- - name: Run tests run: | - pytest -v --doctest-modules --cov=src --junitxml=junit.xml -s --ignore=docs + pytest -v --doctest-modules --cov=src --junitxml=junit.xml -s --ignore=docs --benchmark-autosave --benchmark-compare-fail=mean:5% - name: Upload coverage to Codecov uses: codecov/codecov-action@v4.0.1 diff --git a/pyproject.toml b/pyproject.toml index 5072361..6152edb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = ["numpy"] [project.optional-dependencies] dev = ["pre-commit<4"] -tests = ["pytest", "pytest-cov"] +tests = ["pytest", "pytest-cov", "pytest-benchmark"] [tool.setuptools_scm] diff --git a/tests/test_seedable_mixin.py b/tests/test_seedable_mixin.py index 0fa8432..0dd1936 100644 --- a/tests/test_seedable_mixin.py +++ b/tests/test_seedable_mixin.py @@ -1,5 +1,4 @@ import random -import unittest import numpy as np @@ -24,125 +23,131 @@ def decorated_auto_key(self): return random.random() -class TestSeedableMixin(unittest.TestCase): - def test_constructs(self): - SeedableMixin() - SeedableDerived() +def test_constructs(): + SeedableMixin() + SeedableDerived() - def test_responds_to_methods(self): - T = SeedableMixin() - T._seed() - T._last_seed("foo") +def test_benchmark_seeding(benchmark): + T = SeedableDerived() - T = SeedableDerived() - T._seed() - T._last_seed("foo") + benchmark(T._seed) - def test_seeding_freezes_randomness(self): - T = SeedableDerived() - unseeded_1 = T.gen_random_num() - unseeded_2 = T.gen_random_num() +def test_responds_to_methods(): + T = SeedableMixin() - # Without seeding, repeated calls should be different. - self.assertNotEqual(unseeded_1, unseeded_2) + T._seed() + T._last_seed("foo") - T._seed(1) - seeded_1_1 = T.gen_random_num() - seeded_2_1 = T.gen_random_num() + T = SeedableDerived() + T._seed() + T._last_seed("foo") - # Even if I seeded at the start, repeated calls should still be different. - self.assertNotEqual(seeded_1_1, seeded_2_1) - T._seed(1) - seeded_1_2 = T.gen_random_num() - seeded_2_2 = T.gen_random_num() +def test_seeding_freezes_randomness(): + T = SeedableDerived() - # Since I seeded again, they should match the prior sequence. - self.assertEqual(seeded_1_1, seeded_1_2) - self.assertEqual(seeded_2_1, seeded_2_2) + unseeded_1 = T.gen_random_num() + unseeded_2 = T.gen_random_num() - def test_decorated_seeding_freezes_randomness(self): - T = SeedableDerived() + # Without seeding, repeated calls should be different. + assert unseeded_1 != unseeded_2, "Unseeded calls should be different." - unseeded_1 = T.decorated_gen_random_num() - unseeded_2 = T.decorated_gen_random_num() + T._seed(1) + seeded_1_1 = T.gen_random_num() + seeded_2_1 = T.gen_random_num() - # Without seeding, repeated calls should be different. - self.assertNotEqual(unseeded_1, unseeded_2) + # Even if I seeded at the start, repeated calls should still be different. + assert seeded_1_1 != seeded_2_1, "Seeded calls should be different when called repeatedly." - seeded_1_1 = T.decorated_gen_random_num(seed=1) - seeded_2_1 = T.decorated_gen_random_num(seed=2) + T._seed(1) + seeded_1_2 = T.gen_random_num() + seeded_2_2 = T.gen_random_num() - # Even if I seeded at the start, repeated calls should still be different. - self.assertNotEqual(seeded_1_1, seeded_2_1) + # Since I seeded again, they should match the prior sequence. + assert seeded_1_1 == seeded_1_2 + assert seeded_2_1 == seeded_2_2 - seeded_1_2 = T.decorated_gen_random_num(seed=1) - seeded_2_2 = T.decorated_gen_random_num(seed=2) - # Since they are seeded, they should match the prior sequence. - self.assertEqual(seeded_1_1, seeded_1_2) - self.assertEqual(seeded_2_1, seeded_2_2) +def test_decorated_seeding_freezes_randomness(): + T = SeedableDerived() - # Now we want to make sure the seeding is consistent even interrupted. + unseeded_1 = T.decorated_gen_random_num() + unseeded_2 = T.decorated_gen_random_num() - T._seed(0) - seeded_1_3 = T.decorated_gen_random_num(seed=1) - T._seed(10) - seeded_2_3 = T.decorated_gen_random_num(seed=2) + # Without seeding, repeated calls should be different. + assert unseeded_1 != unseeded_2 - self.assertEqual(seeded_1_1, seeded_1_3) - self.assertEqual(seeded_2_1, seeded_2_3) + seeded_1_1 = T.decorated_gen_random_num(seed=1) + seeded_2_1 = T.decorated_gen_random_num(seed=2) - def test_seeds_follow_consistent_sequence(self): - T = SeedableDerived() + # Even if I seeded at the start, repeated calls should still be different. + assert seeded_1_1 != seeded_2_1 - unseeded_seq = [T._seed() for i in range(5)] + seeded_1_2 = T.decorated_gen_random_num(seed=1) + seeded_2_2 = T.decorated_gen_random_num(seed=2) - seed_1 = T._seed(1) + # Since they are seeded, they should match the prior sequence. + assert seeded_1_1 == seeded_1_2 + assert seeded_2_1 == seeded_2_2 - # seed_1 should be 1 given I passed a seed in: - self.assertEqual(seed_1, 1) + # Now we want to make sure the seeding is consistent even interrupted. - next_seeds_1 = [T._seed() for i in range(5)] + T._seed(0) + seeded_1_3 = T.decorated_gen_random_num(seed=1) + T._seed(10) + seeded_2_3 = T.decorated_gen_random_num(seed=2) - # These should differ from the unseeded sequence of seeds - self.assertNotEqual(unseeded_seq, next_seeds_1) + assert seeded_1_1 == seeded_1_3 + assert seeded_2_1 == seeded_2_3 - T._seed(1) - next_seeds_2 = [T._seed() for i in range(5)] +def test_seeds_follow_consistent_sequence(): + T = SeedableDerived() - # The sequence of seeds should be the same here. - self.assertEqual(next_seeds_1, next_seeds_2) + unseeded_seq = [T._seed() for i in range(5)] - def test_get_last_seed(self): - T = SeedableDerived() + seed_1 = T._seed(1) - key = "key" - non_key = "not_key" + # seed_1 should be 1 given I passed a seed in: + assert seed_1 == 1 - seed_key_early = 1 - seed_key_late = 1 - seed_non_key = 2 + next_seeds_1 = [T._seed() for i in range(5)] - T._seed() + # These should differ from the unseeded sequence of seeds + assert unseeded_seq != next_seeds_1 - idx, seed = T._last_seed(key) - self.assertEqual(idx, -1) - self.assertEqual(seed, None) + T._seed(1) - T._seed(seed_key_early, key) - T._seed() - T._seed(seed_non_key, non_key) - T._seed(seed_key_late, key) - T._seed(seed_non_key, non_key) + next_seeds_2 = [T._seed() for i in range(5)] - idx, seed = T._last_seed(key) - self.assertEqual(idx, 4) - self.assertEqual(seed, seed_key_late) + # The sequence of seeds should be the same here. + assert next_seeds_1 == next_seeds_2 -if __name__ == "__main__": - unittest.main() +def test_get_last_seed(): + T = SeedableDerived() + + key = "key" + non_key = "not_key" + + seed_key_early = 1 + seed_key_late = 1 + seed_non_key = 2 + + T._seed() + + idx, seed = T._last_seed(key) + assert idx == -1 + assert seed is None + + T._seed(seed_key_early, key) + T._seed() + T._seed(seed_non_key, non_key) + T._seed(seed_key_late, key) + T._seed(seed_non_key, non_key) + + idx, seed = T._last_seed(key) + assert idx == 4 + assert seed == seed_key_late diff --git a/tests/test_timeable_mixin.py b/tests/test_timeable_mixin.py index 4b70b70..e8def6c 100644 --- a/tests/test_timeable_mixin.py +++ b/tests/test_timeable_mixin.py @@ -1,5 +1,4 @@ import time -import unittest import numpy as np @@ -24,69 +23,72 @@ def decorated_takes_time_auto_key(self, num_seconds: int = 10): time.sleep(num_seconds) -class TestTimeableMixin(unittest.TestCase): - def test_constructs(self): - TimeableMixin() - TimeableDerived() +def test_constructs(): + TimeableMixin() + TimeableDerived() - def test_responds_to_methods(self): - T = TimeableMixin() - T._register_start("key") - T._register_end("key") +def test_responds_to_methods(): + T = TimeableMixin() - T._times_for("key") + T._register_start("key") + T._register_end("key") - T._register_start("key") - T._time_so_far("key") - T._register_end("key") + T._times_for("key") - def test_pprint_num_unit(self): - self.assertEqual((5, "μs"), TimeableMixin._get_pprint_num_unit(5 * 1e-6)) + T._register_start("key") + T._time_so_far("key") + T._register_end("key") - class Derived(TimeableMixin): - _CUTOFFS_AND_UNITS = [(10, "foo"), (2, "bar"), (None, "biz")] - self.assertEqual((3, "biz"), Derived._get_pprint_num_unit(3, "biz")) - self.assertEqual((3, "foo"), Derived._get_pprint_num_unit(3 / 20, "biz")) - self.assertEqual((1.2, "biz"), Derived._get_pprint_num_unit(2.4 * 10, "foo")) +def test_benchmark_timing(benchmark): + T = TimeableDerived() - def test_context_manager(self): - T = TimeableDerived() + benchmark(T.decorated_takes_time, 0.001) - T.uses_contextlib(num_seconds=1) - duration = T._times_for("using_contextlib")[-1] - np.testing.assert_almost_equal(duration, 1, decimal=1) +def test_pprint_num_unit(): + assert (5, "μs") == TimeableMixin._get_pprint_num_unit(5 * 1e-6) - def test_times_and_profiling(self): - T = TimeableDerived() - T.decorated_takes_time(num_seconds=2) + class Derived(TimeableMixin): + _CUTOFFS_AND_UNITS = [(10, "foo"), (2, "bar"), (None, "biz")] - duration = T._times_for("decorated")[-1] - np.testing.assert_almost_equal(duration, 2, decimal=1) + assert (3, "biz") == Derived._get_pprint_num_unit(3, "biz") + assert (3, "foo") == Derived._get_pprint_num_unit(3 / 20, "biz") + assert (1.2, "biz") == Derived._get_pprint_num_unit(2.4 * 10, "foo") - T.decorated_takes_time_auto_key(num_seconds=2) - duration = T._times_for("decorated_takes_time_auto_key")[-1] - np.testing.assert_almost_equal(duration, 2, decimal=1) - T.decorated_takes_time(num_seconds=1) - stats = T._duration_stats +def test_context_manager(): + T = TimeableDerived() - self.assertEqual({"decorated", "decorated_takes_time_auto_key"}, set(stats.keys())) - np.testing.assert_almost_equal(1.5, stats["decorated"][0], decimal=1) - self.assertEqual(2, stats["decorated"][1]) - np.testing.assert_almost_equal(0.5, stats["decorated"][2], decimal=1) - np.testing.assert_almost_equal(2, stats["decorated_takes_time_auto_key"][0], decimal=1) - self.assertEqual(1, stats["decorated_takes_time_auto_key"][1]) - self.assertEqual(0, stats["decorated_takes_time_auto_key"][2]) + T.uses_contextlib(num_seconds=1) - got_str = T._profile_durations() - want_str = ( - "decorated_takes_time_auto_key: 2.0 sec\n" "decorated: 1.5 ± 0.5 sec (x2)" - ) - self.assertEqual(want_str, got_str, msg=f"Want:\n{want_str}\nGot:\n{got_str}") + duration = T._times_for("using_contextlib")[-1] + np.testing.assert_almost_equal(duration, 1, decimal=1) -if __name__ == "__main__": - unittest.main() +def test_times_and_profiling(): + T = TimeableDerived() + T.decorated_takes_time(num_seconds=2) + + duration = T._times_for("decorated")[-1] + np.testing.assert_almost_equal(duration, 2, decimal=1) + + T.decorated_takes_time_auto_key(num_seconds=2) + duration = T._times_for("decorated_takes_time_auto_key")[-1] + np.testing.assert_almost_equal(duration, 2, decimal=1) + + T.decorated_takes_time(num_seconds=1) + stats = T._duration_stats + + assert {"decorated", "decorated_takes_time_auto_key"} == set(stats.keys()) + np.testing.assert_almost_equal(1.5, stats["decorated"][0], decimal=1) + assert 2 == stats["decorated"][1] + np.testing.assert_almost_equal(0.5, stats["decorated"][2], decimal=1) + np.testing.assert_almost_equal(2, stats["decorated_takes_time_auto_key"][0], decimal=1) + assert 1 == stats["decorated_takes_time_auto_key"][1] + assert 0 == stats["decorated_takes_time_auto_key"][2] + + got_str = T._profile_durations() + want_str = "decorated_takes_time_auto_key: 2.0 sec\n" "decorated: 1.5 ± 0.5 sec (x2)" + assert want_str == got_str, f"Want:\n{want_str}\nGot:\n{got_str}" From 8ac802997345fac6411b82bf8231594b6f5aa724 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Mon, 14 Oct 2024 13:27:30 -0400 Subject: [PATCH 5/8] Added tests of seed everything with and without torch; made it easier to set what should be seeded and what not for faster usage. --- .github/workflows/benchmarks.yaml | 42 +++++++++++++++++++ .github/workflows/tests.yaml | 24 ++++++++++- src/mixins/seedable.py | 38 +++++++++-------- tests/test_seedable_mixin.py | 68 +++++++++++++++++++++++++++++++ tests/test_timeable_mixin.py | 2 +- tests/test_torch.py | 46 +++++++++++++++++++++ 6 files changed, 201 insertions(+), 19 deletions(-) create mode 100644 .github/workflows/benchmarks.yaml create mode 100644 tests/test_torch.py diff --git a/.github/workflows/benchmarks.yaml b/.github/workflows/benchmarks.yaml new file mode 100644 index 0000000..5532fa5 --- /dev/null +++ b/.github/workflows/benchmarks.yaml @@ -0,0 +1,42 @@ +name: Benchmarks + +on: + push: + branches: [main] + pull_request: + branches: [main, "release/*", "dev"] + +jobs: + run_tests_ubuntu: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: 3.11 + + - name: Install packages + run: | + pip install -e .[tests] + + - name: Run benchmark + run: | + pytest tests/ --benchmark-json output.json + + - name: Store benchmark result + uses: benchmark-action/github-action-benchmark@v1 + with: + name: Python Benchmark with pytest-benchmark + tool: "pytest" + output-file-path: examples/pytest/output.json + # Use personal access token instead of GITHUB_TOKEN due to https://github.community/t/github-action-not-triggering-gh-pages-upon-push/16096 + github-token: ${{ secrets.GITHUB_TOKEN }} + auto-push: true + # Show alert with commit comment on detecting possible performance regression + alert-threshold: "100%" + comment-on-alert: true + fail-on-alert: false diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 56904e4..2e0fab8 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -33,14 +33,34 @@ jobs: #---------------------------------------------- # run test suite #---------------------------------------------- - - name: Run tests + - name: Run non-torch tests run: | - pytest -v --doctest-modules --cov=src --junitxml=junit.xml -s --ignore=docs --benchmark-autosave --benchmark-compare-fail=mean:5% + pytest -v --doctest-modules --cov=src --junitxml=junit.xml -s --ignore=docs --ignore=tests/test_torch.py - name: Upload coverage to Codecov uses: codecov/codecov-action@v4.0.1 with: token: ${{ secrets.CODECOV_TOKEN }} + + - name: Upload test results to Codecov + if: ${{ !cancelled() }} + uses: codecov/test-results-action@v1 + with: + token: ${{ secrets.CODECOV_TOKEN }} + + - name: Install torch as well + run: | + pip install torch + + - name: Run torch tests + run: | + pytest -v --cov=src --junitxml=junit.xml -s tests/test_torch.py + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4.0.1 + with: + token: ${{ secrets.CODECOV_TOKEN }} + - name: Upload test results to Codecov if: ${{ !cancelled() }} uses: codecov/test-results-action@v1 diff --git a/src/mixins/seedable.py b/src/mixins/seedable.py index 14f9ede..b309043 100644 --- a/src/mixins/seedable.py +++ b/src/mixins/seedable.py @@ -7,6 +7,11 @@ import numpy as np +_SEED_FUNCTIONS = { + "numpy": np.random.seed, + "random": random.seed, +} + try: import torch @@ -14,16 +19,15 @@ def seed_torch(seed: int): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) + _SEED_FUNCTIONS["torch"] = seed_torch except ModuleNotFoundError: - - def seed_torch(seed: int): - pass + pass from .utils import doublewrap -def seed_everything(seed: int | None = None) -> int: +def seed_everything(seed: int | None = None, seed_engines: set[str] | None = None) -> int: """A simple helper function to seed everything that needs to be seeded. Args: @@ -55,18 +59,19 @@ def seed_everything(seed: int | None = None) -> int: 0 """ - try: - if seed is None: - seed = os.environ.get("PL_GLOBAL_SEED") - seed = int(seed) - except (TypeError, ValueError): - max_seed_value = np.iinfo(np.uint32).max - min_seed_value = np.iinfo(np.uint32).min - seed = np.random.randint(min_seed_value, max_seed_value) + if seed_engines is None: + seed_engines = set(_SEED_FUNCTIONS.keys()) + + if seed is None: + if "PL_GLOBAL_SEED" in os.environ: + seed = int(os.environ["PL_GLOBAL_SEED"]) + else: + max_seed_value = np.iinfo(np.uint32).max + min_seed_value = np.iinfo(np.uint32).min + seed = np.random.randint(min_seed_value, max_seed_value) - random.seed(seed) - np.random.seed(seed) - seed_torch(seed) + for s in seed_engines: + _SEED_FUNCTIONS[s](seed) return seed @@ -82,6 +87,7 @@ class SeedableMixin: def __init__(self, *args, **kwargs): self._past_seeds = kwargs.get("_past_seeds", []) + self._seed_engines = kwargs.get("_seed_engines", set(_SEED_FUNCTIONS.keys())) def _last_seed(self, key: str) -> tuple[int, int | None]: """This returns the most recently used seed with a given key. @@ -164,7 +170,7 @@ def _seed(self, seed: int | None = None, key: str | None = None) -> int: else: self._past_seeds = [(self.seed, key, time)] - seed_everything(seed) + seed_everything(seed, getattr(self, "_seed_engines", None)) return seed @staticmethod diff --git a/tests/test_seedable_mixin.py b/tests/test_seedable_mixin.py index 0dd1936..4d00a4f 100644 --- a/tests/test_seedable_mixin.py +++ b/tests/test_seedable_mixin.py @@ -1,8 +1,17 @@ +import os import random import numpy as np from mixins import SeedableMixin +from mixins.seedable import seed_everything + +try: + pass + + raise ImportError("This test requires torch not to be installed to run.") +except (ImportError, ModuleNotFoundError): + pass class SeedableDerived(SeedableMixin): @@ -23,6 +32,65 @@ def decorated_auto_key(self): return random.random() +def test_benchmark_seed_everything(benchmark): + benchmark(seed_everything) + + +def test_benchmark_seed_everything_with_seed(benchmark): + benchmark(seed_everything, 1) + + +def test_benchmark_seed_everything_with_env(benchmark): + os.environ["PL_GLOBAL_SEED"] = "1" + benchmark(seed_everything) + + +def test_seed_everything(): + os.environ["PL_GLOBAL_SEED"] = "1" + seed_everything() + + rand_1 = random.randint(0, 10) + np_rand_1 = np.random.randint(0, 10) + rand_2 = random.randint(0, 10) + np_rand_2 = np.random.randint(0, 10) + + seed_everything(1) + rand_1_1 = random.randint(0, 10) + np_rand_1_1 = np.random.randint(0, 10) + rand_2_1 = random.randint(0, 10) + np_rand_2_1 = np.random.randint(0, 10) + + seed_everything(1, seed_engines={"random"}) + rand_1_2 = random.randint(0, 10) + np_rand_1_2 = np.random.randint(0, 10) + rand_2_2 = random.randint(0, 10) + np_rand_2_2 = np.random.randint(0, 10) + + seed_everything(1, seed_engines={"numpy"}) + rand_1_3 = random.randint(0, 10) + np_rand_1_3 = np.random.randint(0, 10) + rand_2_3 = random.randint(0, 10) + np_rand_2_3 = np.random.randint(0, 10) + + assert rand_1 == rand_1_1 + assert rand_1 == rand_1_2 + assert rand_1 != rand_1_3 + assert rand_1 != rand_2 + + assert np_rand_1 == np_rand_1_1 + assert np_rand_1 == np_rand_1_3 + assert np_rand_1 != np_rand_1_2 + assert np_rand_1 != np_rand_2 + + assert rand_2 == rand_2_1 + assert rand_2 == rand_2_2 + assert rand_2 != rand_2_3 + + assert np_rand_2 == np_rand_2_1 + assert np_rand_2 == np_rand_2_3 + assert np_rand_2 != np_rand_2_2 + + def test_constructs(): SeedableMixin() SeedableDerived() diff --git a/tests/test_timeable_mixin.py b/tests/test_timeable_mixin.py index e8def6c..7b6aeb4 100644 --- a/tests/test_timeable_mixin.py +++ b/tests/test_timeable_mixin.py @@ -44,7 +44,7 @@ def test_responds_to_methods(): def test_benchmark_timing(benchmark): T = TimeableDerived() - benchmark(T.decorated_takes_time, 0.001) + benchmark(T.decorated_takes_time, 0.00001) def test_pprint_num_unit(): diff --git a/tests/test_torch.py b/tests/test_torch.py new file mode 100644 index 0000000..82902e3 --- /dev/null +++ b/tests/test_torch.py @@ -0,0 +1,46 @@ +import random + +import numpy as np + +from mixins.seedable import seed_everything + +try: + import torch +except (ImportError, ModuleNotFoundError): + raise ImportError("This test requires torch to run.") + + +def test_benchmark_seed_everything_torch(benchmark): + benchmark(seed_everything, seed_engines={"torch"}) + + +def test_seed_everything(): + seed_everything(1, seed_engines={"torch"}) + + rand_1_1 = random.randint(0, 10) + np_rand_1_1 = np.random.randint(0, 10) + torch_rand_1_1 = torch.randint(0, 10, (1,)).item() + rand_2_1 = random.randint(0, 10) + np_rand_2_1 = np.random.randint(0, 10) + torch_rand_2_1 = torch.randint(0, 10, (1,)).item() + + seed_everything(1, seed_engines={"torch"}) + + rand_1_2 = random.randint(0, 10) + np_rand_1_2 = np.random.randint(0, 10) + torch_rand_1_2 = torch.randint(0, 10, (1,)).item() + rand_2_2 = random.randint(0, 10) + np_rand_2_2 = np.random.randint(0, 10) + torch_rand_2_2 = torch.randint(0, 10, (1,)).item() + + assert rand_1_1 != rand_1_2 + assert rand_1_1 != rand_2_1 + assert rand_2_1 != rand_2_2 + + assert np_rand_1_1 != np_rand_1_2 + assert np_rand_1_1 != np_rand_2_1 + assert np_rand_2_1 != np_rand_2_2 + + assert torch_rand_1_1 == torch_rand_1_2 + assert torch_rand_1_1 != torch_rand_2_1 + assert torch_rand_2_1 == torch_rand_2_2 From c19481222363ffdfbfb9746da899660e2702d2af Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Mon, 14 Oct 2024 13:30:51 -0400 Subject: [PATCH 6/8] Removed broken workflow --- .github/workflows/benchmarks.yaml | 42 ------------------------------- 1 file changed, 42 deletions(-) delete mode 100644 .github/workflows/benchmarks.yaml diff --git a/.github/workflows/benchmarks.yaml b/.github/workflows/benchmarks.yaml deleted file mode 100644 index 5532fa5..0000000 --- a/.github/workflows/benchmarks.yaml +++ /dev/null @@ -1,42 +0,0 @@ -name: Benchmarks - -on: - push: - branches: [main] - pull_request: - branches: [main, "release/*", "dev"] - -jobs: - run_tests_ubuntu: - runs-on: ubuntu-latest - - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Set up Python 3.11 - uses: actions/setup-python@v5 - with: - python-version: 3.11 - - - name: Install packages - run: | - pip install -e .[tests] - - - name: Run benchmark - run: | - pytest tests/ --benchmark-json output.json - - - name: Store benchmark result - uses: benchmark-action/github-action-benchmark@v1 - with: - name: Python Benchmark with pytest-benchmark - tool: "pytest" - output-file-path: examples/pytest/output.json - # Use personal access token instead of GITHUB_TOKEN due to https://github.community/t/github-action-not-triggering-gh-pages-upon-push/16096 - github-token: ${{ secrets.GITHUB_TOKEN }} - auto-push: true - # Show alert with commit comment on detecting possible performance regression - alert-threshold: "100%" - comment-on-alert: true - fail-on-alert: false From 40a993e38a47f836bef375160c8300e091755e2e Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Mon, 14 Oct 2024 13:32:05 -0400 Subject: [PATCH 7/8] Made tests more reliable. --- tests/test_seedable_mixin.py | 32 ++++++++++++++++---------------- tests/test_torch.py | 24 ++++++++++++------------ 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/tests/test_seedable_mixin.py b/tests/test_seedable_mixin.py index 4d00a4f..62c0075 100644 --- a/tests/test_seedable_mixin.py +++ b/tests/test_seedable_mixin.py @@ -49,28 +49,28 @@ def test_seed_everything(): os.environ["PL_GLOBAL_SEED"] = "1" seed_everything() - rand_1 = random.randint(0, 10) - np_rand_1 = np.random.randint(0, 10) - rand_2 = random.randint(0, 10) - np_rand_2 = np.random.randint(0, 10) + rand_1 = random.randint(0, 100000000) + np_rand_1 = np.random.randint(0, 100000000) + rand_2 = random.randint(0, 100000000) + np_rand_2 = np.random.randint(0, 100000000) seed_everything(1) - rand_1_1 = random.randint(0, 10) - np_rand_1_1 = np.random.randint(0, 10) - rand_2_1 = random.randint(0, 10) - np_rand_2_1 = np.random.randint(0, 10) + rand_1_1 = random.randint(0, 100000000) + np_rand_1_1 = np.random.randint(0, 100000000) + rand_2_1 = random.randint(0, 100000000) + np_rand_2_1 = np.random.randint(0, 100000000) seed_everything(1, seed_engines={"random"}) - rand_1_2 = random.randint(0, 10) - np_rand_1_2 = np.random.randint(0, 10) - rand_2_2 = random.randint(0, 10) - np_rand_2_2 = np.random.randint(0, 10) + rand_1_2 = random.randint(0, 100000000) + np_rand_1_2 = np.random.randint(0, 100000000) + rand_2_2 = random.randint(0, 100000000) + np_rand_2_2 = np.random.randint(0, 100000000) seed_everything(1, seed_engines={"numpy"}) - rand_1_3 = random.randint(0, 10) - np_rand_1_3 = np.random.randint(0, 10) - rand_2_3 = random.randint(0, 10) - np_rand_2_3 = np.random.randint(0, 10) + rand_1_3 = random.randint(0, 100000000) + np_rand_1_3 = np.random.randint(0, 100000000) + rand_2_3 = random.randint(0, 100000000) + np_rand_2_3 = np.random.randint(0, 100000000) assert rand_1 == rand_1_1 assert rand_1 == rand_1_2 diff --git a/tests/test_torch.py b/tests/test_torch.py index 82902e3..941b540 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -17,21 +17,21 @@ def test_benchmark_seed_everything_torch(benchmark): def test_seed_everything(): seed_everything(1, seed_engines={"torch"}) - rand_1_1 = random.randint(0, 10) - np_rand_1_1 = np.random.randint(0, 10) - torch_rand_1_1 = torch.randint(0, 10, (1,)).item() - rand_2_1 = random.randint(0, 10) - np_rand_2_1 = np.random.randint(0, 10) - torch_rand_2_1 = torch.randint(0, 10, (1,)).item() + rand_1_1 = random.randint(0, 100000000) + np_rand_1_1 = np.random.randint(0, 100000000) + torch_rand_1_1 = torch.randint(0, 100000000, (1,)).item() + rand_2_1 = random.randint(0, 100000000) + np_rand_2_1 = np.random.randint(0, 100000000) + torch_rand_2_1 = torch.randint(0, 100000000, (1,)).item() seed_everything(1, seed_engines={"torch"}) - rand_1_2 = random.randint(0, 10) - np_rand_1_2 = np.random.randint(0, 10) - torch_rand_1_2 = torch.randint(0, 10, (1,)).item() - rand_2_2 = random.randint(0, 10) - np_rand_2_2 = np.random.randint(0, 10) - torch_rand_2_2 = torch.randint(0, 10, (1,)).item() + rand_1_2 = random.randint(0, 100000000) + np_rand_1_2 = np.random.randint(0, 100000000) + torch_rand_1_2 = torch.randint(0, 100000000, (1,)).item() + rand_2_2 = random.randint(0, 100000000) + np_rand_2_2 = np.random.randint(0, 100000000) + torch_rand_2_2 = torch.randint(0, 100000000, (1,)).item() assert rand_1_1 != rand_1_2 assert rand_1_1 != rand_2_1 From 26dc29331399b629f23ea7d06fa1644f8f8dc980 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Mon, 14 Oct 2024 13:45:08 -0400 Subject: [PATCH 8/8] Added timeable tests to get coverage up. --- tests/test_timeable_mixin.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/test_timeable_mixin.py b/tests/test_timeable_mixin.py index 7b6aeb4..286459e 100644 --- a/tests/test_timeable_mixin.py +++ b/tests/test_timeable_mixin.py @@ -57,6 +57,14 @@ class Derived(TimeableMixin): assert (3, "foo") == Derived._get_pprint_num_unit(3 / 20, "biz") assert (1.2, "biz") == Derived._get_pprint_num_unit(2.4 * 10, "foo") + try: + Derived._get_pprint_num_unit(1, "WRONG") + raise AssertionError("Should have raised an exception") + except LookupError: + pass + except Exception as e: + raise AssertionError(f"Raised the wrong exception: {e}") + def test_context_manager(): T = TimeableDerived() @@ -90,5 +98,9 @@ def test_times_and_profiling(): assert 0 == stats["decorated_takes_time_auto_key"][2] got_str = T._profile_durations() - want_str = "decorated_takes_time_auto_key: 2.0 sec\n" "decorated: 1.5 ± 0.5 sec (x2)" + want_str = "decorated_takes_time_auto_key: 2.0 sec\ndecorated: 1.5 ± 0.5 sec (x2)" + assert want_str == got_str, f"Want:\n{want_str}\nGot:\n{got_str}" + + got_str = T._profile_durations(only_keys=["decorated_takes_time_auto_key"]) + want_str = "decorated_takes_time_auto_key: 2.0 sec" assert want_str == got_str, f"Want:\n{want_str}\nGot:\n{got_str}"