Commit
Update scaling files
pierre.delaunay committed Sep 18, 2024
1 parent c1760c7 commit 46647ff
Showing 9 changed files with 145 additions and 38 deletions.
46 changes: 43 additions & 3 deletions benchmarks/purejaxrl/dqn.py
@@ -50,7 +50,9 @@ def make_train(config):
config["NUM_UPDATES"] = config["TOTAL_TIMESTEPS"] // config["NUM_ENVS"]

from benchmate.timings import StepTimer
from benchmate.jaxmem import memory_peak_fetcher
step_timer = StepTimer(give_push())
fetch_memory_peak = memory_peak_fetcher()

basic_env, env_params = gymnax.make(config["ENV_NAME"])
env = FlattenObservationWrapper(basic_env)
@@ -238,6 +240,7 @@ def callback(metrics):

step_timer.step(delta.item())
step_timer.log(returns=returns, loss=loss)
step_timer.log(memory_peak=fetch_memory_peak(), units="MiB")
step_timer.end()

jax.debug.callback(callback, metrics)
@@ -258,12 +261,49 @@ def callback(metrics):
return train


# When using nvidia-smi to monitor memory
# arg: --buffer_size
# model:
# 256: 61900.25 MiB
# 1000: 61900.25 MiB
# 10000: 61900.25 MiB

# dqn:
# arg: --num_envs
# model:
# 2: 61900.25 MiB
# 4: 61900.25 MiB
# 16: 61900.25 MiB
# 32: 61900.25 MiB
# 64: 61900.25 MiB
# 128: 61900.25 MiB

# arg: --total_timesteps
# model:
# 32768: 61900.25 MiB
# 65536: 61900.25 MiB

# When using Jax to monitor memory

# dqn.D0 [stdout] Device: cuda:0
# dqn.D0 [stdout] num_allocs: 0.0006799697875976562 MiB
# dqn.D0 [stdout] bytes_in_use: 0.915771484375 MiB
# dqn.D0 [stdout] peak_bytes_in_use: 80.41552734375 MiB
# dqn.D0 [stdout] largest_alloc_size: 16.07958984375 MiB
# dqn.D0 [stdout] bytes_limit: 60832.359375 MiB
# dqn.D0 [stdout] bytes_reserved: 0.0 MiB
# dqn.D0 [stdout] peak_bytes_reserved: 0.0 MiB
# dqn.D0 [stdout] largest_free_block_bytes: 0.0 MiB
# dqn.D0 [stdout] pool_bytes: 60832.359375 MiB
# dqn.D0 [stdout] peak_pool_bytes: 60832.359375 MiB


@dataclass
class Arguments:
num_envs: int = 10
buffer_size: int = 10000
num_envs: int = 10 # No impact on memory
buffer_size: int = 10000 # No impact on memory
buffer_batch_size: int = 128
total_timesteps: int = 100_000
total_timesteps: int = 100_000 # No impact on memory
epsilon_start: float = 1.0
epsilon_finish: float = 0.05
epsilon_anneal_time: int = 25e4
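A hedged aside on the memory notes above: the flat ~61900 MiB nvidia-smi reading across every --buffer_size, --num_envs, and --total_timesteps setting is consistent with XLA preallocating most of the GPU at startup, which hides the benchmark's true footprint; the JAX allocator stats (peak_bytes_in_use around 80 MiB) are the meaningful signal. A minimal sketch to check this, assuming the standard JAX environment variables:

    import os

    # Must be set before jax initializes its backends.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"   # allocate on demand
    # os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"  # or cap the pool instead

    import jax
    import jax.numpy as jnp

    x = jnp.ones((1024, 1024)).block_until_ready()  # force a device allocation
    stats = jax.devices()[0].memory_stats()
    print(stats["peak_bytes_in_use"] / (1024 ** 2), "MiB")

With preallocation disabled, nvidia-smi should track the allocator's peak rather than the preallocated pool.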
4 changes: 4 additions & 0 deletions benchmarks/purejaxrl/ppo.py
@@ -75,7 +75,10 @@ class Transition(NamedTuple):

def make_train(config):
from benchmate.timings import StepTimer
from benchmate.jaxmem import memory_peak_fetcher

step_timer = StepTimer(give_push())
fetch_memory_peak = memory_peak_fetcher()

config["NUM_UPDATES"] = (
config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
@@ -280,6 +283,7 @@ def callback(info):

step_timer.step(config["NUM_ENVS"] * config["NUM_STEPS"])
step_timer.log(loss=loss)
step_timer.log(memory_peak=fetch_memory_peak(), units="MiB")
step_timer.end()

jax.debug.callback(callback, metrics)
2 changes: 1 addition & 1 deletion benchmarks/purejaxrl/voirfile.py
@@ -32,7 +32,7 @@ def instrument_main(ov, options: Config):
ov.require(dash)

ov.require(
log("value", "progress", "rate", "units", "loss", "gpudata", context="task"),
log("value", "progress", "rate", "units", "loss", "gpudata", "memory_peak", "cpudata", context="task"),
# early_stop(n=options.stop, key="rate", task="train"),
monitor_monogpu(poll_interval=options.gpu_poll),
)
30 changes: 30 additions & 0 deletions benchmate/benchmate/jaxmem.py
@@ -0,0 +1,30 @@



def memory_peak_fetcher():
import jax

def fetch_memory_peak():
# 'memory', 'memory_stats'
devices = jax.devices()
max_mem = -1
for device in devices:
# dqn.D0 [stdout] Device: cuda:0
# dqn.D0 [stdout] num_allocs: 0.0006799697875976562 MiB
# dqn.D0 [stdout] bytes_in_use: 0.915771484375 MiB
# dqn.D0 [stdout] peak_bytes_in_use: 80.41552734375 MiB
# dqn.D0 [stdout] largest_alloc_size: 16.07958984375 MiB
# dqn.D0 [stdout] bytes_limit: 60832.359375 MiB
# dqn.D0 [stdout] bytes_reserved: 0.0 MiB
# dqn.D0 [stdout] peak_bytes_reserved: 0.0 MiB
# dqn.D0 [stdout] largest_free_block_bytes: 0.0 MiB
# dqn.D0 [stdout] pool_bytes: 60832.359375 MiB
# dqn.D0 [stdout] peak_pool_bytes: 60832.359375 MiB

# device_name = str(device)
mem = device.memory_stats().get("peak_bytes_in_use", 0) / (1024 ** 2)
max_mem = max(mem, max_mem)

return max_mem

return fetch_memory_peak
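A rough usage sketch for the new helper, mirroring how dqn.py and ppo.py call it in this commit: build the fetcher once, then read it whenever metrics are logged; it returns the largest peak_bytes_in_use across visible devices, in MiB.

    from benchmate.jaxmem import memory_peak_fetcher

    fetch_memory_peak = memory_peak_fetcher()

    # ... run some jax work ...

    peak_mib = fetch_memory_peak()  # max peak_bytes_in_use across devices, in MiB
    print(f"memory_peak: {peak_mib:.2f} MiB")

One caveat: on backends where device.memory_stats() returns None (some CPU setups, for instance) the helper would raise, so it appears to assume GPU devices, which is how the benchmarks here use it.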
18 changes: 16 additions & 2 deletions config/base.yaml
@@ -707,7 +707,8 @@ dqn:
argv:
dqn: true
--num_envs: auto({cpu_per_gpu}, 128)
--buffer_batch_size: 128
--buffer_size: 131072
--buffer_batch_size: 65536
--env_name: CartPole-v1
--training_interval: 10

@@ -720,7 +721,7 @@ ppo:
--num_minibatches: 32
--update_epochs: 4
--env_name: hopper
--total_timesteps: 200000
--total_timesteps: 2000000

_geo_gnn:
inherits: _defaults
@@ -880,3 +881,16 @@ cleanrljax:
definition: ../benchmarks/cleanrl_jax
plan:
method: per_gpu

# args.batch_size = int(args.num_envs * args.num_steps)
# args.minibatch_size = int(args.batch_size // args.num_minibatches)
# args.num_iterations = args.total_timesteps // args.batch_size
# --total_timesteps
# --num_steps
# --num_minibatches

argv:
--num_envs: auto({cpu_per_gpu}, 128)
--num_steps: 128
--num_minibatches: 4
--total_timesteps: 10000000
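For the cleanrljax entry above, the commented formulas give the following derived sizes; a worked sketch assuming --num_envs resolves to 128 via auto({cpu_per_gpu}, 128):

    # Derived values per the formulas quoted in the comment block above.
    num_envs = 128
    num_steps = 128
    num_minibatches = 4
    total_timesteps = 10_000_000

    batch_size = num_envs * num_steps                # 16384
    minibatch_size = batch_size // num_minibatches   # 4096
    num_iterations = total_timesteps // batch_size   # 610

    print(batch_size, minibatch_size, num_iterations)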
34 changes: 22 additions & 12 deletions config/scaling.yaml
@@ -59,6 +59,9 @@ brax:
arg: --batch-size
model:
1024: 4912.25 MiB
cleanrljax:
arg: --num_steps
optimized: 128
convnext_large-fp16:
arg: --batch-size
model:
@@ -219,7 +222,12 @@ dinov2-giant-single:
32: 74544.25 MiB
dlrm: {}
dqn:
arg: --buffer_size
arg: --buffer_batch_size
model:
1024: 81.81005859375 MiB
2048: 83.40380859375 MiB
32768: 131.21630859375 MiB
65536: 182.21630859375 MiB
optimized: 128
focalnet:
arg: --batch-size
@@ -331,15 +339,16 @@ opt-6_7b-multinode:
1: 55380 MiB
optimized: 1
ppo:
arg: --num_minibatches
model:
1: 62426.25 MiB
2: 62426.25 MiB
4: 62426.25 MiB
16: 62426.25 MiB
32: 62426.25 MiB
64: 62426.25 MiB
128: 62426.25 MiB
arg: --num_steps
model:
8: 80.791748046875 MiB
16: 80.916748046875 MiB
32: 81.166748046875 MiB
64: 81.666748046875 MiB
128: 82.666748046875 MiB
1024: 96.666748046875 MiB
2048: 132.484619140625 MiB
4096: 205.328369140625 MiB
optimized: 32
recursiongfn:
arg: --batch_size
@@ -535,8 +544,9 @@ torchatari:
arg: --num-steps
model:
1: 1124.75 MiB
2: 1138.75 MiB
4: 1166.75 MiB
1024: 20176.25 MiB
2048: 39020.25 MiB
4096: 76708.25 MiB
vjepa-gpus:
arg: --batch_size
model:
7 changes: 6 additions & 1 deletion config/standard.yaml
@@ -161,7 +161,12 @@ dqn:
ppo:
enabled: true
weight: 1.0



cleanrljax:
enabled: false
weight: 1.0

# Geo
dimenet:
enabled: true
19 changes: 16 additions & 3 deletions milabench/sizer.py
@@ -248,7 +248,8 @@ def __init__(self):
self.scaling = None
self.benchname = None
self.batch_size = 0
self.max_usage = float("-inf")
self.max_usage = float("-inf") # Usage from the gpu monitor
self.peak_usage = float("-inf") # Usage provided by the bench itself (for jax)
self.early_stopped = False

def on_start(self, entry):
@@ -259,6 +260,7 @@ def on_start(self, entry):
self.benchname = entry.pack.config["name"]
self.batch_size = None
self.max_usage = float("-inf")
self.peak_usage = float("-inf")

config = self.memory.setdefault(self.benchname, dict())
template = config.get("arg", None)
@@ -300,6 +302,11 @@ def on_data(self, entry):
if entry.data is None:
return

memorypeak = entry.data.get("memory_peak")
if memorypeak is not None:
self.peak_usage = max(memorypeak, self.peak_usage)
return

gpudata = entry.data.get("gpudata")
if gpudata is not None:
current_usage = []
@@ -312,14 +319,19 @@ def on_stop(self, entry):
def on_stop(self, entry):
self.early_stopped = True

def max_memory_usage(self):
if self.peak_usage != float("-inf"):
return self.peak_usage
return self.max_usage

def on_end(self, entry):
if self.filepath is None:
return

if (
self.benchname is None
or self.batch_size is None
or self.max_usage == float("-inf")
or self.max_memory_usage() == float("-inf")
):
return

@@ -328,12 +340,13 @@ if rc == 0 or self.early_stopped:
if rc == 0 or self.early_stopped:
config = self.memory.setdefault(self.benchname, dict())
model = config.setdefault("model", dict())
model[self.batch_size] = f"{self.max_usage} MiB"
model[self.batch_size] = f"{self.max_memory_usage()} MiB"
config["model"] = dict(sorted(model.items(), key=lambda x: x[0]))

self.benchname = None
self.batch_size = None
self.max_usage = float("-inf")
self.peak_usage = float("-inf")

def report(self, *args):
if self.filepath is not None:
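To make the new precedence in sizer.py concrete, a small sketch with made-up numbers: a bench-reported memory_peak (the JAX allocator figure) takes priority over the GPU-monitor reading, which on JAX benches is inflated by preallocation; without a reported peak, the monitor value is kept.

    def max_memory_usage(max_usage, peak_usage):
        # Prefer the peak logged by the bench itself when one was seen.
        if peak_usage != float("-inf"):
            return peak_usage
        return max_usage

    print(max_memory_usage(61900.25, 182.22))        # 182.22  (jax bench reporting memory_peak)
    print(max_memory_usage(4912.25, float("-inf")))  # 4912.25 (regular bench, gpu monitor only)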
23 changes: 7 additions & 16 deletions scripts/article/run_cuda.sh
@@ -94,23 +94,14 @@ if [ "$MILABENCH_PREPARE" -eq 0 ]; then

ARGS="--select resnet50-noio,brax,lightning,dinov2-giant-single,dinov2-giant-gpus,llm-lora-ddp-gpus,llm-lora-ddp-nodes,llm-lora-mp-gpus,llm-full-mp-gpus,llm-full-mp-nodes,dqn,ppo,dimenet,llava-single,rlhf-single,rlhf-gpus,vjepa-single,vjepa-gpus"

# MEMORY_CAPACITY=("4Go" "8Go" "16Go" "32Go" "64Go" "80Go")
# MILABENCH_SIZER_MULTIPLE=16
# MILABENCH_SIZER_CAPACITY="$CAPACITY"

MEMORY_CAPACITY=("2" "4" "16" "32" "64" "128")

# "dqn" "ppo" "torchatari"
# BENCHES=("dqn" "ppo" "torchatari" "cleanrljax")
BENCHES=("dqn")
#
MEMORY_CAPACITY=("4Go" "8Go" "16Go" "32Go" "64Go" "80Go")

# Run the benchmarks
for BENCH in "${BENCHES[@]}"; do
for CAPACITY in "${MEMORY_CAPACITY[@]}"; do
export MILABENCH_SIZER_AUTO=1
export MILABENCH_SIZER_BATCH_SIZE=$CAPACITY
milabench run --run-name "$BENCH.bs$CAPACITY.{time}" --system $MILABENCH_WORDIR/system.yaml --select $BENCH --exclude lightning-gpus
done
for CAPACITY in "${MEMORY_CAPACITY[@]}"; do
export MILABENCH_SIZER_AUTO=1
export MILABENCH_SIZER_MULTIPLE=8
export MILABENCH_SIZER_BATCH_SIZE=$CAPACITY
milabench run --run-name "c$CAPACITY.{time}" --system $MILABENCH_WORDIR/system.yaml $ARGS || true
done

#
