diff --git a/.gitignore b/.gitignore index 90a1e78d7..1bc7f879c 100644 --- a/.gitignore +++ b/.gitignore @@ -34,6 +34,7 @@ scripts/article/xpu/ dependencies/ benchmarks/recursiongfn/gflownet benchmarks/recursiongfn/logs/ +benchmarks/llm/tune/ scripts/inventory.yaml output/ diff --git a/.pin/constraints-cuda-gnn.txt b/.pin/constraints-cuda-gnn.txt index 238faa6ae..d8dd566d9 100644 --- a/.pin/constraints-cuda-gnn.txt +++ b/.pin/constraints-cuda-gnn.txt @@ -29,7 +29,7 @@ blosc2==2.7.1 # via tables botorch==0.11.3 # via -r benchmarks/recursiongfn/requirements.in -certifi==2024.7.4 +certifi==2024.8.30 # via # requests # sentry-sdk @@ -53,7 +53,7 @@ frozenlist==1.4.1 # via # aiohttp # aiosignal -fsspec==2024.6.1 +fsspec==2024.9.0 # via # torch # torch-geometric @@ -77,7 +77,7 @@ idna==3.8 # via # requests # yarl -jaxtyping==0.2.33 +jaxtyping==0.2.34 # via linear-operator jinja2==3.1.4 # via @@ -330,7 +330,7 @@ wandb==0.17.8 # via -r benchmarks/recursiongfn/requirements.in werkzeug==3.0.4 # via tensorboard -yarl==1.9.4 +yarl==1.9.8 # via aiohttp # The following packages are considered to be unsafe in a requirements file: diff --git a/.pin/constraints-cuda-torch.txt b/.pin/constraints-cuda-torch.txt index e5733ae37..f5d6ac6d5 100644 --- a/.pin/constraints-cuda-torch.txt +++ b/.pin/constraints-cuda-torch.txt @@ -2,7 +2,7 @@ # This file is autogenerated by pip-compile with Python 3.10 # by the following command: # -# pip-compile --output-file=.pin/constraints-cuda-torch.txt .pin/tmp-constraints.txt benchmarks/brax/requirements.in benchmarks/diffusion/requirements.in benchmarks/dinov2/requirements.in benchmarks/flops/requirements.in benchmarks/huggingface/requirements.in benchmarks/lightning/requirements.in benchmarks/llama/requirements.in benchmarks/llm/requirements.in benchmarks/super-slomo/requirements.in benchmarks/timm/requirements.in benchmarks/torchatari/requirements.in benchmarks/torchvision/requirements.in benchmarks/torchvision_ddp/requirements.in constraints/extra/torch.cuda.txt +# pip-compile --output-file=.pin/constraints-cuda-torch.txt .pin/tmp-constraints.txt benchmarks/brax/requirements.in benchmarks/diffusion/requirements.in benchmarks/dinov2/requirements.in benchmarks/flops/requirements.in benchmarks/huggingface/requirements.in benchmarks/lightning/requirements.in benchmarks/llm/requirements.in benchmarks/super-slomo/requirements.in benchmarks/timm/requirements.in benchmarks/torchatari/requirements.in benchmarks/torchvision/requirements.in benchmarks/torchvision_ddp/requirements.in constraints/extra/torch.cuda.txt # --extra-index-url https://pypi.ngc.nvidia.com --extra-index-url https://download.pytorch.org/whl/cu121 @@ -20,7 +20,7 @@ absl-py==2.1.0 # optax # orbax-checkpoint # tensorboard -accelerate==0.33.0 +accelerate==0.34.0 # via # -r benchmarks/diffusion/requirements.in # diffusers @@ -54,7 +54,7 @@ brax==0.10.5 # via -r benchmarks/brax/requirements.in cantilever==0.1.0 # via -r benchmarks/torchatari/requirements.in -certifi==2024.7.4 +certifi==2024.8.30 # via requests charset-normalizer==3.3.2 # via requests @@ -74,9 +74,8 @@ contextlib2==21.6.0 datasets==2.21.0 # via # -r benchmarks/diffusion/requirements.in - # -r benchmarks/llama/requirements.in # torchtune -diffusers[torch]==0.30.1 +diffusers[torch]==0.30.2 # via -r benchmarks/diffusion/requirements.in dill==0.3.8 # via @@ -92,7 +91,7 @@ docstring-parser==0.16 # via tyro envpool==0.8.4 # via -r benchmarks/torchatari/requirements.in -etils[epath,epy]==1.7.0 +etils[epath,epy]==1.9.4 # via # brax # mujoco @@ -101,8 +100,6 @@ 
etils[epath,epy]==1.7.0 # orbax-checkpoint executing==1.2.0 # via varname -fairscale==0.4.13 - # via -r benchmarks/llama/requirements.in farama-notifications==0.0.4 # via gymnasium filelock==3.15.4 @@ -114,13 +111,11 @@ filelock==3.15.4 # torch # transformers # triton -fire==0.6.0 - # via -r benchmarks/llama/requirements.in flask==3.0.3 # via # brax # flask-cors -flask-cors==4.0.1 +flask-cors==5.0.0 # via brax flax==0.9.0 # via brax @@ -221,7 +216,7 @@ jinja2==3.1.4 # torch lightning==2.4.0 # via -r benchmarks/lightning/requirements.in -lightning-utilities==0.11.6 +lightning-utilities==0.11.7 # via # lightning # pytorch-lightning @@ -278,7 +273,6 @@ numpy==1.26.4 # diffusers # dm-env # envpool - # fairscale # fvcore # gym # gymnasium @@ -476,15 +470,12 @@ scipy==1.14.1 # jaxopt # mujoco-mjx sentencepiece==0.2.0 - # via - # -r benchmarks/llama/requirements.in - # torchtune + # via torchtune shtab==1.7.1 # via tyro six==1.16.0 # via # asttokens - # fire # ml-collections # python-dateutil # tensorboard @@ -505,9 +496,7 @@ tensorstore==0.1.64 # flax # orbax-checkpoint termcolor==2.4.0 - # via - # fire - # fvcore + # via fvcore tiktoken==0.7.0 # via torchtune tokenizers==0.19.1 @@ -521,7 +510,6 @@ torch==2.4.0+cu121 # -r benchmarks/flops/requirements.in # -r benchmarks/huggingface/requirements.in # -r benchmarks/lightning/requirements.in - # -r benchmarks/llama/requirements.in # -r benchmarks/llm/requirements.in # -r benchmarks/super-slomo/requirements.in # -r benchmarks/timm/requirements.in @@ -530,7 +518,6 @@ torch==2.4.0+cu121 # -r benchmarks/torchvision_ddp/requirements.in # accelerate # diffusers - # fairscale # lightning # pytorch-lightning # torchmetrics @@ -582,8 +569,7 @@ transformers==4.44.2 # via # -r benchmarks/diffusion/requirements.in # -r benchmarks/huggingface/requirements.in - # -r benchmarks/llama/requirements.in -trimesh==4.4.7 +trimesh==4.4.8 # via # brax # mujoco-mjx @@ -629,7 +615,6 @@ voir==0.2.19 # -r benchmarks/flops/requirements.in # -r benchmarks/huggingface/requirements.in # -r benchmarks/lightning/requirements.in - # -r benchmarks/llama/requirements.in # -r benchmarks/llm/requirements.in # -r benchmarks/super-slomo/requirements.in # -r benchmarks/timm/requirements.in @@ -646,7 +631,7 @@ xxhash==3.5.0 # via datasets yacs==0.1.8 # via fvcore -yarl==1.9.4 +yarl==1.9.8 # via aiohttp zipp==3.20.1 # via diff --git a/benchmarks/brax/requirements.cuda.txt b/benchmarks/brax/requirements.cuda.txt index 3a84f7fd5..4ad501766 100644 --- a/benchmarks/brax/requirements.cuda.txt +++ b/benchmarks/brax/requirements.cuda.txt @@ -64,7 +64,7 @@ dm-tree==0.1.8 # via # -c .pin/../.pin/constraints-cuda-torch.txt # dm-env -etils[epath,epy]==1.7.0 +etils[epath,epy]==1.9.4 # via # -c .pin/../.pin/constraints-cuda-torch.txt # brax @@ -86,7 +86,7 @@ flask==3.0.3 # -c .pin/../.pin/constraints-cuda-torch.txt # brax # flask-cors -flask-cors==4.0.1 +flask-cors==5.0.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # brax @@ -414,7 +414,7 @@ torch==2.4.0+cu121 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/brax/requirements.in -trimesh==4.4.7 +trimesh==4.4.8 # via # -c .pin/../.pin/constraints-cuda-torch.txt # brax diff --git a/benchmarks/cleanrl_jax/Makefile b/benchmarks/cleanrl_jax/Makefile new file mode 100644 index 000000000..20c442249 --- /dev/null +++ b/benchmarks/cleanrl_jax/Makefile @@ -0,0 +1,31 @@ +# Use global base if possible +ifndef MILABENCH_BASE + MILABENCH_BASE="base" +endif + +export MILABENCH_BASE + +BENCH_NAME=cleanrl_jax +MILABENCH_CONFIG=dev.yaml 
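+# All targets below reuse the same flags: --config points milabench at the
+# local dev.yaml and --base at the working directory for installs and runs.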
+MILABENCH_ARGS=--config $(MILABENCH_CONFIG) --base $(MILABENCH_BASE) + +all: + install prepare single gpus nodes + +install: + milabench install $(MILABENCH_ARGS) --force + +prepare: + milabench prepare $(MILABENCH_ARGS) + +tests: install prepare + milabench run $(MILABENCH_ARGS) + +single: + milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME) + +gpus: + milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-gpus + +nodes: + milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-nodes diff --git a/benchmarks/cleanrl_jax/README.md b/benchmarks/cleanrl_jax/README.md new file mode 100644 index 000000000..255d805b7 --- /dev/null +++ b/benchmarks/cleanrl_jax/README.md @@ -0,0 +1,4 @@ + +# Cleanrl_jax + +Rewrite this README to explain what the benchmark is! diff --git a/benchmarks/cleanrl_jax/benchfile.py b/benchmarks/cleanrl_jax/benchfile.py new file mode 100644 index 000000000..afc5f5a65 --- /dev/null +++ b/benchmarks/cleanrl_jax/benchfile.py @@ -0,0 +1,31 @@ +from milabench.pack import Package + + +class Cleanrl_jax(Package): + # Requirements file installed by install(). It can be empty or absent. + base_requirements = "requirements.in" + + # The preparation script called by prepare(). It must be executable, + # but it can be any type of script. It can be empty or absent. + prepare_script = "prepare.py" + + # The main script called by run(). It must be a Python file. It has to + # be present. + main_script = "main.py" + + # You can remove the functions below if you don't need to modify them. + + def make_env(self): + # Return a dict of environment variables for prepare_script and + # main_script. + return super().make_env() + + async def install(self): + await super().install() # super() call installs the requirements + + async def prepare(self): + await super().prepare() # super() call executes prepare_script + + + +__pack__ = Cleanrl_jax diff --git a/benchmarks/cleanrl_jax/dev.yaml b/benchmarks/cleanrl_jax/dev.yaml new file mode 100644 index 000000000..37ea10198 --- /dev/null +++ b/benchmarks/cleanrl_jax/dev.yaml @@ -0,0 +1,8 @@ + +cleanrl_jax: + inherits: _defaults + definition: . 
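+  # Hedged reading of these keys, inferred from other milabench configs:
+  # `install-variant: unpinned` installs requirements.in directly instead of a
+  # pinned requirements.cuda.txt, and `plan: method: per_gpu` runs one copy of
+  # main.py per visible GPU.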
+ install-variant: unpinned + install_group: torch + plan: + method: per_gpu diff --git a/benchmarks/cleanrl_jax/main.py b/benchmarks/cleanrl_jax/main.py new file mode 100644 index 000000000..3d6597563 --- /dev/null +++ b/benchmarks/cleanrl_jax/main.py @@ -0,0 +1,522 @@ +# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpool_xla_jaxpy +import os +import random +import time +from dataclasses import dataclass +from functools import partial +from typing import Sequence + +import envpool +import flax +import flax.linen as nn +import gym +import jax +import jax.numpy as jnp +import numpy as np +import optax +import tyro +from flax.linen.initializers import constant, orthogonal +from flax.training.train_state import TrainState +from torch.utils.tensorboard import SummaryWriter + +# Fix weird OOM https://github.com/google/jax/discussions/6332#discussioncomment-1279991 +os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.6" +# Fix CUDNN non-determinism; https://github.com/google/jax/issues/4823#issuecomment-952835771 +os.environ["TF_XLA_FLAGS"] = "--xla_gpu_autotune_level=2 --xla_gpu_deterministic_reductions" +os.environ["TF_CUDNN_DETERMINISTIC"] = "1" + + +@dataclass +class Args: + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + torch_deterministic: bool = True + """if toggled, `torch.backends.cudnn.deterministic=False`""" + cuda: bool = True + """if toggled, cuda will be enabled by default""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanRL" + """the wandb's project name""" + wandb_entity: str = None + """the entity (team) of wandb's project""" + capture_video: bool = False + """whether to capture videos of the agent performances (check out `videos` folder)""" + save_model: bool = False + """whether to save model into the `runs/{run_name}` folder""" + upload_model: bool = False + """whether to upload the saved model to huggingface""" + hf_entity: str = "" + """the user or org name of the model repository from the Hugging Face Hub""" + + # Algorithm specific arguments + env_id: str = "Breakout-v5" + """the id of the environment""" + total_timesteps: int = 10000000 + """total timesteps of the experiments""" + learning_rate: float = 2.5e-4 + """the learning rate of the optimizer""" + num_envs: int = 8 + """the number of parallel game environments""" + num_steps: int = 128 + """the number of steps to run in each environment per policy rollout""" + anneal_lr: bool = True + """Toggle learning rate annealing for policy and value networks""" + gamma: float = 0.99 + """the discount factor gamma""" + gae_lambda: float = 0.95 + """the lambda for the general advantage estimation""" + num_minibatches: int = 4 + """the number of mini-batches""" + update_epochs: int = 4 + """the K epochs to update the policy""" + norm_adv: bool = True + """Toggles advantages normalization""" + clip_coef: float = 0.1 + """the surrogate clipping coefficient""" + clip_vloss: bool = True + """Toggles whether or not to use a clipped loss for the value function, as per the paper.""" + ent_coef: float = 0.01 + """coefficient of the entropy""" + vf_coef: float = 0.5 + """coefficient of the value function""" + max_grad_norm: float = 0.5 + """the maximum norm for the gradient clipping""" + target_kl: float = None + """the target KL divergence threshold""" + + # to be filled in runtime + batch_size: int = 0 +
"""the batch size (computed in runtime)""" + minibatch_size: int = 0 + """the mini-batch size (computed in runtime)""" + num_iterations: int = 0 + """the number of iterations (computed in runtime)""" + + +def make_env(env_id, seed, num_envs): + def thunk(): + envs = envpool.make( + env_id, + env_type="gym", + num_envs=num_envs, + episodic_life=True, + reward_clip=True, + seed=seed, + ) + envs.num_envs = num_envs + envs.single_action_space = envs.action_space + envs.single_observation_space = envs.observation_space + envs.is_vector_env = True + return envs + + return thunk + + +class Network(nn.Module): + @nn.compact + def __call__(self, x): + x = jnp.transpose(x, (0, 2, 3, 1)) + x = x / (255.0) + x = nn.Conv( + 32, + kernel_size=(8, 8), + strides=(4, 4), + padding="VALID", + kernel_init=orthogonal(np.sqrt(2)), + bias_init=constant(0.0), + )(x) + x = nn.relu(x) + x = nn.Conv( + 64, + kernel_size=(4, 4), + strides=(2, 2), + padding="VALID", + kernel_init=orthogonal(np.sqrt(2)), + bias_init=constant(0.0), + )(x) + x = nn.relu(x) + x = nn.Conv( + 64, + kernel_size=(3, 3), + strides=(1, 1), + padding="VALID", + kernel_init=orthogonal(np.sqrt(2)), + bias_init=constant(0.0), + )(x) + x = nn.relu(x) + x = x.reshape((x.shape[0], -1)) + x = nn.Dense(512, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x) + x = nn.relu(x) + return x + + +class Critic(nn.Module): + @nn.compact + def __call__(self, x): + return nn.Dense(1, kernel_init=orthogonal(1), bias_init=constant(0.0))(x) + + +class Actor(nn.Module): + action_dim: Sequence[int] + + @nn.compact + def __call__(self, x): + return nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(x) + + +@flax.struct.dataclass +class AgentParams: + network_params: flax.core.FrozenDict + actor_params: flax.core.FrozenDict + critic_params: flax.core.FrozenDict + + +@flax.struct.dataclass +class Storage: + obs: jnp.array + actions: jnp.array + logprobs: jnp.array + dones: jnp.array + values: jnp.array + advantages: jnp.array + returns: jnp.array + rewards: jnp.array + + +@flax.struct.dataclass +class EpisodeStatistics: + episode_returns: jnp.array + episode_lengths: jnp.array + returned_episode_returns: jnp.array + returned_episode_lengths: jnp.array + + +if __name__ == "__main__": + args = tyro.cli(Args) + args.batch_size = int(args.num_envs * args.num_steps) + args.minibatch_size = int(args.batch_size // args.num_minibatches) + args.num_iterations = args.total_timesteps // args.batch_size + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + key = jax.random.PRNGKey(args.seed) + key, network_key, actor_key, critic_key = jax.random.split(key, 4) + + # env setup + envs = make_env(args.env_id, args.seed, args.num_envs)() + episode_stats = EpisodeStatistics( + episode_returns=jnp.zeros(args.num_envs, dtype=jnp.float32), + episode_lengths=jnp.zeros(args.num_envs, dtype=jnp.int32), + returned_episode_returns=jnp.zeros(args.num_envs, dtype=jnp.float32), + returned_episode_lengths=jnp.zeros(args.num_envs, 
dtype=jnp.int32), + ) + handle, recv, send, step_env = envs.xla() + + def step_env_wrappeed(episode_stats, handle, action): + handle, (next_obs, reward, next_done, info) = step_env(handle, action) + new_episode_return = episode_stats.episode_returns + info["reward"] + new_episode_length = episode_stats.episode_lengths + 1 + episode_stats = episode_stats.replace( + episode_returns=(new_episode_return) * (1 - info["terminated"]) * (1 - info["TimeLimit.truncated"]), + episode_lengths=(new_episode_length) * (1 - info["terminated"]) * (1 - info["TimeLimit.truncated"]), + # only update the `returned_episode_returns` if the episode is done + returned_episode_returns=jnp.where( + info["terminated"] + info["TimeLimit.truncated"], new_episode_return, episode_stats.returned_episode_returns + ), + returned_episode_lengths=jnp.where( + info["terminated"] + info["TimeLimit.truncated"], new_episode_length, episode_stats.returned_episode_lengths + ), + ) + return episode_stats, handle, (next_obs, reward, next_done, info) + + assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" + + def linear_schedule(count): + # anneal learning rate linearly after one training iteration which contains + # (args.num_minibatches * args.update_epochs) gradient updates + frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_iterations + return args.learning_rate * frac + + network = Network() + actor = Actor(action_dim=envs.single_action_space.n) + critic = Critic() + network_params = network.init(network_key, np.array([envs.single_observation_space.sample()])) + agent_state = TrainState.create( + apply_fn=None, + params=AgentParams( + network_params, + actor.init(actor_key, network.apply(network_params, np.array([envs.single_observation_space.sample()]))), + critic.init(critic_key, network.apply(network_params, np.array([envs.single_observation_space.sample()]))), + ), + tx=optax.chain( + optax.clip_by_global_norm(args.max_grad_norm), + optax.inject_hyperparams(optax.adam)( + learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5 + ), + ), + ) + network.apply = jax.jit(network.apply) + actor.apply = jax.jit(actor.apply) + critic.apply = jax.jit(critic.apply) + + @jax.jit + def get_action_and_value( + agent_state: TrainState, + next_obs: np.ndarray, + key: jax.random.PRNGKey, + ): + """sample action, calculate value, logprob, entropy, and update storage""" + hidden = network.apply(agent_state.params.network_params, next_obs) + logits = actor.apply(agent_state.params.actor_params, hidden) + # sample action: Gumbel-softmax trick + # see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution + key, subkey = jax.random.split(key) + u = jax.random.uniform(subkey, shape=logits.shape) + action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1) + logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action] + value = critic.apply(agent_state.params.critic_params, hidden) + return action, logprob, value.squeeze(1), key + + @jax.jit + def get_action_and_value2( + params: flax.core.FrozenDict, + x: np.ndarray, + action: np.ndarray, + ): + """calculate value, logprob of supplied `action`, and entropy""" + hidden = network.apply(params.network_params, x) + logits = actor.apply(params.actor_params, hidden) + logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action] + # normalize the logits https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/ + logits = logits - 
jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True) + logits = logits.clip(min=jnp.finfo(logits.dtype).min) + p_log_p = logits * jax.nn.softmax(logits) + entropy = -p_log_p.sum(-1) + value = critic.apply(params.critic_params, hidden).squeeze() + return logprob, entropy, value + + def compute_gae_once(carry, inp, gamma, gae_lambda): + advantages = carry + nextdone, nextvalues, curvalues, reward = inp + nextnonterminal = 1.0 - nextdone + + delta = reward + gamma * nextvalues * nextnonterminal - curvalues + advantages = delta + gamma * gae_lambda * nextnonterminal * advantages + return advantages, advantages + + compute_gae_once = partial(compute_gae_once, gamma=args.gamma, gae_lambda=args.gae_lambda) + + @jax.jit + def compute_gae( + agent_state: TrainState, + next_obs: np.ndarray, + next_done: np.ndarray, + storage: Storage, + ): + next_value = critic.apply( + agent_state.params.critic_params, network.apply(agent_state.params.network_params, next_obs) + ).squeeze() + + advantages = jnp.zeros((args.num_envs,)) + dones = jnp.concatenate([storage.dones, next_done[None, :]], axis=0) + values = jnp.concatenate([storage.values, next_value[None, :]], axis=0) + _, advantages = jax.lax.scan( + compute_gae_once, advantages, (dones[1:], values[1:], values[:-1], storage.rewards), reverse=True + ) + storage = storage.replace( + advantages=advantages, + returns=advantages + storage.values, + ) + return storage + + def ppo_loss(params, x, a, logp, mb_advantages, mb_returns): + newlogprob, entropy, newvalue = get_action_and_value2(params, x, a) + logratio = newlogprob - logp + ratio = jnp.exp(logratio) + approx_kl = ((ratio - 1) - logratio).mean() + + if args.norm_adv: + mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) + + # Policy loss + pg_loss1 = -mb_advantages * ratio + pg_loss2 = -mb_advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef) + pg_loss = jnp.maximum(pg_loss1, pg_loss2).mean() + + # Value loss + v_loss = 0.5 * ((newvalue - mb_returns) ** 2).mean() + + entropy_loss = entropy.mean() + loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef + return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl)) + + ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True) + + @jax.jit + def update_ppo( + agent_state: TrainState, + storage: Storage, + key: jax.random.PRNGKey, + ): + def update_epoch(carry, unused_inp): + agent_state, key = carry + key, subkey = jax.random.split(key) + + def flatten(x): + return x.reshape((-1,) + x.shape[2:]) + + # taken from: https://github.com/google/brax/blob/main/brax/training/agents/ppo/train.py + def convert_data(x: jnp.ndarray): + x = jax.random.permutation(subkey, x) + x = jnp.reshape(x, (args.num_minibatches, -1) + x.shape[1:]) + return x + + flatten_storage = jax.tree_map(flatten, storage) + shuffled_storage = jax.tree_map(convert_data, flatten_storage) + + def update_minibatch(agent_state, minibatch): + (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn( + agent_state.params, + minibatch.obs, + minibatch.actions, + minibatch.logprobs, + minibatch.advantages, + minibatch.returns, + ) + agent_state = agent_state.apply_gradients(grads=grads) + return agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads) + + agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads) = jax.lax.scan( + update_minibatch, agent_state, shuffled_storage + ) + return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads) + 
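+        # Scan update_epoch for args.update_epochs iterations: the carry is
+        # (agent_state, key) and the stacked outputs collect per-epoch losses.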
+ (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads) = jax.lax.scan( + update_epoch, (agent_state, key), (), length=args.update_epochs + ) + return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key + + # TRY NOT TO MODIFY: start the game + global_step = 0 + start_time = time.time() + next_obs = envs.reset() + next_done = jnp.zeros(args.num_envs, dtype=jax.numpy.bool_) + + # based on https://github.dev/google/evojax/blob/0625d875262011d8e1b6aa32566b236f44b4da66/evojax/sim_mgr.py + def step_once(carry, step, env_step_fn): + agent_state, episode_stats, obs, done, key, handle = carry + action, logprob, value, key = get_action_and_value(agent_state, obs, key) + + episode_stats, handle, (next_obs, reward, next_done, _) = env_step_fn(episode_stats, handle, action) + storage = Storage( + obs=obs, + actions=action, + logprobs=logprob, + dones=done, + values=value, + rewards=reward, + returns=jnp.zeros_like(reward), + advantages=jnp.zeros_like(reward), + ) + return ((agent_state, episode_stats, next_obs, next_done, key, handle), storage) + + def rollout(agent_state, episode_stats, next_obs, next_done, key, handle, step_once_fn, max_steps): + (agent_state, episode_stats, next_obs, next_done, key, handle), storage = jax.lax.scan( + step_once_fn, (agent_state, episode_stats, next_obs, next_done, key, handle), (), max_steps + ) + return agent_state, episode_stats, next_obs, next_done, storage, key, handle + + rollout = partial(rollout, step_once_fn=partial(step_once, env_step_fn=step_env_wrappeed), max_steps=args.num_steps) + + for iteration in range(1, args.num_iterations + 1): + iteration_time_start = time.time() + agent_state, episode_stats, next_obs, next_done, storage, key, handle = rollout( + agent_state, episode_stats, next_obs, next_done, key, handle + ) + global_step += args.num_steps * args.num_envs + storage = compute_gae(agent_state, next_obs, next_done, storage) + agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key = update_ppo( + agent_state, + storage, + key, + ) + avg_episodic_return = np.mean(jax.device_get(episode_stats.returned_episode_returns)) + print(f"global_step={global_step}, avg_episodic_return={avg_episodic_return}") + + # TRY NOT TO MODIFY: record rewards for plotting purposes + writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step) + writer.add_scalar( + "charts/avg_episodic_length", np.mean(jax.device_get(episode_stats.returned_episode_lengths)), global_step + ) + writer.add_scalar("charts/learning_rate", agent_state.opt_state[1].hyperparams["learning_rate"].item(), global_step) + writer.add_scalar("losses/value_loss", v_loss[-1, -1].item(), global_step) + writer.add_scalar("losses/policy_loss", pg_loss[-1, -1].item(), global_step) + writer.add_scalar("losses/entropy", entropy_loss[-1, -1].item(), global_step) + writer.add_scalar("losses/approx_kl", approx_kl[-1, -1].item(), global_step) + writer.add_scalar("losses/loss", loss[-1, -1].item(), global_step) + print("SPS:", int(global_step / (time.time() - start_time))) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + writer.add_scalar( + "charts/SPS_update", int(args.num_envs * args.num_steps / (time.time() - iteration_time_start)), global_step + ) + + if args.save_model: + model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model" + with open(model_path, "wb") as f: + f.write( + flax.serialization.to_bytes( + [ + vars(args), + [ + agent_state.params.network_params, + 
agent_state.params.actor_params, + agent_state.params.critic_params, + ], + ] + ) + ) + print(f"model saved to {model_path}") + from cleanrl_utils.evals.ppo_envpool_jax_eval import evaluate + + episodic_returns = evaluate( + model_path, + make_env, + args.env_id, + eval_episodes=10, + run_name=f"{run_name}-eval", + Model=(Network, Actor, Critic), + ) + for idx, episodic_return in enumerate(episodic_returns): + writer.add_scalar("eval/episodic_return", episodic_return, idx) + + if args.upload_model: + from cleanrl_utils.huggingface import push_to_hub + + repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}" + repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name + push_to_hub(args, episodic_returns, repo_id, "PPO", f"runs/{run_name}", f"videos/{run_name}-eval") + + envs.close() + writer.close() \ No newline at end of file diff --git a/benchmarks/cleanrl_jax/prepare.py b/benchmarks/cleanrl_jax/prepare.py new file mode 100755 index 000000000..32bd5901d --- /dev/null +++ b/benchmarks/cleanrl_jax/prepare.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python + +import os + +if __name__ == "__main__": + # If you need the whole configuration: + # config = json.loads(os.environ["MILABENCH_CONFIG"]) + + data_directory = os.environ["MILABENCH_DIR_DATA"] + + # Download (or generate) the needed dataset(s). You are responsible + # to check if it has already been properly downloaded or not, and to + # do nothing if it has been. + print("Hello I am doing some data stuff!") + + # If there is nothing to download or generate, just delete this file. diff --git a/benchmarks/cleanrl_jax/requirements.in b/benchmarks/cleanrl_jax/requirements.in new file mode 100644 index 000000000..77b50182c --- /dev/null +++ b/benchmarks/cleanrl_jax/requirements.in @@ -0,0 +1,99 @@ +voir>=0.2.17,<0.3 +absl-py==1.4.0 ; python_version >= "3.8" and python_version < "3.11" +appdirs==1.4.4 ; python_version >= "3.8" and python_version < "3.11" +cached-property==1.5.2 ; python_version >= "3.8" and python_version < "3.11" +cachetools==5.3.0 ; python_version >= "3.8" and python_version < "3.11" +certifi==2023.5.7 ; python_version >= "3.8" and python_version < "3.11" +charset-normalizer==3.1.0 ; python_version >= "3.8" and python_version < "3.11" +chex==0.1.5 ; python_version >= "3.8" and python_version < "3.11" +click==8.1.3 ; python_version >= "3.8" and python_version < "3.11" +cloudpickle==2.2.1 ; python_version >= "3.8" and python_version < "3.11" +colorama==0.4.4 ; python_version >= "3.8" and python_version < "3.11" +commonmark==0.9.1 ; python_version >= "3.8" and python_version < "3.11" +cycler==0.11.0 ; python_version >= "3.8" and python_version < "3.11" +decorator==4.4.2 ; python_version >= "3.8" and python_version < "3.11" +dm-tree==0.1.8 ; python_version >= "3.8" and python_version < "3.11" +docker-pycreds==0.4.0 ; python_version >= "3.8" and python_version < "3.11" +docstring-parser==0.15 ; python_version >= "3.8" and python_version < "3.11" +etils==0.9.0 ; python_version >= "3.8" and python_version < "3.11" +exceptiongroup==1.1.1 ; python_version >= "3.8" and python_version < "3.11" +farama-notifications==0.0.4 ; python_version >= "3.8" and python_version < "3.11" +filelock==3.12.0 ; python_version >= "3.8" and python_version < "3.11" +flax==0.6.8 ; python_version >= "3.8" and python_version < "3.11" +fonttools==4.38.0 ; python_version >= "3.8" and python_version < "3.11" +gitdb==4.0.10 ; python_version >= "3.8" and python_version < "3.11" +gitpython==3.1.31 ; python_version >= "3.8" and python_version < 
"3.11" +google-auth-oauthlib==0.4.6 ; python_version >= "3.8" and python_version < "3.11" +google-auth==2.18.0 ; python_version >= "3.8" and python_version < "3.11" +grpcio==1.54.0 ; python_version >= "3.8" and python_version < "3.11" +gym-notices==0.0.8 ; python_version >= "3.8" and python_version < "3.11" +gym==0.23.1 ; python_version >= "3.8" and python_version < "3.11" +gymnasium==0.28.1 ; python_version >= "3.8" and python_version < "3.11" +huggingface-hub==0.11.1 ; python_version >= "3.8" and python_version < "3.11" +idna==3.4 ; python_version >= "3.8" and python_version < "3.11" +imageio-ffmpeg==0.3.0 ; python_version >= "3.8" and python_version < "3.11" +imageio==2.28.1 ; python_version >= "3.8" and python_version < "3.11" +importlib-metadata==5.2.0 ; python_version >= "3.8" and python_version < "3.10" +importlib-resources==5.12.0 ; python_version >= "3.8" and python_version < "3.11" +iniconfig==2.0.0 ; python_version >= "3.8" and python_version < "3.11" +jax-jumpy==1.0.0 ; python_version >= "3.8" and python_version < "3.11" +jax==0.4.8 ; python_version >= "3.8" and python_version < "3.11" +jaxlib==0.4.7 ; python_version >= "3.8" and python_version < "3.11" +kiwisolver==1.4.4 ; python_version >= "3.8" and python_version < "3.11" +markdown==3.3.7 ; python_version >= "3.8" and python_version < "3.11" +markupsafe==2.1.2 ; python_version >= "3.8" and python_version < "3.11" +matplotlib==3.5.3 ; python_version >= "3.8" and python_version < "3.11" +ml-dtypes==0.2.0 ; python_version >= "3.8" and python_version < "3.11" +moviepy==1.0.3 ; python_version >= "3.8" and python_version < "3.11" +msgpack==1.0.5 ; python_version >= "3.8" and python_version < "3.11" +numpy==1.24.4 ; python_version < "3.11" and python_version >= "3.8" +oauthlib==3.2.2 ; python_version >= "3.8" and python_version < "3.11" +opt-einsum==3.3.0 ; python_version >= "3.8" and python_version < "3.11" +optax==0.1.4 ; python_version >= "3.8" and python_version < "3.11" +orbax==0.1.0 ; python_version >= "3.8" and python_version < "3.11" +packaging==23.1 ; python_version >= "3.8" and python_version < "3.11" +pandas==1.3.5 ; python_version >= "3.8" and python_version < "3.11" +pathtools==0.1.2 ; python_version >= "3.8" and python_version < "3.11" +pillow==9.5.0 ; python_version >= "3.8" and python_version < "3.11" +pluggy==1.0.0 ; python_version >= "3.8" and python_version < "3.11" +proglog==0.1.10 ; python_version >= "3.8" and python_version < "3.11" +protobuf==3.20.3 ; python_version < "3.11" and python_version >= "3.8" +psutil +envpool +pyasn1-modules==0.3.0 ; python_version >= "3.8" and python_version < "3.11" +pyasn1==0.5.0 ; python_version >= "3.8" and python_version < "3.11" +pygame==2.1.0 ; python_version >= "3.8" and python_version < "3.11" +pygments==2.15.1 ; python_version >= "3.8" and python_version < "3.11" +pyparsing==3.0.9 ; python_version >= "3.8" and python_version < "3.11" +pytest==7.3.1 ; python_version >= "3.8" and python_version < "3.11" +python-dateutil==2.8.2 ; python_version >= "3.8" and python_version < "3.11" +pytz==2023.3 ; python_version >= "3.8" and python_version < "3.11" +pyyaml==6.0.1 ; python_version >= "3.8" and python_version < "3.11" +requests-oauthlib==1.3.1 ; python_version >= "3.8" and python_version < "3.11" +requests==2.30.0 ; python_version >= "3.8" and python_version < "3.11" +rich +rsa==4.7.2 ; python_version >= "3.8" and python_version < "3.11" +scipy==1.10.1 ; python_version >= "3.8" and python_version < "3.11" +sentry-sdk==1.22.2 ; python_version >= "3.8" and python_version < "3.11" 
+setproctitle==1.3.2 ; python_version >= "3.8" and python_version < "3.11" +setuptools==67.7.2 ; python_version >= "3.8" and python_version < "3.11" +shtab==1.6.4 ; python_version >= "3.8" and python_version < "3.11" +six==1.16.0 ; python_version >= "3.8" and python_version < "3.11" +smmap==5.0.0 ; python_version >= "3.8" and python_version < "3.11" +stable-baselines3==2.0.0 ; python_version >= "3.8" and python_version < "3.11" +tenacity==8.2.3 ; python_version >= "3.8" and python_version < "3.11" +tensorboard-data-server==0.6.1 ; python_version >= "3.8" and python_version < "3.11" +tensorboard-plugin-wit==1.8.1 ; python_version >= "3.8" and python_version < "3.11" +tensorboard==2.11.2 ; python_version >= "3.8" and python_version < "3.11" +tensorstore==0.1.28 ; python_version >= "3.8" and python_version < "3.11" +tomli==2.0.1 ; python_version >= "3.8" and python_version < "3.11" +toolz==0.12.0 ; python_version >= "3.8" and python_version < "3.11" +torch==1.12.1 ; python_version >= "3.8" and python_version < "3.11" +tqdm==4.65.0 ; python_version >= "3.8" and python_version < "3.11" +typing-extensions==4.5.0 ; python_version >= "3.8" and python_version < "3.11" +tyro==0.5.10 ; python_version >= "3.8" and python_version < "3.11" +urllib3==1.26.15 ; python_version >= "3.8" and python_version < "3.11" +wandb==0.13.11 ; python_version >= "3.8" and python_version < "3.11" +werkzeug==2.2.3 ; python_version >= "3.8" and python_version < "3.11" +wheel==0.40.0 ; python_version >= "3.8" and python_version < "3.11" +zipp==3.15.0 ; python_version >= "3.8" and python_version < "3.10" \ No newline at end of file diff --git a/benchmarks/cleanrl_jax/voirfile.py b/benchmarks/cleanrl_jax/voirfile.py new file mode 100644 index 000000000..d93f886cd --- /dev/null +++ b/benchmarks/cleanrl_jax/voirfile.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass + +from voir import configurable +from voir.instruments import dash, early_stop, log, rate +from benchmate.monitor import monitor_monogpu + +@dataclass +class Config: + """voir configuration""" + + # Whether to display the dash or not + dash: bool = False + + # How often to log the rates + interval: str = "1s" + + # Number of rates to skip before logging + skip: int = 5 + + # Number of rates to log before stopping + stop: int = 20 + + # Number of seconds between each gpu poll + gpu_poll: int = 3 + + +@configurable +def instrument_main(ov, options: Config): + yield ov.phases.init + + if options.dash: + ov.require(dash) + + ov.require( + log("value", "progress", "rate", "units", "loss", "gpudata", context="task"), + early_stop(n=options.stop, key="rate", task="train"), + monitor_monogpu(poll_interval=options.gpu_poll), + ) diff --git a/benchmarks/diffusion/main.py b/benchmarks/diffusion/main.py old mode 100644 new mode 100755 index 2b4fe9bfd..09513f606 --- a/benchmarks/diffusion/main.py +++ b/benchmarks/diffusion/main.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python + from dataclasses import dataclass from accelerate import Accelerator diff --git a/benchmarks/diffusion/requirements.cuda.txt b/benchmarks/diffusion/requirements.cuda.txt index 44bd64eab..a978bf360 100644 --- a/benchmarks/diffusion/requirements.cuda.txt +++ b/benchmarks/diffusion/requirements.cuda.txt @@ -9,7 +9,7 @@ --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html --trusted-host pypi.ngc.nvidia.com -accelerate==0.33.0 +accelerate==0.34.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/diffusion/requirements.in @@ -47,7 +47,7 @@ attrs==24.2.0 # via # -c 
.pin/../.pin/constraints-cuda-torch.txt # aiohttp -certifi==2024.7.4 +certifi==2024.8.30 # via # -c .pin/../.pin/constraints-cuda-torch.txt # requests @@ -63,7 +63,7 @@ datasets==2.21.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/diffusion/requirements.in -diffusers[torch]==0.30.1 +diffusers[torch]==0.30.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/diffusion/requirements.in @@ -421,7 +421,7 @@ xxhash==3.5.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # datasets -yarl==1.9.4 +yarl==1.9.8 # via # -c .pin/../.pin/constraints-cuda-torch.txt # aiohttp diff --git a/benchmarks/dinov2/requirements.cuda.txt b/benchmarks/dinov2/requirements.cuda.txt index 3c63f45f4..e3e2536c2 100644 --- a/benchmarks/dinov2/requirements.cuda.txt +++ b/benchmarks/dinov2/requirements.cuda.txt @@ -72,7 +72,7 @@ jinja2==3.1.4 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch -lightning-utilities==0.11.6 +lightning-utilities==0.11.7 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torchmetrics diff --git a/benchmarks/geo_gnn/requirements-pre.cuda.txt b/benchmarks/geo_gnn/requirements-pre.cuda.txt index 45e24191e..822b1e37e 100644 --- a/benchmarks/geo_gnn/requirements-pre.cuda.txt +++ b/benchmarks/geo_gnn/requirements-pre.cuda.txt @@ -14,7 +14,7 @@ filelock==3.15.4 # -c .pin/../.pin/constraints-cuda-gnn.txt # torch # triton -fsspec==2024.6.1 +fsspec==2024.9.0 # via # -c .pin/../.pin/constraints-cuda-gnn.txt # torch diff --git a/benchmarks/geo_gnn/requirements.cuda.txt b/benchmarks/geo_gnn/requirements.cuda.txt index 2596bd63d..f68973ae3 100644 --- a/benchmarks/geo_gnn/requirements.cuda.txt +++ b/benchmarks/geo_gnn/requirements.cuda.txt @@ -37,7 +37,7 @@ attrs==24.2.0 # via # -c .pin/../.pin/constraints-cuda-gnn.txt # aiohttp -certifi==2024.7.4 +certifi==2024.8.30 # via # -c .pin/../.pin/constraints-cuda-gnn.txt # requests @@ -64,7 +64,7 @@ frozenlist==1.4.1 # -c .pin/../.pin/constraints-cuda-gnn.txt # aiohttp # aiosignal -fsspec==2024.6.1 +fsspec==2024.9.0 # via # -c .pin/../.pin/constraints-cuda-gnn.txt # -r benchmarks/geo_gnn/requirements-pre.cuda.txt @@ -334,7 +334,7 @@ voir==0.2.19 # -c .pin/../.pin/constraints-cuda-gnn.txt # -c .pin/../constraints/cuda.txt # -r benchmarks/geo_gnn/requirements.in -yarl==1.9.4 +yarl==1.9.8 # via # -c .pin/../.pin/constraints-cuda-gnn.txt # aiohttp diff --git a/benchmarks/huggingface/requirements.cuda.txt b/benchmarks/huggingface/requirements.cuda.txt index 732f1633d..9342d1170 100644 --- a/benchmarks/huggingface/requirements.cuda.txt +++ b/benchmarks/huggingface/requirements.cuda.txt @@ -17,7 +17,7 @@ asttokens==2.4.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # giving -certifi==2024.7.4 +certifi==2024.8.30 # via # -c .pin/../.pin/constraints-cuda-torch.txt # requests diff --git a/benchmarks/lightning/requirements.cuda.txt b/benchmarks/lightning/requirements.cuda.txt index afafff613..09c8d79cb 100644 --- a/benchmarks/lightning/requirements.cuda.txt +++ b/benchmarks/lightning/requirements.cuda.txt @@ -98,7 +98,7 @@ lightning==2.4.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/lightning/requirements.in -lightning-utilities==0.11.6 +lightning-utilities==0.11.7 # via # -c .pin/../.pin/constraints-cuda-torch.txt # lightning @@ -327,7 +327,7 @@ voir==0.2.19 # -c .pin/../.pin/constraints-cuda-torch.txt # -c .pin/../constraints/cuda.txt # -r benchmarks/lightning/requirements.in -yarl==1.9.4 +yarl==1.9.8 # via # -c .pin/../.pin/constraints-cuda-torch.txt # aiohttp diff --git 
a/benchmarks/llama/requirements.cuda.txt b/benchmarks/llama/requirements.cuda.txt index 11f814c36..80156f391 100644 --- a/benchmarks/llama/requirements.cuda.txt +++ b/benchmarks/llama/requirements.cuda.txt @@ -38,7 +38,7 @@ attrs==24.2.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # aiohttp -certifi==2024.7.4 +certifi==2024.8.30 # via # -c .pin/../.pin/constraints-cuda-torch.txt # requests @@ -390,7 +390,7 @@ xxhash==3.5.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # datasets -yarl==1.9.4 +yarl==1.9.8 # via # -c .pin/../.pin/constraints-cuda-torch.txt # aiohttp diff --git a/benchmarks/llm/requirements.cuda.txt b/benchmarks/llm/requirements.cuda.txt index e65fc1ec8..94afa483c 100644 --- a/benchmarks/llm/requirements.cuda.txt +++ b/benchmarks/llm/requirements.cuda.txt @@ -46,7 +46,7 @@ blobfile==3.0.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torchtune -certifi==2024.7.4 +certifi==2024.8.30 # via # -c .pin/../.pin/constraints-cuda-torch.txt # requests @@ -403,7 +403,7 @@ xxhash==3.5.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # datasets -yarl==1.9.4 +yarl==1.9.8 # via # -c .pin/../.pin/constraints-cuda-torch.txt # aiohttp diff --git a/benchmarks/llm/tune b/benchmarks/llm/tune new file mode 160000 index 000000000..a83eeff00 --- /dev/null +++ b/benchmarks/llm/tune @@ -0,0 +1 @@ +Subproject commit a83eeff0079a73ee04a11e8fc2573ed8f671b231 diff --git a/benchmarks/purejaxrl/Makefile b/benchmarks/purejaxrl/Makefile new file mode 100644 index 000000000..48cd14f88 --- /dev/null +++ b/benchmarks/purejaxrl/Makefile @@ -0,0 +1,31 @@ +# Use global base if possible +ifndef MILABENCH_BASE + MILABENCH_BASE="base" +endif + +export MILABENCH_BASE + +BENCH_NAME=template +MILABENCH_CONFIG=dev.yaml +MILABENCH_ARGS=--config $(MILABENCH_CONFIG) --base $(MILABENCH_BASE) + +all: + install prepare single gpus nodes + +install: + milabench install $(MILABENCH_ARGS) --force + +prepare: + milabench prepare $(MILABENCH_ARGS) + +tests:# install prepare + CUDA_VISIBLE_DEVICES=0 milabench run $(MILABENCH_ARGS) + +single: + milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-single + +gpus: + milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-gpus + +nodes: + milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-nodes diff --git a/benchmarks/purejaxrl/README.md b/benchmarks/purejaxrl/README.md new file mode 100644 index 000000000..239a2dfaf --- /dev/null +++ b/benchmarks/purejaxrl/README.md @@ -0,0 +1,4 @@ + +# Template + +Rewrite this README to explain what the benchmark is! diff --git a/benchmarks/purejaxrl/benchfile.py b/benchmarks/purejaxrl/benchfile.py new file mode 100644 index 000000000..08a51cef0 --- /dev/null +++ b/benchmarks/purejaxrl/benchfile.py @@ -0,0 +1,31 @@ +from milabench.pack import Package + + +class Template(Package): + # Requirements file installed by install(). It can be empty or absent. + base_requirements = "requirements.in" + + # The preparation script called by prepare(). It must be executable, + # but it can be any type of script. It can be empty or absent. + prepare_script = "prepare.py" + + # The main script called by run(). It must be a Python file. It has to + # be present. + main_script = "main.py" + + # You can remove the functions below if you don't need to modify them. + + def make_env(self): + # Return a dict of environment variables for prepare_script and + # main_script. 
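+        # The returned dict becomes the environment for prepare_script and
+        # main_script; e.g. the cleanrl_jax prepare.py in this diff reads
+        # MILABENCH_DIR_DATA from it.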
+ return super().make_env() + + async def install(self): + await super().install() # super() call installs the requirements + + async def prepare(self): + await super().prepare() # super() call executes prepare_script + + + +__pack__ = Template diff --git a/benchmarks/purejaxrl/dev.yaml b/benchmarks/purejaxrl/dev.yaml new file mode 100644 index 000000000..ad1ebf87f --- /dev/null +++ b/benchmarks/purejaxrl/dev.yaml @@ -0,0 +1,28 @@ +_purejaxrl: + inherits: _defaults + definition: . + install-variant: unpinned + install_group: torch + plan: + method: per_gpu + +dqn: + inherits: _purejaxrl + argv: + dqn: true + --num_envs: auto({cpu_per_gpu}, 128) + --buffer_batch_size: 128 + --env_name: CartPole-v1 + --training_interval: 10 + --learning_starts: 10000 + +ppo: + inherits: _purejaxrl + argv: + ppo: true + --num_envs: auto({cpu_per_gpu}, 128) + --num_steps: 10 + --num_minibatches: 32 + --update_epochs: 4 + --env_name: hopper + --total_timesteps: 200000 \ No newline at end of file diff --git a/benchmarks/purejaxrl/dqn.py b/benchmarks/purejaxrl/dqn.py new file mode 100644 index 000000000..16fa55f52 --- /dev/null +++ b/benchmarks/purejaxrl/dqn.py @@ -0,0 +1,325 @@ +""" +PureJaxRL version of CleanRL's DQN: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_jax.py +""" +from dataclasses import dataclass +import time + +import jax +import jax.numpy as jnp +import chex +import flax +import optax +import flax.linen as nn +from flax.training.train_state import TrainState +from gymnax.wrappers.purerl import FlattenObservationWrapper, LogWrapper +import gymnax +import flashbax as fbx + +from benchmate.metrics import give_push + + + +class QNetwork(nn.Module): + action_dim: int + + @nn.compact + def __call__(self, x: jnp.ndarray): + x = nn.Dense(120)(x) + x = nn.relu(x) + x = nn.Dense(84)(x) + x = nn.relu(x) + x = nn.Dense(self.action_dim)(x) + return x + + +@chex.dataclass(frozen=True) +class TimeStep: + obs: chex.Array + action: chex.Array + reward: chex.Array + done: chex.Array + + +class CustomTrainState(TrainState): + target_network_params: flax.core.FrozenDict + timesteps: int + n_updates: int + + +def make_train(config): + config["NUM_UPDATES"] = config["TOTAL_TIMESTEPS"] // config["NUM_ENVS"] + + from benchmate.timings import StepTimer + step_timer = StepTimer(give_push()) + + basic_env, env_params = gymnax.make(config["ENV_NAME"]) + env = FlattenObservationWrapper(basic_env) + env = LogWrapper(env) + + vmap_reset = lambda n_envs: lambda rng: jax.vmap(env.reset, in_axes=(0, None))( + jax.random.split(rng, n_envs), env_params + ) + vmap_step = lambda n_envs: lambda rng, env_state, action: jax.vmap( + env.step, in_axes=(0, 0, 0, None) + )(jax.random.split(rng, n_envs), env_state, action, env_params) + + def train(rng): + + # INIT ENV + rng, _rng = jax.random.split(rng) + init_obs, env_state = vmap_reset(config["NUM_ENVS"])(_rng) + + # INIT BUFFER + buffer = fbx.make_flat_buffer( + max_length=config["BUFFER_SIZE"], + min_length=config["BUFFER_BATCH_SIZE"], + sample_batch_size=config["BUFFER_BATCH_SIZE"], + add_sequences=False, + add_batch_size=config["NUM_ENVS"], + ) + buffer = buffer.replace( + init=jax.jit(buffer.init), + add=jax.jit(buffer.add, donate_argnums=0), + sample=jax.jit(buffer.sample), + can_sample=jax.jit(buffer.can_sample), + ) + rng = jax.random.PRNGKey(0) # use a dummy rng here + _action = basic_env.action_space().sample(rng) + _, _env_state = env.reset(rng, env_params) + _obs, _, _reward, _done, _ = env.step(rng, _env_state, _action, env_params) + _timestep = 
TimeStep(obs=_obs, action=_action, reward=_reward, done=_done) + buffer_state = buffer.init(_timestep) + + # INIT NETWORK AND OPTIMIZER + network = QNetwork(action_dim=env.action_space(env_params).n) + rng, _rng = jax.random.split(rng) + init_x = jnp.zeros(env.observation_space(env_params).shape) + network_params = network.init(_rng, init_x) + + def linear_schedule(count): + frac = 1.0 - (count / config["NUM_UPDATES"]) + return config["LR"] * frac + + lr = linear_schedule if config.get("LR_LINEAR_DECAY", False) else config["LR"] + tx = optax.adam(learning_rate=lr) + + train_state = CustomTrainState.create( + apply_fn=network.apply, + params=network_params, + target_network_params=jax.tree_map(lambda x: jnp.copy(x), network_params), + tx=tx, + timesteps=0, + n_updates=0, + ) + + # epsilon-greedy exploration + def eps_greedy_exploration(rng, q_vals, t): + rng_a, rng_e = jax.random.split( + rng, 2 + ) # a key for sampling random actions and one for picking + eps = jnp.clip( # get epsilon + ( + (config["EPSILON_FINISH"] - config["EPSILON_START"]) + / config["EPSILON_ANNEAL_TIME"] + ) + * t + + config["EPSILON_START"], + config["EPSILON_FINISH"], + ) + greedy_actions = jnp.argmax(q_vals, axis=-1) # get the greedy actions + chosed_actions = jnp.where( + jax.random.uniform(rng_e, greedy_actions.shape) + < eps, # pick the actions that should be random + jax.random.randint( + rng_a, shape=greedy_actions.shape, minval=0, maxval=q_vals.shape[-1] + ), # sample random actions, + greedy_actions, + ) + return chosed_actions + + # TRAINING LOOP + def _update_step(runner_state, unused): + + train_state, buffer_state, env_state, last_obs, rng = runner_state + + # STEP THE ENV + rng, rng_a, rng_s = jax.random.split(rng, 3) + q_vals = network.apply(train_state.params, last_obs) + action = eps_greedy_exploration( + rng_a, q_vals, train_state.timesteps + ) # explore with epsilon greedy_exploration + obs, env_state, reward, done, info = vmap_step(config["NUM_ENVS"])( + rng_s, env_state, action + ) + train_state = train_state.replace( + timesteps=train_state.timesteps + config["NUM_ENVS"] + ) # update timesteps count + + # BUFFER UPDATE + timestep = TimeStep(obs=last_obs, action=action, reward=reward, done=done) + buffer_state = buffer.add(buffer_state, timestep) + + # NETWORKS UPDATE + def _learn_phase(train_state, rng): + + learn_batch = buffer.sample(buffer_state, rng).experience + + q_next_target = network.apply( + train_state.target_network_params, learn_batch.second.obs + ) # (batch_size, num_actions) + q_next_target = jnp.max(q_next_target, axis=-1) # (batch_size,) + target = ( + learn_batch.first.reward + + (1 - learn_batch.first.done) * config["GAMMA"] * q_next_target + ) + + def _loss_fn(params): + q_vals = network.apply( + params, learn_batch.first.obs + ) # (batch_size, num_actions) + chosen_action_qvals = jnp.take_along_axis( + q_vals, + jnp.expand_dims(learn_batch.first.action, axis=-1), + axis=-1, + ).squeeze(axis=-1) + return jnp.mean((chosen_action_qvals - target) ** 2) + + loss, grads = jax.value_and_grad(_loss_fn)(train_state.params) + train_state = train_state.apply_gradients(grads=grads) + train_state = train_state.replace(n_updates=train_state.n_updates + 1) + return train_state, loss + + rng, _rng = jax.random.split(rng) + is_learn_time = ( + (buffer.can_sample(buffer_state)) + & ( # enough experience in buffer + train_state.timesteps > config["LEARNING_STARTS"] + ) + & ( # pure exploration phase ended + train_state.timesteps % config["TRAINING_INTERVAL"] == 0 + ) # training interval + ) + 
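+            # Both branches of jax.lax.cond must return matching pytrees, so
+            # the no-op branch returns (train_state, jnp.array(0.0)) and
+            # `loss` is always defined.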
train_state, loss = jax.lax.cond( + is_learn_time, + lambda train_state, rng: _learn_phase(train_state, rng), + lambda train_state, rng: (train_state, jnp.array(0.0)), # do nothing + train_state, + _rng, + ) + + # update target network + train_state = jax.lax.cond( + train_state.timesteps % config["TARGET_UPDATE_INTERVAL"] == 0, + lambda train_state: train_state.replace( + target_network_params=optax.incremental_update( + train_state.params, + train_state.target_network_params, + config["TAU"], + ) + ), + lambda train_state: train_state, + operand=train_state, + ) + + metrics = { + "timesteps": train_state.timesteps, + "updates": train_state.n_updates, + "loss": loss.mean(), + "returns": info["returned_episode_returns"].mean(), + } + + def callback(metrics): + # .block_until_ready() + if (metrics["timesteps"] + 1) % 1000: + returns = metrics["returns"].item() + loss = metrics["loss"].block_until_ready().item() + delta = metrics["timesteps"] - step_timer.timesteps + step_timer.timestep = metrics["timesteps"] + + step_timer.step(delta.item()) + step_timer.log(returns=returns, loss=loss) + step_timer.end() + + jax.debug.callback(callback, metrics) + + runner_state = (train_state, buffer_state, env_state, obs, rng) + + return runner_state, metrics + + # train + rng, _rng = jax.random.split(rng) + runner_state = (train_state, buffer_state, env_state, init_obs, _rng) + + runner_state, metrics = jax.lax.scan( + _update_step, runner_state, None, config["NUM_UPDATES"] + ) + return {"runner_state": runner_state, "metrics": metrics} + + return train + + +@dataclass +class Arguments: + num_envs: int = 10 + buffer_size: int = 10000 + buffer_batch_size: int = 128 + total_timesteps: int = 100_000 + epsilon_start: float = 1.0 + epsilon_finish: float = 0.05 + epsilon_anneal_time: int = 25e4 + target_update_interval: int = 500 + lr: float = 2.5e-4 + learning_starts: int = 10000 + training_interval: int = 10 + lr_linear_decay: bool = False + gamma: float = 0.99 + tau: float = 1.0 + env_name: str = "CartPole-v1" + seed: int = 0 + num_seeds: int = 1 + project: str = "" + + +def add_dqn_command(subparser): + parser = subparser.add_parser('dqn', help='RL dqn benchmark') + parser.add_arguments(Arguments) + + + +def main(args: Arguments = None): + if args is None: + args = Arguments() + + config = { + "NUM_ENVS": args.num_envs, + "BUFFER_SIZE": args.buffer_size, + "BUFFER_BATCH_SIZE": args.buffer_batch_size, + "TOTAL_TIMESTEPS": args.total_timesteps, + "EPSILON_START": args.epsilon_start, + "EPSILON_FINISH": args.epsilon_finish, + "EPSILON_ANNEAL_TIME": args.epsilon_anneal_time, + "TARGET_UPDATE_INTERVAL": args.target_update_interval, + "LR": args.lr, + "LEARNING_STARTS": args.learning_starts, + "TRAINING_INTERVAL": args.training_interval, + "LR_LINEAR_DECAY": args.lr_linear_decay, + "GAMMA": args.gamma, + "TAU": args.tau, + "ENV_NAME": args.env_name, + "SEED": args.seed, + "NUM_SEEDS": args.num_seeds, + "PROJECT": args.project, + } + + rng = jax.random.PRNGKey(config["SEED"]) + rngs = jax.random.split(rng, config["NUM_SEEDS"]) + train_vjit = jax.jit(jax.vmap(make_train(config), in_axes=(0,))) + compiled_fn = train_vjit.lower(rngs).compile() + + from benchmate.monitor import bench_monitor + with bench_monitor(): + outs = jax.block_until_ready(compiled_fn(rngs)) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/purejaxrl/main.py b/benchmarks/purejaxrl/main.py new file mode 100644 index 000000000..38eaf2792 --- /dev/null +++ b/benchmarks/purejaxrl/main.py @@ -0,0 +1,37 @@ +# This is the script run by 
milabench run (by default) + +# It is possible to use a script from a GitHub repo if it is cloned using +# clone_subtree in the benchfile.py, in which case this file can simply +# be deleted. + +import argparse +import argklass + + +from dqn import add_dqn_command, main as dqn_main +from ppo import add_ppo_command, main as ppo_main + + +def main(): + parser = argklass.ArgumentParser(description="PureJaxRL") + subparser = parser.add_subparsers(title="Benchmark", dest="benchmark") + + add_dqn_command(subparser) + add_ppo_command(subparser) + + bench = { + "dqn": dqn_main, + "ppo": ppo_main + } + + args = parser.parse_args() + + if benchmark := bench.get(args.benchmark): + benchmark(args) + + else: + raise ValueError(f"Unknown benchmark: {args.benchmark}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/purejaxrl/ppo.py b/benchmarks/purejaxrl/ppo.py new file mode 100644 index 000000000..a053373f3 --- /dev/null +++ b/benchmarks/purejaxrl/ppo.py @@ -0,0 +1,359 @@ +from dataclasses import dataclass +import time + +import jax +import jax.numpy as jnp +import flax.linen as nn +import numpy as np +import optax +from flax.linen.initializers import constant, orthogonal +from typing import Sequence, NamedTuple, Any +from flax.training.train_state import TrainState +import distrax +from benchmate.metrics import give_push + +from wrappers import ( + LogWrapper, + BraxGymnaxWrapper, + VecEnv, + NormalizeVecObservation, + NormalizeVecReward, + ClipAction, +) + + + + +class ActorCritic(nn.Module): + action_dim: Sequence[int] + activation: str = "tanh" + + @nn.compact + def __call__(self, x): + if self.activation == "relu": + activation = nn.relu + else: + activation = nn.tanh + actor_mean = nn.Dense( + 256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) + )(x) + actor_mean = activation(actor_mean) + actor_mean = nn.Dense( + 256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) + )(actor_mean) + actor_mean = activation(actor_mean) + actor_mean = nn.Dense( + self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0) + )(actor_mean) + actor_logtstd = self.param("log_std", nn.initializers.zeros, (self.action_dim,)) + pi = distrax.MultivariateNormalDiag(actor_mean, jnp.exp(actor_logtstd)) + + critic = nn.Dense( + 256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) + )(x) + critic = activation(critic) + critic = nn.Dense( + 256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) + )(critic) + critic = activation(critic) + critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))( + critic + ) + + return pi, jnp.squeeze(critic, axis=-1) + + +class Transition(NamedTuple): + done: jnp.ndarray + action: jnp.ndarray + value: jnp.ndarray + reward: jnp.ndarray + log_prob: jnp.ndarray + obs: jnp.ndarray + info: jnp.ndarray + + +def make_train(config): + from benchmate.timings import StepTimer + step_timer = StepTimer(give_push()) + + config["NUM_UPDATES"] = ( + config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"] + ) + config["MINIBATCH_SIZE"] = ( + config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"] + ) + env, env_params = BraxGymnaxWrapper(config["ENV_NAME"]), None + env = LogWrapper(env) + env = ClipAction(env) + env = VecEnv(env) + if config["NORMALIZE_ENV"]: + env = NormalizeVecObservation(env) + env = NormalizeVecReward(env, config["GAMMA"]) + + def linear_schedule(count): + frac = ( + 1.0 + - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) + / config["NUM_UPDATES"] + ) + 
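+        # count is the optimizer-step count: each update performs
+        # NUM_MINIBATCHES * UPDATE_EPOCHS gradient steps, so frac anneals
+        # from 1 to 0 over NUM_UPDATES updates.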
return config["LR"] * frac + + def train(rng): + # INIT NETWORK + network = ActorCritic( + env.action_space(env_params).shape[0], activation=config["ACTIVATION"] + ) + rng, _rng = jax.random.split(rng) + init_x = jnp.zeros(env.observation_space(env_params).shape) + network_params = network.init(_rng, init_x) + if config["ANNEAL_LR"]: + tx = optax.chain( + optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), + optax.adam(learning_rate=linear_schedule, eps=1e-5), + ) + else: + tx = optax.chain( + optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), + optax.adam(config["LR"], eps=1e-5), + ) + train_state = TrainState.create( + apply_fn=network.apply, + params=network_params, + tx=tx, + ) + + # INIT ENV + rng, _rng = jax.random.split(rng) + reset_rng = jax.random.split(_rng, config["NUM_ENVS"]) + obsv, env_state = env.reset(reset_rng, env_params) + + # TRAIN LOOP + def _update_step(runner_state, unused): + # COLLECT TRAJECTORIES + def _env_step(runner_state, unused): + train_state, env_state, last_obs, rng = runner_state + + # SELECT ACTION + rng, _rng = jax.random.split(rng) + pi, value = network.apply(train_state.params, last_obs) + action = pi.sample(seed=_rng) + log_prob = pi.log_prob(action) + + # STEP ENV + rng, _rng = jax.random.split(rng) + rng_step = jax.random.split(_rng, config["NUM_ENVS"]) + obsv, env_state, reward, done, info = env.step( + rng_step, env_state, action, env_params + ) + transition = Transition( + done, action, value, reward, log_prob, last_obs, info + ) + runner_state = (train_state, env_state, obsv, rng) + return runner_state, transition + + runner_state, traj_batch = jax.lax.scan( + _env_step, runner_state, None, config["NUM_STEPS"] + ) + + # CALCULATE ADVANTAGE + train_state, env_state, last_obs, rng = runner_state + _, last_val = network.apply(train_state.params, last_obs) + + def _calculate_gae(traj_batch, last_val): + def _get_advantages(gae_and_next_value, transition): + gae, next_value = gae_and_next_value + done, value, reward = ( + transition.done, + transition.value, + transition.reward, + ) + delta = reward + config["GAMMA"] * next_value * (1 - done) - value + gae = ( + delta + + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae + ) + return (gae, value), gae + + _, advantages = jax.lax.scan( + _get_advantages, + (jnp.zeros_like(last_val), last_val), + traj_batch, + reverse=True, + unroll=16, + ) + return advantages, advantages + traj_batch.value + + advantages, targets = _calculate_gae(traj_batch, last_val) + + # UPDATE NETWORK + def _update_epoch(update_state, unused): + def _update_minbatch(train_state, batch_info): + traj_batch, advantages, targets = batch_info + + def _loss_fn(params, traj_batch, gae, targets): + # RERUN NETWORK + pi, value = network.apply(params, traj_batch.obs) + log_prob = pi.log_prob(traj_batch.action) + + # CALCULATE VALUE LOSS + value_pred_clipped = traj_batch.value + ( + value - traj_batch.value + ).clip(-config["CLIP_EPS"], config["CLIP_EPS"]) + value_losses = jnp.square(value - targets) + value_losses_clipped = jnp.square(value_pred_clipped - targets) + value_loss = ( + 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() + ) + + # CALCULATE ACTOR LOSS + ratio = jnp.exp(log_prob - traj_batch.log_prob) + gae = (gae - gae.mean()) / (gae.std() + 1e-8) + loss_actor1 = ratio * gae + loss_actor2 = ( + jnp.clip( + ratio, + 1.0 - config["CLIP_EPS"], + 1.0 + config["CLIP_EPS"], + ) + * gae + ) + loss_actor = -jnp.minimum(loss_actor1, loss_actor2) + loss_actor = loss_actor.mean() + entropy = pi.entropy().mean() + + 
total_loss = (
+                            loss_actor
+                            + config["VF_COEF"] * value_loss
+                            - config["ENT_COEF"] * entropy
+                        )
+                        return total_loss, (value_loss, loss_actor, entropy)
+
+                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
+                    total_loss, grads = grad_fn(
+                        train_state.params, traj_batch, advantages, targets
+                    )
+                    train_state = train_state.apply_gradients(grads=grads)
+                    return train_state, total_loss
+
+                train_state, traj_batch, advantages, targets, rng = update_state
+                rng, _rng = jax.random.split(rng)
+                batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
+                assert (
+                    batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
+                ), "batch size must be equal to number of steps * number of envs"
+                permutation = jax.random.permutation(_rng, batch_size)
+                batch = (traj_batch, advantages, targets)
+                batch = jax.tree_util.tree_map(
+                    lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
+                )
+                shuffled_batch = jax.tree_util.tree_map(
+                    lambda x: jnp.take(x, permutation, axis=0), batch
+                )
+                minibatches = jax.tree_util.tree_map(
+                    lambda x: jnp.reshape(
+                        x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
+                    ),
+                    shuffled_batch,
+                )
+                train_state, total_loss = jax.lax.scan(
+                    _update_minbatch, train_state, minibatches
+                )
+                update_state = (train_state, traj_batch, advantages, targets, rng)
+                return update_state, total_loss
+
+            update_state = (train_state, traj_batch, advantages, targets, rng)
+            update_state, loss_info = jax.lax.scan(
+                _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
+            )
+            train_state = update_state[0]
+            metric = traj_batch.info
+            rng = update_state[-1]
+
+            metrics = {
+                "loss": loss_info,
+            }
+
+            def callback(info):
+                total_loss, (value_loss, loss_actor, entropy) = info["loss"]
+                loss = total_loss.mean().item()
+
+                step_timer.step(config["NUM_ENVS"] * config["NUM_STEPS"])
+                step_timer.log(loss=loss)
+                step_timer.end()
+
+            jax.debug.callback(callback, metrics)  # host-side push of throughput/loss from the jitted loop
+
+            runner_state = (train_state, env_state, last_obs, rng)
+            return runner_state, metric
+
+        rng, _rng = jax.random.split(rng)
+        runner_state = (train_state, env_state, obsv, _rng)
+        runner_state, metric = jax.lax.scan(
+            _update_step, runner_state, None, config["NUM_UPDATES"]
+        )
+        return {"runner_state": runner_state, "metrics": metric}
+
+    return train
+
+
+@dataclass
+class Arguments:
+    lr: float = 3e-4
+    num_envs: int = 2048
+    num_steps: int = 10
+    total_timesteps: float = 50_000_000
+    update_epochs: int = 4
+    num_minibatches: int = 32
+    gamma: float = 0.99
+    gae_lambda: float = 0.95
+    clip_eps: float = 0.2
+    ent_coef: float = 0.0
+    vf_coef: float = 0.5
+    max_grad_norm: float = 0.5
+    activation: str = "tanh"
+    env_name: str = "hopper"
+    anneal_lr: bool = False
+    normalize_env: bool = True
+
+
+def add_ppo_command(subparser):
+    parser = subparser.add_parser('ppo', help='RL ppo benchmark')
+    parser.add_arguments(Arguments)
+
+
+def main(args: Arguments = None):
+    if args is None:
+        args = Arguments()
+
+    config = {
+        "LR": args.lr,
+        "NUM_ENVS": args.num_envs,
+        "NUM_STEPS": args.num_steps,
+        "TOTAL_TIMESTEPS": args.total_timesteps,
+        "UPDATE_EPOCHS": args.update_epochs,
+        "NUM_MINIBATCHES": args.num_minibatches,
+        "GAMMA": args.gamma,
+        "GAE_LAMBDA": args.gae_lambda,
+        "CLIP_EPS": args.clip_eps,
+        "ENT_COEF": args.ent_coef,
+        "VF_COEF": args.vf_coef,
+        "MAX_GRAD_NORM": args.max_grad_norm,
+        "ACTIVATION": args.activation,
+        "ENV_NAME": args.env_name,
+        "ANNEAL_LR": args.anneal_lr,
+        "NORMALIZE_ENV": args.normalize_env,
+    }
+    rng = jax.random.PRNGKey(30)
+    train_jit = jax.jit(make_train(config))
+    compiled_fn
= train_jit.lower(rng).compile() + + from benchmate.monitor import bench_monitor + + with bench_monitor(): + out = compiled_fn(rng) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/purejaxrl/requirements.cuda.txt b/benchmarks/purejaxrl/requirements.cuda.txt new file mode 100644 index 000000000..9c7294cb8 --- /dev/null +++ b/benchmarks/purejaxrl/requirements.cuda.txt @@ -0,0 +1,525 @@ +# +# This file is autogenerated by pip-compile with Python 3.10 +# by the following command: +# +# pip-compile --output-file=benchmarks/purejaxrl/requirements.cuda.txt .pin/tmp-constraints-cuda-ppo.txt benchmarks/purejaxrl/requirements.in +# +--extra-index-url https://pypi.ngc.nvidia.com +--extra-index-url https://download.pytorch.org/whl/cu121 +--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +--trusted-host pypi.ngc.nvidia.com + +absl-py==2.1.0 + # via + # brax + # chex + # distrax + # dm-env + # ml-collections + # mujoco + # mujoco-mjx + # optax + # orbax-checkpoint + # rlax + # tensorflow-probability +antlr4-python3-runtime==4.9.3 + # via omegaconf +argklass==1.4.4 + # via -r benchmarks/purejaxrl/requirements.in +astroid==3.2.4 + # via pylint +asttokens==2.4.1 + # via giving +black==24.8.0 + # via navix +blinker==1.8.2 + # via flask +brax==0.10.5 + # via -r benchmarks/purejaxrl/requirements.in +certifi==2024.8.30 + # via + # requests + # sentry-sdk +charset-normalizer==3.3.2 + # via requests +chex==0.1.86 + # via + # distrax + # evosax + # flashbax + # gymnax + # optax + # rlax +click==8.1.7 + # via + # black + # flask + # wandb +cloudpickle==3.0.0 + # via + # gym + # gymnasium + # tensorflow-probability +codefind==0.1.6 + # via ptera +contextlib2==21.6.0 + # via ml-collections +contourpy==1.3.0 + # via matplotlib +cycler==0.12.1 + # via matplotlib +decorator==5.1.1 + # via tensorflow-probability +dill==0.3.8 + # via pylint +distrax==0.1.5 + # via + # -r benchmarks/purejaxrl/requirements.in + # rlax +dm-env==1.6 + # via + # brax + # rlax +dm-tree==0.1.8 + # via + # dm-env + # tensorflow-probability +docker-pycreds==0.4.0 + # via wandb +docstring-parser==0.16 + # via tyro +dotmap==1.3.30 + # via evosax +etils[epath,epy]==1.9.4 + # via + # brax + # mujoco + # mujoco-mjx + # optax + # orbax-checkpoint +evosax==0.1.6 + # via -r benchmarks/purejaxrl/requirements.in +exceptiongroup==1.2.2 + # via pytest +executing==1.2.0 + # via varname +farama-notifications==0.0.4 + # via gymnasium +filelock==3.15.4 + # via + # torch + # triton +flake8==7.1.1 + # via navix +flashbax==0.1.2 + # via -r benchmarks/purejaxrl/requirements.in +flask==3.0.3 + # via + # brax + # flask-cors +flask-cors==5.0.0 + # via brax +flax==0.9.0 + # via + # -r benchmarks/purejaxrl/requirements.in + # brax + # evosax + # flashbax + # gymnax + # navix +fonttools==4.53.1 + # via matplotlib +fsspec==2024.9.0 + # via + # etils + # torch +gast==0.6.0 + # via tensorflow-probability +gitdb==4.0.11 + # via gitpython +gitpython==3.1.43 + # via wandb +giving==0.4.2 + # via + # ptera + # voir +glfw==2.7.0 + # via mujoco +grpcio==1.66.1 + # via brax +gym==0.26.2 + # via + # brax + # gymnax +gym-notices==0.0.8 + # via gym +gymnasium==0.29.1 + # via gymnax +gymnax==0.0.8 + # via -r benchmarks/purejaxrl/requirements.in +hjson==3.1.0 + # via argklass +humanize==4.10.0 + # via orbax-checkpoint +idna==3.8 + # via requests +importlib-resources==6.4.4 + # via + # argklass + # etils +iniconfig==2.0.0 + # via pytest +isort==5.13.2 + # via pylint +itsdangerous==2.2.0 + # via flask +jax[cuda12]==0.4.31 + # via + # -r 
benchmarks/purejaxrl/requirements.in + # brax + # chex + # distrax + # evosax + # flashbax + # flax + # gymnax + # jaxopt + # mujoco-mjx + # optax + # orbax-checkpoint + # rlax +jax-cuda12-pjrt==0.4.31 + # via jax-cuda12-plugin +jax-cuda12-plugin[with-cuda]==0.4.31 + # via jax +jaxlib==0.4.31 + # via + # brax + # chex + # distrax + # evosax + # flashbax + # gymnax + # jax + # jaxopt + # mujoco-mjx + # optax + # orbax-checkpoint + # rlax +jaxopt==0.8.3 + # via brax +jinja2==3.1.4 + # via + # brax + # flask + # torch +kiwisolver==1.4.7 + # via matplotlib +markdown-it-py==3.0.0 + # via rich +markupsafe==2.1.5 + # via + # jinja2 + # werkzeug +matplotlib==3.9.2 + # via + # evosax + # gymnax + # seaborn +mccabe==0.7.0 + # via + # flake8 + # pylint +mdurl==0.1.2 + # via markdown-it-py +ml-collections==0.1.1 + # via brax +ml-dtypes==0.4.0 + # via + # jax + # jaxlib + # tensorstore +mpmath==1.3.0 + # via sympy +msgpack==1.0.8 + # via + # flax + # orbax-checkpoint +mujoco==3.2.2 + # via + # brax + # mujoco-mjx +mujoco-mjx==3.2.2 + # via brax +mypy-extensions==1.0.0 + # via black +navix==0.7.0 + # via -r benchmarks/purejaxrl/requirements.in +nest-asyncio==1.6.0 + # via orbax-checkpoint +networkx==3.3 + # via torch +numpy==2.1.1 + # via + # -r benchmarks/purejaxrl/requirements.in + # brax + # chex + # contourpy + # distrax + # dm-env + # evosax + # flashbax + # gym + # gymnasium + # jax + # jaxlib + # jaxopt + # matplotlib + # ml-dtypes + # mujoco + # navix + # opt-einsum + # optax + # orbax-checkpoint + # pandas + # rlax + # scipy + # seaborn + # tensorboardx + # tensorflow-probability + # tensorstore + # trimesh +nvidia-cublas-cu12==12.1.3.1 + # via + # jax-cuda12-plugin + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch +nvidia-cuda-cupti-cu12==12.1.105 + # via + # jax-cuda12-plugin + # torch +nvidia-cuda-nvcc-cu12==12.6.68 + # via jax-cuda12-plugin +nvidia-cuda-nvrtc-cu12==12.1.105 + # via torch +nvidia-cuda-runtime-cu12==12.1.105 + # via + # jax-cuda12-plugin + # torch +nvidia-cudnn-cu12==9.1.0.70 + # via + # jax-cuda12-plugin + # torch +nvidia-cufft-cu12==11.0.2.54 + # via + # jax-cuda12-plugin + # torch +nvidia-curand-cu12==10.3.2.106 + # via torch +nvidia-cusolver-cu12==11.4.5.107 + # via + # jax-cuda12-plugin + # torch +nvidia-cusparse-cu12==12.1.0.106 + # via + # jax-cuda12-plugin + # nvidia-cusolver-cu12 + # torch +nvidia-ml-py==12.560.30 + # via voir +nvidia-nccl-cu12==2.20.5 + # via + # jax-cuda12-plugin + # torch +nvidia-nvjitlink-cu12==12.6.68 + # via + # jax-cuda12-plugin + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 +nvidia-nvtx-cu12==12.1.105 + # via torch +omegaconf==2.3.0 + # via voir +opt-einsum==3.3.0 + # via jax +optax==0.2.3 + # via + # -r benchmarks/purejaxrl/requirements.in + # brax + # flax +orbax-checkpoint==0.6.1 + # via + # brax + # flax +ovld==0.3.9 + # via voir +packaging==24.1 + # via + # black + # matplotlib + # pytest + # setuptools-scm + # tensorboardx +pandas==2.2.2 + # via seaborn +pathspec==0.12.1 + # via black +pillow==10.4.0 + # via + # brax + # matplotlib + # navix +platformdirs==4.2.2 + # via + # black + # pylint + # wandb +pluggy==1.5.0 + # via pytest +protobuf==5.28.0 + # via + # orbax-checkpoint + # tensorboardx + # wandb +psutil==5.9.8 + # via + # voir + # wandb +ptera==1.4.1 + # via voir +pycodestyle==2.12.1 + # via flake8 +pyflakes==3.2.0 + # via flake8 +pygments==2.18.0 + # via rich +pylint==3.2.7 + # via navix +pyopengl==3.1.7 + # via mujoco +pyparsing==3.1.4 + # via matplotlib +pytest==8.3.2 + # via navix +python-dateutil==2.9.0.post0 + # 
via + # matplotlib + # pandas +pytinyrenderer==0.0.14 + # via brax +pytz==2024.1 + # via pandas +pyyaml==6.0.2 + # via + # evosax + # flax + # gymnax + # ml-collections + # omegaconf + # orbax-checkpoint + # wandb +reactivex==4.0.4 + # via giving +requests==2.32.3 + # via wandb +rich==13.8.0 + # via + # flax + # tyro + # voir +rlax==0.1.6 + # via navix +scipy==1.14.1 + # via + # brax + # jax + # jaxlib + # jaxopt + # mujoco-mjx +seaborn==0.13.2 + # via gymnax +sentry-sdk==2.13.0 + # via wandb +setproctitle==1.3.3 + # via wandb +setuptools-scm==8.1.0 + # via navix +shtab==1.7.1 + # via tyro +six==1.16.0 + # via + # asttokens + # docker-pycreds + # ml-collections + # python-dateutil + # tensorflow-probability +smmap==5.0.1 + # via gitdb +sympy==1.13.2 + # via torch +tensorboardx==2.6.2.2 + # via brax +tensorflow-probability==0.24.0 + # via distrax +tensorstore==0.1.64 + # via + # flashbax + # flax + # orbax-checkpoint +tomli==2.0.1 + # via + # black + # pylint + # pytest + # setuptools-scm +tomlkit==0.13.2 + # via pylint +toolz==0.12.1 + # via chex +torch==2.4.1+cu121 + # via -r benchmarks/purejaxrl/requirements.in +trimesh==4.4.9 + # via + # brax + # mujoco-mjx +triton==3.0.0 + # via torch +typing-extensions==4.12.2 + # via + # astroid + # black + # brax + # chex + # etils + # flashbax + # flax + # gymnasium + # navix + # orbax-checkpoint + # reactivex + # torch + # tyro +tyro==0.8.10 + # via navix +tzdata==2024.1 + # via pandas +urllib3==2.2.2 + # via + # requests + # sentry-sdk +varname==0.10.0 + # via giving +voir==0.2.19 + # via + # -c .pin/../constraints/cuda.txt + # -r benchmarks/purejaxrl/requirements.in +wandb==0.17.8 + # via navix +werkzeug==3.0.4 + # via flask +zipp==3.20.1 + # via etils + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/benchmarks/purejaxrl/requirements.in b/benchmarks/purejaxrl/requirements.in new file mode 100644 index 000000000..dc225d151 --- /dev/null +++ b/benchmarks/purejaxrl/requirements.in @@ -0,0 +1,14 @@ +voir +torch +jax +gymnax +evosax +distrax +optax +flax +numpy +brax +flashbax +navix +torch +argklass \ No newline at end of file diff --git a/benchmarks/purejaxrl/voirfile.py b/benchmarks/purejaxrl/voirfile.py new file mode 100644 index 000000000..5305be3f4 --- /dev/null +++ b/benchmarks/purejaxrl/voirfile.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass + +from voir import configurable +from voir.instruments import dash, early_stop, log, rate +from benchmate.monitor import monitor_monogpu + +@dataclass +class Config: + """voir configuration""" + + # Whether to display the dash or not + dash: bool = False + + # How often to log the rates + interval: str = "1s" + + # Number of rates to skip before logging + skip: int = 5 + + # Number of rates to log before stopping + stop: int = 20 + + # Number of seconds between each gpu poll + gpu_poll: float = 0.5 + + +@configurable +def instrument_main(ov, options: Config): + yield ov.phases.init + + if options.dash: + ov.require(dash) + + ov.require( + log("value", "progress", "rate", "units", "loss", "gpudata", context="task"), + # early_stop(n=options.stop, key="rate", task="train"), + monitor_monogpu(poll_interval=options.gpu_poll), + ) diff --git a/benchmarks/purejaxrl/wrappers.py b/benchmarks/purejaxrl/wrappers.py new file mode 100644 index 000000000..81b397211 --- /dev/null +++ b/benchmarks/purejaxrl/wrappers.py @@ -0,0 +1,349 @@ +import jax +import jax.numpy as jnp +import chex +import numpy as np +from flax import struct +from functools import 
partial +from typing import Optional, Tuple, Union, Any +from gymnax.environments import environment, spaces +from brax import envs +from brax.envs.wrappers.training import EpisodeWrapper, AutoResetWrapper +import navix as nx + + +class GymnaxWrapper(object): + """Base class for Gymnax wrappers.""" + + def __init__(self, env): + self._env = env + + # provide proxy access to regular attributes of wrapped object + def __getattr__(self, name): + return getattr(self._env, name) + + +class FlattenObservationWrapper(GymnaxWrapper): + """Flatten the observations of the environment.""" + + def __init__(self, env: environment.Environment): + super().__init__(env) + + def observation_space(self, params) -> spaces.Box: + assert isinstance( + self._env.observation_space(params), spaces.Box + ), "Only Box spaces are supported for now." + return spaces.Box( + low=self._env.observation_space(params).low, + high=self._env.observation_space(params).high, + shape=(np.prod(self._env.observation_space(params).shape),), + dtype=self._env.observation_space(params).dtype, + ) + + @partial(jax.jit, static_argnums=(0,)) + def reset( + self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None + ) -> Tuple[chex.Array, environment.EnvState]: + obs, state = self._env.reset(key, params) + obs = jnp.reshape(obs, (-1,)) + return obs, state + + @partial(jax.jit, static_argnums=(0,)) + def step( + self, + key: chex.PRNGKey, + state: environment.EnvState, + action: Union[int, float], + params: Optional[environment.EnvParams] = None, + ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]: + obs, state, reward, done, info = self._env.step(key, state, action, params) + obs = jnp.reshape(obs, (-1,)) + return obs, state, reward, done, info + + +@struct.dataclass +class LogEnvState: + env_state: environment.EnvState + episode_returns: float + episode_lengths: int + returned_episode_returns: float + returned_episode_lengths: int + timestep: int + + +class LogWrapper(GymnaxWrapper): + """Log the episode returns and lengths.""" + + def __init__(self, env: environment.Environment): + super().__init__(env) + + @partial(jax.jit, static_argnums=(0,)) + def reset( + self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None + ) -> Tuple[chex.Array, environment.EnvState]: + obs, env_state = self._env.reset(key, params) + state = LogEnvState(env_state, 0, 0, 0, 0, 0) + return obs, state + + @partial(jax.jit, static_argnums=(0,)) + def step( + self, + key: chex.PRNGKey, + state: environment.EnvState, + action: Union[int, float], + params: Optional[environment.EnvParams] = None, + ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]: + obs, env_state, reward, done, info = self._env.step( + key, state.env_state, action, params + ) + new_episode_return = state.episode_returns + reward + new_episode_length = state.episode_lengths + 1 + state = LogEnvState( + env_state=env_state, + episode_returns=new_episode_return * (1 - done), + episode_lengths=new_episode_length * (1 - done), + returned_episode_returns=state.returned_episode_returns * (1 - done) + + new_episode_return * done, + returned_episode_lengths=state.returned_episode_lengths * (1 - done) + + new_episode_length * done, + timestep=state.timestep + 1, + ) + info["returned_episode_returns"] = state.returned_episode_returns + info["returned_episode_lengths"] = state.returned_episode_lengths + info["timestep"] = state.timestep + info["returned_episode"] = done + return obs, state, reward, done, info + + +class BraxGymnaxWrapper: + def 
__init__(self, env_name, backend="positional"): + env = envs.get_environment(env_name=env_name, backend=backend) + env = EpisodeWrapper(env, episode_length=1000, action_repeat=1) + env = AutoResetWrapper(env) + self._env = env + self.action_size = env.action_size + self.observation_size = (env.observation_size,) + + def reset(self, key, params=None): + state = self._env.reset(key) + return state.obs, state + + def step(self, key, state, action, params=None): + next_state = self._env.step(state, action) + return next_state.obs, next_state, next_state.reward, next_state.done > 0.5, {} + + def observation_space(self, params): + return spaces.Box( + low=-jnp.inf, + high=jnp.inf, + shape=(self._env.observation_size,), + ) + + def action_space(self, params): + return spaces.Box( + low=-1.0, + high=1.0, + shape=(self._env.action_size,), + ) + + +class NavixGymnaxWrapper: + def __init__(self, env_name): + self._env = nx.make(env_name) + + def reset(self, key, params=None): + timestep = self._env.reset(key) + return timestep.observation, timestep + + def step(self, key, state, action, params=None): + timestep = self._env.step(state, action) + return timestep.observation, timestep, timestep.reward, timestep.is_done(), {} + + def observation_space(self, params): + return spaces.Box( + low=self._env.observation_space.minimum, + high=self._env.observation_space.maximum, + shape=(np.prod(self._env.observation_space.shape),), + dtype=self._env.observation_space.dtype, + ) + + def action_space(self, params): + return spaces.Discrete( + num_categories=self._env.action_space.maximum.item() + 1, + ) + + +class ClipAction(GymnaxWrapper): + def __init__(self, env, low=-1.0, high=1.0): + super().__init__(env) + self.low = low + self.high = high + + def step(self, key, state, action, params=None): + """TODO: In theory the below line should be the way to do this.""" + # action = jnp.clip(action, self.env.action_space.low, self.env.action_space.high) + action = jnp.clip(action, self.low, self.high) + return self._env.step(key, state, action, params) + + +class TransformObservation(GymnaxWrapper): + def __init__(self, env, transform_obs): + super().__init__(env) + self.transform_obs = transform_obs + + def reset(self, key, params=None): + obs, state = self._env.reset(key, params) + return self.transform_obs(obs), state + + def step(self, key, state, action, params=None): + obs, state, reward, done, info = self._env.step(key, state, action, params) + return self.transform_obs(obs), state, reward, done, info + + +class TransformReward(GymnaxWrapper): + def __init__(self, env, transform_reward): + super().__init__(env) + self.transform_reward = transform_reward + + def step(self, key, state, action, params=None): + obs, state, reward, done, info = self._env.step(key, state, action, params) + return obs, state, self.transform_reward(reward), done, info + + +class VecEnv(GymnaxWrapper): + def __init__(self, env): + super().__init__(env) + self.reset = jax.vmap(self._env.reset, in_axes=(0, None)) + self.step = jax.vmap(self._env.step, in_axes=(0, 0, 0, None)) + + +@struct.dataclass +class NormalizeVecObsEnvState: + mean: jnp.ndarray + var: jnp.ndarray + count: float + env_state: environment.EnvState + + +class NormalizeVecObservation(GymnaxWrapper): + def __init__(self, env): + super().__init__(env) + + def reset(self, key, params=None): + obs, state = self._env.reset(key, params) + state = NormalizeVecObsEnvState( + mean=jnp.zeros_like(obs), + var=jnp.ones_like(obs), + count=1e-4, + env_state=state, + ) + batch_mean = 
jnp.mean(obs, axis=0) + batch_var = jnp.var(obs, axis=0) + batch_count = obs.shape[0] + + delta = batch_mean - state.mean + tot_count = state.count + batch_count + + new_mean = state.mean + delta * batch_count / tot_count + m_a = state.var * state.count + m_b = batch_var * batch_count + M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count + new_var = M2 / tot_count + new_count = tot_count + + state = NormalizeVecObsEnvState( + mean=new_mean, + var=new_var, + count=new_count, + env_state=state.env_state, + ) + + return (obs - state.mean) / jnp.sqrt(state.var + 1e-8), state + + def step(self, key, state, action, params=None): + obs, env_state, reward, done, info = self._env.step( + key, state.env_state, action, params + ) + + batch_mean = jnp.mean(obs, axis=0) + batch_var = jnp.var(obs, axis=0) + batch_count = obs.shape[0] + + delta = batch_mean - state.mean + tot_count = state.count + batch_count + + new_mean = state.mean + delta * batch_count / tot_count + m_a = state.var * state.count + m_b = batch_var * batch_count + M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count + new_var = M2 / tot_count + new_count = tot_count + + state = NormalizeVecObsEnvState( + mean=new_mean, + var=new_var, + count=new_count, + env_state=env_state, + ) + return ( + (obs - state.mean) / jnp.sqrt(state.var + 1e-8), + state, + reward, + done, + info, + ) + + +@struct.dataclass +class NormalizeVecRewEnvState: + mean: jnp.ndarray + var: jnp.ndarray + count: float + return_val: float + env_state: environment.EnvState + + +class NormalizeVecReward(GymnaxWrapper): + def __init__(self, env, gamma): + super().__init__(env) + self.gamma = gamma + + def reset(self, key, params=None): + obs, state = self._env.reset(key, params) + batch_count = obs.shape[0] + state = NormalizeVecRewEnvState( + mean=0.0, + var=1.0, + count=1e-4, + return_val=jnp.zeros((batch_count,)), + env_state=state, + ) + return obs, state + + def step(self, key, state, action, params=None): + obs, env_state, reward, done, info = self._env.step( + key, state.env_state, action, params + ) + return_val = state.return_val * self.gamma * (1 - done) + reward + + batch_mean = jnp.mean(return_val, axis=0) + batch_var = jnp.var(return_val, axis=0) + batch_count = obs.shape[0] + + delta = batch_mean - state.mean + tot_count = state.count + batch_count + + new_mean = state.mean + delta * batch_count / tot_count + m_a = state.var * state.count + m_b = batch_var * batch_count + M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count + new_var = M2 / tot_count + new_count = tot_count + + state = NormalizeVecRewEnvState( + mean=new_mean, + var=new_var, + count=new_count, + return_val=return_val, + env_state=env_state, + ) + return obs, state, reward / jnp.sqrt(state.var + 1e-8), done, info diff --git a/benchmarks/recursiongfn/requirements.cuda.txt b/benchmarks/recursiongfn/requirements.cuda.txt index 41cc95be8..ab28d70b2 100644 --- a/benchmarks/recursiongfn/requirements.cuda.txt +++ b/benchmarks/recursiongfn/requirements.cuda.txt @@ -49,7 +49,7 @@ botorch==0.11.3 # via # -c .pin/../.pin/constraints-cuda-gnn.txt # -r benchmarks/recursiongfn/requirements.in -certifi==2024.7.4 +certifi==2024.8.30 # via # -c .pin/../.pin/constraints-cuda-gnn.txt # requests @@ -88,7 +88,7 @@ frozenlist==1.4.1 # -c .pin/../.pin/constraints-cuda-gnn.txt # aiohttp # aiosignal -fsspec==2024.6.1 +fsspec==2024.9.0 # via # -c .pin/../.pin/constraints-cuda-gnn.txt # torch @@ -121,7 +121,7 @@ idna==3.8 # -c 
.pin/../.pin/constraints-cuda-gnn.txt
 #   requests
 #   yarl
-jaxtyping==0.2.33
+jaxtyping==0.2.34
 # via
 #   -c .pin/../.pin/constraints-cuda-gnn.txt
 #   linear-operator
@@ -487,7 +487,7 @@ werkzeug==3.0.4
 # via
 #   -c .pin/../.pin/constraints-cuda-gnn.txt
 #   tensorboard
-yarl==1.9.4
+yarl==1.9.8
 # via
 #   -c .pin/../.pin/constraints-cuda-gnn.txt
 #   aiohttp
diff --git a/benchmarks/timm/requirements.cuda.txt b/benchmarks/timm/requirements.cuda.txt
index 99103a8cf..33ec5562c 100644
--- a/benchmarks/timm/requirements.cuda.txt
+++ b/benchmarks/timm/requirements.cuda.txt
@@ -17,7 +17,7 @@ asttokens==2.4.1
 # via
 #   -c .pin/../.pin/constraints-cuda-torch.txt
 #   giving
-certifi==2024.7.4
+certifi==2024.8.30
 # via
 #   -c .pin/../.pin/constraints-cuda-torch.txt
 #   requests
diff --git a/benchmate/benchmate/timings.py b/benchmate/benchmate/timings.py
new file mode 100644
index 000000000..a4b7343e3
--- /dev/null
+++ b/benchmate/benchmate/timings.py
@@ -0,0 +1,64 @@
+import os
+import time
+
+
+def getenv(name, type, default):
+    try:
+        return type(os.getenv(name, default))
+    except (TypeError, ValueError):
+        return default
+
+
+def total_observations():
+    return (
+        getenv("VOIR_EARLYSTOP_COUNT", int, 60)
+        + getenv("VOIR_EARLYSTOP_SKIP", int, 5)
+    )
+
+
+class StepTimer:
+    """
+
+    Examples
+    --------
+
+    .. code-block:: python
+
+       step_timer = StepTimer(pusher)
+       for epoch in range(epochs):
+
+           for i, batch in enumerate(data):
+               step_timer.step(batch.shape[0])
+               step_timer.log(loss=...)
+
+               if (i + 1) % grad_acc == 0:
+                   optimizer.step()
+                   step_timer.end()
+
+    """
+    def __init__(self, pusher, sync=lambda: None):
+        self.start_time = time.perf_counter()
+        self.end_time = 0
+        self.n_size = 0
+        self.n_obs = 0
+        self.total_obs = total_observations()
+        self.pusher = pusher
+        self.sync = sync
+        self.timesteps = 0
+
+    def step(self, step_size):
+        """Log the batch size or amount of work that was done"""
+        self.n_size += step_size
+
+    def end(self):
+        """Push a new perf observation"""
+        self.sync()
+        self.end_time = time.perf_counter()
+        self.pusher(rate=self.n_size/(self.end_time - self.start_time), units="items/s", task="train")
+        self.pusher(progress=(self.n_obs, self.total_obs), task="early_stop")
+        self.n_size = 0  # reset accumulated work so each observation covers one window
+        self.n_obs += 1
+        self.start_time = self.end_time
+
+    def log(self, **kwargs):
+        self.pusher(**kwargs)
diff --git a/config/base.yaml b/config/base.yaml
index 3d02f33e6..4d93cef35 100644
--- a/config/base.yaml
+++ b/config/base.yaml
@@ -394,10 +394,11 @@ _diffusion:
   inherits: _defaults
   definition: ../benchmarks/diffusion
   install_group: torch
+  tags:
+    - diffusion
   plan:
-    method: njobs
-    n: 1
-
+    method: per_gpu
+
   argv:
     --num_epochs: 5
     --batch_size: 32
@@ -408,10 +409,13 @@ diffusion-single:
   inherits: _diffusion
   num_machines: 1
   plan:
-    method: njobs
+    method: per_gpu

 diffusion-gpus:
   inherits: _diffusion
+  plan:
+    method: njobs
+    n: 1
   num_machines: 1

 diffusion-nodes:
@@ -426,6 +430,8 @@ _lightning:
   inherits: _defaults
   definition: ../benchmarks/lightning
   install_group: torch
+  tags:
+    - lightning
   argv:
     --epochs: 10
     --num-workers: "auto({n_worker}, 8)"
@@ -452,6 +458,9 @@ _dinov2:
   definition: ../benchmarks/dinov2
   install_group: torch
   num_machines: 1
+  tags:
+    - image
+    - transformer
   plan:
     method: njobs
     n: 1
@@ -505,7 +514,9 @@ _llm:
   voir:
     options:
       stop: 30
-
+  tags:
+    - nlp
+    - llm
   max_duration: 1200
   num_machines: 1
   inherits: _defaults
@@ -517,7 +528,6 @@ llm-lora-single:
   inherits: _llm
   plan:
     method: per_gpu
-
   argv:
     "{milabench_code}/recipes/lora_finetune_single_device.py": true
     --config: "{milabench_code}/configs/llama3_8B_lora_single_device.yaml"
@@ -647,10 +657,36 @@ llm-full-mp-nodes:
   requires_capabilities:
     - "len(nodes) >= ${num_machines}"

+_purejaxrl:
+  inherits: _defaults
+  definition: ../benchmarks/purejaxrl
+  plan:
+    method: per_gpu
+
+dqn:
+  inherits: _purejaxrl
+  argv:
+    dqn: true
+    --num_envs: auto({cpu_per_gpu}, 128)
+    --buffer_batch_size: 128
+    --env_name: CartPole-v1
+    --training_interval: 10
+
+ppo:
+  inherits: _purejaxrl
+  argv:
+    ppo: true
+    --num_envs: auto({cpu_per_gpu}, 128)
+    --num_steps: 10
+    --num_minibatches: 32
+    --update_epochs: 4
+    --env_name: hopper
+    --total_timesteps: 200000

 _geo_gnn:
   inherits: _defaults
-  definition: .
+  tags:
+    - graph
   # FIXME: torch cluster is lagging behind pytorch
   #        we are forced to use torch==2.3 instead of torch==2.4
   install_group: gnn
@@ -666,12 +702,13 @@ dimenet:
     --num-samples: 10000
     --use3d: True

-
 recursiongfn:
   inherits: _defaults
   definition: ../benchmarks/recursiongfn
   install_group: gnn
   group: recursiongfn_gnn
+  tags:
+    - graph
   plan:
     method: per_gpu

@@ -682,14 +719,14 @@ recursiongfn:
     --layer_width: 128
     --num_layers: 4

-
 torchatari:
   inherits: _defaults
   definition: ../benchmarks/torchatari
   install_group: torch
   plan:
     method: per_gpu
-
+  tags:
+    - rl
   argv:
     --num-minibatches: 16
     --update-epochs: 4
diff --git a/milabench/_version.py b/milabench/_version.py
index b6553ed7d..0202d13c4 100644
--- a/milabench/_version.py
+++ b/milabench/_version.py
@@ -1,5 +1,6 @@
 """This file is generated, do not modify"""

-__tag__ = "v0.1.0-55-g4230ec2d"
-__commit__ = "4230ec2d7c587501f8fa8496fba68b6985423a05"
-__date__ = "2024-08-29 15:34:06 -0400"
+__tag__ = "v0.1.0-28-g8069946"
+__commit__ = "8069946d331fb92090057d7eedd598515249521d"
+__date__ = "2024-08-01 12:39:13 -0400"
+
diff --git a/milabench/scripts/vcs.py b/milabench/scripts/vcs.py
index 0f895f886..54bc7638d 100644
--- a/milabench/scripts/vcs.py
+++ b/milabench/scripts/vcs.py
@@ -26,10 +26,16 @@ def retrieve_git_versions(tag="", commit="", date=""):
     }


+def version_file():
+    return os.path.join(ROOT, "milabench", "_version.py")
+
+
 def read_previous():
     info = ["", "", ""]
-
-    with open(os.path.join(ROOT, "milabench", "_version.py"), "r") as file:
+
+    if not os.path.exists(version_file()):
+        return info
+
+    with open(version_file(), "r") as file:
         for line in file.readlines():
             if "tag" in line:
                 _, v = line.split("=")
@@ -49,7 +55,7 @@ def read_previous():
 def update_version_file():
     version_info = retrieve_git_versions(*read_previous())

-    with open(os.path.join(ROOT, "milabench", "_version.py"), "w") as file:
+    with open(version_file(), "w") as file:
         file.write('"""')
         file.write("This file is generated, do not modify")
         file.write('"""\n\n')
diff --git a/milabench/sizer.py b/milabench/sizer.py
index b3fa40478..b1f717247 100644
--- a/milabench/sizer.py
+++ b/milabench/sizer.py
@@ -172,7 +172,7 @@ def argv(self, benchmark, capacity, argv):
             return argv

         newsize = self.size(benchmark, capacity)
-
+
         if newsize is None:
             return argv

@@ -181,10 +181,10 @@ def argv(self, benchmark, capacity, argv):
         argname = config.get("arg")
         if argname is None:
             return argv
-
+
         # placeholder replace
         #    train.batch_size_per_gpu={batch_size}
-        placeholder = "{batch_size}"
+        placeholder = "{batch_size}"

         if placeholder in argname:
             newval = argname.format(batch_size=str(newsize))
@@ -193,7 +193,7 @@ def argv(self, benchmark, capacity, argv):
                 break
         else:
             return argv + [newval]
-
+
         argv[i] = newval
         return argv

@@ -230,7 +230,7 @@ def scale_argv(pack, argv):
     sizer = batch_sizer()

     system = system_global.get()
-
+
     if system:
         capacity = system.get("gpu", dict()).get("capacity")
         return
sizer.argv(pack, capacity, argv) @@ -266,8 +266,8 @@ def on_start(self, entry): if template is None: self.benchname = None return - - placeholder = "{batch_size}" + + placeholder = "{batch_size}" argstart = template.replace(placeholder, "") is_template = False @@ -276,8 +276,8 @@ def on_start(self, entry): if arg.endswith(template): found = i break - - # + + # if arg.startswith(argstart): found = i is_template = True @@ -350,12 +350,12 @@ def report(self, *args): def arch_to_device(arch): device_types = [ - "cpu", - "cuda", - "ipu", - "xpu", - "mkldnn", - "opengl", "opencl", "ideep", "hip", "ve", + "cpu", + "cuda", + "ipu", + "xpu", + "mkldnn", + "opengl", "opencl", "ideep", "hip", "ve", "fpga", "maia", "xla", "lazy", "vulkan", "mps", "meta", "hpu", "mtia", "privateuseone" ] @@ -384,7 +384,6 @@ def new_argument_resolver(pack): ccl = {"hpu": "hccl", "cuda": "nccl", "rocm": "rccl", "xpu": "ccl", "cpu": "gloo"} - cpu_opt = CPUOptions() def auto(value, default): if cpu_opt.enabled: @@ -407,7 +406,7 @@ def clamp(x, mn=cpu_opt.cpu_min, mx=cpu_opt.cpu_max): context["arch"] = arch context["device_name"] = arch_to_device(arch) context["ccl"] = ccl.get(arch, "gloo") - + context["milabench_base"] = option("base", str, default="") dirs = vars(pack.dirs) context["milabench_venv"] = dirs.get('venv', "") diff --git a/tests/test_command_reg/test_command_reg_one_node.txt b/tests/test_command_reg/test_command_reg_one_node.txt index f3ff218ae..fa898ed7c 100644 --- a/tests/test_command_reg/test_command_reg_one_node.txt +++ b/tests/test_command_reg/test_command_reg_one_node.txt @@ -353,7 +353,14 @@ echo "---" echo "diffusion-single" echo "================" time ( - $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache accelerate launch --mixed_precision=bf16 --dynamo_backend=no --machine_rank=0 --num_machines=1 --multi_gpu --gradient_accumulation_steps=1 --num_cpu_threads_per_process=4 --main_process_ip=127.0.0.1 --main_process_port=29400 --num_processes=8 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=0 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=1 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=2 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=3 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=4 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=5 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=6 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=7 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & wait ) @@ -369,7 +376,14 @@ echo "---" echo "diffusion-nodes" echo "===============" time ( - $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache accelerate launch --mixed_precision=bf16 --dynamo_backend=no --machine_rank=0 --num_machines=1 --multi_gpu --gradient_accumulation_steps=1 --num_cpu_threads_per_process=4 
--main_process_ip=127.0.0.1 --main_process_port=29400 --num_processes=8 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=0 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=1 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=2 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=3 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=4 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=5 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=6 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=7 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & wait ) @@ -474,6 +488,36 @@ time ( wait ) +echo "---" +echo "dqn" +echo "===" +time ( + CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + wait +) + +echo "---" +echo "ppo" +echo "===" +time ( + CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + 
CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + wait +) + echo "---" echo "dimenet" echo "=======" diff --git a/tests/test_command_reg/test_command_reg_two_nodes.txt b/tests/test_command_reg/test_command_reg_two_nodes.txt index bda22033e..97b3f683c 100644 --- a/tests/test_command_reg/test_command_reg_two_nodes.txt +++ b/tests/test_command_reg/test_command_reg_two_nodes.txt @@ -353,7 +353,14 @@ echo "---" echo "diffusion-single" echo "================" time ( - $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache accelerate launch --mixed_precision=bf16 --dynamo_backend=no --machine_rank=0 --num_machines=1 --multi_gpu --gradient_accumulation_steps=1 --num_cpu_threads_per_process=4 --main_process_ip=127.0.0.1 --main_process_port=29400 --num_processes=8 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=0 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=1 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=2 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=3 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=4 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=5 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=6 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=7 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & wait ) @@ -369,8 +376,22 @@ echo "---" echo "diffusion-nodes" echo "===============" time ( - $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache accelerate launch --mixed_precision=bf16 --dynamo_backend=no --machine_rank=0 --num_machines=2 --multi_gpu --gradient_accumulation_steps=1 --num_cpu_threads_per_process=4 --main_process_ip=127.0.0.1 --main_process_port=29400 --num_processes=16 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & - ssh -oCheckHostIP=no -oStrictHostKeyChecking=no -oPasswordAuthentication=no -oPasswordAuthentication=no -p 22 username@192.168.0.11 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache accelerate launch --mixed_precision=bf16 --dynamo_backend=no --machine_rank=1 --num_machines=2 
--multi_gpu --gradient_accumulation_steps=1 --num_cpu_threads_per_process=4 --main_process_ip=127.0.0.1 --main_process_port=29400 --num_processes=16 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=0 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache accelerate launch --mixed_precision=bf16 --dynamo_backend=no --machine_rank=0 --num_machines=2 --gradient_accumulation_steps=1 --num_cpu_threads_per_process=4 --main_process_ip=127.0.0.1 --main_process_port=29400 --num_processes=2 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=0 ssh -oCheckHostIP=no -oStrictHostKeyChecking=no -oPasswordAuthentication=no -oPasswordAuthentication=no -p 22 username@192.168.0.11 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache accelerate launch --mixed_precision=bf16 --dynamo_backend=no --machine_rank=1 --num_machines=2 --gradient_accumulation_steps=1 --num_cpu_threads_per_process=4 --main_process_ip=127.0.0.1 --main_process_port=29400 --num_processes=2 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=1 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache accelerate launch --mixed_precision=bf16 --dynamo_backend=no --machine_rank=0 --num_machines=2 --gradient_accumulation_steps=1 --num_cpu_threads_per_process=4 --main_process_ip=127.0.0.1 --main_process_port=29400 --num_processes=2 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=1 ssh -oCheckHostIP=no -oStrictHostKeyChecking=no -oPasswordAuthentication=no -oPasswordAuthentication=no -p 22 username@192.168.0.11 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache accelerate launch --mixed_precision=bf16 --dynamo_backend=no --machine_rank=1 --num_machines=2 --gradient_accumulation_steps=1 --num_cpu_threads_per_process=4 --main_process_ip=127.0.0.1 --main_process_port=29400 --num_processes=2 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=2 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache accelerate launch --mixed_precision=bf16 --dynamo_backend=no --machine_rank=0 --num_machines=2 --gradient_accumulation_steps=1 --num_cpu_threads_per_process=4 --main_process_ip=127.0.0.1 --main_process_port=29400 --num_processes=2 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=2 ssh -oCheckHostIP=no -oStrictHostKeyChecking=no -oPasswordAuthentication=no -oPasswordAuthentication=no -p 22 username@192.168.0.11 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache accelerate launch --mixed_precision=bf16 --dynamo_backend=no --machine_rank=1 --num_machines=2 --gradient_accumulation_steps=1 --num_cpu_threads_per_process=4 --main_process_ip=127.0.0.1 --main_process_port=29400 --num_processes=2 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=3 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache accelerate launch --mixed_precision=bf16 --dynamo_backend=no --machine_rank=0 --num_machines=2 --gradient_accumulation_steps=1 
--num_cpu_threads_per_process=4 --main_process_ip=127.0.0.1 --main_process_port=29400 --num_processes=2 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=3 ssh -oCheckHostIP=no -oStrictHostKeyChecking=no -oPasswordAuthentication=no -oPasswordAuthentication=no -p 22 username@192.168.0.11 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache accelerate launch --mixed_precision=bf16 --dynamo_backend=no --machine_rank=1 --num_machines=2 --gradient_accumulation_steps=1 --num_cpu_threads_per_process=4 --main_process_ip=127.0.0.1 --main_process_port=29400 --num_processes=2 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=4 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache accelerate launch --mixed_precision=bf16 --dynamo_backend=no --machine_rank=0 --num_machines=2 --gradient_accumulation_steps=1 --num_cpu_threads_per_process=4 --main_process_ip=127.0.0.1 --main_process_port=29400 --num_processes=2 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=4 ssh -oCheckHostIP=no -oStrictHostKeyChecking=no -oPasswordAuthentication=no -oPasswordAuthentication=no -p 22 username@192.168.0.11 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache accelerate launch --mixed_precision=bf16 --dynamo_backend=no --machine_rank=1 --num_machines=2 --gradient_accumulation_steps=1 --num_cpu_threads_per_process=4 --main_process_ip=127.0.0.1 --main_process_port=29400 --num_processes=2 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=5 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache accelerate launch --mixed_precision=bf16 --dynamo_backend=no --machine_rank=0 --num_machines=2 --gradient_accumulation_steps=1 --num_cpu_threads_per_process=4 --main_process_ip=127.0.0.1 --main_process_port=29400 --num_processes=2 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=5 ssh -oCheckHostIP=no -oStrictHostKeyChecking=no -oPasswordAuthentication=no -oPasswordAuthentication=no -p 22 username@192.168.0.11 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache accelerate launch --mixed_precision=bf16 --dynamo_backend=no --machine_rank=1 --num_machines=2 --gradient_accumulation_steps=1 --num_cpu_threads_per_process=4 --main_process_ip=127.0.0.1 --main_process_port=29400 --num_processes=2 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=6 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache accelerate launch --mixed_precision=bf16 --dynamo_backend=no --machine_rank=0 --num_machines=2 --gradient_accumulation_steps=1 --num_cpu_threads_per_process=4 --main_process_ip=127.0.0.1 --main_process_port=29400 --num_processes=2 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=6 ssh -oCheckHostIP=no -oStrictHostKeyChecking=no -oPasswordAuthentication=no -oPasswordAuthentication=no -p 22 username@192.168.0.11 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache accelerate launch --mixed_precision=bf16 --dynamo_backend=no 
--machine_rank=1 --num_machines=2 --gradient_accumulation_steps=1 --num_cpu_threads_per_process=4 --main_process_ip=127.0.0.1 --main_process_port=29400 --num_processes=2 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=7 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache accelerate launch --mixed_precision=bf16 --dynamo_backend=no --machine_rank=0 --num_machines=2 --gradient_accumulation_steps=1 --num_cpu_threads_per_process=4 --main_process_ip=127.0.0.1 --main_process_port=29400 --num_processes=2 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & + CUDA_VISIBLE_DEVICES=7 ssh -oCheckHostIP=no -oStrictHostKeyChecking=no -oPasswordAuthentication=no -oPasswordAuthentication=no -p 22 username@192.168.0.11 $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache accelerate launch --mixed_precision=bf16 --dynamo_backend=no --machine_rank=1 --num_machines=2 --gradient_accumulation_steps=1 --num_cpu_threads_per_process=4 --main_process_ip=127.0.0.1 --main_process_port=29400 --num_processes=2 $SRC/milabench/benchmarks/diffusion/main.py --num_epochs 5 --batch_size 32 --num_workers 8 --cache $BASE/cache & wait ) @@ -477,6 +498,36 @@ time ( wait ) +echo "---" +echo "dqn" +echo "===" +time ( + CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + wait +) + +echo "---" +echo "ppo" +echo "===" +time ( + CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + 
CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + wait +) + echo "---" echo "dimenet" echo "======="