diff --git a/pangaea/run.py b/pangaea/run.py index 82f2c9a..6ca3978 100644 --- a/pangaea/run.py +++ b/pangaea/run.py @@ -133,7 +133,7 @@ def main(cfg: DictConfig) -> None: ) decoder.to(device) decoder = torch.nn.parallel.DistributedDataParallel( - decoder, device_ids=[local_rank], output_device=local_rank + decoder, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True ) logger.info( "Built {} for with {} encoder.".format(