Skip to content

Commit

Permalink
WIP: Rework main.py, fix resulting errors
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Dec 6, 2024
1 parent 51d800e commit 23156e1
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 45 deletions.
Empty file.
4 changes: 2 additions & 2 deletions docs/profiling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ def test_notebook_commands_dont_cause_errors(experiment_dictconfig: DictConfig):
# _experiment = _setup_experiment(config)
setup_logging(log_level=config.log_level)
lightning.seed_everything(config.seed, workers=True)
_trainer = instantiate_trainer(config)
_trainer = instantiate_trainer(config.trainer)
datamodule = instantiate_datamodule(config.datamodule)
_algorithm = instantiate_algorithm(config.algorithm, datamodule=datamodule)
_algorithm = instantiate_algorithm(config, datamodule=datamodule)

# Note: Here we don't actually do anything with the objects.
4 changes: 2 additions & 2 deletions project/algorithms/testsuites/lightning_module_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def trainer(
) -> lightning.Trainer | JaxTrainer:
setup_logging(log_level=experiment_config.log_level)
lightning.seed_everything(experiment_config.seed, workers=True)
return instantiate_trainer(experiment_config)
return instantiate_trainer(experiment_config.trainer)

@pytest.fixture(scope="class")
def algorithm(
Expand All @@ -82,7 +82,7 @@ def algorithm(
):
"""Fixture that creates the "algorithm" (a
[LightningModule][lightning.pytorch.core.module.LightningModule])."""
algorithm = instantiate_algorithm(experiment_config.algorithm, datamodule=datamodule)
algorithm = instantiate_algorithm(experiment_config, datamodule=datamodule)
if isinstance(trainer, lightning.Trainer) and isinstance(
algorithm, lightning.LightningModule
):
Expand Down
5 changes: 3 additions & 2 deletions project/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def algorithm(
):
"""Fixture that creates the "algorithm" (a
[LightningModule][lightning.pytorch.core.module.LightningModule])."""
algorithm = instantiate_algorithm(experiment_config.algorithm, datamodule=datamodule)
algorithm = instantiate_algorithm(experiment_config, datamodule=datamodule)
if isinstance(trainer, lightning.Trainer) and isinstance(algorithm, lightning.LightningModule):
with trainer.init_module(), device:
# A bit hacky, but we have to do this because the lightningmodule isn't associated
Expand All @@ -346,8 +346,9 @@ def trainer(
experiment_config: Config,
) -> pl.Trainer | JaxTrainer:
setup_logging(log_level=experiment_config.log_level)
# put here to copy what's done in main.py
lightning.seed_everything(experiment_config.seed, workers=True)
return instantiate_trainer(experiment_config)
return instantiate_trainer(experiment_config.trainer)


@pytest.fixture(scope="session")
Expand Down
11 changes: 7 additions & 4 deletions project/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def train(
algorithm,
/,
**kwargs,
) -> Any:
) -> tuple[Any, Any]:
raise NotImplementedError(
f"There is no registered handler for training algorithm {algorithm} of type "
f"{type(algorithm)}! (kwargs: {kwargs})."
Expand Down Expand Up @@ -92,14 +92,17 @@ def evaluate_lightningmodule(
/,
*,
trainer: lightning.Trainer,
datamodule: lightning.LightningDataModule | None,
datamodule: lightning.LightningDataModule | None = None,
config: Config,
train_results: Any = None,
) -> tuple[MetricName, float | None, dict]:
"""Evaluates the algorithm and returns the metrics.
By default, if validation is to be performed, returns the validation error. Returns the
training error when `trainer.overfit_batches != 0` (e.g. when debugging or testing). Otherwise,
if `trainer.limit_val_batches == 0`, returns the test error.
"""
datamodule = datamodule or getattr(algorithm, "datamodule", None)

# exp.trainer.logger.log_hyperparams()
# When overfitting on a single batch or only training, we return the train error.
Expand Down Expand Up @@ -169,7 +172,7 @@ def instantiate_datamodule(
f"Datamodule was already instantiated (probably to interpolate a field value). "
f"{datamodule_config=}"
)
return
return datamodule_config

logger.debug(f"Instantiating datamodule from config: {datamodule_config}")
return hydra.utils.instantiate(datamodule_config)
Expand All @@ -180,8 +183,8 @@ def train_lightningmodule(
algorithm: lightning.LightningModule,
/,
*,
datamodule: lightning.LightningDataModule | None,
trainer: lightning.Trainer | None,
datamodule: lightning.LightningDataModule | None = None,
config: Config,
):
# Create the Trainer from the config.
Expand Down
17 changes: 11 additions & 6 deletions project/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,12 @@ def main(dict_config: DictConfig) -> dict:
global_log_level="DEBUG" if config.debug else "INFO" if config.verbose else "WARNING",
)

# Seed the random number generators, so the weights that are
# constructed are deterministic and reproducible.
lightning.seed_everything(seed=config.seed, workers=True)

# Create the algo.
algorithm = hydra.utils.instantiate(config.algorithm)
algorithm = instantiate_algorithm(config)

# Create the trainer
trainer = instantiate_trainer(config.trainer)
Expand Down Expand Up @@ -142,7 +146,9 @@ def setup_logging(log_level: str, global_log_level: str = "WARNING") -> None:
project_logger.setLevel(log_level.upper())


