started refactoring GeneralizedTMaze(Env) class
conorheins committed Oct 2, 2024
1 parent 786ec50 commit b4513b4
Showing 1 changed file with 157 additions and 77 deletions.
234 changes: 157 additions & 77 deletions pymdp/envs/generalized_tmaze.py
@@ -1,52 +1,18 @@
-from pymdp.jax.envs import PyMDPEnv
+from .env import Env
import numpy as np
-import jax.numpy as jnp

import matplotlib.pyplot as plt
import io
import PIL.Image

+import jax.numpy as jnp
import jax.tree_util as jtu
+from jax import random as jr
+from jaxtyping import Array, PRNGKeyArray
from matplotlib.lines import Line2D


def position_to_index(position, shape):
"""
Maps the position in the grid to a flat index
Parameters
----------
position
Tuple of position (row, col)
shape
The shape of the grid (n_rows, n_cols)
Returns
----------
index
A flattened index of position
"""
return position[0] * shape[1] + position[1]


def index_to_position(index, shape):
"""
Maps the flat index to a position coordinate in the grid
Parameters
----------
shape
The shape of the grid (n_rows, n_cols)
index
A flattened index of position
Returns
----------
position
Tuple of position (row, col)
"""
return index // shape[1], index % shape[1]
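
These two grid-index helpers are deleted by this commit; the refactored code calls JAX's built-ins instead. A minimal sketch of the equivalence (illustrative, not part of the diff):

import jax.numpy as jnp

shape = (10, 10)  # hypothetical (n_rows, n_cols) grid
pos = (8, 4)      # (row, col)

# position_to_index(pos, shape) == pos[0] * shape[1] + pos[1]
flat = jnp.ravel_multi_index(jnp.array(pos), shape)  # Array(84)
# index_to_position(flat, shape) == (flat // shape[1], flat % shape[1])
row, col = jnp.unravel_index(flat, shape)            # (Array(8), Array(4))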


-def parse_maze(maze):
+def parse_maze(maze, rng_key: PRNGKeyArray):
"""
Parameters
----------
@@ -67,45 +33,36 @@ def parse_maze(maze):
purposes
"""

maze = np.array(maze)
rows, cols = maze.shape

-num_cues = int((np.max(maze) - 2) // 3)
+num_cues = int((jnp.max(maze) - 2) // 3)

cue_positions = []
reward_1_positions = []
reward_2_positions = []
for i in range(num_cues):
-cue_positions.append(tuple(np.argwhere(maze == 3 + 3 * i)[0]))
-reward_1_positions.append(tuple(np.argwhere(maze == 4 + 3 * i)[0]))
-reward_2_positions.append(tuple(np.argwhere(maze == 5 + 3 * i)[0]))
+cue_positions.append(tuple(jnp.argwhere(maze == 3 + 3 * i)[0]))
+reward_1_positions.append(tuple(jnp.argwhere(maze == 4 + 3 * i)[0]))
+reward_2_positions.append(tuple(jnp.argwhere(maze == 5 + 3 * i)[0]))

# Initialize agent's starting position (can be customized if required)
-initial_position = tuple(np.argwhere(maze == 1)[0])
+initial_position = tuple(jnp.argwhere(maze == 1)[0])

# Actions: up, down, left, right
actions = [(-1, 0), (1, 0), (0, -1), (0, 1)]

# Set reward location randomly
-reward_locations = np.random.choice([0, 1], size=num_cues)
+reward_locations = jr.choice(rng_key, 2, shape=(num_cues,))
reward_indices = []
no_reward_indices = []

for i in range(num_cues):
if reward_locations[i] == 0:
-reward_indices += position_to_index(
-reward_1_positions[i], maze.shape
-)
-no_reward_indices += position_to_index(
-reward_2_positions[i], maze.shape
-)
+reward_indices += [jnp.ravel_multi_index(jnp.array(reward_1_positions[i]), maze.shape).item()]
+no_reward_indices += [jnp.ravel_multi_index(jnp.array(reward_2_positions[i]), maze.shape).item()]
else:
-reward_indices += position_to_index(
-reward_2_positions[i], maze.shape
-)
-no_reward_indices += position_to_index(
-reward_1_positions[i], maze.shape
-)
+reward_indices += [jnp.ravel_multi_index(jnp.array(reward_2_positions[i]), maze.shape).item()]
+no_reward_indices += [jnp.ravel_multi_index(jnp.array(reward_1_positions[i]), maze.shape).item()]

return {
"maze": maze,
@@ -158,13 +115,9 @@ def generate_A(maze_info):
cue_likelihood = np.zeros((3, num_states, 2))
cue_likelihood[0, :, :] = 1 # Default: no info about reward

-cue_state_idx = position_to_index(cue_positions[i], maze.shape)
-reward_1_state_idx = position_to_index(
-reward_1_positions[i], maze.shape
-)
-reward_2_state_idx = position_to_index(
-reward_2_positions[i], maze.shape
-)
+cue_state_idx = jnp.ravel_multi_index(jnp.array(cue_positions[i]), maze.shape)
+reward_1_state_idx = jnp.ravel_multi_index(jnp.array(reward_1_positions[i]), maze.shape)
+reward_2_state_idx = jnp.ravel_multi_index(jnp.array(reward_2_positions[i]), maze.shape)

cue_likelihood[:, cue_state_idx, 0] = [0, 1, 0] # Reward in r1
cue_likelihood[:, cue_state_idx, 1] = [0, 0, 1] # Reward in r2
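
For reference, each slice cue_likelihood[:, s, r] above is a categorical distribution over the three cue observations (no info / reward in arm 1 / reward in arm 2), and only the cue tile is informative about the reward state. A quick sanity check one could run on the assembled array (illustrative):

import numpy as np

# every (state, reward-condition) column must normalize over the 3 observations
assert np.allclose(cue_likelihood.sum(axis=0), 1.0)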
@@ -178,13 +131,8 @@ def generate_A(maze_info):
reward_likelihood = np.zeros((3, num_states, 2))
reward_likelihood[0, :, :] = 1 # Default: no reward

-reward_1_state_idx = position_to_index(
-reward_1_positions[i], maze.shape
-)
-
-reward_2_state_idx = position_to_index(
-reward_2_positions[i], maze.shape
-)
+reward_1_state_idx = jnp.ravel_multi_index(jnp.array(reward_1_positions[i]), maze.shape)
+reward_2_state_idx = jnp.ravel_multi_index(jnp.array(reward_2_positions[i]), maze.shape)

# Reward in (8,4) if reward state is 0
reward_likelihood[:, reward_1_state_idx, 0] = [0, 1, 0]
@@ -254,7 +202,7 @@ def generate_B(maze_info):
):
P[s, a] = s
else:
-P[s, a] = position_to_index((ns_row, ns_col), maze.shape)
+P[s, a] = jnp.ravel_multi_index(jnp.array((ns_row, ns_col)), maze.shape)

B = np.zeros((num_states, num_states, num_actions))
for s in range(num_states):
@@ -306,7 +254,7 @@ def generate_D(maze_info):

D[0] = np.zeros(cols * rows)
# Position of the agent when starting the environment
-D[0][position_to_index(initial_position, maze.shape)] = 1
+D[0][jnp.ravel_multi_index(jnp.array(initial_position), maze.shape)] = 1

# Cue state i.e. where is the reward
for i in range(num_cues):
@@ -339,7 +287,7 @@ def render(maze_info, env_state, show_img=True):
reward_2_positions = maze_info["reward_2_positions"]

current_position = env_state.state[0]
-current_position = index_to_position(current_position, maze.shape)
+current_position = jnp.unravel_index(current_position, maze.shape)

# Set all states not in [1] to be 0 (accessible state)
mask = np.isin(maze, [2], invert=True)
@@ -447,7 +395,7 @@ def render(maze_info, env_state, show_img=True):
return image


-class GeneralizedTMazeEnv(PyMDPEnv):
+class GeneralizedTMazeEnv(Env):
"""
Extended version of the T-Maze in which there are multiple cues and reward pairs
similar to the original T-maze.
@@ -465,4 +413,136 @@ def __init__(self, env_info, batch_size=1):
}
dependencies = {"A": A_dependencies, "B": B_dependencies}

-PyMDPEnv.__init__(self, params, dependencies)
+Env.__init__(self, params, dependencies)

def render(self, mode="human"):
"""
Renders the environment
Parameters
----------
mode: str, optional
The mode to render with ("human" or "rgb_array")
Returns
----------
if mode == "human":
returns None, renders the environment using matplotlib inside the function
elif mode == "rgb_array":
A (H, W, 3) jax.numpy array that can act as input to functions like plt.imshow, with values between 0 and 255
"""
pass
# maze = maze_info["maze"]
# num_cues = maze_info["num_cues"]
# cue_positions = maze_info["cue_positions"]
# reward_1_positions = maze_info["reward_1_positions"]
# reward_2_positions = maze_info["reward_2_positions"]

# current_position = env_state.state[0]
# current_position = jnp.unravel_index(current_position, maze.shape)

# # Set all states not in [1] to be 0 (accessible state)
# mask = np.isin(maze, [2], invert=True)
# maze[mask] = 0

# plt.figure()
# plt.imshow(maze, cmap="gray_r", origin="lower")

# cmap = plt.get_cmap("tab10")
# plt.scatter(
# [ci[1] for ci in cue_positions],
# [ci[0] for ci in cue_positions],
# color=[cmap(i) for i in range(len(cue_positions))],
# s=200,
# alpha=0.5,
# )
# plt.scatter(
# [ci[1] for ci in cue_positions],
# [ci[0] for ci in cue_positions],
# color="black",
# s=50,
# label="Cue",
# marker="x",
# )

# plt.scatter(
# [ri[1] for ri in reward_1_positions],
# [ri[0] for ri in reward_1_positions],
# color=[cmap(i) for i in range(len(cue_positions))],
# s=200,
# alpha=0.5,
# )

# plt.scatter(
# [ri[1] for ri in reward_2_positions],
# [ri[0] for ri in reward_2_positions],
# color=[cmap(i) for i in range(len(cue_positions))],
# s=200,
# alpha=0.5,
# )

# plt.scatter(
# [ri[1] for ri in reward_1_positions[-1:]],
# [ri[0] for ri in reward_1_positions[-1:]],
# marker="o",
# color="red",
# s=50,
# label="Positive",
# )

# plt.scatter(
# [ri[1] for ri in reward_2_positions[-1:]],
# [ri[0] for ri in reward_2_positions[-1:]],
# marker="o",
# color="blue",
# s=50,
# label="Negative",
# )

# plt.scatter(
# current_position[1],
# current_position[0],
# c="tab:green",
# marker="s",
# s=100,
# label="Agent",
# )

# plt.title("Generalized T-Maze Environment")

# handles, labels = plt.gca().get_legend_handles_labels()
# for i in range(num_cues):
# if i == num_cues - 1:
# label = "Reward set"
# else:
# label = f"Distractor {i + 1} set"
# patch = Line2D(
# [0],
# [0],
# marker="o",
# markersize=10,
# markerfacecolor=cmap(i),
# markeredgecolor=cmap(i),
# label=label,
# alpha=0.5,
# linestyle="",
# )
# handles.append(patch)

# plt.legend(
# handles=handles, loc="upper left", bbox_to_anchor=(1, 1), fancybox=True
# )
# #plt.axis("off")
# plt.tight_layout()

# # Capture the current figure as an image
# buf = io.BytesIO()
# plt.savefig(buf, format="png")
# buf.seek(0)
# image = PIL.Image.open(buf)

# if show_img:
# plt.show()

# return image
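
The commented-out body above still returns a PIL image built from a PNG buffer, while the new docstring instead promises an (H, W, 3) array for mode="rgb_array". One common matplotlib pattern for that conversion, shown here as a sketch rather than the commit's code (it assumes an Agg-style canvas):

import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt

fig = plt.gcf()
fig.canvas.draw()                            # render the figure on its canvas
rgba = np.asarray(fig.canvas.buffer_rgba())  # (H, W, 4) uint8
rgb = jnp.asarray(rgba[..., :3])             # drop alpha -> (H, W, 3), values 0-255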



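Taken together, the refactor suggests the following construction pattern for the environment. The maze layout below is hypothetical; the values follow parse_maze's encoding (0 = open, 1 = start, 2 = wall, 3 + 3i = cue i, 4 + 3i and 5 + 3i = its two reward arms):

import jax.numpy as jnp
from jax import random as jr

maze = jnp.array([
    [2, 4, 2, 5, 2],
    [2, 0, 0, 0, 2],
    [2, 3, 1, 0, 2],
    [2, 2, 2, 2, 2],
])

env_info = parse_maze(maze, jr.PRNGKey(0))          # randomizes which arm holds the reward
env = GeneralizedTMazeEnv(env_info, batch_size=1)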