diff --git a/eval_copy_detection.py b/eval_copy_detection.py
index 73dcd5078..320800a79 100644
--- a/eval_copy_detection.py
+++ b/eval_copy_detection.py
@@ -223,7 +223,12 @@ def extract_features(image_list, model, args):
     parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
     parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
         distributed training; see https://pytorch.org/docs/stable/distributed.html""")
-    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
+    # In PyTorch 2.0 the argument name changes to --local-rank
+    if torch.__version__ >= "2.0.0":
+        parser.add_argument("--local-rank", default=0, type=int, help="Please ignore and do not set this argument.")
+    else:
+        parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
+
     args = parser.parse_args()
 
     utils.init_distributed_mode(args)
diff --git a/eval_image_retrieval.py b/eval_image_retrieval.py
index 999f8c900..2017a9c04 100644
--- a/eval_image_retrieval.py
+++ b/eval_image_retrieval.py
@@ -94,7 +94,13 @@ def config_qimname(cfg, i):
     parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
     parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
         distributed training; see https://pytorch.org/docs/stable/distributed.html""")
-    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
+    # In PyTorch 2.0 the argument name changes to --local-rank
+    if torch.__version__ >= "2.0.0":
+        parser.add_argument("--local-rank", default=0, type=int, help="Please ignore and do not set this argument.")
+    else:
+        parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
+
+
     args = parser.parse_args()
 
     utils.init_distributed_mode(args)
diff --git a/eval_knn.py b/eval_knn.py
index fe99a2604..15330ded0 100644
--- a/eval_knn.py
+++ b/eval_knn.py
@@ -209,7 +209,12 @@ def __getitem__(self, idx):
     parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
     parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
         distributed training; see https://pytorch.org/docs/stable/distributed.html""")
-    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
+    # In PyTorch 2.0 the argument name changes to --local-rank
+    if torch.__version__ >= "2.0.0":
+        parser.add_argument("--local-rank", default=0, type=int, help="Please ignore and do not set this argument.")
+    else:
+        parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
+
     parser.add_argument('--data_path', default='/path/to/imagenet/', type=str)
     args = parser.parse_args()
diff --git a/eval_linear.py b/eval_linear.py
index cdef16b47..03e383b24 100644
--- a/eval_linear.py
+++ b/eval_linear.py
@@ -270,7 +270,11 @@ def forward(self, x):
     parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size')
     parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
         distributed training; see https://pytorch.org/docs/stable/distributed.html""")
-    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
+    # In PyTorch 2.0 the argument name changes to --local-rank
+    if torch.__version__ >= "2.0.0":
+        parser.add_argument("--local-rank", default=0, type=int, help="Please ignore and do not set this argument.")
+    else:
+        parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
     parser.add_argument('--data_path', default='/path/to/imagenet/', type=str)
     parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
     parser.add_argument('--val_freq', default=1, type=int, help="Epoch frequency for validation.")
diff --git a/main_dino.py b/main_dino.py
index cade9873d..78a91304e 100644
--- a/main_dino.py
+++ b/main_dino.py
@@ -125,9 +125,13 @@ def get_args_parser():
     parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
     parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
         distributed training; see https://pytorch.org/docs/stable/distributed.html""")
-    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
-    return parser
+    # In PyTorch 2.0 the argument name changes to --local-rank
+    if torch.__version__ >= "2.0.0":
+        parser.add_argument("--local-rank", default=0, type=int, help="Please ignore and do not set this argument.")
+    else:
+        parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
+    return parser
 
 
 def train_dino(args):
     utils.init_distributed_mode(args)
@@ -221,6 +225,10 @@ def train_dino(args):
         args.epochs,
     ).cuda()
 
+    # torch.compile() is a new feature in PyTorch 2.0 that can improve the performance of PyTorch code.
+    if torch.__version__ >= "2.0.0":
+        dino_loss = torch.compile(dino_loss)
+
     # ============ preparing optimizer ... ============
     params_groups = utils.get_params_groups(student)
     if args.optimizer == "adamw":
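
Note on the version check used above: torch.__version__ >= "2.0.0" compares version strings lexicographically, which happens to work for current PyTorch releases but is fragile. A minimal sketch of a semantic comparison for the same check, assuming the packaging package is available; the helper name add_local_rank_arg is hypothetical and not part of the patch:

    import argparse

    import torch
    from packaging import version


    def add_local_rank_arg(parser: argparse.ArgumentParser) -> None:
        # packaging handles suffixes such as "2.0.1+cu117" correctly,
        # unlike a plain lexicographic string comparison.
        if version.parse(torch.__version__) >= version.parse("2.0.0"):
            # torch.distributed.launch in PyTorch >= 2.0 passes --local-rank
            flag = "--local-rank"
        else:
            flag = "--local_rank"
        parser.add_argument(flag, default=0, type=int,
                            help="Please ignore and do not set this argument.")

An alternative that avoids the version check entirely is to register both spellings in one call, e.g. parser.add_argument("--local_rank", "--local-rank", default=0, type=int, ...): argparse accepts multiple option strings for a single argument and derives dest ("local_rank") from the first long option.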
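
Similarly, the torch.compile call added in train_dino can be guarded by a feature test instead of a version comparison. A minimal sketch, reusing the dino_loss object from the patch:

    # torch.compile only exists from PyTorch 2.0 onwards, so checking for the
    # attribute works on every version without comparing version strings.
    if hasattr(torch, "compile"):
        dino_loss = torch.compile(dino_loss)

Since DINOLoss is an nn.Module, torch.compile returns a wrapped module with the same call signature, so the rest of the training loop is unchanged.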