Skip to content

Commit

Permalink
MAJOR: Allow transition and reward noises to depend on the current state and action,
Browse files Browse the repository at this point in the history
improve the default noise for continuous environments. example.py: set up logging; add CLI argument to toggle displaying image observations. Improve logging in general.
  • Loading branch information
RaghuSpaceRajan committed Sep 20, 2024
1 parent 3ce485f commit 2b9f7d5
Show file tree
Hide file tree
Showing 5 changed files with 2,175 additions and 109 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ __pycache__/
MUJOCO_LOG.TXT
*.pdf

log*

*.swp
*.csv
.directory
Expand Down Expand Up @@ -114,4 +116,4 @@ venv.bak/

#whitelist
!tests/files/mdpp_12744267_SAC_target_radius/*.csv
!misc/sample_recorded_data/*/*.csv
!misc/sample_recorded_data/*/*.csv
122 changes: 83 additions & 39 deletions example.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""We collect here some examples of basic usage for MDP Playground.
Example call: python example.py --do_not_display_images --log_level INFO
Calling this file as a script invokes the following examples:
one for basic discrete environments
Expand All @@ -10,7 +11,7 @@
one for basic grid environments
one for grid environments with image representations
one for wrapping Atari env qbert
one for wrapping Mujoco env HalfCheetah
one for wrapping Mujoco envs HalfCheetah, Pusher, Reacher
one for wrapping MiniGrid env # Currently commented out due to some errors
one for wrapping ProcGen env # Currently commented out due to some errors
two examples at the end showing how to create toy envs using gym.make()
Expand All @@ -25,6 +26,7 @@
from mdp_playground.envs import RLToyEnv
import numpy as np

display_images = True

def display_image(obs, mode="RGB"):
# Display the image observation associated with the next state
Expand Down Expand Up @@ -121,7 +123,8 @@ def discrete_environment_image_representations_example():

env.close()

display_image(next_state_image, mode="L")
if display_images:
display_image(next_state_image, mode="L")


def discrete_environment_diameter_image_representations_example():
Expand Down Expand Up @@ -167,7 +170,8 @@ def discrete_environment_diameter_image_representations_example():

env.close()

display_image(next_state_image, mode="L")
if display_images:
display_image(next_state_image, mode="L")


def continuous_environment_example_move_to_a_point():
Expand Down Expand Up @@ -249,8 +253,9 @@ def continuous_environment_example_move_to_a_point_irrelevant_image():

env.close()

img1 = display_image(next_state_image, mode="RGB")
img1.save("cont_env_irrelevant_image.pdf")
if display_images:
img1 = display_image(next_state_image, mode="RGB")
# img1.save("cont_env_irrelevant_image.pdf")


def continuous_environment_example_move_along_a_line():
Expand Down Expand Up @@ -342,7 +347,8 @@ def grid_environment_image_representations_example():
env.reset()[0]
env.close()

display_image(next_obs)
if display_images:
display_image(next_obs)


def atari_wrapper_example():
Expand All @@ -351,7 +357,7 @@ def atari_wrapper_example():
"seed": 0,
"delay": 1,
"transition_noise": 0.25,
"reward_noise": lambda a: a.normal(0, 0.1),
"reward_noise": lambda s, a, rng: rng.normal(0, 0.1),
"state_space_type": "discrete",
}

Expand Down Expand Up @@ -380,7 +386,8 @@ def atari_wrapper_example():

env.close()

display_image(next_state)
if display_images:
display_image(next_state)


def mujoco_wrapper_examples():
Expand Down Expand Up @@ -435,11 +442,13 @@ def mujoco_wrapper_examples():
state = env.reset(seed=gym_wrap_config["seed"])[0]

print(
"Taking a step in the environment with a random action and printing the transition:"
"Taking steps in the HalfCheetah environment with a random action and printing the transition:"
)
action = env.action_space.sample()
next_state, reward, done, trunc, info = env.step(action)
print("sars', done =", state, action, reward, next_state, done)
for i in range(3):
action = env.action_space.sample()
next_state, reward, done, trunc, info = env.step(action)
print("sars', done =", state, action, reward, next_state, done)
state = next_state

env.close()

Expand All @@ -453,14 +462,16 @@ def mujoco_wrapper_examples():
import gymnasium as gym
env = GymEnvWrapper(env, **gym_wrap_config)

state = env.reset(seed=gym_wrap_config["seed"])[0]
state = env.reset(seed=gym_wrap_config["seed"] + 1)[0]

print(
"Taking a step in the environment with a random action and printing the transition:"
"Taking steps in the Pusher environment with a random action and printing the transition:"
)
action = env.action_space.sample()
next_state, reward, done, trunc, info = env.step(action)
print("sars', done =", state, action, reward, next_state, done)
for i in range(3):
action = env.action_space.sample()
next_state, reward, done, trunc, info = env.step(action)
print("sars', done =", state, action, reward, next_state, done)
state = next_state

env.close()

Expand All @@ -474,14 +485,16 @@ def mujoco_wrapper_examples():
import gymnasium as gym
env = GymEnvWrapper(env, **gym_wrap_config)

state = env.reset(seed=gym_wrap_config["seed"])[0]
state = env.reset(seed=gym_wrap_config["seed"] + 2)[0]

print(
"Taking a step in the environment with a random action and printing the transition:"
"Taking steps in the Reacher environment with a random action and printing the transition:"
)
action = env.action_space.sample()
next_state, reward, done, trunc, info = env.step(action)
print("sars', done =", state, action, reward, next_state, done)
for i in range(3):
action = env.action_space.sample()
next_state, reward, done, trunc, info = env.step(action)
print("sars', done =", state, action, reward, next_state, done)
state = next_state

env.close()

Expand All @@ -501,7 +514,7 @@ def minigrid_wrapper_example():
"seed": 0,
"delay": 1,
"transition_noise": 0.25,
"reward_noise": lambda a: a.normal(0, 0.1),
"reward_noise": lambda s, a, rng: rng.normal(0, 0.1),
"state_space_type": "discrete",
}

Expand Down Expand Up @@ -533,7 +546,8 @@ def minigrid_wrapper_example():

env.close()

display_image(next_obs)
if display_images:
display_image(next_obs)


def procgen_wrapper_example():
Expand All @@ -542,7 +556,7 @@ def procgen_wrapper_example():
"seed": 0,
"delay": 1,
"transition_noise": 0.25,
"reward_noise": lambda a: a.normal(0, 0.1),
"reward_noise": lambda s, a, rng: rng.normal(0, 0.1),
"state_space_type": "discrete",
}

Expand All @@ -569,76 +583,106 @@ def procgen_wrapper_example():

env.close()

display_image(next_obs)
if display_images:
display_image(next_obs)


if __name__ == "__main__":

# Use argparse to set display_images to False if you don't want to display images
# and to set log level.
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--display_images", "-di", help="Display image observations (available for some examples)", action="store_true")
parser.add_argument("--do_not_display_images", "-n", help="Do not display image observations (available for some examples)", action="store_false", dest="display_images")
parser.add_argument("--log_level", type=str, default="DEBUG", help="Set the log level")
parser.set_defaults(display_images=True)
args = parser.parse_args()
display_images = args.display_images

# Set up logging globally for the MDP Playground library:
import logging
logger = logging.getLogger("mdp_playground")
logger.setLevel(args.log_level)
if not logger.handlers:
log_filename = "log_file.txt"
log_file_handler = logging.FileHandler(log_filename)
log_file_handler.setFormatter(logging.Formatter('%(message)s - %(levelname)s - %(name)s - %(asctime)s', datefmt='%m.%d.%Y %I:%M:%S %p'))
logger.addHandler(log_file_handler)
# Add a console handler:
console_handler = logging.StreamHandler()
console_handler.setFormatter(logging.Formatter('%(message)s'))
# Have less verbose logging to console:
console_handler.setLevel(logging.INFO)
logger.addHandler(console_handler)
logger.info("Begin logging to: %s", log_filename)


# Colour print
set_ansi_escape = "\033[33;1m" # Yellow, bold
reset_ansi_escape = "\033[0m"

print(set_ansi_escape + "Running discrete environment\n" + reset_ansi_escape)
logger.info(set_ansi_escape + "Running discrete environment\n" + reset_ansi_escape)
discrete_environment_example()

print(
logger.info(
set_ansi_escape
+ "\nRunning discrete environment with image representations\n"
+ reset_ansi_escape
)
discrete_environment_image_representations_example()

print(
logger.info(
set_ansi_escape
+ "\nRunning discrete environment with diameter and image representations\n"
+ reset_ansi_escape
)
discrete_environment_diameter_image_representations_example()

print(
logger.info(
set_ansi_escape
+ "\nRunning continuous environment: move_to_a_point\n"
+ reset_ansi_escape
)
continuous_environment_example_move_to_a_point()

print(
logger.info(
set_ansi_escape
+ "\nRunning continuous environment: move_to_a_point with irrelevant features and image representations\n"
+ reset_ansi_escape
)
continuous_environment_example_move_to_a_point_irrelevant_image()

print(
logger.info(
set_ansi_escape
+ "\nRunning continuous environment: move_along_a_line\n"
+ reset_ansi_escape
)
continuous_environment_example_move_along_a_line()

print(
logger.info(
set_ansi_escape
+ "\nRunning grid environment: move_to_a_point\n"
+ reset_ansi_escape
)
grid_environment_example()

print(
logger.info(
set_ansi_escape + "\nRunning grid environment: move_to_a_point "
"with image representations\n" + reset_ansi_escape
)
grid_environment_image_representations_example()

print(set_ansi_escape + "\nRunning Atari wrapper example:\n" + reset_ansi_escape)
logger.info(set_ansi_escape + "\nRunning Atari wrapper example:\n" + reset_ansi_escape)
atari_wrapper_example()

print(set_ansi_escape + "\nRunning Mujoco wrapper example:\n" + reset_ansi_escape)
logger.info(set_ansi_escape + "\nRunning Mujoco wrapper example:\n" + reset_ansi_escape)
mujoco_wrapper_examples()

print(set_ansi_escape + "\nRunning MiniGrid wrapper example:\n" + reset_ansi_escape)
# logger.info(set_ansi_escape + "\nRunning MiniGrid wrapper example:\n" + reset_ansi_escape)
# minigrid_wrapper_example()

# print(set_ansi_escape + "\nRunning ProcGen wrapper example:\n" + reset_ansi_escape)
# logger.info(set_ansi_escape + "\nRunning ProcGen wrapper example:\n" + reset_ansi_escape)
# procgen_wrapper_example()

# Using gym.make() example 1
Expand All @@ -660,4 +704,4 @@ def procgen_wrapper_example():
)
env.reset()[0]
for i in range(10):
print(env.step(env.action_space.sample()))
logger.info(env.step(env.action_space.sample()))
Loading

0 comments on commit 2b9f7d5

Please sign in to comment.