diff --git a/examples/imagenet/README.md b/examples/imagenet/README.md
index f5995f0..37e24bc 100644
--- a/examples/imagenet/README.md
+++ b/examples/imagenet/README.md
@@ -30,7 +30,7 @@ On A100 (80GB), it takes less than 4 hours to compute the pairwise scores with q
 ## Computing Pairwise Influence Scores with DDP
 
-You can also use DistributedDataParallel to speed up influence computations.
+You can also use DistributedDataParallel to speed up influence computations. To run with 2 GPUs:
 
 ```bash
-torchrun --standalone --nnodes=1 --nproc-per-node=4 ddp_analyze.py
+torchrun --standalone --nnodes=1 --nproc-per-node=2 ddp_analyze.py
 ```
\ No newline at end of file
diff --git a/examples/imagenet/ddp_analyze.py b/examples/imagenet/ddp_analyze.py
index dbbafef..1655560 100644
--- a/examples/imagenet/ddp_analyze.py
+++ b/examples/imagenet/ddp_analyze.py
@@ -37,7 +37,7 @@ def parse_args():
         help="Rank for the low-rank query gradient approximation.",
     )
     parser.add_argument(
-        "--covariance_batch_size",
+        "--factor_batch_size",
         type=int,
         default=512,
         help="Batch size for computing query gradients.",
@@ -104,7 +104,7 @@ def main():
     analyzer.fit_all_factors(
         factors_name=args.factor_strategy,
         dataset=train_dataset,
-        per_device_batch_size=None,
+        per_device_batch_size=args.factor_batch_size,
         factor_args=factor_args,
         overwrite_output_dir=False,
     )
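
A minimal usage sketch of the renamed flag, assuming the wiring above: the `torchrun` invocation, the `--factor_batch_size` name, and its default of 512 come from the diff, while the override value shown here is hypothetical. Because the value is forwarded to `per_device_batch_size`, it is interpreted per device rather than globally.

```bash
# Launch the DDP factor computation on 2 GPUs (one process per GPU).
# --factor_batch_size is per device and defaults to 512; the 256 below is an
# arbitrary example of overriding that default.
torchrun --standalone --nnodes=1 --nproc-per-node=2 ddp_analyze.py \
    --factor_batch_size 256
```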