
RL

Reinforcement learning

Framework Overview

Directory Structure

src
├── requestment.txt
├── RLAlgo
│   ├── _base_net.py
│   ├── DDPG.py
│   ├── DQN.py
│   ├── grad_ana.py
│   ├── PPO2_old.py
│   ├── PPO2.py
│   ├── PPO.py
│   ├── __pycache__
│   ├── SAC.py
│   └── TD3.py
├── RLUtils
│   ├── config.py
│   ├── env_wrapper.py
│   ├── __init__.py
│   ├── memory.py
│   ├── __pycache__
│   ├── state_util.py
│   └── trainer.py
├── setup.py
├── test
│   ├── border_detector.py
│   ├── test_ddpg.py
│   ├── test_dqn.py
│   ├── test_env_explore.ipynb
│   ├── README.md
│   ├── test_models
│   ├── test_ppo_atari.py
│   ├── test_ppo_new.py
│   ├── test_ppo.py
│   ├── test_sac.py
│   └── test_TD3.py
└── TODO.md
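
The algorithm implementations live under RLAlgo and the shared training utilities under RLUtils. Assuming src is importable (for example after an editable install via the provided setup.py), the tree maps directly onto the imports used in the example further down; a minimal sketch:

# Minimal import sketch (assumption: the package is importable from src,
# e.g. after an editable install using setup.py).
from RLAlgo.PPO2 import PPO2                                        # algorithm implementation
from RLUtils import Config, train_on_policy, play, gym_env_desc    # training utilities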

Environment Requirements

Core Packages

| package | version |
| --- | --- |
| python | 3.10 |
| torch | 2.1.1 |
| torchvision | 0.16.1 |
| gymnasium | 0.29.1 |
| cloudpickle | 2.2.1 |
| envpool | 0.8.4 |
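
A quick sanity check that the installed packages match the table (a minimal sketch, not part of the framework):

import torch
import torchvision
import gymnasium
import cloudpickle
import envpool

# Print the installed versions so they can be compared with the table above.
for pkg in (torch, torchvision, gymnasium, cloudpickle, envpool):
    print(pkg.__name__, pkg.__version__)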

Usage Example

import os
import gymnasium as gym
import torch
from RLAlgo.PPO2 import PPO2
from RLUtils import train_on_policy, random_play, play, Config, gym_env_desc

env_name = 'Hopper-v4'
gym_env_desc(env_name)
path_ = os.path.dirname(__file__) 
env = gym.make(
    env_name, 
    exclude_current_positions_from_observation=True,
    # healthy_reward=0
)
cfg = Config(
    env, 
    # environment settings
    save_path=os.path.join(path_, "test_models" ,'PPO_Hopper-v4_test2'), 
    seed=42,
    # network sizes
    actor_hidden_layers_dim=[256, 256, 256],
    critic_hidden_layers_dim=[256, 256, 256],
    # agent hyperparameters
    actor_lr=1.5e-4,
    critic_lr=5.5e-4,
    gamma=0.99,
    # training settings
    num_episode=12500,
    off_buffer_size=512,
    off_minimal_size=510,
    max_episode_steps=500,
    PPO_kwargs={
        'lmbda': 0.9,
        'eps': 0.25,
        'k_epochs': 4, 
        'sgd_batch_size': 128,
        'minibatch_size': 12, 
        'actor_bound': 1,
        'dist_type': 'beta'
    }
)
agent = PPO2(
    state_dim=cfg.state_dim,
    actor_hidden_layers_dim=cfg.actor_hidden_layers_dim,
    critic_hidden_layers_dim=cfg.critic_hidden_layers_dim,
    action_dim=cfg.action_dim,
    actor_lr=cfg.actor_lr,
    critic_lr=cfg.critic_lr,
    gamma=cfg.gamma,
    PPO_kwargs=cfg.PPO_kwargs,
    device=cfg.device,
    reward_func=None
)

agent.train()
train_on_policy(env, agent, cfg, wandb_flag=False, train_without_seed=True, test_ep_freq=1000, 
                online_collect_nums=cfg.off_buffer_size,
                test_episode_count=5)
agent.load_model(cfg.save_path)
agent.eval()
env_ = gym.make(env_name, 
                exclude_current_positions_from_observation=True,
                # render_mode='human'
                )
play(env_, agent, cfg, episode_count=3, play_without_seed=True, render=False)
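
To watch the trained policy, the evaluation call can be repeated with the render_mode='human' option that is commented out above (a minimal sketch; only the gym.make keyword changes):

# Same evaluation as above, but with an on-screen viewer provided by gymnasium.
env_render = gym.make(env_name,
                      exclude_current_positions_from_observation=True,
                      render_mode='human')
play(env_render, agent, cfg, episode_count=3, play_without_seed=True, render=False)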

Training Results

| Environment & Description | Config Function | Result |
| --- | --- | --- |
| Hopper-v4 (state: (11,), action: (3,), continuous <-1.0, 1.0>) | Hopper_v4_ppo2_test | PPO2-PPO2_Hopper-v4 |
| Humanoid-v4 (state: (376,), action: (17,), continuous <-0.4, 0.4>) | Humanoid_v4_ppo2_test | PPO2-PPO2_Humanoid-v4 |
| ALE/DemonAttack-v5 (state: (210, 160, 3), action: 6, discrete) | DemonAttack_v5_ppo2_test | PPO2_DemonAttack_v5 |
| ALE/AirRaid-v5 (state: (250, 160, 3), action: 6, discrete) | AirRaid_v5_ppo2_test | PPO2_AirRaid_v5 |
| ALE/Alien-v5 (state: (210, 160, 3), action: 18, discrete) | Alien_v5_ppo2_test | PPO2_Alien_v5 |
| Walker2d-v4 (state: (17,), action: (6,), continuous <-1.0, 1.0>) | Walker2d_v4_ppo2_test | warlker |
| HumanoidStandup-v4 (state: (376,), action: (17,), continuous <-0.4, 0.4>) | HumanoidStandup_v4_ppo2_test | stand |
| CartPole-v1 (state: (4,), action: 2, discrete) | duelingDQN: dqn_test | duelingDQN_CartPole |
| MountainCar-v0 (state: (2,), action: 3, discrete) | duelingDQN: dqn_test | duelingDQN_MountainCar |
| Acrobot-v1 (state: (6,), action: 3, discrete) | duelingDQN: Acrobot_dqn_test | duelingDQN_Acrobot |
| LunarLander-v2 (state: (8,), action: 4, discrete) | duelingDQN: LunarLander_dqn_test | duelingDQN_LunarLander |
| ALE/DemonAttack-v5 (state: (210, 160, 3), action: 6, discrete) | doubleDQN: DemonAttack_v5_dqn_new_test | doubleDQN-DemonAc |
| BipedalWalker-v3 (state: (24,), action: (4,), continuous <-1.0, 1.0>) | BipedalWalker_ddpg_test | DDPG |
| BipedalWalkerHardcore-v3 (state: (24,), action: (4,), continuous <-1.0, 1.0>) | BipedalWalkerHardcore_TD3_test | TD3 |
| Reacher-v4 (state: (11,), action: (2,), continuous <-1.0, 1.0>) | sac_Reacher_v4_test | SAC |
| Pusher-v4 (state: (23,), action: (7,), continuous <-2.0, 2.0>) | sac_Pusher_v4_test | SAC-2 |
| CarRacing-v2 (state: (96, 96, 3), action: (3,), continuous <-1.0, 1.0>) | CarRacing_TD3_test | TD3-car |
| InvertedPendulum-v4 (state: (4,), action: (1,), continuous <-3.0, 3.0>) | InvertedPendulum_TD3_test | TD3-InvertedPendulum |
| HalfCheetah-v4 (state: (17,), action: (6,), continuous <-1.0, 1.0>) | HalfCheetah_v4_ppo_test | PPO-PPO_HalfCheetah-v4 |
| ALE/Breakout-v5 (state: (210, 160, 3), action: 4, discrete) | Breakout_v5_ppo2_test | Breakout |