Skip to content

Commit

Permalink
Added default reward_function for cont. envs; remove bug in ImageCont…
Browse files Browse the repository at this point in the history
…inuous;
  • Loading branch information
RaghuSpaceRajan committed Nov 22, 2024
1 parent ce5aa6a commit 3411bf7
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 12 deletions.
8 changes: 4 additions & 4 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def discrete_environment_image_representations_example():
augmented_state_dict = env.get_augmented_state()
next_state = augmented_state_dict["curr_state"] # Underlying MDP state holds
# the current discrete state.
print("sars', done =", state, action, reward, next_state, done)
print("sars', done, image shape =", state, action, reward, next_state, done, next_state_image.shape)

env.close()

Expand Down Expand Up @@ -175,7 +175,7 @@ def discrete_environment_diameter_image_representations_example():
augmented_state_dict = env.get_augmented_state()
next_state = augmented_state_dict["curr_state"] # Underlying MDP state holds
# the current discrete state.
print("sars', done =", state, action, reward, next_state, done)
print("sars', done, shape =", state, action, reward, next_state, done, next_state_image.shape)

env.close()

Expand Down Expand Up @@ -262,7 +262,7 @@ def continuous_environment_example_move_to_a_point_irrelevant_image():
augmented_state_dict = env.get_augmented_state()
next_state = augmented_state_dict["curr_state"].copy() # Underlying MDP state holds
# the current continuous state.
print("sars', done =", state, action, reward, next_state, done)
print("sars', done, image shape =", state, action, reward, next_state, done, next_state_image.shape)

env.close()

Expand Down Expand Up @@ -388,7 +388,7 @@ def grid_environment_image_representations_example():
action = actions[i]
next_obs, reward, done, trunc, info = env.step(action)
next_state = env.get_augmented_state()["augmented_state"][-1]
print("sars', done =", state, action, reward, next_state, done)
print("sars', done, image shape =", state, action, reward, next_state, done, next_obs.shape)
state = next_state

env.reset()[0]
Expand Down
12 changes: 10 additions & 2 deletions mdp_playground/envs/rl_toy_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ def __init__(self, **config):
# if config["state_space_type"] == "discrete":
# assert "init_state_dist" in config

# Common defaults for all types of environments:
if "terminal_state_density" not in config:
self.terminal_state_density = 0.25
else:
Expand Down Expand Up @@ -483,6 +484,7 @@ def __init__(self, **config):
else:
self.image_scale_range = config["image_scale_range"]

# Defaults for the individual environment types:
if config["state_space_type"] == "discrete":
if "reward_dist" not in config:
self.reward_dist = None
Expand All @@ -498,6 +500,11 @@ def __init__(self, **config):
# if not self.use_custom_mdp:
self.state_space_dim = config["state_space_dim"]

# ##TODO Do something to dismbiguate the Python function reward_function from the
# choice of reward_function below.
if "reward_function" not in config:
config["reward_function"] = "move_to_a_point"

if "transition_dynamics_order" not in config:
self.dynamics_order = 1
else:
Expand Down Expand Up @@ -548,8 +555,9 @@ def __init__(self, **config):
self.repeats_in_sequences = config["repeats_in_sequences"]


# ##TODO Move these to the individual env types' defaults section above?
if config["state_space_type"] == "discrete":
self.dtype_s = np.int64 if "dtype_s" not in config else config["dtype_s"]
self.dtype_s = np.int32 if "dtype_s" not in config else config["dtype_s"]
if self.irrelevant_features:
assert (
len(config["action_space_size"]) == 2
Expand Down Expand Up @@ -589,7 +597,7 @@ def __init__(self, **config):

# Set the dtype for the observation space:
if self.image_representations:
self.dtype_o = np.float32 if "dtype_o" not in config else config["dtype_o"]
self.dtype_o = np.uint8 if "dtype_o" not in config else config["dtype_o"]
else:
self.dtype_o = self.dtype_s if "dtype_o" not in config else config["dtype_o"]

Expand Down
16 changes: 10 additions & 6 deletions mdp_playground/spaces/image_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(
term_spaces=None,
width=100,
height=100,
num_channels=3,
circle_radius=5,
target_point=None,
relevant_indices=[0, 1],
Expand All @@ -43,6 +44,8 @@ def __init__(
The width of the image
height : int
The height of the image
num_channels : int
The number of channels in the image ###TODO: Support for 1 channel; unify with ImageMultiDiscrete
circle_radius : int
The radius of the circle which represents the agent and target point
target_point : np.array
Expand All @@ -60,6 +63,7 @@ def __init__(
assert (self.feature_space.low != -np.inf).any()
self.width = width
self.height = height
self.num_channels = num_channels
# Warn if resolution is too low?
self.circle_radius = circle_radius
self.target_point = target_point
Expand Down Expand Up @@ -99,7 +103,7 @@ def __init__(

# Shape has 1 appended for Ray Rllib to be compatible IIRC
super(ImageContinuous, self).__init__(
shape=(width, height, 1), dtype=dtype, low=0, high=255
shape=(width, height, num_channels), dtype=dtype, low=0, high=255
)
super(ImageContinuous, self).seed(seed=seed)

Expand All @@ -117,10 +121,10 @@ def generate_image(self, position, relevant=True):
"""
# Use RGB
image_ = Image.new("RGB", (self.width, self.height), color=self.bg_colour)
# Use L for black and white 8-bit pixels instead of RGB in case not
# using custom images
# image_ = Image.new("L", (self.width, self.height))
if self.num_channels == 3:
image_ = Image.new("RGB", (self.width, self.height), color=self.bg_colour)
elif self.num_channels == 1:
image_ = Image.new("L", (self.width, self.height), color=self.bg_colour)
draw = ImageDraw.Draw(image_)

# Draw in decreasing order of importance:
Expand Down Expand Up @@ -239,7 +243,7 @@ def contains(self, x):
if x.shape == (
self.width,
self.height,
1,
self.num_channels,
): # TODO compare each pixel for all possible images?
return True

Expand Down

0 comments on commit 3411bf7

Please sign in to comment.