diff --git a/pymdp/envs/generalized_tmaze.py b/pymdp/envs/generalized_tmaze.py
index 74419f2e..6a8c4ee6 100644
--- a/pymdp/envs/generalized_tmaze.py
+++ b/pymdp/envs/generalized_tmaze.py
@@ -1,5 +1,6 @@
-from pymdp.jax.envs import PyMDPEnv
+from .env import Env
 import numpy as np
+import jax.numpy as jnp
 import matplotlib.pyplot as plt
 import io
@@ -7,46 +8,11 @@ import jax.numpy as jnp
 import jax.tree_util as jtu
+from jax import random as jr
+from jaxtyping import Array, PRNGKeyArray
 from matplotlib.lines import Line2D
 
-
-def position_to_index(position, shape):
-    """
-    Maps the position in the grid to a flat index
-
-    Parameters
-    ----------
-    position
-        Tuple of position (row, col)
-    shape
-        The shape of the grid (n_rows, n_cols)
-
-    Returns
-    ----------
-    index
-        A flattened index of position
-    """
-    return position[0] * shape[1] + position[1]
-
-
-def index_to_position(index, shape):
-    """
-    Maps the flat index to a position coordinate in the grid
-
-    Parameters
-    ----------
-    shape
-        The shape of the grid (n_rows, n_cols)
-    index
-        A flattened index of position
-
-    Returns
-    ----------
-    position
-        Tuple of position (row, col)
-    """
-    return index // shape[1], index % shape[1]
-
-
-def parse_maze(maze):
+def parse_maze(maze, rng_key: PRNGKeyArray):
     """
     Parameters
     ----------
@@ -67,45 +33,36 @@ def parse_maze(maze):
         purposes
     """
 
-    maze = np.array(maze)
     rows, cols = maze.shape
-    num_cues = int((np.max(maze) - 2) // 3)
+    num_cues = int((jnp.max(maze) - 2) // 3)
 
     cue_positions = []
     reward_1_positions = []
    reward_2_positions = []
     for i in range(num_cues):
-        cue_positions.append(tuple(np.argwhere(maze == 3 + 3 * i)[0]))
-        reward_1_positions.append(tuple(np.argwhere(maze == 4 + 3 * i)[0]))
-        reward_2_positions.append(tuple(np.argwhere(maze == 5 + 3 * i)[0]))
+        cue_positions.append(tuple(jnp.argwhere(maze == 3 + 3 * i)[0]))
+        reward_1_positions.append(tuple(jnp.argwhere(maze == 4 + 3 * i)[0]))
+        reward_2_positions.append(tuple(jnp.argwhere(maze == 5 + 3 * i)[0]))
 
     # Initialize agent's starting position (can be customized if required)
-    initial_position = tuple(np.argwhere(maze == 1)[0])
+    initial_position = tuple(jnp.argwhere(maze == 1)[0])
 
     # Actions: up, down, left, right
     actions = [(-1, 0), (1, 0), (0, -1), (0, 1)]
 
     # Set reward location randomly
-    reward_locations = np.random.choice([0, 1], size=num_cues)
+    reward_locations = jr.choice(rng_key, 2, shape=(num_cues,))
 
     reward_indices = []
     no_reward_indices = []
     for i in range(num_cues):
         if reward_locations[i] == 0:
-            reward_indices += position_to_index(
-                reward_1_positions[i], maze.shape
-            )
-            no_reward_indices += position_to_index(
-                reward_2_positions[i], maze.shape
-            )
+            reward_indices += [jnp.ravel_multi_index(jnp.array(reward_1_positions[i]), maze.shape).item()]
+            no_reward_indices += [jnp.ravel_multi_index(jnp.array(reward_2_positions[i]), maze.shape).item()]
         else:
-            reward_indices += position_to_index(
-                reward_2_positions[i], maze.shape
-            )
-            no_reward_indices += position_to_index(
-                reward_1_positions[i], maze.shape
-            )
+            reward_indices += [jnp.ravel_multi_index(jnp.array(reward_2_positions[i]), maze.shape).item()]
+            no_reward_indices += [jnp.ravel_multi_index(jnp.array(reward_1_positions[i]), maze.shape).item()]
 
     return {
         "maze": maze,
@@ -158,13 +115,9 @@ def generate_A(maze_info):
         cue_likelihood = np.zeros((3, num_states, 2))
         cue_likelihood[0, :, :] = 1  # Default: no info about reward
 
-        cue_state_idx = position_to_index(cue_positions[i], maze.shape)
-        reward_1_state_idx = position_to_index(
-            reward_1_positions[i], maze.shape
-        )
-        reward_2_state_idx = position_to_index(
-            reward_2_positions[i], maze.shape
-        )
+        cue_state_idx = jnp.ravel_multi_index(jnp.array(cue_positions[i]), maze.shape)
+        reward_1_state_idx = jnp.ravel_multi_index(jnp.array(reward_1_positions[i]), maze.shape)
+        reward_2_state_idx = jnp.ravel_multi_index(jnp.array(reward_2_positions[i]), maze.shape)
 
         cue_likelihood[:, cue_state_idx, 0] = [0, 1, 0]  # Reward in r1
         cue_likelihood[:, cue_state_idx, 1] = [0, 0, 1]  # Reward in r2
@@ -178,13 +131,8 @@ def generate_A(maze_info):
         reward_likelihood = np.zeros((3, num_states, 2))
         reward_likelihood[0, :, :] = 1  # Default: no reward
 
-        reward_1_state_idx = position_to_index(
-            reward_1_positions[i], maze.shape
-        )
-
-        reward_2_state_idx = position_to_index(
-            reward_2_positions[i], maze.shape
-        )
+        reward_1_state_idx = jnp.ravel_multi_index(jnp.array(reward_1_positions[i]), maze.shape)
+        reward_2_state_idx = jnp.ravel_multi_index(jnp.array(reward_2_positions[i]), maze.shape)
 
         # Reward in (8,4) if reward state is 0
         reward_likelihood[:, reward_1_state_idx, 0] = [0, 1, 0]
@@ -254,7 +202,7 @@ def generate_B(maze_info):
             ):
                 P[s, a] = s
             else:
-                P[s, a] = position_to_index((ns_row, ns_col), maze.shape)
+                P[s, a] = jnp.ravel_multi_index(jnp.array((ns_row, ns_col)), maze.shape)
 
     B = np.zeros((num_states, num_states, num_actions))
     for s in range(num_states):
@@ -306,7 +254,7 @@ def generate_D(maze_info):
 
     D[0] = np.zeros(cols * rows)
     # Position of the agent when starting the environment
-    D[0][position_to_index(initial_position, maze.shape)] = 1
+    D[0][jnp.ravel_multi_index(jnp.array(initial_position), maze.shape)] = 1
 
     # Cue state i.e. where is the reward
    for i in range(num_cues):
@@ -339,7 +287,7 @@ def render(maze_info, env_state, show_img=True):
     reward_2_positions = maze_info["reward_2_positions"]
 
     current_position = env_state.state[0]
-    current_position = index_to_position(current_position, maze.shape)
+    current_position = jnp.unravel_index(current_position, maze.shape)
 
     # Set all states not in [1] to be 0 (accessible state)
     mask = np.isin(maze, [2], invert=True)
@@ -447,7 +395,7 @@ def render(maze_info, env_state, show_img=True):
 
     return image
 
 
-class GeneralizedTMazeEnv(PyMDPEnv):
+class GeneralizedTMazeEnv(Env):
     """
     Extended version of the T-Maze in which there are multiple cues and
     reward pairs similar to the original T-maze.
@@ -465,4 +413,136 @@ def __init__(self, env_info, batch_size=1):
         }
         dependencies = {"A": A_dependencies, "B": B_dependencies}
 
-        PyMDPEnv.__init__(self, params, dependencies)
+        Env.__init__(self, params, dependencies)
+
+    def render(self, mode="human"):
+        """
+        Renders the environment
+
+        Parameters
+        ----------
+        mode: str, optional
+            The mode to render with ("human" or "rgb_array")
+
+        Returns
+        ----------
+        if mode == "human":
+            returns None, renders the environment using matplotlib inside the function
+        elif mode == "rgb_array":
+            A (H, W, 3) jax.numpy array that can act as input to functions like plt.imshow, with values between 0 and 255
+        """
+        pass
+        # maze = maze_info["maze"]
+        # num_cues = maze_info["num_cues"]
+        # cue_positions = maze_info["cue_positions"]
+        # reward_1_positions = maze_info["reward_1_positions"]
+        # reward_2_positions = maze_info["reward_2_positions"]
+
+        # current_position = env_state.state[0]
+        # current_position = jnp.unravel_index(current_position, maze.shape)
+
+        # # Set all states not in [1] to be 0 (accessible state)
+        # mask = np.isin(maze, [2], invert=True)
+        # maze[mask] = 0
+
+        # plt.figure()
+        # plt.imshow(maze, cmap="gray_r", origin="lower")
+
+        # cmap = plt.get_cmap("tab10")
+        # plt.scatter(
+        #     [ci[1] for ci in cue_positions],
+        #     [ci[0] for ci in cue_positions],
+        #     color=[cmap(i) for i in range(len(cue_positions))],
+        #     s=200,
+        #     alpha=0.5,
+        # )
+        # plt.scatter(
+        #     [ci[1] for ci in cue_positions],
+        #     [ci[0] for ci in cue_positions],
+        #     color="black",
+        #     s=50,
+        #     label="Cue",
+        #     marker="x",
+        # )
+
+        # plt.scatter(
+        #     [ri[1] for ri in reward_1_positions],
+        #     [ri[0] for ri in reward_1_positions],
+        #     color=[cmap(i) for i in range(len(cue_positions))],
+        #     s=200,
+        #     alpha=0.5,
+        # )
+
+        # plt.scatter(
+        #     [ri[1] for ri in reward_2_positions],
+        #     [ri[0] for ri in reward_2_positions],
+        #     color=[cmap(i) for i in range(len(cue_positions))],
+        #     s=200,
+        #     alpha=0.5,
+        # )
+
+        # plt.scatter(
+        #     [ri[1] for ri in reward_1_positions[-1:]],
+        #     [ri[0] for ri in reward_1_positions[-1:]],
+        #     marker="o",
+        #     color="red",
+        #     s=50,
+        #     label="Positive",
+        # )
+
+        # plt.scatter(
+        #     [ri[1] for ri in reward_2_positions[-1:]],
+        #     [ri[0] for ri in reward_2_positions[-1:]],
+        #     marker="o",
+        #     color="blue",
+        #     s=50,
+        #     label="Negative",
+        # )
+
+        # plt.scatter(
+        #     current_position[1],
+        #     current_position[0],
+        #     c="tab:green",
+        #     marker="s",
+        #     s=100,
+        #     label="Agent",
+        # )
+
+        # plt.title("Generalized T-Maze Environment")
+
+        # handles, labels = plt.gca().get_legend_handles_labels()
+        # for i in range(num_cues):
+        #     if i == num_cues - 1:
+        #         label = "Reward set"
+        #     else:
+        #         label = f"Distractor {i + 1} set"
+        #     patch = Line2D(
+        #         [0],
+        #         [0],
+        #         marker="o",
+        #         markersize=10,
+        #         markerfacecolor=cmap(i),
+        #         markeredgecolor=cmap(i),
+        #         label=label,
+        #         alpha=0.5,
+        #         linestyle="",
+        #     )
+        #     handles.append(patch)
+
+        # plt.legend(
+        #     handles=handles, loc="upper left", bbox_to_anchor=(1, 1), fancybox=True
+        # )
+        # # plt.axis("off")
+        # plt.tight_layout()
+
+        # # Capture the current figure as an image
+        # buf = io.BytesIO()
+        # plt.savefig(buf, format="png")
+        # buf.seek(0)
+        # image = PIL.Image.open(buf)
+
+        # if show_img:
+        #     plt.show()
+
+        # return image
+
+
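Reviewer note: the removed position_to_index / index_to_position helpers were plain row-major stride arithmetic, so swapping in the jax.numpy built-ins is behavior-preserving. A minimal sketch of the equivalence (standalone example, not code from this PR):

    import jax.numpy as jnp

    shape = (9, 7)  # (n_rows, n_cols), illustrative
    pos = (8, 4)    # (row, col)

    # old helper: position[0] * shape[1] + position[1]
    idx = jnp.ravel_multi_index(jnp.array(pos), shape)  # 8 * 7 + 4 == 60
    # old helper: (index // shape[1], index % shape[1])
    row, col = jnp.unravel_index(idx, shape)            # back to (8, 4)
    assert (int(row), int(col)) == pos

One caveat: jnp.ravel_multi_index returns a zero-dimensional array rather than a Python int, which is why the parse_maze call sites append .item(); the generate_A / generate_B / generate_D call sites index NumPy arrays with the 0-d result directly, which NumPy accepts for concrete integer arrays.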
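Reviewer note: parse_maze now threads an explicit PRNG key (jr.choice(rng_key, 2, shape=(num_cues,)) replaces the global-state np.random.choice([0, 1], size=num_cues)), and the maze = np.array(maze) coercion was dropped, so callers must pass an array rather than a nested list. A sketch of the new call pattern, using the cell encoding the code relies on (1 = start, 2 = inaccessible, 3 = cue, 4/5 = reward pair); the maze layout and import path are assumptions for illustration:

    import jax.numpy as jnp
    from jax import random as jr
    from pymdp.envs.generalized_tmaze import parse_maze

    maze = jnp.array([
        [2, 4, 2, 5, 2],
        [2, 0, 0, 0, 2],
        [2, 0, 1, 0, 2],  # 1 marks the agent's start
        [2, 0, 3, 0, 2],  # 3 marks the cue
        [2, 2, 2, 2, 2],
    ])

    key = jr.PRNGKey(0)          # explicit seed: same key, same reward placement
    key, subkey = jr.split(key)  # split before consuming, per JAX convention
    maze_info = parse_maze(maze, subkey)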
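Reviewer note: the new GeneralizedTMazeEnv.render is currently a stub (pass plus the old module-level render body carried over as comments), while its docstring promises an rgb_array mode. If the matplotlib body is restored, the rgb_array branch could convert the captured figure along these lines (hypothetical completion sketch under that assumption, not code from this PR):

    import io
    import numpy as np
    import PIL.Image
    import jax.numpy as jnp
    import matplotlib.pyplot as plt

    buf = io.BytesIO()
    plt.savefig(buf, format="png")  # capture the current matplotlib figure
    buf.seek(0)
    # (H, W, 3) uint8 array with values in [0, 255], per the docstring's contract
    frame = jnp.array(np.asarray(PIL.Image.open(buf).convert("RGB")))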