Allow dtype_s and dtype_o of toy envs to be set for the underlying state space and observation space, respectively (action_space is currently set the same as the state space); partially fix some test cases.
RaghuSpaceRajan committed Nov 21, 2024
1 parent 49551d2 commit ce5aa6a
Showing 2 changed files with 44 additions and 29 deletions.
56 changes: 35 additions & 21 deletions mdp_playground/envs/rl_toy_env.py
@@ -176,6 +176,8 @@ class RLToyEnv(gym.Env):
The externally visible observation space for the environment.
action_space : Gym.Space
The externally visible action space for the environment.
feature_space : Gym.Space
In case of continuous and grid environments, this is the underlying state space. ##TODO Unify this across all types of environments.
rewardable_sequences : dict
Holds the rewardable sequences. The keys are tuples of rewardable sequences and the values are the rewards handed out. When make_denser is True for discrete environments, this dict also holds the rewardable partial sequences.
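A minimal usage sketch of the new options (the import path is assumed; the config keys other than dtype_s and dtype_o are illustrative RLToyEnv settings in the style of the test file):

import numpy as np
from mdp_playground.envs import RLToyEnv

config = {}
config["state_space_type"] = "discrete"
config["action_space_size"] = 8
config["delay"] = 0
config["sequence_length"] = 1
config["reward_scale"] = 1.0
config["generate_random_mdp"] = True
config["dtype_s"] = np.int64    # dtype of the underlying state space
config["dtype_o"] = np.float32  # dtype of the observation space
env = RLToyEnv(**config)
state = env.get_augmented_state()["curr_state"]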
@@ -519,7 +521,6 @@ def __init__(self, **config):
elif config["state_space_type"] == "grid":
assert "grid_shape" in config
self.grid_shape = config["grid_shape"]
self.grid_np_data_type = np.int64
else:
raise ValueError("Unknown state_space_type")

@@ -546,9 +547,9 @@ def __init__(self, **config):
else:
self.repeats_in_sequences = config["repeats_in_sequences"]

self.dtype = np.float32 if "dtype" not in config else config["dtype"]

if config["state_space_type"] == "discrete":
self.dtype_s = np.int64 if "dtype_s" not in config else config["dtype_s"]
if self.irrelevant_features:
assert (
len(config["action_space_size"]) == 2
@@ -570,6 +571,7 @@ def __init__(self, **config):
)
# assert (np.array(self.state_space_size) % np.array(self.diameter) == 0).all(), "state_space_size should be a multiple of the diameter to allow for the generation of regularly connected MDPs."
elif config["state_space_type"] == "continuous":
self.dtype_s = np.float32 if "dtype_s" not in config else config["dtype_s"]
self.action_space_dim = self.state_space_dim
if self.irrelevant_features:
assert (
@@ -580,10 +582,18 @@ def __init__(self, **config):
config["relevant_indices"] = range(self.state_space_dim)
# config["irrelevant_indices"] = list(set(range(len(config["state_space_dim"]))) - set(config["relevant_indices"]))
elif config["state_space_type"] == "grid":
self.dtype_s = np.int64 if "dtype_s" not in config else config["dtype_s"]
# Repeat the grid for the irrelevant part as well
if self.irrelevant_features:
self.grid_shape = self.grid_shape * 2

# Set the dtype for the observation space:
if self.image_representations:
self.dtype_o = np.float32 if "dtype_o" not in config else config["dtype_o"]
else:
self.dtype_o = self.dtype_s if "dtype_o" not in config else config["dtype_o"]
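To summarise the defaulting rules introduced above, a standalone sketch (mirroring the diff, not code from the repo):

import numpy as np

# Per-state-space-type defaults for the underlying state dtype (dtype_s):
DTYPE_S_DEFAULTS = {
    "discrete": np.int64,
    "continuous": np.float32,
    "grid": np.int64,
}

def default_dtype_o(image_representations: bool, dtype_s):
    # Image observations default to float32; otherwise the observation
    # dtype follows the state dtype.
    return np.float32 if image_representations else dtype_s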


if ("init_state_dist" in config) and ("relevant_init_state_dist" not in config):
config["relevant_init_state_dist"] = config["init_state_dist"]

@@ -614,7 +624,7 @@ def __init__(self, **config):
assert self.sequence_length == 1
if "target_point" in config:
self.target_point = np.array(
config["target_point"], dtype=self.dtype
config["target_point"], dtype=self.dtype_s
)
assert self.target_point.shape == (
len(config["relevant_indices"]),
@@ -640,6 +650,7 @@ def __init__(self, **config):
DiscreteExtended(
self.state_space_size[0],
seed=self.seed_dict["relevant_state_space"],
# dtype=self.dtype_o, # Gymnasium seems to hardcode as np.int64
)
] # #seed #hardcoded, many times below as well
self.action_spaces = [
@@ -671,7 +682,7 @@ def __init__(self, **config):
# self.action_spaces[i] = DiscreteExtended(self.action_space_size[i],
# seed=self.seed_dict["irrelevant_action_space"]) #seed

if self.image_representations:
if self.image_representations: # for discrete envs
# underlying_obs_space = MultiDiscreteExtended(self.state_space_size, seed=self.seed_dict["state_space"]) #seed
self.observation_space = ImageMultiDiscrete(
self.state_space_size,
@@ -714,7 +725,7 @@ def __init__(self, **config):
self.state_space_max,
shape=(self.state_space_dim,),
seed=self.seed_dict["state_space"],
dtype=self.dtype,
dtype=self.dtype_s,
) # #seed
# hack #TODO # low and high are 1st 2 and required arguments
# for instantiating BoxExtended
@@ -729,7 +740,7 @@ def __init__(self, **config):
self.action_space_max,
shape=(self.action_space_dim,),
seed=self.seed_dict["action_space"],
dtype=self.dtype,
dtype=self.dtype_s,
) # #seed
# hack #TODO

@@ -754,7 +765,7 @@ def __init__(self, **config):
0 * underlying_space_maxes,
underlying_space_maxes,
seed=self.seed_dict["state_space"],
dtype=self.dtype,
dtype=self.dtype_s,
) # #seed

lows = np.array([-1] * len(self.grid_shape))
@@ -893,7 +904,7 @@ def init_terminal_states(self):
# print("Term state lows, highs:", lows, highs)
self.term_spaces.append(
BoxExtended(
low=lows, high=highs, seed=self.seed_, dtype=self.dtype
low=lows, high=highs, seed=self.seed_, dtype=self.dtype_s
)
) # #seed #hack #TODO
self.logger.debug(
@@ -931,7 +942,7 @@ def init_terminal_states(self):
highs = term_state # #hardcoded
self.term_spaces.append(
BoxExtended(
low=lows, high=highs, seed=self.seed_, dtype=self.grid_np_data_type
low=lows, high=highs, seed=self.seed_, dtype=self.dtype_s
)
) # #seed #hack #TODO

@@ -1657,7 +1668,7 @@ def transition_function(self, state, action):
# for a "wall", but would need to take care of multiple
# reflections near a corner/edge.
# Resets all higher order derivatives to 0
zero_state = np.array([0.0] * (self.state_space_dim), dtype=self.dtype)
zero_state = np.array([0.0] * (self.state_space_dim), dtype=self.dtype_s)
# #####IMP to have copy() otherwise it's the same array
# (in memory) at every position in the list:
self.state_derivatives = [
@@ -1666,7 +1677,7 @@ def transition_function(self, state, action):
self.state_derivatives[0] = next_state

if self.config["reward_function"] == "move_to_a_point":
next_state_rel = np.array(next_state, dtype=self.dtype)[
next_state_rel = np.array(next_state, dtype=self.dtype_s)[
self.config["relevant_indices"]
]
dist_ = np.linalg.norm(next_state_rel - self.target_point)
@@ -1678,7 +1689,7 @@ def transition_function(self, state, action):
# Need to check that dtype is int because Gym doesn't
if (
self.action_space.contains(action)
and np.array(action).dtype == self.grid_np_data_type
and np.array(action).dtype == self.dtype_s
):
if self.transition_noise:
# self._np_random.choice only works for 1-D arrays
@@ -1820,7 +1831,7 @@ def reward_function(self, state, action):
# of the formulae and see that programmatic results match: should
# also have a unit version of 4. for dist_of_pt_from_line() and
# an integration version here for total_deviation calc.?.
data_ = np.array(state_considered, dtype=self.dtype)[
data_ = np.array(state_considered, dtype=self.dtype_s)[
1 + delay : self.augmented_state_length,
self.config["relevant_indices"],
]
@@ -1863,10 +1874,10 @@ def reward_function(self, state, action):
# that. #TODO Generate it randomly to have random Rs?
if self.make_denser:
old_relevant_state = np.array(
state_considered, dtype=self.dtype
state_considered, dtype=self.dtype_s
)[-2, self.config["relevant_indices"]]
new_relevant_state = np.array(
state_considered, dtype=self.dtype
state_considered, dtype=self.dtype_s
)[-1, self.config["relevant_indices"]]
reward = -np.linalg.norm(new_relevant_state - self.target_point)
# Should allow other powers of the distance from target_point,
@@ -1879,7 +1890,7 @@ def reward_function(self, state, action):
# TODO also make_denser, sparse rewards only at target
else: # sparse reward
new_relevant_state = np.array(
state_considered, dtype=self.dtype
state_considered, dtype=self.dtype_s
)[-1, self.config["relevant_indices"]]
if (
np.linalg.norm(new_relevant_state - self.target_point)
@@ -1890,7 +1901,7 @@ def reward_function(self, state, action):
# stay in the radius and earn more reward.

reward -= self.action_loss_weight * np.linalg.norm(
np.array(action, dtype=self.dtype)
np.array(action, dtype=self.dtype_s)
)

elif self.config["state_space_type"] == "grid":
@@ -2044,8 +2055,8 @@ def step(self, action, imaginary_rollout=False):
if self.image_representations:
next_obs = self.observation_space.get_concatenated_image(next_state)

self.curr_state = next_state
self.curr_obs = next_obs
self.curr_state = self.dtype_s(next_state)
self.curr_obs = self.dtype_o(next_obs)

# #### TODO curr_state is external state, while we need to check relevant state for terminality! Done - by using augmented_state now instead of curr_state!
self.done = (
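The casts via self.dtype_s(...) and self.dtype_o(...) above rely on NumPy scalar types doubling as cast constructors; an illustrative, standalone check (not repo code):

import numpy as np

next_state = np.array([3, 1], dtype=np.int64)
cast = np.float32(next_state)  # returns a float32 copy of the array
assert cast.dtype == np.float32
assert np.int64(5).dtype == np.int64  # scalars keep the requested dtype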
@@ -2199,7 +2210,7 @@ def reset(self, seed=None):

# if not self.use_custom_mdp:
# init the state derivatives needed for continuous spaces
zero_state = np.array([0.0] * (self.state_space_dim), dtype=self.dtype)
zero_state = np.array([0.0] * (self.state_space_dim), dtype=self.dtype_s)
self.state_derivatives = [
zero_state.copy() for i in range(self.dynamics_order + 1)
] # #####IMP to have copy()
@@ -2217,7 +2228,7 @@ def reset(self, seed=None):
while True: # Be careful about infinite loops
term_space_was_sampled = False
# curr_state is an np.array while curr_state_relevant is a list
self.curr_state = self.feature_space.sample().astype(int) # #random
self.curr_state = self.feature_space.sample().astype(self.dtype_s) # #random
self.curr_state_relevant = list(self.curr_state[[0, 1]]) # #hardcoded
if self.is_terminal_state(self.curr_state_relevant):
self.logger.debug(
@@ -2241,6 +2252,9 @@ def reset(self, seed=None):
else:
self.curr_obs = self.curr_state

self.curr_state = self.dtype_s(self.curr_state)
self.curr_obs = self.dtype_o(self.curr_obs)

self.logger.info("RESET called. curr_state reset to: " + str(self.curr_state))
self.reached_terminal = False

17 changes: 9 additions & 8 deletions tests/test_mdp_playground.py
@@ -173,7 +173,7 @@ def test_continuous_dynamics_move_along_a_line(self):
# Test 5: R noise - same as Test 1 above except with reward noise and with only 5 steps
# instead of 20.
print("\nTest 5: \033[32;1;4mTEST_CONTINUOUS_DYNAMICS_R_NOISE\033[0m")
config["reward_noise"] = lambda a: a.normal(0, 0.5)
config["reward_noise"] = lambda s, a, rng: rng.normal(0, 0.5)
config["delay"] = 0
env = RLToyEnv(**config)
state = env.get_augmented_state()["curr_state"].copy() # env.reset()[0]
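The noise callables in these tests now take the state and action in addition to the RNG; a hedged sketch of the updated signature (parameter names are illustrative):

import numpy as np

def reward_noise(state, action, rng: np.random.Generator) -> float:
    # Zero-mean Gaussian reward noise, as in Test 5; state and action
    # are accepted but unused here.
    return rng.normal(0, 0.5)

config["reward_noise"] = reward_noise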
@@ -303,7 +303,7 @@ def test_continuous_dynamics_move_along_a_line(self):

# Test P noise
print("\nTest 9: \033[32;1;4mTEST_CONTINUOUS_DYNAMICS_P_NOISE\033[0m")
config["transition_noise"] = lambda a: a.normal([0] * 7, [0.5] * 7)
config["transition_noise"] = lambda s, a, rng: rng.normal([0] * 7, [0.5] * 7)
# Reset seed to have states far away from state maxes so that it is easier to
# test stuff below, but in the end, the state is clipped to [-5, 5] anyway
# while testing, so this wasn't really needed.
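For the continuous env in Test 9, the transition-noise callable returns a per-dimension noise vector; a sketch under the same assumed signature:

import numpy as np

def transition_noise(state, action, rng: np.random.Generator) -> np.ndarray:
    # Per-dimension Gaussian noise for the 7-dimensional state used in
    # this test.
    return rng.normal([0.0] * 7, [0.5] * 7)

config["transition_noise"] = transition_noise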
@@ -1243,9 +1243,10 @@ def test_discrete_dynamics(self):
config["generate_random_mdp"] = True
env = RLToyEnv(**config)
state = env.get_augmented_state()["curr_state"]
self.assertEqual(
type(state), int, "Type of discrete state should be int."
) # TODO Move this and the test_continuous_dynamics type checks to separate unit tests
if type(state) != int:
self.assertEqual(
state.dtype, env.observation_space.dtype, "Type of discrete state should be: " + str(env.observation_space.dtype)
) # TODO Move this and the test_continuous_dynamics type checks to separate unit tests

action = 2
next_state, reward, done, trunc, info = env.step(action)
@@ -1482,7 +1483,7 @@ def test_discrete_r_noise(self):
config["delay"] = 0
config["sequence_length"] = 1
config["reward_scale"] = 1.0
config["reward_noise"] = lambda a: a.normal(0, 0.5)
config["reward_noise"] = lambda s, a, rng: rng.normal(0, 0.5)

config["generate_random_mdp"] = True
config["log_level"] = logging.INFO
@@ -1545,7 +1546,7 @@ def test_discrete_multiple_meta_features(self):
config["reward_scale"] = 2.5
config["reward_shift"] = -1.75
# config["transition_noise"] = 0.1
config["reward_noise"] = lambda a: a.normal(0, 0.5)
config["reward_noise"] = lambda s, a, rng: rng.normal(0, 0.5)

config["generate_random_mdp"] = True
env = RLToyEnv(**config)
@@ -1804,7 +1805,7 @@ def test_discrete_image_representations(self):
config["reward_scale"] = 2.5
config["reward_shift"] = -1.75
# config["transition_noise"] = 0.1
config["reward_noise"] = lambda a: a.normal(0, 0.5)
config["reward_noise"] = lambda s, a, rng: rng.normal(0, 0.5)

config["generate_random_mdp"] = True

