Skip to content

Commit

Permalink
Save datamodule on self in text_classifier.py
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Dec 6, 2024
1 parent 4ca6d9e commit e2e1e07
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
1 change: 1 addition & 0 deletions project/algorithms/text_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
init_seed: int = 42,
):
super().__init__()
self.datamodule = datamodule
self.network_config = network
self.num_labels = datamodule.num_classes
self.task_name = datamodule.task_name
Expand Down
6 changes: 4 additions & 2 deletions project/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,10 @@ def train_lightningmodule(
# example in RL, where we need to set the actor to use in the environment, as well as
# potentially adding Wrappers on top of the environment, or having a replay buffer, etc.
if datamodule is None:
datamodule = datamodule or getattr(algorithm, "datamodule", None)

if hasattr(algorithm, "datamodule"):
datamodule = getattr(algorithm, "datamodule")
elif config.datamodule is not None:
datamodule = instantiate_datamodule(config.datamodule)
trainer.fit(algorithm, datamodule=datamodule, ckpt_path=config.ckpt_path)
train_results = None # todo: get the train results from the trainer.
return algorithm, train_results

0 comments on commit e2e1e07

Please sign in to comment.