diff --git a/pangaea/run.py b/pangaea/run.py index 82f2c9a..478670e 100644 --- a/pangaea/run.py +++ b/pangaea/run.py @@ -1,8 +1,8 @@ +import hashlib import os as os import pathlib import pprint import time -import hashlib import hydra import torch @@ -39,7 +39,9 @@ def get_exp_info(hydra_config: HydraConf) -> dict[str, str]: str: experiment information. """ choices = OmegaConf.to_container(hydra_config.runtime.choices) - cfg_hash = hashlib.sha1(OmegaConf.to_yaml(hydra_config).encode(), usedforsecurity=False).hexdigest()[:6] + cfg_hash = hashlib.sha1( + OmegaConf.to_yaml(hydra_config).encode(), usedforsecurity=False + ).hexdigest()[:6] timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime()) fm = choices["encoder"] decoder = choices["decoder"] @@ -133,7 +135,10 @@ def main(cfg: DictConfig) -> None: ) decoder.to(device) decoder = torch.nn.parallel.DistributedDataParallel( - decoder, device_ids=[local_rank], output_device=local_rank + decoder, + device_ids=[local_rank], + output_device=local_rank, + find_unused_parameters=cfg.finetune, ) logger.info( "Built {} for with {} encoder.".format( @@ -185,10 +190,13 @@ def main(cfg: DictConfig) -> None: logger=logger, ) raw_val_dataset = GeoFMSubset(raw_val_dataset, indices) - - - train_dataset = GeoFMDataset(raw_train_dataset, train_preprocessor, cfg.data_replicate) - val_dataset = GeoFMDataset(raw_val_dataset, val_preprocessor, cfg.data_replicate) + + train_dataset = GeoFMDataset( + raw_train_dataset, train_preprocessor, cfg.data_replicate + ) + val_dataset = GeoFMDataset( + raw_val_dataset, val_preprocessor, cfg.data_replicate + ) logger.info("Built {} dataset.".format(cfg.dataset.dataset_name))