def instantiate_algorithm(config: Config) -> lightning.LightningModule | JaxModule:
def instantiate_algorithm(
config: Config, datamodule: lightning.LightningDataModule | None = None
) -> lightning.LightningModule | JaxModule:
"""Function used to instantiate the algorithm.
It is suggested that your algorithm (LightningModule) take in the `datamodule` and `network`
Expand All @@ -151,15 +157,14 @@ def instantiate_algorithm(config: Config) -> lightning.LightningModule | JaxModu
The instantiated datamodule and network will be passed to the algorithm's constructor.
"""
# seed the random number generators, so the weights that are
# constructed are deterministic and reproducible.
lightning.seed_everything(seed=config.seed, workers=True)

# Create the algorithm
algo_config = config.algorithm

# Create the datamodule (if present) from the config
datamodule: lightning.LightningDataModule | None = instantiate_datamodule(config.datamodule)
if datamodule is None and config.datamodule is not None:
datamodule = instantiate_datamodule(config.datamodule)

if datamodule:
algo_or_algo_partial = hydra.utils.instantiate(algo_config, datamodule=datamodule)
else:
Expand Down
49 changes: 20 additions & 29 deletions project/main_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# ADAPTED FROM https://github.com/facebookresearch/hydra/blob/main/examples/advanced/hydra_app_example/tests/test_example.py
from __future__ import annotations

import shlex
import shutil
import subprocess
import sys
import uuid
from unittest.mock import Mock
Expand All @@ -12,7 +14,9 @@
from _pytest.mark.structures import ParameterSet
from hydra.types import RunMode
from omegaconf import DictConfig
from pytest_regressions.file_regression import FileRegressionFixture

import project.experiment
import project.main
from project.conftest import command_line_overrides, skip_on_macOS_in_CI
from project.utils.env_vars import REPO_ROOTDIR, SLURM_JOB_ID
Expand Down Expand Up @@ -46,31 +50,20 @@ def test_torch_can_use_the_GPU():

@pytest.fixture
def mock_train(monkeypatch: pytest.MonkeyPatch):
mock_train_fn = Mock(spec=project.main.train)
mock_train_fn = Mock(spec=project.main.train, return_value=(None, None))
monkeypatch.setattr(project.main, project.main.train.__name__, mock_train_fn)
return mock_train_fn


@pytest.fixture
def mock_evaluate_lightningmodule(monkeypatch: pytest.MonkeyPatch):
mock_eval_lightningmodule = Mock(
spec=project.main.evaluate_lightningmodule, return_value=("fake", 0.0, {})
)
def mock_evaluate(monkeypatch: pytest.MonkeyPatch):
mock_eval = Mock(spec=project.experiment.evaluate, return_value=("fake", 0.0, {}))
monkeypatch.setattr(
project.main, project.main.evaluate_lightningmodule.__name__, mock_eval_lightningmodule
project.main,
project.experiment.evaluate.__name__,
mock_eval,
)
return mock_eval_lightningmodule


@pytest.fixture
def mock_evaluate_jax_module(monkeypatch: pytest.MonkeyPatch):
mock_eval_jax_module = Mock(
spec=project.main.evaluate_jax_module, return_value=("fake", 0.0, {})
)
monkeypatch.setattr(
project.main, project.main.evaluate_jax_module.__name__, mock_eval_jax_module
)
return mock_eval_jax_module
return mock_eval


experiment_configs = [p.stem for p in (CONFIG_DIR / "experiment").glob("*.yaml")]
Expand All @@ -93,11 +86,6 @@ def mock_evaluate_jax_module(monkeypatch: pytest.MonkeyPatch):
IN_GITHUB_CI,
reason="Remote launcher tries to do a git push, doesn't work in github CI.",
),
pytest.mark.xfail(
raises=TypeError,
reason="TODO: Getting a `TypeError: cannot pickle 'weakref.ReferenceType' object` error.",
strict=False,
),
],
),
pytest.param(
Expand Down Expand Up @@ -152,8 +140,7 @@ def test_experiment_config_is_tested(experiment_config: str):
def test_can_load_experiment_configs(
experiment_dictconfig: DictConfig,
mock_train: Mock,
mock_evaluate_lightningmodule: Mock,
mock_evaluate_jax_module: Mock,
mock_evaluate: Mock,
):
# Mock out some part of the `main` function to not actually run anything.
if experiment_dictconfig["hydra"]["mode"] == RunMode.MULTIRUN:
Expand All @@ -168,10 +155,7 @@ def test_can_load_experiment_configs(
assert results is not None

mock_train.assert_called_once()
# One of them should have been called once.
assert (mock_evaluate_lightningmodule.call_count == 1) ^ (
mock_evaluate_jax_module.call_count == 1
)
mock_evaluate.assert_called_once()


@pytest.mark.slow
Expand Down Expand Up @@ -212,6 +196,13 @@ def test_setting_just_algorithm_isnt_enough(experiment_dictconfig: DictConfig) -
_ = resolve_dictconfig(experiment_dictconfig)


def test_help_string(file_regression: FileRegressionFixture) -> None:
help_string = subprocess.run(
shlex.split("python project/main.py --help"), text=True, capture_output=True
).stderr
file_regression.check(help_string)


@pytest.mark.skipif(
IN_GITHUB_CI and sys.platform == "darwin",
reason="TODO: Getting a 'MPS backend out of memory' error on the Github CI. ",
Expand Down

0 comments on commit 23156e1

Please sign in to comment.