diff --git a/generic_trainer/configs.py b/generic_trainer/configs.py index 3ec0f6e..7fb284f 100644 --- a/generic_trainer/configs.py +++ b/generic_trainer/configs.py @@ -240,6 +240,7 @@ class InferenceConfig(Config): def string_to_object(self, key, value): if key == 'model_save_dir': self.pretrained_model_path = os.path.join(value, 'best_model.pth') + return value @dataclasses.dataclass