
PyTorch implementations of Deep Reinforcement Learning algorithms (DQN, DDQN, A2C, VPG, TRPO, PPO, DDPG, TD3, SAC, SAC-AEA)



LWw586/deep_rl

 
 


Deep Reinforcement Learning (DRL) Algorithms with PyTorch

This repository contains PyTorch implementations of deep reinforcement learning algorithms. The repository will soon be updated to include the PyBullet environments!

Algorithms Implemented

  1. Deep Q-Network (DQN) (V. Mnih et al. 2015)
  2. Double DQN (DDQN) (H. Van Hasselt et al. 2015)
  3. Advantage Actor Critic (A2C)
  4. Vanilla Policy Gradient (VPG)
  5. Natural Policy Gradient (NPG) (S. Kakade et al. 2002)
  6. Trust Region Policy Optimization (TRPO) (J. Schulman et al. 2015)
  7. Proximal Policy Optimization (PPO) (J. Schulman et al. 2017)
  8. Deep Deterministic Policy Gradient (DDPG) (T. Lillicrap et al. 2015)
  9. Twin Delayed DDPG (TD3) (S. Fujimoto et al. 2018)
  10. Soft Actor-Critic (SAC) (T. Haarnoja et al. 2018)
  11. SAC with automatic entropy adjustment (SAC-AEA) (T. Haarnoja et al. 2018)
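The first two entries differ only in how the bootstrapped target is formed for a transition (s, a, r, s'). A minimal plain-Python sketch for a single transition follows; the function names are hypothetical and are not taken from this repository's agents, and Q-values are passed in as plain lists instead of network outputs:

```python
from typing import List


def dqn_target(reward: float, done: bool, gamma: float,
               q_target_next: List[float]) -> float:
    """DQN: the target network both selects and evaluates the next action."""
    if done:
        return reward  # no bootstrapping past a terminal state
    return reward + gamma * max(q_target_next)


def ddqn_target(reward: float, done: bool, gamma: float,
                q_online_next: List[float], q_target_next: List[float]) -> float:
    """Double DQN: the online network selects the action, the target network evaluates it."""
    if done:
        return reward
    best_action = max(range(len(q_online_next)), key=lambda a: q_online_next[a])
    return reward + gamma * q_target_next[best_action]


if __name__ == "__main__":
    q_online_next = [1.0, 3.0, 2.0]
    q_target_next = [4.0, 2.0, 5.0]
    print(dqn_target(1.0, False, 0.5, q_target_next))                  # 3.5
    print(ddqn_target(1.0, False, 0.5, q_online_next, q_target_next))  # 2.0
```

Because the online network picks the action and the target network scores it, Double DQN reduces the overestimation that comes from taking the max over noisy Q-value estimates.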

Environments Implemented

  1. Classic control environments (CartPole-v1, Pendulum-v0, etc.)
  2. MuJoCo environments (Hopper-v2, HalfCheetah-v2, Ant-v2, Humanoid-v2, etc.)
  3. PyBullet environments (HopperBulletEnv-v0, HalfCheetahBulletEnv-v0, AntBulletEnv-v0, HumanoidDeepMimicWalkBulletEnv-v1, etc.)

Results (MuJoCo, PyBullet)

MuJoCo environments

Hopper-v2

  • Observation space: 11
  • Action space: 3

HalfCheetah-v2

  • Observation space: 17
  • Action space: 6

Ant-v2

  • Observation space: 111
  • Action space: 8

Humanoid-v2

  • Observation space: 376
  • Action space: 17

PyBullet environments

HopperBulletEnv-v0

  • Observation space: 15
  • Action space: 3

HalfCheetahBulletEnv-v0

  • Observation space: 26
  • Action space: 6

AntBulletEnv-v0

  • Observation space: 28
  • Action space: 8

HumanoidDeepMimicWalkBulletEnv-v1

  • Observation space: 197
  • Action space: 36
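The observation and action sizes above fix the input and output widths of the networks. The sketch below assumes MLP actors and Q(s, a) critics that take the concatenated state and action (as in DDPG, TD3 and SAC), and assumes SAC-AEA uses the common target-entropy heuristic of -dim(A) from Haarnoja et al. (2018); neither choice is confirmed for this repository, and network_sizes is a hypothetical helper:

```python
# Observation/action sizes copied from the PyBullet list above.
ENV_DIMS = {
    "HopperBulletEnv-v0": (15, 3),
    "HalfCheetahBulletEnv-v0": (26, 6),
    "AntBulletEnv-v0": (28, 8),
    "HumanoidDeepMimicWalkBulletEnv-v1": (197, 36),
}


def network_sizes(env_name: str) -> dict:
    """Input/output widths implied by an environment's spaces (assumed MLP layout)."""
    obs_dim, act_dim = ENV_DIMS[env_name]
    return {
        "actor_in": obs_dim,                # the policy reads the observation vector
        "actor_out": act_dim,               # one output per action dimension (e.g. action or Gaussian mean)
        "critic_in": obs_dim + act_dim,     # Q(s, a) critics concatenate state and action
        "target_entropy": -float(act_dim),  # assumed SAC-AEA heuristic: -dim(A)
    }


if __name__ == "__main__":
    for name in ENV_DIMS:
        print(name, network_sizes(name))
```

For HumanoidDeepMimicWalkBulletEnv-v1 this gives a critic input of 197 + 36 = 233 and a target entropy of -36 under the stated assumptions.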

Requirements

Usage

The repository's high-level structure is:

├── agents
│   └── common
├── results
│   ├── data
│   └── graphs
└── save_model

1) To train the agents on the environments

To train all the agents on the PyBullet environments, follow these steps:

git clone https://github.com/dongminlee94/deep_rl.git
cd deep_rl
python run_bullet.py

For other environments, replace run_bullet.py in the last line with run_cartpole.py, run_pendulum.py, or run_mujoco.py.

To change the configuration of an agent, pass command-line arguments, for example:

python run_bullet.py \
    --env=HumanoidDeepMimicWalkBulletEnv-v1 \
    --algo=sac-aea \
    --phase=train \
    --render=False \
    --load=None \
    --seed=0 \
    --iterations=200 \
    --steps_per_iter=5000 \
    --max_step=1000 \
    --tensorboard=True \
    --gpu_index=0
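The flags above map naturally onto Python's argparse. The sketch below is a hypothetical parser that mirrors them; it is not this repository's actual code, and its defaults simply repeat the example values. It shows why flags written as --render=False or --load=None need custom converters:

```python
import argparse
from typing import Optional


def str2bool(value: str) -> bool:
    # argparse's type=bool would turn any non-empty string, including "False", into True.
    lowered = value.lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def optional_str(value: str) -> Optional[str]:
    # Map the literal string "None" (as in --load=None) to Python's None.
    return None if value == "None" else value


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser; defaults copied from the example command above.
    parser = argparse.ArgumentParser(description="Sketch of run_bullet.py-style options")
    parser.add_argument("--env", type=str, default="HumanoidDeepMimicWalkBulletEnv-v1")
    parser.add_argument("--algo", type=str, default="sac-aea")
    parser.add_argument("--phase", type=str, default="train", choices=["train", "test"])
    parser.add_argument("--render", type=str2bool, default=False)
    parser.add_argument("--load", type=optional_str, default=None)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--iterations", type=int, default=200)
    parser.add_argument("--steps_per_iter", type=int, default=5000)
    parser.add_argument("--max_step", type=int, default=1000)
    parser.add_argument("--tensorboard", type=str2bool, default=True)
    parser.add_argument("--gpu_index", type=int, default=0)
    return parser


if __name__ == "__main__":
    # Parse a fixed example instead of sys.argv so the sketch runs as-is.
    args = build_parser().parse_args(["--render=False", "--load=None", "--seed=0"])
    print(vars(args))
```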

2) To watch the learned agents on the above environments

To watch a learned agent on a PyBullet environment, run:

python run_bullet.py \
    --env=HumanoidDeepMimicWalkBulletEnv-v1 \
    --algo=sac-aea \
    --phase=test \
    --render=True \
    --load=envname_algoname_... \
    --seed=0 \
    --iterations=200 \
    --steps_per_iter=5000 \
    --max_step=1000 \
    --tensorboard=False \
    --gpu_index=0

Copy the name of the saved model from save_model/envname_algoname_... and paste it in place of envname_algoname_... in the --load argument, so that the saved model is loaded.
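The copy-and-paste step can also be scripted. A minimal sketch, assuming saved models sit directly in save_model/ with names that start with envname_algoname_ and that --load expects the name exactly as it appears there; find_saved_models is a hypothetical helper, not part of this repository:

```python
from pathlib import Path
from typing import List


def find_saved_models(save_dir: str, env: str, algo: str) -> List[str]:
    """Return file names in save_dir that start with '<env>_<algo>_' (assumed naming scheme)."""
    directory = Path(save_dir)
    if not directory.is_dir():
        return []
    prefix = f"{env}_{algo}_"
    # The trailing underscore keeps "sac" from matching "sac-aea" models.
    return sorted(p.name for p in directory.iterdir() if p.name.startswith(prefix))


if __name__ == "__main__":
    for name in find_saved_models("save_model", "HumanoidDeepMimicWalkBulletEnv-v1", "sac-aea"):
        print(name)
```

Any printed name can then be passed as --load=<name> together with --phase=test.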
