From e2e1e071a39346aa1c68ac8b9b3c2788399f37b4 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 6 Dec 2024 15:16:39 -0500 Subject: [PATCH] Save datamodule on self in text_classifier.py Signed-off-by: Fabrice Normandin --- project/algorithms/text_classifier.py | 1 + project/experiment.py | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/project/algorithms/text_classifier.py b/project/algorithms/text_classifier.py index 2ef16b1a..1d9ea4e1 100644 --- a/project/algorithms/text_classifier.py +++ b/project/algorithms/text_classifier.py @@ -30,6 +30,7 @@ def __init__( init_seed: int = 42, ): super().__init__() + self.datamodule = datamodule self.network_config = network self.num_labels = datamodule.num_classes self.task_name = datamodule.task_name diff --git a/project/experiment.py b/project/experiment.py index 69c26eaa..bd5097a5 100644 --- a/project/experiment.py +++ b/project/experiment.py @@ -198,8 +198,10 @@ def train_lightningmodule( # example in RL, where we need to set the actor to use in the environment, as well as # potentially adding Wrappers on top of the environment, or having a replay buffer, etc. if datamodule is None: - datamodule = datamodule or getattr(algorithm, "datamodule", None) - + if hasattr(algorithm, "datamodule"): + datamodule = getattr(algorithm, "datamodule") + elif config.datamodule is not None: + datamodule = instantiate_datamodule(config.datamodule) trainer.fit(algorithm, datamodule=datamodule, ckpt_path=config.ckpt_path) train_results = None # todo: get the train results from the trainer. return algorithm, train_results