Skip to content

Commit

Permalink
[misc] Small ddp script adjustments (#1793)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixT2K authored Nov 22, 2024
1 parent c3ec3cb commit aca6b36
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 9 deletions.
4 changes: 2 additions & 2 deletions references/detection/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ or PyTorch:
```shell
python references/detection/train_pytorch.py db_resnet50 --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5
```

### Multi-GPU support (PyTorch only)

Multi-GPU support on Detection task with PyTorch has been added.
Multi-GPU support on Detection task with PyTorch has been added.
Arguments are the same as the ones for single GPU, except:

- `--devices`: **by default, if you do not pass `--devices`, it will use all GPUs on your computer**.
Expand All @@ -41,7 +42,6 @@ device_names = [torch.cuda.get_device_name(d) for d in devices]
- `--backend`: you can specify another `backend` for `DistributedDataParallel` if the default one is not available on
your operating system. The fastest one is `nccl` according to the [PyTorch Documentation](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html).


```shell
python references/detection/train_pytorch_ddp.py db_resnet50 --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5 --devices 0 1 --backend nccl
```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import hashlib
import multiprocessing
import time

import numpy as np
import torch

Expand Down Expand Up @@ -330,9 +331,7 @@ def main(rank: int, world_size: int, args):
pin_memory=torch.cuda.is_available(),
collate_fn=train_set.collate_fn,
)
print(
f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in {len(train_loader)} batches)"
)
print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in {len(train_loader)} batches)")

with open(os.path.join(args.train_path, "labels.json"), "rb") as f:
train_hash = hashlib.sha256(f.read()).hexdigest()
Expand Down Expand Up @@ -446,7 +445,7 @@ def parse_args():
import argparse

parser = argparse.ArgumentParser(
description="DocTR training script for text detection (PyTorch)",
description="DocTR DDP training script for text detection (PyTorch)",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

Expand Down Expand Up @@ -505,7 +504,7 @@ def parse_args():
if __name__ == "__main__":
args = parse_args()
if not torch.cuda.is_available():
raise AssertionError("PyTorch cannot access your GPUs. please look into it bro !!!")
raise AssertionError("PyTorch cannot access your GPUs. Please investigate!")

if not isinstance(args.devices, list):
args.devices = list(range(torch.cuda.device_count()))
Expand Down
6 changes: 4 additions & 2 deletions references/recognition/train_pytorch_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,10 +357,13 @@ def parse_args():
import argparse

parser = argparse.ArgumentParser(
description="DocTR training script for text recognition (PyTorch)",
description="DocTR DDP training script for text recognition (PyTorch)",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

# DDP related args
parser.add_argument("--backend", default="nccl", type=str, help="Backend to use for Torch DDP")

parser.add_argument("arch", type=str, help="text-recognition model to train")
parser.add_argument("--train_path", type=str, default=None, help="path to train data folder(s)")
parser.add_argument("--val_path", type=str, default=None, help="path to val data folder")
Expand All @@ -384,7 +387,6 @@ def parse_args():
parser.add_argument("--name", type=str, default=None, help="Name of your training experiment")
parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on")
parser.add_argument("-b", "--batch_size", type=int, default=64, help="batch size for training")
parser.add_argument("--backend", default="nccl", type=str, help="Backend to use for Torch DDP")
parser.add_argument("--devices", default=None, nargs="+", type=int, help="GPU devices to use for training")
parser.add_argument("--input_size", type=int, default=32, help="input size H for the model, W = 4*H")
parser.add_argument("--lr", type=float, default=0.001, help="learning rate for the optimizer (Adam)")
Expand Down

0 comments on commit aca6b36

Please sign in to comment.