Skip to content

Commit

Permalink
padding for masks
Browse files Browse the repository at this point in the history
  • Loading branch information
DennisSoemers committed Dec 13, 2023
1 parent ae4106c commit 476a819
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions gym_microrts/envs/vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,14 @@ def get_action_mask(self):
"""
# action_mask shape: [num_envs, map height, map width, 1 + action types + params]
action_mask = np.array(self.vec_client.getMasks(0))
num_envs, height, width, action_channels = action_mask.shape

# Add padding to the mask such that it is as big as we need for our biggest map
pad_width = self.width - width
pad_height = self.height - height
if pad_width > 0 or pad_height > 0:
action_mask = np.pad(action_mask, ((0, 0), (0, pad_height), (0, pad_width), (0, 0)))

# self.source_unit_mask shape: [num_envs, map height * map width * 1]
self.source_unit_mask = action_mask[:, :, :, 0].reshape(self.num_envs, -1)
action_type_and_parameter_mask = action_mask[:, :, :, 1:].reshape(self.num_envs, self.height * self.width, -1)
Expand Down

0 comments on commit 476a819

Please sign in to comment.