diff --git a/.pin/constraints-cuda-torch.txt b/.pin/constraints-cuda-torch.txt index 8990a5f26..5b804aa17 100644 --- a/.pin/constraints-cuda-torch.txt +++ b/.pin/constraints-cuda-torch.txt @@ -768,7 +768,7 @@ six==1.16.0 # tensorflow-probability smmap==5.0.1 # via gitdb -submitit==1.5.1 +submitit==1.5.2 # via # -r benchmarks/dinov2/requirements.in # -r benchmarks/vjepa/requirements.in @@ -864,6 +864,8 @@ torch-sparse==0.6.18+pt24cu121 # via # -r benchmarks/geo_gnn/requirements.in # -r benchmarks/recursiongfn/requirements.in +torchao==0.5.0+cu121 + # via -r benchmarks/llm/requirements.in torchcompat==1.1.4 # via # -c .pin/../constraints/cuda.txt @@ -955,7 +957,7 @@ typing-extensions==4.12.2 # torch # typeguard # tyro -tyro==0.8.10 +tyro==0.8.11 # via # -r benchmarks/torchatari/requirements.in # navix diff --git a/benchmarks/dinov2/requirements.cuda.txt b/benchmarks/dinov2/requirements.cuda.txt index 78ac05616..3ea8a7b4a 100644 --- a/benchmarks/dinov2/requirements.cuda.txt +++ b/benchmarks/dinov2/requirements.cuda.txt @@ -246,7 +246,7 @@ six==1.16.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # asttokens -submitit==1.5.1 +submitit==1.5.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/dinov2/requirements.in diff --git a/benchmarks/llm/requirements.cuda.txt b/benchmarks/llm/requirements.cuda.txt index aec9cd89b..854bbd102 100644 --- a/benchmarks/llm/requirements.cuda.txt +++ b/benchmarks/llm/requirements.cuda.txt @@ -402,6 +402,10 @@ torch==2.4.0+cu121 # accelerate # fairscale # xformers +torchao==0.5.0+cu121 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # -r benchmarks/llm/requirements.in torchtune==0.3.0+cu121 # via # -c .pin/../.pin/constraints-cuda-torch.txt diff --git a/benchmarks/llm/requirements.in b/benchmarks/llm/requirements.in index 91b62c073..36832ad67 100644 --- a/benchmarks/llm/requirements.in +++ b/benchmarks/llm/requirements.in @@ -4,6 +4,7 @@ torch PyYAML argklass fairscale +torchao # Prepare accelerate diff --git a/benchmarks/purejaxrl/requirements.cuda.txt b/benchmarks/purejaxrl/requirements.cuda.txt index 8c89f572f..c567bef87 100644 --- a/benchmarks/purejaxrl/requirements.cuda.txt +++ b/benchmarks/purejaxrl/requirements.cuda.txt @@ -734,7 +734,7 @@ typing-extensions==4.12.2 # reactivex # torch # tyro -tyro==0.8.10 +tyro==0.8.11 # via # -c .pin/../.pin/constraints-cuda-torch.txt # navix diff --git a/benchmarks/rlhf/requirements.cuda.txt b/benchmarks/rlhf/requirements.cuda.txt index 788f69ff4..df5fa0f95 100644 --- a/benchmarks/rlhf/requirements.cuda.txt +++ b/benchmarks/rlhf/requirements.cuda.txt @@ -384,7 +384,7 @@ typing-extensions==4.12.2 # reactivex # torch # tyro -tyro==0.8.10 +tyro==0.8.11 # via # -c .pin/../.pin/constraints-cuda-torch.txt # trl diff --git a/benchmarks/torchatari/requirements.cuda.txt b/benchmarks/torchatari/requirements.cuda.txt index fdb39ed50..baebdd7b4 100644 --- a/benchmarks/torchatari/requirements.cuda.txt +++ b/benchmarks/torchatari/requirements.cuda.txt @@ -337,7 +337,7 @@ typing-extensions==4.12.2 # reactivex # torch # tyro -tyro==0.8.10 +tyro==0.8.11 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/torchatari/requirements.in diff --git a/benchmarks/vjepa/requirements.cuda.txt b/benchmarks/vjepa/requirements.cuda.txt index f87b5c6d3..4efa6b922 100644 --- a/benchmarks/vjepa/requirements.cuda.txt +++ b/benchmarks/vjepa/requirements.cuda.txt @@ -290,7 +290,7 @@ six==1.16.0 # -c .pin/../.pin/constraints-cuda-torch.txt # asttokens # python-dateutil -submitit==1.5.1 +submitit==1.5.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/vjepa/requirements.in diff --git a/config/scaling.yaml b/config/scaling.yaml index b6226a37d..fe37379b3 100644 --- a/config/scaling.yaml +++ b/config/scaling.yaml @@ -55,13 +55,13 @@ bert-tf32-fp16: 112: 81140.75 MiB optimized: 128 bf16: {} -# brax: -# arg: --batch-size -# model: -# 1024: 4912.25 MiB -# cleanrljax: -# arg: --num_steps -# optimized: 128 +brax: + arg: --batch-size + model: + 1024: 4912.25 MiB +cleanrljax: + arg: --num_steps + optimized: 128 convnext_large-fp16: arg: --batch-size model: @@ -194,16 +194,19 @@ diffusion-single: 4: 23478.75 MiB 16: 33850.25 MiB 32: 55354.25 MiB -pna: - arg: --batch-size - dimenet: arg: --batch-size model: 2: 452.6875 MiB 4: 1604.25 MiB + 24: 4776.25 MiB + 56: 6330.25 MiB 64: 12274.25 MiB + 112: 15294.25 MiB 128: 13002.25 MiB + 240: 67506.25 MiB + 280: 56556.25 MiB + 488: 80406.25 MiB dinov2-giant-gpus: arg: train.batch_size_per_gpu={batch_size} model: @@ -211,7 +214,8 @@ dinov2-giant-gpus: 2: 32252.25 MiB 4: 32404.25 MiB 16: 38350.25 MiB - 32: 69614 MiB + 24: 48856.25 MiB + 32: 72102.25 MiB optimized: 32 dinov2-giant-nodes: arg: train.batch_size_per_gpu={batch_size} @@ -222,16 +226,17 @@ dinov2-giant-single: 2: 20682.25 MiB 4: 20682.25 MiB 16: 52748.25 MiB + 24: 60792.25 MiB 32: 74544.25 MiB dlrm: {} -# dqn: -# arg: --buffer_batch_size -# model: -# 1024: 81.81005859375 MiB -# 2048: 83.40380859375 MiB -# 32768: 131.21630859375 MiB -# 65536: 182.21630859375 MiB -# optimized: 128 +dqn: + arg: --buffer_batch_size + model: + 1024: 81.81005859375 MiB + 2048: 83.40380859375 MiB + 32768: 131.21630859375 MiB + 65536: 182.21630859375 MiB + optimized: 128 focalnet: arg: --batch-size model: @@ -260,9 +265,15 @@ lightning: 2: 1054.25 MiB 4: 1856.25 MiB 16: 4728.25 MiB + 24: 5482.25 MiB 32: 6352.25 MiB + 56: 1054.25 MiB 64: 1856.25 MiB + 120: 14522.25 MiB 128: 14818.25 MiB + 240: 25480.25 MiB + 488: 49042.25 MiB + 664: 65914.25 MiB lightning-gpus: arg: --batch-size model: @@ -271,7 +282,12 @@ lightning-gpus: 4: 1156.75 MiB 8: 1260.75 MiB 16: 4150.75 MiB + 48: 11056.25 MiB + 112: 16776.25 MiB 128: 15858 MiB + 240: 28942.25 MiB + 504: 54100.25 MiB + 624: 65386.25 MiB optimized: 16 llama: {} llava-gpus: @@ -280,6 +296,7 @@ llava-gpus: llava-single: arg: --batch_size model: + 1: 72614.25 MiB 2: 15168.25 MiB 4: 72362.25 MiB optimized: 1 @@ -341,18 +358,21 @@ opt-6_7b-multinode: model: 1: 55380 MiB optimized: 1 -# ppo: -# arg: --num_steps -# model: -# 8: 80.791748046875 MiB -# 16: 80.916748046875 MiB -# 32: 81.166748046875 MiB -# 64: 81.666748046875 MiB -# 128: 82.666748046875 MiB -# 1024: 96.666748046875 MiB -# 2048: 132.484619140625 MiB -# 4096: 205.328369140625 MiB -# optimized: 32 +pna: + arg: --batch-size +ppo: + arg: --num_steps + model: + 8: 80.791748046875 MiB + 16: 80.916748046875 MiB + 32: 81.166748046875 MiB + 64: 81.666748046875 MiB + 128: 82.666748046875 MiB + 1024: 96.666748046875 MiB + 2048: 132.484619140625 MiB + 4096: 205.328369140625 MiB + 2517448: 62094.25 MiB + optimized: 32 recursiongfn: arg: --batch_size model: @@ -477,6 +497,11 @@ resnet50-noio: 4: 1854.25 MiB 16: 3052.25 MiB 32: 4690.25 MiB + 56: 7114.25 MiB + 136: 15194.25 MiB + 288: 30632.25 MiB + 592: 64483.8125 MiB + 736: 76050.25 MiB rlhf-gpus: arg: --per_device_train_batch_size model: @@ -487,6 +512,9 @@ rlhf-gpus: 32: 17918.25 MiB 64: 24374.25 MiB 128: 25830.25 MiB + 136: 29442.25 MiB + 392: 15372.25 MiB + 520: 15808.25 MiB optimized: 64 rlhf-single: arg: --per_device_train_batch_size @@ -496,8 +524,12 @@ rlhf-single: 4: 8822.25 MiB 16: 9694.25 MiB 32: 12952.25 MiB + 40: 14638.25 MiB 64: 19422.25 MiB + 120: 31048.25 MiB 128: 32442.25 MiB + 280: 63262.25 MiB + 352: 77536.25 MiB optimized: 64 rwkv: arg: --micro_bsz @@ -553,10 +585,11 @@ torchatari: vjepa-gpus: arg: --batch_size model: + 1: 27196.25 MiB 2: 28896.25 MiB 4: 30784.25 MiB 16: 52722.25 MiB - 32: 76372.25 MiB + 32: 77124.25 MiB optimized: 24 vjepa-single: arg: --batch_size @@ -564,8 +597,10 @@ vjepa-single: 1: 6644.25 MiB 2: 18984.25 MiB 4: 11860.25 MiB + 8: 30764.25 MiB 16: 45516.25 MiB - 32: 70586.25 MiB + 24: 57574.25 MiB + 32: 67122.25 MiB optimized: 24 whisper: arg: --batch-size diff --git a/scripts/article/run_cuda.sh b/scripts/article/run_cuda.sh index 83c956d18..4135da37f 100644 --- a/scripts/article/run_cuda.sh +++ b/scripts/article/run_cuda.sh @@ -88,9 +88,9 @@ if [ "$MILABENCH_PREPARE" -eq 0 ]; then # milabench prepare --system $MILABENCH_WORDIR/system.yaml $ARGS # pip install torch - # milabench pin --variant cuda --from-scratch $ARGS - # milabench install --system $MILABENCH_WORDIR/system.yaml --force $ARGS - # milabench prepare --system $MILABENCH_WORDIR/system.yaml $ARGS + milabench pin --variant cuda --from-scratch $ARGS + milabench install --system $MILABENCH_WORDIR/system.yaml --force $ARGS + milabench prepare --system $MILABENCH_WORDIR/system.yaml $ARGS ARGS="--select resnet50-noio,brax,lightning,dinov2-giant-single,dinov2-giant-gpus,llm-lora-ddp-gpus,llm-lora-ddp-nodes,llm-lora-mp-gpus,llm-full-mp-gpus,llm-full-mp-nodes,dqn,ppo,dimenet,llava-single,rlhf-single,rlhf-gpus,vjepa-single,vjepa-gpus"