Commit v0.1.7
Jintao-Huang committed Dec 4, 2022
1 parent e2d9312 commit 2477c5d
Showing 3 changed files with 9 additions and 8 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -56,7 +56,7 @@ pip install "torchvision>=0.13.*"
python examples/gan.py

### contrastive_learning.py
pip install "torchvision>=0.13.*" "torchmetrics>=0.10.2" "scikit-learn>=1.1.*"
pip install "torchvision>=0.13.*" "scikit-learn>=1.1.*"
python examples/contrastive_learning.py

### gnn.py gnn2.py
4 changes: 2 additions & 2 deletions examples/test_env.py
@@ -140,7 +140,7 @@ def training_epoch_end(self) -> Dict[str, float]:
    ldm = ml.LDataModule(train_dataset, val_dataset, test_dataset, 64)
    lmodel = MyLModule(model, [optimizer], "loss")
    trainer = ml.Trainer(lmodel, [0], 20, RUNS_DIR, gradient_clip_norm=10,
-                        val_every_n_epoch=10, verbose=True, model_fpath=ckpt_path)
+                        val_every_n_epoch=10, verbose=True, ckpt_fpath=ckpt_path)
    logger.info(trainer.test(ldm.val_dataloader, True, True))
    logger.info(trainer.fit(ldm.train_dataloader, ldm.val_dataloader))
    logger.info(trainer.test(ldm.test_dataloader, True, True))
@@ -152,5 +152,5 @@ def training_epoch_end(self) -> Dict[str, float]:
    model = MLP_L2(2, 4, 1)
    ldm = ml.LDataModule(train_dataset, val_dataset, test_dataset, 64)
    lmodel = MyLModule(model, [], "loss")
-   trainer = ml.Trainer(lmodel, [], None, RUNS_DIR, model_fpath=ckpt_path)
+   trainer = ml.Trainer(lmodel, [], None, RUNS_DIR, ckpt_fpath=ckpt_path)
    logger.info(trainer.test(ldm.test_dataloader, True, True))
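
For downstream code, v0.1.7 is a pure keyword rename at the ml.Trainer call site. A minimal sketch of the updated call, mirroring the second example in examples/test_env.py and assuming lmodel, RUNS_DIR, and ckpt_path are set up as in that script:

# Sketch only: lmodel, RUNS_DIR, and ckpt_path are assumed to be defined as in
# examples/test_env.py; the keyword rename is the only change required here.
import mini_lightning as ml

# Before v0.1.7:
#   trainer = ml.Trainer(lmodel, [], None, RUNS_DIR, model_fpath=ckpt_path)
# From v0.1.7 on:
trainer = ml.Trainer(lmodel, [], None, RUNS_DIR, ckpt_fpath=ckpt_path)
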
11 changes: 6 additions & 5 deletions mini_lightning/_mini_lightning.py
@@ -313,7 +313,7 @@ def __init__(
        gradient_clip_norm: Optional[float] = None,
        sync_bn: bool = False,
        replace_sampler_ddp: bool = True,
-       model_fpath: Optional[str] = None,
+       ckpt_fpath: Optional[str] = None,
        #
        val_every_n_epoch: int = 1,
        log_every_n_steps: int = 10,
@@ -360,7 +360,7 @@ def __init__(
        replace_sampler_ddp=False: each gpu will use the complete dataset.
        replace_sampler_ddp=True: It will slice the dataset into world_size chunks and distribute them to each gpu.
            note: Replace train_dataloader only. Because DDP uses a single gpu for val/test.
-       model_fpath: only load model_state_dict.
+       ckpt_fpath: only load model_state_dict.
            If you want to resume from ckpt. please see `save_optimizers_state_dict` and examples in `examples/test_env.py`
        *
        val_every_n_epoch: Frequency of validation and prog_bar_leave of training. (the last epoch will always be validated)
@@ -472,9 +472,10 @@ def __init__(
        hparams = lmodel.hparams
        self.save_hparams(hparams)
        #
-       if model_fpath is not None:
-           self._load_ckpt(model_fpath)
-           logger.info(f"Using ckpt: {model_fpath}")
+       self.ckpt_fpath = ckpt_fpath
+       if ckpt_fpath is not None:
+           self._load_ckpt(ckpt_fpath)
+           logger.info(f"Using ckpt: {ckpt_fpath}")
        lmodel.trainer_init(self)
        for s in lmodel._models:
            model: Module = getattr(lmodel, s)
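
Per the updated docstring, ckpt_fpath restores only the model weights; optimizer and scheduler state are untouched (use `save_optimizers_state_dict` for a full resume). A rough illustration of that behavior in plain PyTorch, not mini-lightning's actual `_load_ckpt` (which is not shown in this diff), and assuming the checkpoint stores the weights under a "model" key:

# Illustrative only: not the real mini_lightning._load_ckpt, and the "model"
# key is an assumption about the checkpoint layout.
import torch
from torch.nn import Module


def load_model_state_only(model: Module, ckpt_fpath: str) -> None:
    # Load onto CPU first so the caller controls device placement.
    ckpt = torch.load(ckpt_fpath, map_location="cpu")
    # Restore only the model weights; optimizer/scheduler state is ignored,
    # matching "only load model_state_dict" in the docstring.
    model.load_state_dict(ckpt["model"])
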
