From 4ee0b51b46c8e8ae4d43f93457567316ad44fc24 Mon Sep 17 00:00:00 2001 From: Ilya Orson Date: Fri, 20 Sep 2024 13:46:19 +0100 Subject: [PATCH 1/3] Add pixi compatibility With this versions 'python leanrl/dqn.py' runs --- pyproject.toml | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 49c7fab..99bd1f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,3 +103,29 @@ qdagger_dqn_atari_jax_impalacnn = [ "ale-py", "AutoROM", "opencv-python", # atari "jax", "jaxlib", "flax", # jax ] + +[tool.pixi.project] +channels = ["conda-forge","pytorch"] +platforms = ["win-64", "linux-64", "osx-64", "osx-arm64"] + +[tool.pixi.pypi-dependencies] +gym = "==0.23.1" + +[tool.pixi.tasks] + +[tool.pixi.dependencies] +python = ">=3.8,<3.11" +tensorboard = ">=2.10.0,<3.0.0" +wandb = ">=0.13.11,<1.0.0" +pytorch = ">=2.0" +stable-baselines3 = "2.0.0" +gymnasium = ">=0.28.1" +moviepy = ">=1.0.3,<2.0.0" +pygame = ">=2.1,<2.2" +huggingface_hub = ">=0.11.1,<1.0.0" +rich = "<12.0" +tenacity = ">=8.2.2,<9.0.0" +tyro = ">=0.5.10,<1.0.0" +pyyaml = ">=6.0.1,<7.0.0" +numpy = ">=1.21.6" + From e5ca95293bb0b8e0c7254db73ed5dbd863b63823 Mon Sep 17 00:00:00 2001 From: Ilya Orson Sandoval Date: Fri, 20 Sep 2024 16:52:49 +0100 Subject: [PATCH 2/3] Ignore pixi files --- .gitignore | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.gitignore b/.gitignore index 1d4cfa0..8a177e7 100644 --- a/.gitignore +++ b/.gitignore @@ -129,3 +129,8 @@ venv.bak/ # mypy .mypy_cache/ + +# pixi environments +.pixi +pixi.lock +*.egg-info From e9fb1c76992b4a335a5282bc5d40823f84915d2d Mon Sep 17 00:00:00 2001 From: Ilya Orson Sandoval Date: Fri, 20 Sep 2024 16:53:28 +0100 Subject: [PATCH 3/3] Use features for optional dependencies --- pyproject.toml | 87 +++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 83 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 99bd1f7..c87d520 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,17 +104,24 @@ qdagger_dqn_atari_jax_impalacnn = [ "jax", "jaxlib", "flax", # jax ] +[project] +name = "leanrl" +requires-python = ">=3.8,<3.11" + [tool.pixi.project] channels = ["conda-forge","pytorch"] -platforms = ["win-64", "linux-64", "osx-64", "osx-arm64"] +platforms = ["osx-arm64", "linux-64", "osx-64", "win-64"] + +[tool.pixi.tasks] +dqn = "python leanrl/dqn.py --total-timesteps 256" +#dqn-torchcompile = "python leanrl/dqn_torchcompile.py --total-timesteps 256 --compile --cudagraphs" +dqn-jax = "python leanrl/dqn_jax.py --total-timesteps 256" +ppo = "python leanrl/ppo_continuous_action.py --num-steps 64 --total-timesteps 256" [tool.pixi.pypi-dependencies] gym = "==0.23.1" -[tool.pixi.tasks] - [tool.pixi.dependencies] -python = ">=3.8,<3.11" tensorboard = ">=2.10.0,<3.0.0" wandb = ">=0.13.11,<1.0.0" pytorch = ">=2.0" @@ -129,3 +136,75 @@ tyro = ">=0.5.10,<1.0.0" pyyaml = ">=6.0.1,<7.0.0" numpy = ">=1.21.6" +[tool.pixi.feature.tensordict.dependencies] +tensordict = "*" # latest version not available in conda for osx-arm64 + +[tool.pixi.feature.atari.dependencies] +ale-py = "0.8.1" + +[tool.pixi.feature.atari.pypi-dependencies] +AutoROM = {version = ">=0.4.2,<0.5.0", extras = ["accept-rom-license"]} + +[tool.pixi.feature.opencv.dependencies] +opencv-python = ">=4.6.0.66" + +[tool.pixi.feature.procgen.dependencies] +procgen = ">=0.10.7" + +[tool.pixi.feature.pytest.dependencies] +pytest = ">=7.1.3" + +[tool.pixi.feature.mujoco.dependencies] +mujoco = "<=2.3.3" + +[tool.pixi.feature.imageio.dependencies] +imageio = ">=2.14.1" + +[tool.pixi.feature.docs.dependencies] +mkdocs-material = ">=8.4.3" +markdown-include = ">=0.7.0" + +[tool.pixi.feature.openrlbenchmark.dependencies] +openrlbenchmark = ">=0.1.1b4" + +[tool.pixi.feature.jax.pypi-dependencies] +jax = ">=0.4.11,<0.5" +jaxlib = ">=0.4.7,<0.5" +flax = ">=0.6.8,<0.10" + +[tool.pixi.feature.optuna.dependencies] +optuna = ">=3.0.1" +optuna-dashboard = ">=0.7.2" + +[tool.pixi.feature.envpool.dependencies] +envpool = ">=0.6.4" + +[tool.pixi.feature.multi_agent.dependencies] +PettingZoo = "1.18.1" +SuperSuit = "3.4.0" +multi-agent-ale-py = "0.1.11" + +[tool.pixi.feature.aws.dependencies] +boto3 = ">=1.24.70" +awscli = ">=1.31.0" + +[tool.pixi.feature.shimmy.dependencies] +shimmy = ">=1.1.0" + +[tool.pixi.feature.dm_control.dependencies] +dm-control = ">=1.0.10" + +[tool.pixi.feature.h5py.dependencies] +h5py = ">=3.7.0" + +[tool.pixi.feature.optax.dependencies] +optax = "0.1.4" + +[tool.pixi.feature.chex.dependencies] +chex = "0.1.5" + +[tool.pixi.environments] +dqn = [] +#dqn-torchcompile = ["tensordict"] +dqn-jax = ["jax"] +ppo = ["mujoco"]