Commit
Merge pull request #110 from sony/feature/20231012-enable-tuple-action-support

Feature/20231012 enable tuple action support
TakayoshiTakayanagi authored Oct 18, 2023
2 parents d7142cf + eff8bdb commit 5df4f05
Showing 37 changed files with 305 additions and 61 deletions.
2 changes: 1 addition & 1 deletion nnabla_rl/algorithms/a2c.py
@@ -366,7 +366,7 @@ def _solvers(self):
     def is_supported_env(cls, env_or_env_info):
         env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \
             else env_or_env_info
-        return not env_info.is_continuous_action_env()
+        return not env_info.is_continuous_action_env() and not env_info.is_tuple_action_env()

     @property
     def latest_iteration_state(self):
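The same one-line guard, with the action-space polarity flipped where appropriate, is repeated across the algorithm files below. A minimal sketch of how the check is typically used before constructing an algorithm; the env id and the constructor call are illustrative assumptions, not part of this commit:

```python
import gym
import nnabla_rl.algorithms as A

env = gym.make("CartPole-v1")  # discrete, non-tuple action space (assumption)
if A.A2C.is_supported_env(env):  # True: discrete and not a tuple action env
    algorithm = A.A2C(env)
else:
    raise ValueError("A2C does not support this environment")
```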
2 changes: 1 addition & 1 deletion nnabla_rl/algorithms/atrpo.py
@@ -405,7 +405,7 @@ def _solvers(self):
     def is_supported_env(cls, env_or_env_info):
         env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \
             else env_or_env_info
-        return not env_info.is_discrete_action_env()
+        return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env()

     @property
     def latest_iteration_state(self):
2 changes: 1 addition & 1 deletion nnabla_rl/algorithms/bcq.py
@@ -341,7 +341,7 @@ def _solvers(self):
     def is_supported_env(cls, env_or_env_info):
         env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \
             else env_or_env_info
-        return not env_info.is_discrete_action_env()
+        return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env()

     @property
     def latest_iteration_state(self):
2 changes: 1 addition & 1 deletion nnabla_rl/algorithms/bear.py
@@ -401,7 +401,7 @@ def _solvers(self):
     def is_supported_env(cls, env_or_env_info):
         env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \
             else env_or_env_info
-        return not env_info.is_discrete_action_env()
+        return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env()

     @property
     def latest_iteration_state(self):
2 changes: 1 addition & 1 deletion nnabla_rl/algorithms/categorical_dqn.py
@@ -339,7 +339,7 @@ def _solvers(self):
     def is_supported_env(cls, env_or_env_info):
         env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \
             else env_or_env_info
-        return not env_info.is_continuous_action_env()
+        return not env_info.is_continuous_action_env() and not env_info.is_tuple_action_env()

     @classmethod
     def is_rnn_supported(self):
10 changes: 7 additions & 3 deletions nnabla_rl/algorithms/common_utils.py
@@ -374,9 +374,13 @@ def __call__(self, s: Union[np.ndarray, Tuple[np.ndarray, ...]], *, begin_of_episode=False):
         for key in self._rnn_internal_states.keys():
             # copy internal states of previous iteration
             self._rnn_internal_states[key].d = prev_rnn_states[key].d
-        self._action.forward(clear_no_need_grad=True)
-        # No need to save internal states
-        action = np.squeeze(self._action.d, axis=0) if batch_size == 1 else self._action.d
+        if self._env_info.is_tuple_action_env():
+            nn.forward_all(self._action, clear_no_need_grad=True)
+            action = tuple(np.squeeze(a.d, axis=0) if batch_size == 1 else a.d for a in self._action)
+        else:
+            self._action.forward(clear_no_need_grad=True)
+            # No need to save internal states
+            action = np.squeeze(self._action.d, axis=0) if batch_size == 1 else self._action.d
         return action, {}

     @abstractmethod
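In the tuple branch above, the action is a tuple of output Variables, so a single `forward()` call no longer applies; `nn.forward_all` runs one forward pass covering all of them. A self-contained sketch of that flow, assuming a hypothetical two-head policy graph (the graph and shapes are assumptions):

```python
import numpy as np
import nnabla as nn
import nnabla.functions as F

batch_size = 1
s = nn.Variable((batch_size, 4))
# Two heads of a hypothetical tuple-action policy.
action = (F.tanh(F.sum(s, axis=1, keepdims=True)),
          F.relu(F.sum(s, axis=1, keepdims=True)))
s.d = np.random.randn(batch_size, 4)
nn.forward_all(action, clear_no_need_grad=True)  # one pass computes every head
# Mirror the squeeze logic from the diff when batch_size == 1.
out = tuple(np.squeeze(a.d, axis=0) if batch_size == 1 else a.d for a in action)
```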
2 changes: 1 addition & 1 deletion nnabla_rl/algorithms/ddp.py
@@ -308,7 +308,7 @@ def _solvers(self):
     def is_supported_env(cls, env_or_env_info):
         env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \
             else env_or_env_info
-        return not env_info.is_discrete_action_env()
+        return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env()

     @property
     def trainers(self):
2 changes: 1 addition & 1 deletion nnabla_rl/algorithms/ddpg.py
@@ -359,7 +359,7 @@ def is_rnn_supported(self):
     def is_supported_env(cls, env_or_env_info):
         env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \
             else env_or_env_info
-        return not env_info.is_discrete_action_env()
+        return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env()

     @property
     def latest_iteration_state(self):
3 changes: 2 additions & 1 deletion nnabla_rl/algorithms/decision_transformer.py
@@ -263,7 +263,8 @@ def _solvers(self):
     def is_supported_env(cls, env_or_env_info):
         env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(
             env_or_env_info, gym.Env) else env_or_env_info
-        return env_info.is_continuous_action_env() or env_info.is_discrete_action_env()
+        return ((env_info.is_continuous_action_env() or env_info.is_discrete_action_env())
+                and not env_info.is_tuple_action_env())

     @classmethod
     def is_rnn_supported(self):
2 changes: 1 addition & 1 deletion nnabla_rl/algorithms/demme_sac.py
@@ -624,7 +624,7 @@ def is_rnn_supported(cls):
     def is_supported_env(cls, env_or_env_info):
         env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \
             else env_or_env_info
-        return not env_info.is_discrete_action_env()
+        return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env()

     @property
     def latest_iteration_state(self):
2 changes: 1 addition & 1 deletion nnabla_rl/algorithms/dqn.py
@@ -335,7 +335,7 @@ def _solvers(self):
     def is_supported_env(cls, env_or_env_info):
         env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \
             else env_or_env_info
-        return not env_info.is_continuous_action_env()
+        return not env_info.is_continuous_action_env() and not env_info.is_tuple_action_env()

     @classmethod
     def is_rnn_supported(self):
2 changes: 1 addition & 1 deletion nnabla_rl/algorithms/gail.py
@@ -543,7 +543,7 @@ def _solvers(self):
     def is_supported_env(cls, env_or_env_info):
         env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \
             else env_or_env_info
-        return not env_info.is_discrete_action_env()
+        return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env()

     @property
     def latest_iteration_state(self):
2 changes: 1 addition & 1 deletion nnabla_rl/algorithms/her.py
@@ -352,4 +352,4 @@ def is_supported_env(cls, env_or_env_info):
         # continuous action env
         is_continuous_action_env = env_info.is_continuous_action_env()
         is_goal_conditioned_env = env_info.is_goal_conditioned_env()
-        return (is_continuous_action_env and is_goal_conditioned_env)
+        return (is_continuous_action_env and is_goal_conditioned_env) and not env_info.is_tuple_action_env()
4 changes: 3 additions & 1 deletion nnabla_rl/algorithms/icml2015_trpo.py
@@ -279,7 +279,9 @@ def _solvers(self):

     @classmethod
     def is_supported_env(cls, env_or_env_info):
-        return True  # supports all enviroments
+        env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \
+            else env_or_env_info
+        return not env_info.is_tuple_action_env()

     @property
     def latest_iteration_state(self):
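ICML2015TRPO (like PPO and REINFORCE below) previously reported support for every environment; the new check excludes tuple action spaces. A hedged illustration of the kind of space that is now rejected (the composition is an assumption):

```python
import numpy as np
from gym import spaces

# A composite action, e.g. a discrete gear plus two continuous controls.
tuple_space = spaces.Tuple((spaces.Discrete(3),
                            spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)))
# An env with this action space makes is_tuple_action_env() return True,
# so is_supported_env() now returns False for these algorithms.
print(tuple_space.sample())  # e.g. (1, array([ 0.42, -0.13], dtype=float32))
```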
2 changes: 1 addition & 1 deletion nnabla_rl/algorithms/icml2018_sac.py
@@ -442,7 +442,7 @@ def is_rnn_supported(cls):
     def is_supported_env(cls, env_or_env_info):
         env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \
             else env_or_env_info
-        return not env_info.is_discrete_action_env()
+        return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env()

     @property
     def latest_iteration_state(self):
2 changes: 1 addition & 1 deletion nnabla_rl/algorithms/icra2018_qtopt.py
@@ -264,4 +264,4 @@ def _random_action_selector(self, s, *, begin_of_episode=False):
     def is_supported_env(cls, env_or_env_info):
         env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \
             else env_or_env_info
-        return env_info.is_continuous_action_env()
+        return env_info.is_continuous_action_env() and not env_info.is_tuple_action_env()
2 changes: 1 addition & 1 deletion nnabla_rl/algorithms/iqn.py
@@ -352,7 +352,7 @@ def _solvers(self):
     def is_supported_env(cls, env_or_env_info):
         env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \
             else env_or_env_info
-        return not env_info.is_continuous_action_env()
+        return not env_info.is_continuous_action_env() and not env_info.is_tuple_action_env()

     @classmethod
     def is_rnn_supported(self):
2 changes: 1 addition & 1 deletion nnabla_rl/algorithms/lqr.py
@@ -159,7 +159,7 @@ def _solvers(self):
     def is_supported_env(cls, env_or_env_info):
         env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \
             else env_or_env_info
-        return not env_info.is_discrete_action_env()
+        return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env()

     @property
     def trainers(self):
2 changes: 1 addition & 1 deletion nnabla_rl/algorithms/mppi.py
@@ -365,7 +365,7 @@ def _solvers(self):
     def is_supported_env(cls, env_or_env_info):
         env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \
             else env_or_env_info
-        return not env_info.is_discrete_action_env()
+        return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env()

     @property
     def latest_iteration_state(self):
4 changes: 3 additions & 1 deletion nnabla_rl/algorithms/ppo.py
@@ -448,7 +448,9 @@ def _solvers(self):

     @classmethod
     def is_supported_env(cls, env_or_env_info):
-        return True  # supports all enviroments
+        env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \
+            else env_or_env_info
+        return not env_info.is_tuple_action_env()

     def _build_ppo_actors(self, env, v_function, policy, state_preprocessor):
         actors = []
2 changes: 1 addition & 1 deletion nnabla_rl/algorithms/qrdqn.py
@@ -325,7 +325,7 @@ def _solvers(self):
     def is_supported_env(cls, env_or_env_info):
         env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \
             else env_or_env_info
-        return not env_info.is_continuous_action_env()
+        return not env_info.is_continuous_action_env() and not env_info.is_tuple_action_env()

     @classmethod
     def is_rnn_supported(self):
2 changes: 1 addition & 1 deletion nnabla_rl/algorithms/qrsac.py
@@ -395,7 +395,7 @@ def is_rnn_supported(self):
     def is_supported_env(cls, env_or_env_info):
         env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \
             else env_or_env_info
-        return not env_info.is_discrete_action_env()
+        return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env()

     @property
     def latest_iteration_state(self):
4 changes: 3 additions & 1 deletion nnabla_rl/algorithms/reinforce.py
@@ -266,7 +266,9 @@ def _solvers(self):

     @classmethod
     def is_supported_env(cls, env_or_env_info):
-        return True  # supports all enviroments
+        env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \
+            else env_or_env_info
+        return not env_info.is_tuple_action_env()

     @property
     def latest_iteration_state(self):
2 changes: 1 addition & 1 deletion nnabla_rl/algorithms/sac.py
@@ -397,7 +397,7 @@ def is_rnn_supported(self):
     def is_supported_env(cls, env_or_env_info):
         env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \
             else env_or_env_info
-        return not env_info.is_discrete_action_env()
+        return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env()

     @property
     def latest_iteration_state(self):
2 changes: 1 addition & 1 deletion nnabla_rl/algorithms/td3.py
@@ -384,7 +384,7 @@ def is_rnn_supported(self):
     def is_supported_env(cls, env_or_env_info):
         env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \
             else env_or_env_info
-        return not env_info.is_discrete_action_env()
+        return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env()

     @property
     def latest_iteration_state(self):
2 changes: 1 addition & 1 deletion nnabla_rl/algorithms/trpo.py
@@ -390,7 +390,7 @@ def _solvers(self):
     def is_supported_env(cls, env_or_env_info):
         env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \
             else env_or_env_info
-        return not env_info.is_discrete_action_env()
+        return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env()

     @property
     def latest_iteration_state(self):
2 changes: 1 addition & 1 deletion nnabla_rl/algorithms/xql.py
@@ -420,7 +420,7 @@ def is_rnn_supported(self):
     def is_supported_env(cls, env_or_env_info):
         env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \
             else env_or_env_info
-        return not env_info.is_discrete_action_env()
+        return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env()

     @property
     def latest_iteration_state(self):
28 changes: 22 additions & 6 deletions nnabla_rl/environment_explorer.py
@@ -122,12 +122,7 @@ def rollout(self, env: gym.Env) -> List[Experience]:
     def _step_once(self, env, *, begin_of_episode=False) -> Tuple[Experience, bool]:
         self._steps += 1
         if self._steps < self._config.warmup_random_steps:
-            action_info: Dict[str, Any] = {}
-            if self._env_info.is_discrete_action_env():
-                action = env.action_space.sample()
-                self._action = np.asarray(action).reshape((1, ))
-            else:
-                self._action = env.action_space.sample()
+            self._action, action_info = self._warmup_action(env)
         else:
             self._action, action_info = self.action(self._steps,
                                                     cast(np.ndarray, self._state),
@@ -156,9 +151,30 @@ def _step_once(self, env, *, begin_of_episode=False) -> Tuple[Experience, bool]:
         self._state = self._next_state
         return experience, done

+    def _warmup_action(self, env):
+        return _sample_action(env, self._env_info)
+

 def _is_end_of_episode(done, timelimit, timelimit_as_terminal):
     if not done:
         return False
     else:
         return (not timelimit) or (timelimit and timelimit_as_terminal)


+def _sample_action(env, env_info):
+    action_info: Dict[str, Any] = {}
+    if env_info.is_tuple_action_env():
+        action = []
+        for a, action_space in zip(env.action_space.sample(), env_info.action_space):
+            if isinstance(action_space, gym.spaces.Discrete):
+                a = np.asarray(a).reshape((1, ))
+            action.append(a)
+        action = tuple(action)
+    else:
+        if env_info.is_discrete_action_env():
+            action = env.action_space.sample()
+            action = np.asarray(action).reshape((1, ))
+        else:
+            action = env.action_space.sample()
+    return action, action_info
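A sketch of what `_sample_action` produces during warmup for a tuple action space, assuming a `Tuple(Discrete, Box)` composition (the concrete spaces are assumptions):

```python
import numpy as np
from gym import spaces

space = spaces.Tuple((spaces.Discrete(4),
                      spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)))
raw = space.sample()  # e.g. (2, array([0.31, -0.72], dtype=float32))
# As in _sample_action, discrete members become (1,)-shaped arrays so every
# element of the resulting tuple is a numpy array with a leading axis.
action = tuple(np.asarray(a).reshape((1,)) if isinstance(sp, spaces.Discrete) else a
               for a, sp in zip(raw, space.spaces))
```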
8 changes: 5 additions & 3 deletions nnabla_rl/environments/__init__.py
@@ -1,5 +1,5 @@
 # Copyright 2020,2021 Sony Corporation.
-# Copyright 2021,2022 Sony Group Corporation.
+# Copyright 2021,2022,2023 Sony Group Corporation.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,8 +17,10 @@

 from nnabla_rl.environments.dummy import (DummyAtariEnv, DummyContinuous, DummyContinuousActionGoalEnv, DummyDiscrete,  # noqa
                                           DummyDiscreteActionGoalEnv, DummyDiscreteImg, DummyContinuousImg,
-                                          DummyFactoredContinuous, DummyMujocoEnv, DummyTupleContinuous,
-                                          DummyTupleDiscrete, DummyTupleMixed)
+                                          DummyFactoredContinuous, DummyMujocoEnv,
+                                          DummyTupleContinuous, DummyTupleDiscrete, DummyTupleMixed,
+                                          DummyTupleStateContinuous, DummyTupleStateDiscrete,
+                                          DummyTupleActionContinuous, DummyTupleActionDiscrete)

 register(
     id='FakeMujocoNNablaRL-v1',
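The newly exported dummy environments give tests a way to exercise tuple state/action handling. A sketch, assuming `DummyTupleActionContinuous` can be constructed without arguments (its signature is not shown in this commit):

```python
from nnabla_rl.environments.dummy import DummyTupleActionContinuous

env = DummyTupleActionContinuous()  # zero-arg construction is an assumption
state = env.reset()
action = env.action_space.sample()  # a tuple, since the action space is a Tuple
next_state, reward, done, info = env.step(action)
```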