diff --git a/.github/workflows/code-quality-main.yaml b/.github/workflows/code-quality-main.yaml index ec878bf..1fe53e6 100644 --- a/.github/workflows/code-quality-main.yaml +++ b/.github/workflows/code-quality-main.yaml @@ -23,5 +23,9 @@ jobs: with: python-version: ${{ matrix.python-version }} + - name: Install packages + run: | + pip install -e .[dev] + - name: Run pre-commits uses: pre-commit/action@v3.0.1 diff --git a/.github/workflows/code-quality-pr.yaml b/.github/workflows/code-quality-pr.yaml index 2e08be0..a942c5e 100644 --- a/.github/workflows/code-quality-pr.yaml +++ b/.github/workflows/code-quality-pr.yaml @@ -26,6 +26,10 @@ jobs: with: python-version: ${{ matrix.python-version }} + - name: Install packages + run: | + pip install -e .[dev] + - name: Find modified files id: file_changes uses: trilom/file-changes-action@v1.2.4 diff --git a/.github/workflows/python-build.yaml b/.github/workflows/python-build.yaml index a32827f..1420804 100644 --- a/.github/workflows/python-build.yaml +++ b/.github/workflows/python-build.yaml @@ -74,7 +74,7 @@ jobs: path: dist/ - name: Sign the dists with Sigstore - uses: sigstore/gh-action-sigstore-python@v2.1.1 + uses: sigstore/gh-action-sigstore-python@v3.0.0 with: inputs: >- ./dist/*.tar.gz diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index aface00..2e0fab8 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -28,19 +28,39 @@ jobs: - name: Install packages run: | - pip install -e .[tests,tqdmable,tensorable] + pip install -e .[tests] #---------------------------------------------- # run test suite #---------------------------------------------- - - name: Run tests + - name: Run non-torch tests run: | - pytest -v --doctest-modules --cov=src --junitxml=junit.xml -s --ignore=docs + pytest -v --doctest-modules --cov=src --junitxml=junit.xml -s --ignore=docs --ignore=tests/test_torch.py - name: Upload coverage to Codecov uses: codecov/codecov-action@v4.0.1 with: token: ${{ secrets.CODECOV_TOKEN }} + + - name: Upload test results to Codecov + if: ${{ !cancelled() }} + uses: codecov/test-results-action@v1 + with: + token: ${{ secrets.CODECOV_TOKEN }} + + - name: Install torch as well + run: | + pip install torch + + - name: Run torch tests + run: | + pytest -v --cov=src --junitxml=junit.xml -s tests/test_torch.py + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4.0.1 + with: + token: ${{ secrets.CODECOV_TOKEN }} + - name: Upload test results to Codecov if: ${{ !cancelled() }} uses: codecov/test-results-action@v1 diff --git a/pyproject.toml b/pyproject.toml index 573db20..6152edb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,10 +19,8 @@ classifiers = [ dependencies = ["numpy"] [project.optional-dependencies] -dev = ["pre-commit"] -tests = ["pytest", "pytest-cov"] -tqdmable = ["tqdm"] -tensorable = ["torch"] +dev = ["pre-commit<4"] +tests = ["pytest", "pytest-cov", "pytest-benchmark"] [tool.setuptools_scm] diff --git a/src/mixins/seedable.py b/src/mixins/seedable.py index b519d7a..b309043 100644 --- a/src/mixins/seedable.py +++ b/src/mixins/seedable.py @@ -7,41 +7,110 @@ import numpy as np -from .utils import doublewrap +_SEED_FUNCTIONS = { + "numpy": np.random.seed, + "random": random.seed, +} +try: + import torch -def seed_everything(seed: int | None = None, try_import_torch: bool | None = True) -> int: - max_seed_value = np.iinfo(np.uint32).max - min_seed_value = np.iinfo(np.uint32).min + def seed_torch(seed: int): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) - try: - if seed is None: - seed = os.environ.get("PL_GLOBAL_SEED") - seed = int(seed) - except (TypeError, ValueError): - seed = np.random.randint(min_seed_value, max_seed_value) + _SEED_FUNCTIONS["torch"] = seed_torch +except ModuleNotFoundError: + pass - assert min_seed_value <= seed <= max_seed_value - random.seed(seed) - np.random.seed(seed) - if try_import_torch: - try: - import torch +from .utils import doublewrap - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - except ModuleNotFoundError: - pass + +def seed_everything(seed: int | None = None, seed_engines: set[str] | None = None) -> int: + """A simple helper function to seed everything that needs to be seeded. + + Args: + seed: The seed to use. If None, a random seed is chosen. + + Returns: + The seed that was used. + + Examples: + >>> random.seed(0) + >>> np.random.seed(0) + >>> random.randint(0, 10) + 6 + >>> random.randint(0, 10) + 6 + >>> np.random.randint(0, 10) + 5 + >>> np.random.randint(0, 10) + 0 + >>> seed_everything(0) + 0 + >>> random.randint(0, 10) + 6 + >>> random.randint(0, 10) + 6 + >>> np.random.randint(0, 10) + 5 + >>> np.random.randint(0, 10) + 0 + """ + + if seed_engines is None: + seed_engines = set(_SEED_FUNCTIONS.keys()) + + if seed is None: + if "PL_GLOBAL_SEED" in os.environ: + seed = int(os.environ["PL_GLOBAL_SEED"]) + else: + max_seed_value = np.iinfo(np.uint32).max + min_seed_value = np.iinfo(np.uint32).min + seed = np.random.randint(min_seed_value, max_seed_value) + + for s in seed_engines: + _SEED_FUNCTIONS[s](seed) return seed class SeedableMixin: + """This class provides easy utilities to reliably seed stochastic processes. + + This seeding can be used to ensure reproducibility in experiments, both in individual examples with an + integral seed or in a stochastic process both at a per-event level and at a whole process level by seeding + with `None`, in which case a new seed is chosen for each event in the process based on the prior seed and + stored. + """ + def __init__(self, *args, **kwargs): self._past_seeds = kwargs.get("_past_seeds", []) - - def _last_seed(self, key: str): + self._seed_engines = kwargs.get("_seed_engines", set(_SEED_FUNCTIONS.keys())) + + def _last_seed(self, key: str) -> tuple[int, int | None]: + """This returns the most recently used seed with a given key. + + Args: + key: The key to search for. + + Returns: + The index of the most recent seed with a given key in the list of past seeds and the seed itself. + + Examples: + >>> M = SeedableMixin() + >>> _ = M._seed(0, "foo") + >>> _ = M._seed(2, "bar") + >>> _ = M._seed(4, "foo") + >>> _ = M._seed(6, "baz") + >>> M._last_seed("foo") + (2, 4) + >>> M._last_seed("bar") + (1, 2) + >>> M._last_seed("baz") + (3, 6) + """ for idx, (s, k, time) in enumerate(self._past_seeds[::-1]): if k == key: idx = len(self._past_seeds) - 1 - idx @@ -49,7 +118,46 @@ def _last_seed(self, key: str): return -1, None - def _seed(self, seed: int | None = None, key: str | None = None): + def _seed(self, seed: int | None = None, key: str | None = None) -> int: + """This seeds the random number generators. + + Args: + seed: The seed to use. If None, a new seed is chosen. + key: The key to associate with this seed. + + Returns: + The seed that was used. + + Examples: + >>> M = SeedableMixin() + >>> M._seed(0, "foo") + 0 + >>> M._seed(2, "bar") + 2 + >>> M._seed(4, "foo") + 4 + + Note that by virtue of the fact that we've already seeded `M`, future seeds are deterministic (though + they are still pseudo-random, as they are simply random integers drawn from the current random + distribution, which in this test was seeded at 4 immediately prior to this call). + >>> M._seed() + 31681838 + + Past seeds and keys are stored in the `_past_seeds` attribute, which is created if the object does not + have it at the start. + >>> M = SeedableMixin() + >>> del M._past_seeds + >>> M._seed(0, "foo") + 0 + >>> M._seed(2, "bar") + 2 + >>> M._seed(4, "foo") + 4 + >>> M._seed() + 31681838 + >>> M._past_seeds + [(0, 'foo', ...), (2, 'bar', ...), (4, 'foo', ...), (31681838, '', ...)] + """ if seed is None: seed = random.randint(0, int(1e8)) if key is None: @@ -62,12 +170,51 @@ def _seed(self, seed: int | None = None, key: str | None = None): else: self._past_seeds = [(self.seed, key, time)] - seed_everything(seed) + seed_everything(seed, getattr(self, "_seed_engines", None)) return seed @staticmethod @doublewrap - def WithSeed(fn, key: str | None = None): + def WithSeed(fn, key: str | None = None) -> callable: + """This function is a decorator that returns a function that also takes a seed which seeds the RNG. + + This decorator can either be called with a `key` argument or without arguments. In the latter case, + the decorator is used like this: + + ``` + @SeedableMixin.WithSeed + def func(...): + ... + ``` + + In this case, the name of the function is used as the key to the associated seed call. If a key is + provided, the decorator is used like this: + + ``` + @SeedableMixin.WithSeed(key="foo") + def func(...): + ... + ``` + + In this case, the key is used as the key to the associated seed call. This is useful when the function + name is not the desired seed. + + Args: + fn: The function to wrap. This argument _does not need to be provided_ if a key is used; instead + the `doublewrap` decorator is used to allow the key to be passed as a keyword argument to a + meta-function that returns the true decorator applied to the target function. + key: The key to use for the seed. If None, the function name is used. + + Returns: + A function that takes all the input arguments of the wrapped function and a seed keyword argument. + If the seed is not provided, a new seed is chosen. The seed is used to seed the RNG before calling + the wrapped function, under the provided key. + + Note that if the function being wrapped explicitly takes a seed argument, this decorator will not + work, and the failure will not necessarily be graceful. + + Examples: + """ if key is None: key = fn.__name__ diff --git a/src/mixins/timeable.py b/src/mixins/timeable.py index 14fb8ca..f8e8e42 100644 --- a/src/mixins/timeable.py +++ b/src/mixins/timeable.py @@ -11,6 +11,20 @@ class TimeableMixin: + """A mixin class to add timing functionality to a class for profiling its methods. + + This mixin class provides the following functionality: + - Timing of methods using the TimeAs decorator. + - Timing of arbitrary code blocks using the _time_as context manager. + - Profiling of the durations of the timed methods. + + Attributes: + _timings: A dictionary of lists of dictionaries containing the start and end times of timed methods. + The keys of the dictionary are the names of the timed methods. + The values are lists of dictionaries containing the start and end times of each timed method call. + The dictionaries contain the keys "start" and "end" with the corresponding times. + """ + _START_TIME = "start" _END_TIME = "end" diff --git a/tests/test_seedable_mixin.py b/tests/test_seedable_mixin.py index 0fa8432..62c0075 100644 --- a/tests/test_seedable_mixin.py +++ b/tests/test_seedable_mixin.py @@ -1,9 +1,17 @@ +import os import random -import unittest import numpy as np from mixins import SeedableMixin +from mixins.seedable import seed_everything + +try: + pass + + raise ImportError("This test requires torch not to be installed to run.") +except (ImportError, ModuleNotFoundError): + pass class SeedableDerived(SeedableMixin): @@ -24,125 +32,190 @@ def decorated_auto_key(self): return random.random() -class TestSeedableMixin(unittest.TestCase): - def test_constructs(self): - SeedableMixin() - SeedableDerived() +def test_benchmark_seed_everything(benchmark): + benchmark(seed_everything) + + +def test_benchmark_seed_everything_with_seed(benchmark): + benchmark(seed_everything, 1) + + +def test_benchmark_seed_everything_with_env(benchmark): + os.environ["PL_GLOBAL_SEED"] = "1" + benchmark(seed_everything) + + +def test_seed_everything(): + os.environ["PL_GLOBAL_SEED"] = "1" + seed_everything() + + rand_1 = random.randint(0, 100000000) + np_rand_1 = np.random.randint(0, 100000000) + rand_2 = random.randint(0, 100000000) + np_rand_2 = np.random.randint(0, 100000000) + + seed_everything(1) + rand_1_1 = random.randint(0, 100000000) + np_rand_1_1 = np.random.randint(0, 100000000) + rand_2_1 = random.randint(0, 100000000) + np_rand_2_1 = np.random.randint(0, 100000000) + + seed_everything(1, seed_engines={"random"}) + rand_1_2 = random.randint(0, 100000000) + np_rand_1_2 = np.random.randint(0, 100000000) + rand_2_2 = random.randint(0, 100000000) + np_rand_2_2 = np.random.randint(0, 100000000) + + seed_everything(1, seed_engines={"numpy"}) + rand_1_3 = random.randint(0, 100000000) + np_rand_1_3 = np.random.randint(0, 100000000) + rand_2_3 = random.randint(0, 100000000) + np_rand_2_3 = np.random.randint(0, 100000000) + + assert rand_1 == rand_1_1 + assert rand_1 == rand_1_2 + assert rand_1 != rand_1_3 + assert rand_1 != rand_2 + + assert np_rand_1 == np_rand_1_1 + assert np_rand_1 == np_rand_1_3 + assert np_rand_1 != np_rand_1_2 + assert np_rand_1 != np_rand_2 + + assert rand_2 == rand_2_1 + assert rand_2 == rand_2_2 + assert rand_2 != rand_2_3 + + assert np_rand_2 == np_rand_2_1 + assert np_rand_2 == np_rand_2_3 + assert np_rand_2 != np_rand_2_2 + + +def test_constructs(): + SeedableMixin() + SeedableDerived() + + +def test_benchmark_seeding(benchmark): + T = SeedableDerived() + + benchmark(T._seed) + + +def test_responds_to_methods(): + T = SeedableMixin() + + T._seed() + T._last_seed("foo") - def test_responds_to_methods(self): - T = SeedableMixin() + T = SeedableDerived() + T._seed() + T._last_seed("foo") - T._seed() - T._last_seed("foo") - T = SeedableDerived() - T._seed() - T._last_seed("foo") +def test_seeding_freezes_randomness(): + T = SeedableDerived() - def test_seeding_freezes_randomness(self): - T = SeedableDerived() + unseeded_1 = T.gen_random_num() + unseeded_2 = T.gen_random_num() - unseeded_1 = T.gen_random_num() - unseeded_2 = T.gen_random_num() + # Without seeding, repeated calls should be different. + assert unseeded_1 != unseeded_2, "Unseeded calls should be different." - # Without seeding, repeated calls should be different. - self.assertNotEqual(unseeded_1, unseeded_2) + T._seed(1) + seeded_1_1 = T.gen_random_num() + seeded_2_1 = T.gen_random_num() - T._seed(1) - seeded_1_1 = T.gen_random_num() - seeded_2_1 = T.gen_random_num() + # Even if I seeded at the start, repeated calls should still be different. + assert seeded_1_1 != seeded_2_1, "Seeded calls should be different when called repeatedly." - # Even if I seeded at the start, repeated calls should still be different. - self.assertNotEqual(seeded_1_1, seeded_2_1) + T._seed(1) + seeded_1_2 = T.gen_random_num() + seeded_2_2 = T.gen_random_num() - T._seed(1) - seeded_1_2 = T.gen_random_num() - seeded_2_2 = T.gen_random_num() + # Since I seeded again, they should match the prior sequence. + assert seeded_1_1 == seeded_1_2 + assert seeded_2_1 == seeded_2_2 - # Since I seeded again, they should match the prior sequence. - self.assertEqual(seeded_1_1, seeded_1_2) - self.assertEqual(seeded_2_1, seeded_2_2) - def test_decorated_seeding_freezes_randomness(self): - T = SeedableDerived() +def test_decorated_seeding_freezes_randomness(): + T = SeedableDerived() - unseeded_1 = T.decorated_gen_random_num() - unseeded_2 = T.decorated_gen_random_num() + unseeded_1 = T.decorated_gen_random_num() + unseeded_2 = T.decorated_gen_random_num() - # Without seeding, repeated calls should be different. - self.assertNotEqual(unseeded_1, unseeded_2) + # Without seeding, repeated calls should be different. + assert unseeded_1 != unseeded_2 - seeded_1_1 = T.decorated_gen_random_num(seed=1) - seeded_2_1 = T.decorated_gen_random_num(seed=2) + seeded_1_1 = T.decorated_gen_random_num(seed=1) + seeded_2_1 = T.decorated_gen_random_num(seed=2) - # Even if I seeded at the start, repeated calls should still be different. - self.assertNotEqual(seeded_1_1, seeded_2_1) + # Even if I seeded at the start, repeated calls should still be different. + assert seeded_1_1 != seeded_2_1 - seeded_1_2 = T.decorated_gen_random_num(seed=1) - seeded_2_2 = T.decorated_gen_random_num(seed=2) + seeded_1_2 = T.decorated_gen_random_num(seed=1) + seeded_2_2 = T.decorated_gen_random_num(seed=2) - # Since they are seeded, they should match the prior sequence. - self.assertEqual(seeded_1_1, seeded_1_2) - self.assertEqual(seeded_2_1, seeded_2_2) + # Since they are seeded, they should match the prior sequence. + assert seeded_1_1 == seeded_1_2 + assert seeded_2_1 == seeded_2_2 - # Now we want to make sure the seeding is consistent even interrupted. + # Now we want to make sure the seeding is consistent even interrupted. - T._seed(0) - seeded_1_3 = T.decorated_gen_random_num(seed=1) - T._seed(10) - seeded_2_3 = T.decorated_gen_random_num(seed=2) + T._seed(0) + seeded_1_3 = T.decorated_gen_random_num(seed=1) + T._seed(10) + seeded_2_3 = T.decorated_gen_random_num(seed=2) - self.assertEqual(seeded_1_1, seeded_1_3) - self.assertEqual(seeded_2_1, seeded_2_3) + assert seeded_1_1 == seeded_1_3 + assert seeded_2_1 == seeded_2_3 - def test_seeds_follow_consistent_sequence(self): - T = SeedableDerived() - unseeded_seq = [T._seed() for i in range(5)] +def test_seeds_follow_consistent_sequence(): + T = SeedableDerived() - seed_1 = T._seed(1) + unseeded_seq = [T._seed() for i in range(5)] - # seed_1 should be 1 given I passed a seed in: - self.assertEqual(seed_1, 1) + seed_1 = T._seed(1) - next_seeds_1 = [T._seed() for i in range(5)] + # seed_1 should be 1 given I passed a seed in: + assert seed_1 == 1 - # These should differ from the unseeded sequence of seeds - self.assertNotEqual(unseeded_seq, next_seeds_1) + next_seeds_1 = [T._seed() for i in range(5)] - T._seed(1) + # These should differ from the unseeded sequence of seeds + assert unseeded_seq != next_seeds_1 - next_seeds_2 = [T._seed() for i in range(5)] + T._seed(1) - # The sequence of seeds should be the same here. - self.assertEqual(next_seeds_1, next_seeds_2) + next_seeds_2 = [T._seed() for i in range(5)] - def test_get_last_seed(self): - T = SeedableDerived() + # The sequence of seeds should be the same here. + assert next_seeds_1 == next_seeds_2 - key = "key" - non_key = "not_key" - seed_key_early = 1 - seed_key_late = 1 - seed_non_key = 2 +def test_get_last_seed(): + T = SeedableDerived() - T._seed() + key = "key" + non_key = "not_key" - idx, seed = T._last_seed(key) - self.assertEqual(idx, -1) - self.assertEqual(seed, None) + seed_key_early = 1 + seed_key_late = 1 + seed_non_key = 2 - T._seed(seed_key_early, key) - T._seed() - T._seed(seed_non_key, non_key) - T._seed(seed_key_late, key) - T._seed(seed_non_key, non_key) + T._seed() - idx, seed = T._last_seed(key) - self.assertEqual(idx, 4) - self.assertEqual(seed, seed_key_late) + idx, seed = T._last_seed(key) + assert idx == -1 + assert seed is None + T._seed(seed_key_early, key) + T._seed() + T._seed(seed_non_key, non_key) + T._seed(seed_key_late, key) + T._seed(seed_non_key, non_key) -if __name__ == "__main__": - unittest.main() + idx, seed = T._last_seed(key) + assert idx == 4 + assert seed == seed_key_late diff --git a/tests/test_timeable_mixin.py b/tests/test_timeable_mixin.py index 4b70b70..286459e 100644 --- a/tests/test_timeable_mixin.py +++ b/tests/test_timeable_mixin.py @@ -1,5 +1,4 @@ import time -import unittest import numpy as np @@ -24,69 +23,84 @@ def decorated_takes_time_auto_key(self, num_seconds: int = 10): time.sleep(num_seconds) -class TestTimeableMixin(unittest.TestCase): - def test_constructs(self): - TimeableMixin() - TimeableDerived() +def test_constructs(): + TimeableMixin() + TimeableDerived() - def test_responds_to_methods(self): - T = TimeableMixin() - T._register_start("key") - T._register_end("key") +def test_responds_to_methods(): + T = TimeableMixin() - T._times_for("key") + T._register_start("key") + T._register_end("key") - T._register_start("key") - T._time_so_far("key") - T._register_end("key") + T._times_for("key") - def test_pprint_num_unit(self): - self.assertEqual((5, "μs"), TimeableMixin._get_pprint_num_unit(5 * 1e-6)) + T._register_start("key") + T._time_so_far("key") + T._register_end("key") - class Derived(TimeableMixin): - _CUTOFFS_AND_UNITS = [(10, "foo"), (2, "bar"), (None, "biz")] - self.assertEqual((3, "biz"), Derived._get_pprint_num_unit(3, "biz")) - self.assertEqual((3, "foo"), Derived._get_pprint_num_unit(3 / 20, "biz")) - self.assertEqual((1.2, "biz"), Derived._get_pprint_num_unit(2.4 * 10, "foo")) +def test_benchmark_timing(benchmark): + T = TimeableDerived() - def test_context_manager(self): - T = TimeableDerived() + benchmark(T.decorated_takes_time, 0.00001) - T.uses_contextlib(num_seconds=1) - duration = T._times_for("using_contextlib")[-1] - np.testing.assert_almost_equal(duration, 1, decimal=1) +def test_pprint_num_unit(): + assert (5, "μs") == TimeableMixin._get_pprint_num_unit(5 * 1e-6) - def test_times_and_profiling(self): - T = TimeableDerived() - T.decorated_takes_time(num_seconds=2) + class Derived(TimeableMixin): + _CUTOFFS_AND_UNITS = [(10, "foo"), (2, "bar"), (None, "biz")] - duration = T._times_for("decorated")[-1] - np.testing.assert_almost_equal(duration, 2, decimal=1) + assert (3, "biz") == Derived._get_pprint_num_unit(3, "biz") + assert (3, "foo") == Derived._get_pprint_num_unit(3 / 20, "biz") + assert (1.2, "biz") == Derived._get_pprint_num_unit(2.4 * 10, "foo") - T.decorated_takes_time_auto_key(num_seconds=2) - duration = T._times_for("decorated_takes_time_auto_key")[-1] - np.testing.assert_almost_equal(duration, 2, decimal=1) + try: + Derived._get_pprint_num_unit(1, "WRONG") + raise AssertionError("Should have raised an exception") + except LookupError: + pass + except Exception as e: + raise AssertionError(f"Raised the wrong exception: {e}") - T.decorated_takes_time(num_seconds=1) - stats = T._duration_stats - self.assertEqual({"decorated", "decorated_takes_time_auto_key"}, set(stats.keys())) - np.testing.assert_almost_equal(1.5, stats["decorated"][0], decimal=1) - self.assertEqual(2, stats["decorated"][1]) - np.testing.assert_almost_equal(0.5, stats["decorated"][2], decimal=1) - np.testing.assert_almost_equal(2, stats["decorated_takes_time_auto_key"][0], decimal=1) - self.assertEqual(1, stats["decorated_takes_time_auto_key"][1]) - self.assertEqual(0, stats["decorated_takes_time_auto_key"][2]) +def test_context_manager(): + T = TimeableDerived() - got_str = T._profile_durations() - want_str = ( - "decorated_takes_time_auto_key: 2.0 sec\n" "decorated: 1.5 ± 0.5 sec (x2)" - ) - self.assertEqual(want_str, got_str, msg=f"Want:\n{want_str}\nGot:\n{got_str}") + T.uses_contextlib(num_seconds=1) + duration = T._times_for("using_contextlib")[-1] + np.testing.assert_almost_equal(duration, 1, decimal=1) -if __name__ == "__main__": - unittest.main() + +def test_times_and_profiling(): + T = TimeableDerived() + T.decorated_takes_time(num_seconds=2) + + duration = T._times_for("decorated")[-1] + np.testing.assert_almost_equal(duration, 2, decimal=1) + + T.decorated_takes_time_auto_key(num_seconds=2) + duration = T._times_for("decorated_takes_time_auto_key")[-1] + np.testing.assert_almost_equal(duration, 2, decimal=1) + + T.decorated_takes_time(num_seconds=1) + stats = T._duration_stats + + assert {"decorated", "decorated_takes_time_auto_key"} == set(stats.keys()) + np.testing.assert_almost_equal(1.5, stats["decorated"][0], decimal=1) + assert 2 == stats["decorated"][1] + np.testing.assert_almost_equal(0.5, stats["decorated"][2], decimal=1) + np.testing.assert_almost_equal(2, stats["decorated_takes_time_auto_key"][0], decimal=1) + assert 1 == stats["decorated_takes_time_auto_key"][1] + assert 0 == stats["decorated_takes_time_auto_key"][2] + + got_str = T._profile_durations() + want_str = "decorated_takes_time_auto_key: 2.0 sec\ndecorated: 1.5 ± 0.5 sec (x2)" + assert want_str == got_str, f"Want:\n{want_str}\nGot:\n{got_str}" + + got_str = T._profile_durations(only_keys=["decorated_takes_time_auto_key"]) + want_str = "decorated_takes_time_auto_key: 2.0 sec" + assert want_str == got_str, f"Want:\n{want_str}\nGot:\n{got_str}" diff --git a/tests/test_torch.py b/tests/test_torch.py new file mode 100644 index 0000000..941b540 --- /dev/null +++ b/tests/test_torch.py @@ -0,0 +1,46 @@ +import random + +import numpy as np + +from mixins.seedable import seed_everything + +try: + import torch +except (ImportError, ModuleNotFoundError): + raise ImportError("This test requires torch to run.") + + +def test_benchmark_seed_everything_torch(benchmark): + benchmark(seed_everything, seed_engines={"torch"}) + + +def test_seed_everything(): + seed_everything(1, seed_engines={"torch"}) + + rand_1_1 = random.randint(0, 100000000) + np_rand_1_1 = np.random.randint(0, 100000000) + torch_rand_1_1 = torch.randint(0, 100000000, (1,)).item() + rand_2_1 = random.randint(0, 100000000) + np_rand_2_1 = np.random.randint(0, 100000000) + torch_rand_2_1 = torch.randint(0, 100000000, (1,)).item() + + seed_everything(1, seed_engines={"torch"}) + + rand_1_2 = random.randint(0, 100000000) + np_rand_1_2 = np.random.randint(0, 100000000) + torch_rand_1_2 = torch.randint(0, 100000000, (1,)).item() + rand_2_2 = random.randint(0, 100000000) + np_rand_2_2 = np.random.randint(0, 100000000) + torch_rand_2_2 = torch.randint(0, 100000000, (1,)).item() + + assert rand_1_1 != rand_1_2 + assert rand_1_1 != rand_2_1 + assert rand_2_1 != rand_2_2 + + assert np_rand_1_1 != np_rand_1_2 + assert np_rand_1_1 != np_rand_2_1 + assert np_rand_2_1 != np_rand_2_2 + + assert torch_rand_1_1 == torch_rand_1_2 + assert torch_rand_1_1 != torch_rand_2_1 + assert torch_rand_2_1 == torch_rand_2_2