Skip to content

Commit

Permalink
Update batch sizer model
Browse files Browse the repository at this point in the history
  • Loading branch information
pierre.delaunay committed Sep 18, 2024
1 parent 46647ff commit 71e45c7
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 39 deletions.
17 changes: 12 additions & 5 deletions benchmarks/geo_gnn/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,20 @@ def train_degree(train_dataset):
# Compute the maximum in-degree in the training data.
max_degree = -1
for data in train_dataset:
d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
max_degree = max(max_degree, int(d.max()))
try:
d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
max_degree = max(max_degree, int(d.max()))
except TypeError:
pass

# Compute the in-degree histogram tensor
deg = torch.zeros(max_degree + 1, dtype=torch.long)
for data in train_dataset:
d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
deg += torch.bincount(d, minlength=deg.numel())
try:
d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
deg += torch.bincount(d, minlength=deg.numel())
except TypeError:
pass

return deg

Expand All @@ -109,13 +115,14 @@ def batch_size(x):
observer = BenchObserver(batch_size_fn=batch_size)

train_dataset = PCQM4Mv2Subset(args.num_samples, args.root)
degree = train_degree(train_dataset)

sample = next(iter(train_dataset))

info = models[args.model](
args,
sample=sample,
degree=lambda: train_degree(train_dataset),
degree=lambda: degree,
)

TRAIN_mean, TRAIN_std = (
Expand Down
11 changes: 8 additions & 3 deletions config/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -733,18 +733,23 @@ _geo_gnn:
plan:
method: per_gpu

pna:
inherits: _geo_gnn
argv:
--model: 'PNA'
--num-samples: 100000
--batch-size: 4096
--num-workers: "auto({n_worker}, 0)"

dimenet:
inherits: _geo_gnn
tags:
- monogpu
argv:
--model: 'DimeNet'
--num-samples: 10000
--use3d: True
--batch-size: 16
--num-workers: "auto({n_worker}, 0)"


recursiongfn:
inherits: _defaults
definition: ../benchmarks/recursiongfn
Expand Down
57 changes: 30 additions & 27 deletions config/scaling.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ bert-tf32-fp16:
112: 81140.75 MiB
optimized: 128
bf16: {}
brax:
arg: --batch-size
model:
1024: 4912.25 MiB
cleanrljax:
arg: --num_steps
optimized: 128
# brax:
# arg: --batch-size
# model:
# 1024: 4912.25 MiB
# cleanrljax:
# arg: --num_steps
# optimized: 128
convnext_large-fp16:
arg: --batch-size
model:
Expand Down Expand Up @@ -194,6 +194,9 @@ diffusion-single:
4: 23478.75 MiB
16: 33850.25 MiB
32: 55354.25 MiB
pna:
arg: --batch-size

dimenet:
arg: --batch-size
model:
Expand Down Expand Up @@ -221,14 +224,14 @@ dinov2-giant-single:
16: 52748.25 MiB
32: 74544.25 MiB
dlrm: {}
dqn:
arg: --buffer_batch_size
model:
1024: 81.81005859375 MiB
2048: 83.40380859375 MiB
32768: 131.21630859375 MiB
65536: 182.21630859375 MiB
optimized: 128
# dqn:
# arg: --buffer_batch_size
# model:
# 1024: 81.81005859375 MiB
# 2048: 83.40380859375 MiB
# 32768: 131.21630859375 MiB
# 65536: 182.21630859375 MiB
# optimized: 128
focalnet:
arg: --batch-size
model:
Expand Down Expand Up @@ -338,18 +341,18 @@ opt-6_7b-multinode:
model:
1: 55380 MiB
optimized: 1
ppo:
arg: --num_steps
model:
8: 80.791748046875 MiB
16: 80.916748046875 MiB
32: 81.166748046875 MiB
64: 81.666748046875 MiB
128: 82.666748046875 MiB
1024: 96.666748046875 MiB
2048: 132.484619140625 MiB
4096: 205.328369140625 MiB
optimized: 32
# ppo:
# arg: --num_steps
# model:
# 8: 80.791748046875 MiB
# 16: 80.916748046875 MiB
# 32: 81.166748046875 MiB
# 64: 81.666748046875 MiB
# 128: 82.666748046875 MiB
# 1024: 96.666748046875 MiB
# 2048: 132.484619140625 MiB
# 4096: 205.328369140625 MiB
# optimized: 32
recursiongfn:
arg: --batch_size
model:
Expand Down
5 changes: 4 additions & 1 deletion config/standard.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,6 @@ ppo:
enabled: true
weight: 1.0


cleanrljax:
enabled: false
weight: 1.0
Expand All @@ -172,6 +171,10 @@ dimenet:
enabled: true
weight: 1.0

pna:
enabled: false
weight: 1.0

recursiongfn:
enabled: true
weight: 1.0
Expand Down
8 changes: 5 additions & 3 deletions scripts/article/run_cuda.sh
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,15 @@ if [ "$MILABENCH_PREPARE" -eq 0 ]; then
ARGS="--select resnet50-noio,brax,lightning,dinov2-giant-single,dinov2-giant-gpus,llm-lora-ddp-gpus,llm-lora-ddp-nodes,llm-lora-mp-gpus,llm-full-mp-gpus,llm-full-mp-nodes,dqn,ppo,dimenet,llava-single,rlhf-single,rlhf-gpus,vjepa-single,vjepa-gpus"

MEMORY_CAPACITY=("4Go" "8Go" "16Go" "32Go" "64Go" "80Go")

# MEMORY_CAPACITY=("2048" "4096" "8192")

# Run the benchmarks
for CAPACITY in "${MEMORY_CAPACITY[@]}"; do
export MILABENCH_SIZER_AUTO=1
export MILABENCH_SIZER_MULTIPLE=8
export MILABENCH_SIZER_BATCH_SIZE=$CAPACITY
milabench run --run-name "c$CAPACITY.{time}" --system $MILABENCH_WORDIR/system.yaml $ARGS || true
export MILABENCH_SIZER_CAPACITY=$CAPACITY
# export MILABENCH_SIZER_BATCH_SIZE=$CAPACITY
milabench run --run-name "bs$CAPACITY.{time}" --system $MILABENCH_WORDIR/system.yaml $ARGS || true
done

#
Expand Down

0 comments on commit 71e45c7

Please sign in to comment.