Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added some new tests. #4

Merged
merged 6 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/code-quality-main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,9 @@ jobs:
with:
python-version: ${{ matrix.python-version }}

- name: Install packages
run: |
pip install -e .[dev]

- name: Run pre-commits
uses: pre-commit/[email protected]
4 changes: 4 additions & 0 deletions .github/workflows/code-quality-pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ jobs:
with:
python-version: ${{ matrix.python-version }}

- name: Install packages
run: |
pip install -e .[dev]

- name: Find modified files
id: file_changes
uses: trilom/[email protected]
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/python-build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ jobs:
path: dist/

- name: Sign the dists with Sigstore
uses: sigstore/gh-action-sigstore-python@v2.1.1
uses: sigstore/gh-action-sigstore-python@v3.0.0
with:
inputs: >-
./dist/*.tar.gz
Expand Down
26 changes: 23 additions & 3 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,39 @@ jobs:

- name: Install packages
run: |
pip install -e .[tests,tqdmable,tensorable]
pip install -e .[tests]

#----------------------------------------------
# run test suite
#----------------------------------------------
- name: Run tests
- name: Run non-torch tests
run: |
pytest -v --doctest-modules --cov=src --junitxml=junit.xml -s --ignore=docs
pytest -v --doctest-modules --cov=src --junitxml=junit.xml -s --ignore=docs --ignore=tests/test_torch.py

- name: Upload coverage to Codecov
uses: codecov/[email protected]
with:
token: ${{ secrets.CODECOV_TOKEN }}

- name: Upload test results to Codecov
if: ${{ !cancelled() }}
uses: codecov/test-results-action@v1
with:
token: ${{ secrets.CODECOV_TOKEN }}

- name: Install torch as well
run: |
pip install torch

- name: Run torch tests
run: |
pytest -v --cov=src --junitxml=junit.xml -s tests/test_torch.py

- name: Upload coverage to Codecov
uses: codecov/[email protected]
with:
token: ${{ secrets.CODECOV_TOKEN }}

- name: Upload test results to Codecov
if: ${{ !cancelled() }}
uses: codecov/test-results-action@v1
Expand Down
6 changes: 2 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@ classifiers = [
dependencies = ["numpy"]

[project.optional-dependencies]
dev = ["pre-commit"]
tests = ["pytest", "pytest-cov"]
tqdmable = ["tqdm"]
tensorable = ["torch"]
dev = ["pre-commit<4"]
tests = ["pytest", "pytest-cov", "pytest-benchmark"]

[tool.setuptools_scm]

Expand Down
197 changes: 172 additions & 25 deletions src/mixins/seedable.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,49 +7,157 @@

import numpy as np

from .utils import doublewrap
_SEED_FUNCTIONS = {
"numpy": np.random.seed,
"random": random.seed,
}

try:
import torch

def seed_everything(seed: int | None = None, try_import_torch: bool | None = True) -> int:
max_seed_value = np.iinfo(np.uint32).max
min_seed_value = np.iinfo(np.uint32).min
def seed_torch(seed: int):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

try:
if seed is None:
seed = os.environ.get("PL_GLOBAL_SEED")
seed = int(seed)
except (TypeError, ValueError):
seed = np.random.randint(min_seed_value, max_seed_value)
_SEED_FUNCTIONS["torch"] = seed_torch
except ModuleNotFoundError:
pass

assert min_seed_value <= seed <= max_seed_value

random.seed(seed)
np.random.seed(seed)
if try_import_torch:
try:
import torch
from .utils import doublewrap

torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
except ModuleNotFoundError:
pass

def seed_everything(seed: int | None = None, seed_engines: set[str] | None = None) -> int:
"""A simple helper function to seed everything that needs to be seeded.

Args:
seed: The seed to use. If None, a random seed is chosen.

Returns:
The seed that was used.

Examples:
>>> random.seed(0)
>>> np.random.seed(0)
>>> random.randint(0, 10)
6
>>> random.randint(0, 10)
6
>>> np.random.randint(0, 10)
5
>>> np.random.randint(0, 10)
0
>>> seed_everything(0)
0
>>> random.randint(0, 10)
6
>>> random.randint(0, 10)
6
>>> np.random.randint(0, 10)
5
>>> np.random.randint(0, 10)
0
"""

if seed_engines is None:
seed_engines = set(_SEED_FUNCTIONS.keys())

if seed is None:
if "PL_GLOBAL_SEED" in os.environ:
seed = int(os.environ["PL_GLOBAL_SEED"])
else:
max_seed_value = np.iinfo(np.uint32).max
min_seed_value = np.iinfo(np.uint32).min
seed = np.random.randint(min_seed_value, max_seed_value)

for s in seed_engines:
_SEED_FUNCTIONS[s](seed)

return seed


class SeedableMixin:
"""This class provides easy utilities to reliably seed stochastic processes.

This seeding can be used to ensure reproducibility in experiments, both in individual examples with an
integral seed or in a stochastic process both at a per-event level and at a whole process level by seeding
with `None`, in which case a new seed is chosen for each event in the process based on the prior seed and
stored.
"""

def __init__(self, *args, **kwargs):
self._past_seeds = kwargs.get("_past_seeds", [])

def _last_seed(self, key: str):
self._seed_engines = kwargs.get("_seed_engines", set(_SEED_FUNCTIONS.keys()))

def _last_seed(self, key: str) -> tuple[int, int | None]:
"""This returns the most recently used seed with a given key.

Args:
key: The key to search for.

Returns:
The index of the most recent seed with a given key in the list of past seeds and the seed itself.

Examples:
>>> M = SeedableMixin()
>>> _ = M._seed(0, "foo")
>>> _ = M._seed(2, "bar")
>>> _ = M._seed(4, "foo")
>>> _ = M._seed(6, "baz")
>>> M._last_seed("foo")
(2, 4)
>>> M._last_seed("bar")
(1, 2)
>>> M._last_seed("baz")
(3, 6)
"""
for idx, (s, k, time) in enumerate(self._past_seeds[::-1]):
if k == key:
idx = len(self._past_seeds) - 1 - idx
return idx, s

return -1, None

def _seed(self, seed: int | None = None, key: str | None = None):
def _seed(self, seed: int | None = None, key: str | None = None) -> int:
"""This seeds the random number generators.

Args:
seed: The seed to use. If None, a new seed is chosen.
key: The key to associate with this seed.

Returns:
The seed that was used.

Examples:
>>> M = SeedableMixin()
>>> M._seed(0, "foo")
0
>>> M._seed(2, "bar")
2
>>> M._seed(4, "foo")
4

Note that by virtue of the fact that we've already seeded `M`, future seeds are deterministic (though
they are still pseudo-random, as they are simply random integers drawn from the current random
distribution, which in this test was seeded at 4 immediately prior to this call).
>>> M._seed()
31681838

Past seeds and keys are stored in the `_past_seeds` attribute, which is created if the object does not
have it at the start.
>>> M = SeedableMixin()
>>> del M._past_seeds
>>> M._seed(0, "foo")
0
>>> M._seed(2, "bar")
2
>>> M._seed(4, "foo")
4
>>> M._seed()
31681838
>>> M._past_seeds
[(0, 'foo', ...), (2, 'bar', ...), (4, 'foo', ...), (31681838, '', ...)]
"""
if seed is None:
seed = random.randint(0, int(1e8))
if key is None:
Expand All @@ -62,12 +170,51 @@ def _seed(self, seed: int | None = None, key: str | None = None):
else:
self._past_seeds = [(self.seed, key, time)]

seed_everything(seed)
seed_everything(seed, getattr(self, "_seed_engines", None))
return seed

@staticmethod
@doublewrap
def WithSeed(fn, key: str | None = None):
def WithSeed(fn, key: str | None = None) -> callable:
"""This function is a decorator that returns a function that also takes a seed which seeds the RNG.

This decorator can either be called with a `key` argument or without arguments. In the latter case,
the decorator is used like this:

```
@SeedableMixin.WithSeed
def func(...):
...
```

In this case, the name of the function is used as the key to the associated seed call. If a key is
provided, the decorator is used like this:

```
@SeedableMixin.WithSeed(key="foo")
def func(...):
...
```

In this case, the key is used as the key to the associated seed call. This is useful when the function
name is not the desired seed.

Args:
fn: The function to wrap. This argument _does not need to be provided_ if a key is used; instead
the `doublewrap` decorator is used to allow the key to be passed as a keyword argument to a
meta-function that returns the true decorator applied to the target function.
key: The key to use for the seed. If None, the function name is used.

Returns:
A function that takes all the input arguments of the wrapped function and a seed keyword argument.
If the seed is not provided, a new seed is chosen. The seed is used to seed the RNG before calling
the wrapped function, under the provided key.

Note that if the function being wrapped explicitly takes a seed argument, this decorator will not
work, and the failure will not necessarily be graceful.

Examples:
"""
if key is None:
key = fn.__name__

Expand Down
14 changes: 14 additions & 0 deletions src/mixins/timeable.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,20 @@


class TimeableMixin:
"""A mixin class to add timing functionality to a class for profiling its methods.

This mixin class provides the following functionality:
- Timing of methods using the TimeAs decorator.
- Timing of arbitrary code blocks using the _time_as context manager.
- Profiling of the durations of the timed methods.

Attributes:
_timings: A dictionary of lists of dictionaries containing the start and end times of timed methods.
The keys of the dictionary are the names of the timed methods.
The values are lists of dictionaries containing the start and end times of each timed method call.
The dictionaries contain the keys "start" and "end" with the corresponding times.
"""

_START_TIME = "start"
_END_TIME = "end"

Expand Down
Loading
Loading