diff --git a/src/seals/util.py b/src/seals/util.py index 6299bdc..c3b6718 100644 --- a/src/seals/util.py +++ b/src/seals/util.py @@ -16,7 +16,6 @@ import gymnasium as gym import numpy as np -import numpy.typing as npt # Note: we redefine the type vars from gymnasium.core here, because pytype does not # recognize them as valid type vars if we import them from gymnasium.core. @@ -137,10 +136,7 @@ class BoxRegion: MaskedRegionSpecifier = List[BoxRegion] -class MaskScoreWrapper( - gym.Wrapper[npt.NDArray, ActType, npt.NDArray, ActType], - Generic[ActType], -): +class MaskScoreWrapper(gym.ObservationWrapper): """Mask a list of box-shaped regions in the observation to hide reward info. Intended for environments whose observations are raw pixels (like Atari @@ -178,22 +174,10 @@ def __init__( raise ValueError('Invalid region: "x" and "y" must be increasing.') self.mask[r.x[0] : r.x[1], r.y[0] : r.y[1]] = 0 - def _mask_obs(self, obs) -> npt.NDArray: + def observation(self, obs): + """Returns observation with masked regions filled with `fill_value`.""" return np.where(self.mask, obs, self.fill_value) - def step( - self, - action: ActType, - ) -> Tuple[npt.NDArray, SupportsFloat, bool, bool, Dict[str, Any]]: - """Returns (obs, rew, terminated, truncated, info) with masked obs.""" - obs, rew, terminated, truncated, info = self.env.step(action) - return self._mask_obs(obs), rew, terminated, truncated, info - - def reset(self, **kwargs): - """Returns masked reset observation.""" - obs, info = self.env.reset(**kwargs) - return self._mask_obs(obs), info - class ObsCastWrapper(gym.ObservationWrapper): """Cast observations to specified dtype.