Skip to content

Commit

Permalink
enable ddp find_unused_parameters (#111)
Browse files Browse the repository at this point in the history
* enable ddp find_unused_parameters

* Find unused parameters only when finetuning

---------

Co-authored-by: gle-bellier <[email protected]>
  • Loading branch information
yurujaja and gle-bellier authored Nov 13, 2024
1 parent 4b3f9cf commit 9a795ad
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions pangaea/run.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import hashlib
import os as os
import pathlib
import pprint
import time
import hashlib

import hydra
import torch
Expand Down Expand Up @@ -40,7 +40,9 @@ def get_exp_info(hydra_config: HydraConf) -> dict[str, str]:
str: experiment information.
"""
choices = OmegaConf.to_container(hydra_config.runtime.choices)
cfg_hash = hashlib.sha1(OmegaConf.to_yaml(hydra_config).encode(), usedforsecurity=False).hexdigest()[:6]
cfg_hash = hashlib.sha1(
OmegaConf.to_yaml(hydra_config).encode(), usedforsecurity=False
).hexdigest()[:6]
timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())
fm = choices["encoder"]
decoder = choices["decoder"]
Expand Down Expand Up @@ -134,7 +136,10 @@ def main(cfg: DictConfig) -> None:
)
decoder.to(device)
decoder = torch.nn.parallel.DistributedDataParallel(
decoder, device_ids=[local_rank], output_device=local_rank
decoder,
device_ids=[local_rank],
output_device=local_rank,
find_unused_parameters=cfg.finetune,
)
logger.info(
"Built {} for with {} encoder.".format(
Expand Down Expand Up @@ -186,10 +191,13 @@ def main(cfg: DictConfig) -> None:
logger=logger,
)
raw_val_dataset = GeoFMSubset(raw_val_dataset, indices)


train_dataset = GeoFMDataset(raw_train_dataset, train_preprocessor, cfg.data_replicate)
val_dataset = GeoFMDataset(raw_val_dataset, val_preprocessor, cfg.data_replicate)

train_dataset = GeoFMDataset(
raw_train_dataset, train_preprocessor, cfg.data_replicate
)
val_dataset = GeoFMDataset(
raw_val_dataset, val_preprocessor, cfg.data_replicate
)

logger.info("Built {} dataset.".format(cfg.dataset.dataset_name))

Expand Down

0 comments on commit 9a795ad

Please sign in to comment.