From dc7511cd82c71f95d93f91b3c97e754dda38dc50 Mon Sep 17 00:00:00 2001 From: mishari <44849486+maliesa96@users.noreply.github.com> Date: Mon, 14 Sep 2020 14:46:21 -0700 Subject: [PATCH] Add two torch primitives (#1993) This adds the following primitives to torch: - DiscreteQFArgmaxPolicy - DiscreteMLPQFunction --- ...trained_network_to_start_new_experiment.md | 4 +- examples/tf/dqn_cartpole.py | 4 +- examples/tf/dqn_pong.py | 4 +- src/garage/tf/policies/__init__.py | 5 +- ...policy.py => discrete_qf_argmax_policy.py} | 9 +- src/garage/torch/policies/__init__.py | 3 + .../policies/discrete_qf_argmax_policy.py | 68 +++++++++++++++ src/garage/torch/q_functions/__init__.py | 6 +- .../q_functions/discrete_mlp_q_function.py | 60 ++++++++++++++ tests/garage/tf/algos/test_dqn.py | 10 +-- ...y.py => test_discrete_qf_argmax_policy.py} | 16 ++-- .../test_discrete_qf_argmax_policy.py | 74 +++++++++++++++++ .../test_discrete_mlp_q_function.py | 82 +++++++++++++++++++ 13 files changed, 317 insertions(+), 28 deletions(-) rename src/garage/tf/policies/{discrete_qf_derived_policy.py => discrete_qf_argmax_policy.py} (95%) create mode 100644 src/garage/torch/policies/discrete_qf_argmax_policy.py create mode 100644 src/garage/torch/q_functions/discrete_mlp_q_function.py rename tests/garage/tf/policies/{test_qf_derived_policy.py => test_discrete_qf_argmax_policy.py} (87%) create mode 100644 tests/garage/torch/policies/test_discrete_qf_argmax_policy.py create mode 100644 tests/garage/torch/q_functions/test_discrete_mlp_q_function.py diff --git a/docs/user/use_pretrained_network_to_start_new_experiment.md b/docs/user/use_pretrained_network_to_start_new_experiment.md index c80264dd34..b316d5f861 100644 --- a/docs/user/use_pretrained_network_to_start_new_experiment.md +++ b/docs/user/use_pretrained_network_to_start_new_experiment.md @@ -88,7 +88,7 @@ from garage.experiment.deterministic import set_seed from garage.np.exploration_policies import EpsilonGreedyPolicy from garage.replay_buffer import PathBuffer from garage.tf.algos import DQN -from garage.tf.policies import DiscreteQfDerivedPolicy +from garage.tf.policies import DiscreteQFArgmaxPolicy from garage.trainer import TFTrainer @click.command() @@ -135,7 +135,7 @@ def dqn_pong(ctxt=None, seed=1, buffer_size=int(5e4), max_episode_length=500): qf = snapshot['algo']._qf # MARK: end modifications to existing example - policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf) + policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf) exploration_policy = EpsilonGreedyPolicy(env_spec=env.spec, policy=policy, total_timesteps=num_timesteps, diff --git a/examples/tf/dqn_cartpole.py b/examples/tf/dqn_cartpole.py index c04fab5c96..ad194fe57f 100755 --- a/examples/tf/dqn_cartpole.py +++ b/examples/tf/dqn_cartpole.py @@ -9,7 +9,7 @@ from garage.np.exploration_policies import EpsilonGreedyPolicy from garage.replay_buffer import PathBuffer from garage.tf.algos import DQN -from garage.tf.policies import DiscreteQfDerivedPolicy +from garage.tf.policies import DiscreteQFArgmaxPolicy from garage.tf.q_functions import DiscreteMLPQFunction from garage.trainer import TFTrainer @@ -34,7 +34,7 @@ def dqn_cartpole(ctxt=None, seed=1): env = GymEnv('CartPole-v0') replay_buffer = PathBuffer(capacity_in_transitions=int(1e4)) qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(64, 64)) - policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf) + policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf) exploration_policy = EpsilonGreedyPolicy(env_spec=env.spec, policy=policy,
total_timesteps=num_timesteps, diff --git a/examples/tf/dqn_pong.py b/examples/tf/dqn_pong.py index 18be0b21a6..ee79aa69e1 100755 --- a/examples/tf/dqn_pong.py +++ b/examples/tf/dqn_pong.py @@ -20,7 +20,7 @@ from garage.np.exploration_policies import EpsilonGreedyPolicy from garage.replay_buffer import PathBuffer from garage.tf.algos import DQN -from garage.tf.policies import DiscreteQfDerivedPolicy +from garage.tf.policies import DiscreteQFArgmaxPolicy from garage.tf.q_functions import DiscreteCNNQFunction from garage.trainer import TFTrainer @@ -75,7 +75,7 @@ def dqn_pong(ctxt=None, seed=1, buffer_size=int(5e4), max_episode_length=500): strides=(4, 2, 1), dueling=False) # yapf: disable - policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf) + policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf) exploration_policy = EpsilonGreedyPolicy(env_spec=env.spec, policy=policy, total_timesteps=num_timesteps, diff --git a/src/garage/tf/policies/__init__.py b/src/garage/tf/policies/__init__.py index dd83296828..25fa526bf4 100644 --- a/src/garage/tf/policies/__init__.py +++ b/src/garage/tf/policies/__init__.py @@ -4,8 +4,7 @@ from garage.tf.policies.categorical_lstm_policy import CategoricalLSTMPolicy from garage.tf.policies.categorical_mlp_policy import CategoricalMLPPolicy from garage.tf.policies.continuous_mlp_policy import ContinuousMLPPolicy -from garage.tf.policies.discrete_qf_derived_policy import ( - DiscreteQfDerivedPolicy) +from garage.tf.policies.discrete_qf_argmax_policy import DiscreteQFArgmaxPolicy from garage.tf.policies.gaussian_gru_policy import GaussianGRUPolicy from garage.tf.policies.gaussian_lstm_policy import GaussianLSTMPolicy from garage.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy @@ -17,7 +16,7 @@ __all__ = [ 'Policy', 'CategoricalCNNPolicy', 'CategoricalGRUPolicy', 'CategoricalLSTMPolicy', 'CategoricalMLPPolicy', 'ContinuousMLPPolicy', - 'DiscreteQfDerivedPolicy', 'GaussianGRUPolicy', 'GaussianLSTMPolicy', + 'DiscreteQFArgmaxPolicy', 'GaussianGRUPolicy', 'GaussianLSTMPolicy', 'GaussianMLPPolicy', 'GaussianMLPTaskEmbeddingPolicy', 'TaskEmbeddingPolicy' ] diff --git a/src/garage/tf/policies/discrete_qf_derived_policy.py b/src/garage/tf/policies/discrete_qf_argmax_policy.py similarity index 95% rename from src/garage/tf/policies/discrete_qf_derived_policy.py rename to src/garage/tf/policies/discrete_qf_argmax_policy.py index b0dbf63b3b..d4c2c96ece 100644 --- a/src/garage/tf/policies/discrete_qf_derived_policy.py +++ b/src/garage/tf/policies/discrete_qf_argmax_policy.py @@ -10,8 +10,8 @@ from garage.tf.policies.policy import Policy -class DiscreteQfDerivedPolicy(Module, Policy): - """DiscreteQfDerived policy. +class DiscreteQFArgmaxPolicy(Module, Policy): + """DiscreteQFArgmax policy. Args: env_spec (garage.envs.env_spec.EnvSpec): Environment specification. 
@@ -20,10 +20,9 @@ class DiscreteQfDerivedPolicy(Module, Policy): """ - def __init__(self, env_spec, qf, name='DiscreteQfDerivedPolicy'): + def __init__(self, env_spec, qf, name='DiscreteQFArgmaxPolicy'): assert isinstance(env_spec.action_space, akro.Discrete), ( - 'DiscreteQfDerivedPolicy only supports akro.Discrete action spaces' - ) + 'DiscreteQFArgmaxPolicy only supports akro.Discrete action spaces') if isinstance(env_spec.observation_space, akro.Dict): raise ValueError('CNN policies do not support' 'with akro.Dict observation spaces.') diff --git a/src/garage/torch/policies/__init__.py b/src/garage/torch/policies/__init__.py index b3d826623c..a6c1ffb609 100644 --- a/src/garage/torch/policies/__init__.py +++ b/src/garage/torch/policies/__init__.py @@ -4,6 +4,8 @@ ContextConditionedPolicy) from garage.torch.policies.deterministic_mlp_policy import ( DeterministicMLPPolicy) +from garage.torch.policies.discrete_qf_argmax_policy import ( + DiscreteQFArgmaxPolicy) from garage.torch.policies.gaussian_mlp_policy import GaussianMLPPolicy from garage.torch.policies.policy import Policy from garage.torch.policies.tanh_gaussian_mlp_policy import ( @@ -12,6 +14,7 @@ __all__ = [ 'CategoricalCNNPolicy', 'DeterministicMLPPolicy', + 'DiscreteQFArgmaxPolicy', 'GaussianMLPPolicy', 'Policy', 'TanhGaussianMLPPolicy', diff --git a/src/garage/torch/policies/discrete_qf_argmax_policy.py b/src/garage/torch/policies/discrete_qf_argmax_policy.py new file mode 100644 index 0000000000..4ed39c53b4 --- /dev/null +++ b/src/garage/torch/policies/discrete_qf_argmax_policy.py @@ -0,0 +1,68 @@ +"""A Discrete QFunction-derived policy. + +This policy chooses the action that yields the largest Q-value. +""" +import numpy as np +import torch + +from garage.torch.policies.policy import Policy + + +class DiscreteQFArgmaxPolicy(Policy): + """Policy that derives its actions from a learned Q function. + + The action returned is the one that yields the highest Q value for + a given state, as determined by the supplied Q function. + + Args: + qf (object): Q network. + env_spec (EnvSpec): Environment specification. + name (str): Name of this policy. + """ + + def __init__(self, qf, env_spec, name='DiscreteQFArgmaxPolicy'): + super().__init__(env_spec, name) + self._qf = qf + + # pylint: disable=arguments-differ + def forward(self, observations): + """Get actions corresponding to a batch of observations. + + Args: + observations (torch.Tensor): Batch of observations of shape + :math:`(N, O)`. Observations should be flattened even + if they are images as the underlying Q network handles + unflattening. + + Returns: + torch.Tensor: Batch of action indices of shape :math:`(N, )`. + """ + qs = self._qf(observations) + return torch.argmax(qs, dim=1) + + def get_action(self, observation): + """Get a single action given an observation. + + Args: + observation (np.ndarray): Observation with shape :math:`(O, )`. + + Returns: + np.ndarray: Predicted action, a scalar of shape :math:`()`. + dict: Empty since this policy does not produce a distribution. + """ + act, dist = self.get_actions(np.expand_dims(observation, axis=0)) + return act[0], dist + + def get_actions(self, observations): + """Get actions given observations. + + Args: + observations (np.ndarray): Batch of observations, should + have shape :math:`(N, O)`. + + Returns: + np.ndarray: Predicted actions with shape :math:`(N, )`. + dict: Empty since this policy does not produce a distribution.
+ """ + with torch.no_grad(): + return self(torch.Tensor(observations)).numpy(), dict() diff --git a/src/garage/torch/q_functions/__init__.py b/src/garage/torch/q_functions/__init__.py index 9a402446da..31c1252272 100644 --- a/src/garage/torch/q_functions/__init__.py +++ b/src/garage/torch/q_functions/__init__.py @@ -3,5 +3,9 @@ ContinuousMLPQFunction) from garage.torch.q_functions.discrete_cnn_q_function import ( DiscreteCNNQFunction) +from garage.torch.q_functions.discrete_mlp_q_function import ( + DiscreteMLPQFunction) -__all__ = ['ContinuousMLPQFunction', 'DiscreteCNNQFunction'] +__all__ = [ + 'ContinuousMLPQFunction', 'DiscreteCNNQFunction', 'DiscreteMLPQFunction' +] diff --git a/src/garage/torch/q_functions/discrete_mlp_q_function.py b/src/garage/torch/q_functions/discrete_mlp_q_function.py new file mode 100644 index 0000000000..9b8663908b --- /dev/null +++ b/src/garage/torch/q_functions/discrete_mlp_q_function.py @@ -0,0 +1,60 @@ +"""This modules creates a continuous Q-function network.""" + +from torch import nn +from torch.nn import functional as F + +from garage.torch.modules import MLPModule + + +# pytorch v1.6 issue, see https://github.com/pytorch/pytorch/issues/42305 +# pylint: disable=abstract-method +class DiscreteMLPQFunction(MLPModule): + """Implements a discrete MLP Q-value network. + + It predicts the Q-value for all possible actions based on the + input state. + + Args: + env_spec (EnvSpec): Environment specification. + hidden_sizes (list[int]): Output dimension of dense layer(s). + For example, (32, 32) means this MLP consists of two + hidden layers, each with 32 hidden units. + hidden_nonlinearity (callable or torch.nn.Module): Activation function + for intermediate dense layer(s). It should return a torch.Tensor. + Set it to None to maintain a linear activation. + hidden_w_init (callable): Initializer function for the weight + of intermediate dense layer(s). The function should return a + torch.Tensor. + hidden_b_init (callable): Initializer function for the bias + of intermediate dense layer(s). The function should return a + torch.Tensor. + output_nonlinearity (callable or torch.nn.Module): Activation function + for output dense layer. It should return a torch.Tensor. + Set it to None to maintain a linear activation. + output_w_init (callable): Initializer function for the weight + of output dense layer(s). The function should return a + torch.Tensor. + output_b_init (callable): Initializer function for the bias + of output dense layer(s). The function should return a + torch.Tensor. + layer_normalization (bool): Bool for using layer normalization or not. 
+ + """ + + def __init__(self, + env_spec, + hidden_sizes, + hidden_nonlinearity=F.relu, + hidden_w_init=nn.init.xavier_normal_, + hidden_b_init=nn.init.zeros_, + output_nonlinearity=None, + output_w_init=nn.init.xavier_normal_, + output_b_init=nn.init.zeros_, + layer_normalization=False): + + input_dim = env_spec.observation_space.flat_dim + output_dim = env_spec.action_space.flat_dim + super().__init__(input_dim, output_dim, hidden_sizes, + hidden_nonlinearity, hidden_w_init, hidden_b_init, + output_nonlinearity, output_w_init, output_b_init, + layer_normalization) diff --git a/tests/garage/tf/algos/test_dqn.py b/tests/garage/tf/algos/test_dqn.py index 314a39b094..8c004bf155 100644 --- a/tests/garage/tf/algos/test_dqn.py +++ b/tests/garage/tf/algos/test_dqn.py @@ -13,7 +13,7 @@ from garage.np.exploration_policies import EpsilonGreedyPolicy from garage.replay_buffer import PathBuffer from garage.tf.algos import DQN -from garage.tf.policies import DiscreteQfDerivedPolicy +from garage.tf.policies import DiscreteQFArgmaxPolicy from garage.tf.q_functions import DiscreteMLPQFunction from garage.trainer import TFTrainer @@ -34,7 +34,7 @@ def test_dqn_cartpole(self): env = GymEnv('CartPole-v0') replay_buffer = PathBuffer(capacity_in_transitions=int(1e4)) qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(64, 64)) - policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf) + policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf) epilson_greedy_policy = EpsilonGreedyPolicy( env_spec=env.spec, policy=policy, @@ -75,7 +75,7 @@ def test_dqn_cartpole_double_q(self): env = GymEnv('CartPole-v0') replay_buffer = PathBuffer(capacity_in_transitions=int(1e4)) qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(64, 64)) - policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf) + policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf) epilson_greedy_policy = EpsilonGreedyPolicy( env_spec=env.spec, policy=policy, @@ -116,7 +116,7 @@ def test_dqn_cartpole_grad_clip(self): env = GymEnv('CartPole-v0') replay_buffer = PathBuffer(capacity_in_transitions=int(1e4)) qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(64, 64)) - policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf) + policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf) epilson_greedy_policy = EpsilonGreedyPolicy( env_spec=env.spec, policy=policy, @@ -157,7 +157,7 @@ def test_dqn_cartpole_pickle(self): env = GymEnv('CartPole-v0') replay_buffer = PathBuffer(capacity_in_transitions=int(1e4)) qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(64, 64)) - policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf) + policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf) epilson_greedy_policy = EpsilonGreedyPolicy( env_spec=env.spec, policy=policy, diff --git a/tests/garage/tf/policies/test_qf_derived_policy.py b/tests/garage/tf/policies/test_discrete_qf_argmax_policy.py similarity index 87% rename from tests/garage/tf/policies/test_qf_derived_policy.py rename to tests/garage/tf/policies/test_discrete_qf_argmax_policy.py index 784733ea50..6712fba4ed 100644 --- a/tests/garage/tf/policies/test_qf_derived_policy.py +++ b/tests/garage/tf/policies/test_discrete_qf_argmax_policy.py @@ -5,7 +5,7 @@ from garage.envs import GymEnv from garage.envs.wrappers import AtariEnv -from garage.tf.policies import DiscreteQfDerivedPolicy +from garage.tf.policies import DiscreteQFArgmaxPolicy from garage.tf.q_functions import DiscreteCNNQFunction # yapf: disable @@ -24,12 +24,12 @@ def setup_method(self): super().setup_method() 
self.env = GymEnv(DummyDiscreteEnv()) self.qf = SimpleQFunction(self.env.spec) - self.policy = DiscreteQfDerivedPolicy(env_spec=self.env.spec, - qf=self.qf) + self.policy = DiscreteQFArgmaxPolicy(env_spec=self.env.spec, + qf=self.qf) self.sess.run(tf.compat.v1.global_variables_initializer()) self.env.reset() - def test_discrete_qf_derived_policy(self): + def test_discrete_qf_argmax_policy(self): obs = self.env.step(1).observation action, _ = self.policy.get_action(obs) assert self.env.action_space.contains(action) @@ -62,14 +62,14 @@ def test_does_not_support_dict_obs_space(self): with pytest.raises(ValueError): qf = SimpleQFunction(env.spec, name='does_not_support_dict_obs_space') - DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf) + DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf) def test_invalid_action_spaces(self): """Test that policy raises error if passed a dict obs space.""" env = GymEnv(DummyDictEnv(act_space_type='box')) with pytest.raises(ValueError): qf = SimpleQFunction(env.spec) - DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf) + DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf) class TestQfDerivedPolicyImageObs(TfGraphTestCase): @@ -82,8 +82,8 @@ def setup_method(self): filters=((1, (1, 1)), ), strides=(1, ), dueling=False) - self.policy = DiscreteQfDerivedPolicy(env_spec=self.env.spec, - qf=self.qf) + self.policy = DiscreteQFArgmaxPolicy(env_spec=self.env.spec, + qf=self.qf) self.sess.run(tf.compat.v1.global_variables_initializer()) self.env.reset() diff --git a/tests/garage/torch/policies/test_discrete_qf_argmax_policy.py b/tests/garage/torch/policies/test_discrete_qf_argmax_policy.py new file mode 100644 index 0000000000..1f40314b8b --- /dev/null +++ b/tests/garage/torch/policies/test_discrete_qf_argmax_policy.py @@ -0,0 +1,74 @@ +import pickle + +import numpy as np +import pytest +import torch + +from garage.envs import GymEnv +from garage.torch.policies import DiscreteQFArgmaxPolicy +from garage.torch.q_functions import DiscreteMLPQFunction + +from tests.fixtures.envs.dummy import DummyBoxEnv + + +@pytest.mark.parametrize('batch_size', [1, 5, 10]) +def test_forward(batch_size): + env_spec = GymEnv(DummyBoxEnv()).spec + obs_dim = env_spec.observation_space.flat_dim + obs = torch.ones([batch_size, obs_dim], dtype=torch.float32) + qf = DiscreteMLPQFunction(env_spec=env_spec, + hidden_nonlinearity=None, + hidden_sizes=(2, 2)) + qvals = qf(obs) + policy = DiscreteQFArgmaxPolicy(qf, env_spec) + assert (policy(obs) == torch.argmax(qvals, dim=1)).all() + assert policy(obs).shape == (batch_size, ) + + +def test_get_action(): + env_spec = GymEnv(DummyBoxEnv()).spec + obs_dim = env_spec.observation_space.flat_dim + obs = torch.ones([ + obs_dim, + ], dtype=torch.float32) + qf = DiscreteMLPQFunction(env_spec=env_spec, + hidden_nonlinearity=None, + hidden_sizes=(2, 2)) + qvals = qf(obs.unsqueeze(0)) + policy = DiscreteQFArgmaxPolicy(qf, env_spec) + action, _ = policy.get_action(obs) + assert action == torch.argmax(qvals, dim=1).numpy() + assert action.shape == () + + +@pytest.mark.parametrize('batch_size', [1, 5, 10]) +def test_get_actions(batch_size): + env_spec = GymEnv(DummyBoxEnv()).spec + obs_dim = env_spec.observation_space.flat_dim + obs = torch.ones([batch_size, obs_dim], dtype=torch.float32) + qf = DiscreteMLPQFunction(env_spec=env_spec, + hidden_nonlinearity=None, + hidden_sizes=(2, 2)) + qvals = qf(obs) + policy = DiscreteQFArgmaxPolicy(qf, env_spec) + actions, _ = policy.get_actions(obs) + assert (actions == torch.argmax(qvals, dim=1).numpy()).all() + assert 
actions.shape == (batch_size, ) + + +@pytest.mark.parametrize('batch_size', [1, 5, 10]) +def test_is_pickleable(batch_size): + env_spec = GymEnv(DummyBoxEnv()) + obs_dim = env_spec.observation_space.flat_dim + obs = torch.ones([batch_size, obs_dim], dtype=torch.float32) + qf = DiscreteMLPQFunction(env_spec=env_spec, + hidden_nonlinearity=None, + hidden_sizes=(2, 2)) + policy = DiscreteQFArgmaxPolicy(qf, env_spec) + + output1 = policy.get_actions(obs)[0] + + p = pickle.dumps(policy) + policy_pickled = pickle.loads(p) + output2 = policy_pickled.get_actions(obs)[0] + assert np.array_equal(output1, output2) diff --git a/tests/garage/torch/q_functions/test_discrete_mlp_q_function.py b/tests/garage/torch/q_functions/test_discrete_mlp_q_function.py new file mode 100644 index 0000000000..14b5af1671 --- /dev/null +++ b/tests/garage/torch/q_functions/test_discrete_mlp_q_function.py @@ -0,0 +1,82 @@ +import pickle + +import numpy as np +import pytest +import torch +from torch import nn + +from garage.envs import GymEnv +from garage.torch.q_functions import DiscreteMLPQFunction + +from tests.fixtures.envs.dummy import DummyBoxEnv + + +# yapf: disable +@pytest.mark.parametrize('hidden_sizes', [ + (1, ), (2, ), (3, ), (1, 1), (2, 2)]) +# yapf: enable +def test_forward(hidden_sizes): + env_spec = GymEnv(DummyBoxEnv()).spec + obs_dim = env_spec.observation_space.flat_dim + obs = torch.ones(obs_dim, dtype=torch.float32).unsqueeze(0) + + qf = DiscreteMLPQFunction(env_spec=env_spec, + hidden_nonlinearity=None, + hidden_sizes=hidden_sizes, + hidden_w_init=nn.init.ones_, + output_w_init=nn.init.ones_) + + output = qf(obs) + + expected_output = torch.full([1, 1], + fill_value=(obs_dim) * np.prod(hidden_sizes), + dtype=torch.float32) + assert torch.eq(output, expected_output).all() + + +# yapf: disable +@pytest.mark.parametrize('batch_size, hidden_sizes', [ + (1, (1, )), + (3, (2, )), + (9, (3, )), + (15, (1, 1)), + (22, (2, 2)), +]) +# yapf: enable +def test_output_shape(batch_size, hidden_sizes): + env_spec = GymEnv(DummyBoxEnv()).spec + obs_dim = env_spec.observation_space.flat_dim + obs = torch.ones(batch_size, obs_dim, dtype=torch.float32) + + qf = DiscreteMLPQFunction(env_spec=env_spec, + hidden_nonlinearity=None, + hidden_sizes=hidden_sizes, + hidden_w_init=nn.init.ones_, + output_w_init=nn.init.ones_) + output = qf(obs) + + assert output.shape == (batch_size, env_spec.action_space.flat_dim) + + +# yapf: disable +@pytest.mark.parametrize('hidden_sizes', [ + (1, ), (2, ), (3, ), (1, 5), (2, 7, 10)]) +# yapf: enable +def test_is_pickleable(hidden_sizes): + env_spec = GymEnv(DummyBoxEnv()).spec + obs_dim = env_spec.observation_space.flat_dim + obs = torch.ones(obs_dim, dtype=torch.float32).unsqueeze(0) + + qf = DiscreteMLPQFunction(env_spec=env_spec, + hidden_nonlinearity=None, + hidden_sizes=hidden_sizes, + hidden_w_init=nn.init.ones_, + output_w_init=nn.init.ones_) + + output1 = qf(obs) + + p = pickle.dumps(qf) + qf_pickled = pickle.loads(p) + output2 = qf_pickled(obs) + + assert torch.eq(output1, output2).all()
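
For reference, the two new torch primitives can be wired together the same way the tf dqn_cartpole example touched by this patch wires its DiscreteMLPQFunction and DiscreteQFArgmaxPolicy. The sketch below is not part of the patch; it only combines classes that already appear in this diff (GymEnv, EpsilonGreedyPolicy, and the two new torch classes), and the CartPole-v0 environment and total_timesteps value are illustrative assumptions.

# Minimal usage sketch of the new torch primitives. Not part of this patch;
# CartPole-v0 and total_timesteps below are arbitrary illustrative choices.
from garage.envs import GymEnv
from garage.np.exploration_policies import EpsilonGreedyPolicy
from garage.torch.policies import DiscreteQFArgmaxPolicy
from garage.torch.q_functions import DiscreteMLPQFunction

env = GymEnv('CartPole-v0')

# Q-network producing one Q-value per discrete action.
qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(64, 64))

# Greedy policy that returns argmax_a Q(s, a) for a flattened observation.
policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)

# Epsilon-greedy wrapper typically used for exploration during training.
exploration_policy = EpsilonGreedyPolicy(env_spec=env.spec,
                                         policy=policy,
                                         total_timesteps=10000)

env.reset()
obs = env.step(env.action_space.sample()).observation
greedy_action, _ = policy.get_action(obs)               # argmax action
explore_action, _ = exploration_policy.get_action(obs)  # epsilon-greedy action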