From 5820cfe206b1c7e7ba1a597c07e4bdfdb7a097ca Mon Sep 17 00:00:00 2001 From: Tom Tseng Date: Mon, 6 Jan 2025 21:21:11 -0800 Subject: [PATCH] Fix mypy errors --- src/imitation/algorithms/adversarial/common.py | 2 +- src/imitation/policies/base.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/imitation/algorithms/adversarial/common.py b/src/imitation/algorithms/adversarial/common.py index ece30b011..bde75f83d 100644 --- a/src/imitation/algorithms/adversarial/common.py +++ b/src/imitation/algorithms/adversarial/common.py @@ -207,7 +207,7 @@ def __init__( self.debug_use_ground_truth = debug_use_ground_truth self.venv = venv self.gen_algo = gen_algo - self._reward_net = reward_net.to(gen_algo.device) + self._reward_net: reward_nets.RewardNet = reward_net.to(gen_algo.device) self._log_dir = util.parse_path(log_dir) # Create graph for optimising/recording stats on discriminator diff --git a/src/imitation/policies/base.py b/src/imitation/policies/base.py index ba1f550df..c12de74e0 100644 --- a/src/imitation/policies/base.py +++ b/src/imitation/policies/base.py @@ -31,7 +31,7 @@ def _predict( ): np_actions = [] if isinstance(obs, dict): - np_obs = types.DictObs( + np_obs: Union[types.DictObs, np.ndarray] = types.DictObs( {k: v.detach().cpu().numpy() for k, v in obs.items()}, ) else: