Merge TorchGeo env with main torch env
pierre.delaunay committed Sep 11, 2024
1 parent 8dfefe0 commit cb1b231
Showing 10 changed files with 56 additions and 15 deletions.
2 changes: 0 additions & 2 deletions benchmarks/llava/main.py
@@ -116,8 +116,6 @@ def batch_size_fn(batch):
for k, v in inputs.items()
}

inputs["labels"] = inputs["input_ids"]

outputs = model(**inputs)

loss = outputs.loss
7 changes: 7 additions & 0 deletions benchmarks/vjepa/main.py
@@ -609,6 +609,9 @@ def main():
parser = ArgumentParser()
parser.add_argument("--dataset", help="path to the csv that list all videos", type=str)
parser.add_argument("--output", help="path to an output directory", type=str)
parser.add_argument("--batch_size", type=int, default=24)
parser.add_argument("--num_frames", type=int, default=16)
parser.add_argument("--num_workers", type=int, default=12)
args = parser.parse_args()

# relying on environment variables is annoying in multinode setups
@@ -621,6 +624,10 @@
logger.info('loaded params...')

params["data"]["datasets"] = [args.dataset]
params["data"]["batch_size"] = args.batch_size
params["data"]["num_frames"] = args.num_frames
params["data"]["num_workers"] = args.num_workers

params["logging"]["folder"] = args.output

gpu_per_nodes = int(os.getenv("LOCAL_WORLD_SIZE", 1))
13 changes: 7 additions & 6 deletions config/base.yaml
@@ -56,6 +56,7 @@ _flops:
- diagnostic
- flops
- monogpu
+ - nobatch

argv:
--number: 10
@@ -72,6 +73,7 @@ llama:
- llm
- inference
- monogpu
+ - nobatch

voir:
options:
@@ -713,10 +715,7 @@ _geo_gnn:
tags:
- monogpu
- graph
- # FIXME: torch cluster is laging behind pytorch
- # we are forced to use torch==2.3 instead of torch==2.4
- install_group: gnn
- group: geo_gnn
+ install_group: torch
definition: ../benchmarks/geo_gnn
plan:
method: per_gpu
@@ -733,8 +732,7 @@
recursiongfn:
inherits: _defaults
definition: ../benchmarks/recursiongfn
- install_group: gnn
- group: recursiongfn_gnn
+ install_group: torch
tags:
- graph
- monogpu
@@ -838,6 +836,9 @@ _vjepa:
definition: ../benchmarks/vjepa
tags:
- video
+ argv:
+ --batch_size: 24
+ --num_workers: "auto({n_worker}, 12)"

vjepa-single:
inherits: _vjepa
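Note on the _vjepa argv block above: it lines up with the --batch_size and --num_workers arguments added to benchmarks/vjepa/main.py in this commit, so milabench can now drive the benchmark's batch size and worker count from configuration. A rough sketch of the resulting invocation, assuming milabench simply appends these flags to the benchmark command (paths are placeholders, not from the commit):

    python benchmarks/vjepa/main.py --dataset <videos.csv> --output <out_dir> \
        --batch_size 24 --num_workers 12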
33 changes: 33 additions & 0 deletions config/scaling.yaml
@@ -442,3 +442,36 @@ whisper:
128: 71634.375 MiB
144: 80412.75 MiB
optimized: 128


+ llava-single:
+ arg: --batch_size
+ optimized: 1
+
+ llava-gpus:
+ arg: --batch_size
+ optimized: 1
+
+ rlhf-single:
+ arg: --per_device_train_batch_size
+ optimized: 64
+
+ rlhf-gpus:
+ arg: --per_device_train_batch_size
+ optimized: 64
+
+ vjepa-single:
+ arg: --batch_size
+ optimized: 24
+
+ vjepa-gpus:
+ arg: --batch_size
+ optimized: 24
+
+ ppo:
+ arg: --num_minibatches
+ optimized: 32
+
+ dqn:
+ arg: --buffer_batch_size
+ optimized: 128
4 changes: 0 additions & 4 deletions constraints/extra/gnn.cuda.txt

This file was deleted.

Empty file removed constraints/extra/gnn.hpu.txt
Empty file removed constraints/extra/gnn.rocm.txt
Empty file removed constraints/extra/gnn.xpu.txt
6 changes: 5 additions & 1 deletion constraints/extra/torch.cuda.txt
@@ -7,4 +7,8 @@ jax[cuda12]
# --extra-index-url https://download.pytorch.org/whl/cu121
# --find-links https://download.pytorch.org/whl/xformers/

- xformers==0.0.27.post2
+ xformers==0.0.27.post2
+
+
+ # Torch geometric
+ --find-links https://data.pyg.org/whl/torch-2.4.0+cu121.html
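This find-links line is what makes the merge possible: data.pyg.org publishes PyTorch Geometric extension wheels built against torch 2.4.0 / CUDA 12.1, so the GNN benchmarks no longer need their own gnn install group and constraint files. A minimal sketch of the resolution it enables (the package list is illustrative, not taken from the commit):

    pip install torch-cluster torch-scatter torch-sparse \
        --find-links https://data.pyg.org/whl/torch-2.4.0+cu121.html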
6 changes: 4 additions & 2 deletions scripts/article/run_cuda.sh
@@ -8,6 +8,8 @@ export MILABENCH_BASE="$MILABENCH_WORDIR/results"

export MILABENCH_VENV="$MILABENCH_WORDIR/env"
export BENCHMARK_VENV="$MILABENCH_WORDIR/results/venv/torch"
+ export MILABENCH_SIZER_SAVE="$MILABENCH_WORDIR/scaling.yaml"
+

if [ -z "${MILABENCH_PREPARE}" ]; then
export MILABENCH_PREPARE=0
@@ -85,7 +87,7 @@ if [ "$MILABENCH_PREPARE" -eq 0 ]; then
. $MILABENCH_WORDIR/env/bin/activate

# milabench pin --variant cuda --from-scratch $ARGS
- milabench install --system $MILABENCH_WORDIR/system.yaml $ARGS --force
+ # milabench install --system $MILABENCH_WORDIR/system.yaml $ARGS --force
# milabench prepare --system $MILABENCH_WORDIR/system.yaml $ARGS

#
@@ -94,5 +96,5 @@ if [ "$MILABENCH_PREPARE" -eq 0 ]; then

#
# Display report
- # milabench report --runs $MILABENCH_WORDIR/results/runs
+ milabench report --runs $MILABENCH_WORDIR/results/runs
fi
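The MILABENCH_SIZER_SAVE export added above points milabench's sizer at the same scaling.yaml format extended in config/scaling.yaml, presumably so measured batch-size data is written back to that file. A minimal usage sketch, assuming the run step elided from this diff is a plain milabench run (the exact flags are not shown in the commit):

    export MILABENCH_SIZER_SAVE="$MILABENCH_WORDIR/scaling.yaml"
    milabench run --system $MILABENCH_WORDIR/system.yaml $ARGS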
