Skip to content

Commit

Permalink
Merge pull request #265 from mila-iqia/staging
Browse files Browse the repository at this point in the history
Staging
  • Loading branch information
Delaunay authored Sep 5, 2024
2 parents 95f5fc9 + 554f136 commit 6cb2c92
Show file tree
Hide file tree
Showing 42 changed files with 2,773 additions and 92 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ scripts/article/xpu/
dependencies/
benchmarks/recursiongfn/gflownet
benchmarks/recursiongfn/logs/
benchmarks/llm/tune/

scripts/inventory.yaml
output/
Expand Down
8 changes: 4 additions & 4 deletions .pin/constraints-cuda-gnn.txt

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

37 changes: 11 additions & 26 deletions .pin/constraints-cuda-torch.txt

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions benchmarks/brax/requirements.cuda.txt

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 31 additions & 0 deletions benchmarks/cleanrl_jax/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Use global base if possible
ifndef MILABENCH_BASE
MILABENCH_BASE="base"
endif

export MILABENCH_BASE

BENCH_NAME=cleanrl_jax
MILABENCH_CONFIG=dev.yaml
MILABENCH_ARGS=--config $(MILABENCH_CONFIG) --base $(MILABENCH_BASE)

all:
install prepare single gpus nodes

install:
milabench install $(MILABENCH_ARGS) --force

prepare:
milabench prepare $(MILABENCH_ARGS)

tests: install prepare
milabench run $(MILABENCH_ARGS)

single:
milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)

gpus:
milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-gpus

nodes:
milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-nodes
4 changes: 4 additions & 0 deletions benchmarks/cleanrl_jax/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

# Cleanrl_jax

Rewrite this README to explain what the benchmark is!
31 changes: 31 additions & 0 deletions benchmarks/cleanrl_jax/benchfile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from milabench.pack import Package


class Cleanrl_jax(Package):
# Requirements file installed by install(). It can be empty or absent.
base_requirements = "requirements.in"

# The preparation script called by prepare(). It must be executable,
# but it can be any type of script. It can be empty or absent.
prepare_script = "prepare.py"

# The main script called by run(). It must be a Python file. It has to
# be present.
main_script = "main.py"

# You can remove the functions below if you don't need to modify them.

def make_env(self):
# Return a dict of environment variables for prepare_script and
# main_script.
return super().make_env()

async def install(self):
await super().install() # super() call installs the requirements

async def prepare(self):
await super().prepare() # super() call executes prepare_script



__pack__ = Cleanrl_jax
8 changes: 8 additions & 0 deletions benchmarks/cleanrl_jax/dev.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@

cleanrl_jax:
inherits: _defaults
definition: .
install-variant: unpinned
install_group: torch
plan:
method: per_gpu
Loading

0 comments on commit 6cb2c92

Please sign in to comment.