Add two torch primitives (#1993)
This adds the following primitives to torch (a minimal usage sketch follows the change summary below):
  - DiscreteQFDerivedPolicy
  - DiscreteMLPQFunction
maliesa96 authored Sep 14, 2020
1 parent 4b88bef commit dc7511c
Showing 13 changed files with 317 additions and 28 deletions.
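Before the per-file diffs, here is a minimal usage sketch (not part of the commit) of how the two new torch primitives are meant to compose with garage's existing exploration machinery, mirroring examples/tf/dqn_cartpole.py below. CartPole, the hidden sizes, and total_timesteps=10000 are placeholder assumptions, the epsilon settings are left at their defaults, and the torch DQN algorithm itself is not part of this change.

from garage.envs import GymEnv
from garage.np.exploration_policies import EpsilonGreedyPolicy
from garage.torch.policies import DiscreteQFArgmaxPolicy
from garage.torch.q_functions import DiscreteMLPQFunction

# Build the new torch Q network and the argmax policy on top of it.
env = GymEnv('CartPole-v0')
qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(64, 64))
policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)

# Wrap the policy's get_action() with epsilon-greedy exploration.
exploration_policy = EpsilonGreedyPolicy(env_spec=env.spec,
                                         policy=policy,
                                         total_timesteps=10000)

env.reset()
obs = env.step(1).observation  # any valid CartPole action
action, _ = exploration_policy.get_action(obs)
assert env.spec.action_space.contains(action)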
4 changes: 2 additions & 2 deletions docs/user/use_pretrained_network_to_start_new_experiment.md
@@ -88,7 +88,7 @@ from garage.experiment.deterministic import set_seed
from garage.np.exploration_policies import EpsilonGreedyPolicy
from garage.replay_buffer import PathBuffer
from garage.tf.algos import DQN
from garage.tf.policies import DiscreteQfDerivedPolicy
from garage.tf.policies import DiscreteQFArgmaxPolicy
from garage.trainer import TFTrainer

@click.command()
@@ -135,7 +135,7 @@ def dqn_pong(ctxt=None, seed=1, buffer_size=int(5e4), max_episode_length=500):
qf = snapshot['algo']._qf
# MARK: end modifications to existing example

policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
exploration_policy = EpsilonGreedyPolicy(env_spec=env.spec,
policy=policy,
total_timesteps=num_timesteps,
4 changes: 2 additions & 2 deletions examples/tf/dqn_cartpole.py
@@ -9,7 +9,7 @@
from garage.np.exploration_policies import EpsilonGreedyPolicy
from garage.replay_buffer import PathBuffer
from garage.tf.algos import DQN
from garage.tf.policies import DiscreteQfDerivedPolicy
from garage.tf.policies import DiscreteQFArgmaxPolicy
from garage.tf.q_functions import DiscreteMLPQFunction
from garage.trainer import TFTrainer

@@ -34,7 +34,7 @@ def dqn_cartpole(ctxt=None, seed=1):
env = GymEnv('CartPole-v0')
replay_buffer = PathBuffer(capacity_in_transitions=int(1e4))
qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(64, 64))
policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
exploration_policy = EpsilonGreedyPolicy(env_spec=env.spec,
policy=policy,
total_timesteps=num_timesteps,
4 changes: 2 additions & 2 deletions examples/tf/dqn_pong.py
@@ -20,7 +20,7 @@
from garage.np.exploration_policies import EpsilonGreedyPolicy
from garage.replay_buffer import PathBuffer
from garage.tf.algos import DQN
from garage.tf.policies import DiscreteQfDerivedPolicy
from garage.tf.policies import DiscreteQFArgmaxPolicy
from garage.tf.q_functions import DiscreteCNNQFunction
from garage.trainer import TFTrainer

@@ -75,7 +75,7 @@ def dqn_pong(ctxt=None, seed=1, buffer_size=int(5e4), max_episode_length=500):
strides=(4, 2, 1),
dueling=False) # yapf: disable

policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
exploration_policy = EpsilonGreedyPolicy(env_spec=env.spec,
policy=policy,
total_timesteps=num_timesteps,
5 changes: 2 additions & 3 deletions src/garage/tf/policies/__init__.py
@@ -4,8 +4,7 @@
from garage.tf.policies.categorical_lstm_policy import CategoricalLSTMPolicy
from garage.tf.policies.categorical_mlp_policy import CategoricalMLPPolicy
from garage.tf.policies.continuous_mlp_policy import ContinuousMLPPolicy
from garage.tf.policies.discrete_qf_derived_policy import (
DiscreteQfDerivedPolicy)
from garage.tf.policies.discrete_qf_argmax_policy import DiscreteQFArgmaxPolicy
from garage.tf.policies.gaussian_gru_policy import GaussianGRUPolicy
from garage.tf.policies.gaussian_lstm_policy import GaussianLSTMPolicy
from garage.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy
@@ -17,7 +16,7 @@
__all__ = [
'Policy', 'CategoricalCNNPolicy', 'CategoricalGRUPolicy',
'CategoricalLSTMPolicy', 'CategoricalMLPPolicy', 'ContinuousMLPPolicy',
'DiscreteQfDerivedPolicy', 'GaussianGRUPolicy', 'GaussianLSTMPolicy',
'DiscreteQFArgmaxPolicy', 'GaussianGRUPolicy', 'GaussianLSTMPolicy',
'GaussianMLPPolicy', 'GaussianMLPTaskEmbeddingPolicy',
'TaskEmbeddingPolicy'
]
src/garage/tf/policies/discrete_qf_argmax_policy.py (renamed from discrete_qf_derived_policy.py)
@@ -10,8 +10,8 @@
from garage.tf.policies.policy import Policy


class DiscreteQfDerivedPolicy(Module, Policy):
"""DiscreteQfDerived policy.
class DiscreteQFArgmaxPolicy(Module, Policy):
"""DiscreteQFArgmax policy.
Args:
env_spec (garage.envs.env_spec.EnvSpec): Environment specification.
@@ -20,10 +20,9 @@ class DiscreteQfDerivedPolicy(Module, Policy):
"""

def __init__(self, env_spec, qf, name='DiscreteQfDerivedPolicy'):
def __init__(self, env_spec, qf, name='DiscreteQFArgmaxPolicy'):
assert isinstance(env_spec.action_space, akro.Discrete), (
'DiscreteQfDerivedPolicy only supports akro.Discrete action spaces'
)
'DiscreteQFArgmaxPolicy only supports akro.Discrete action spaces')
if isinstance(env_spec.observation_space, akro.Dict):
raise ValueError('CNN policies do not support'
'with akro.Dict observation spaces.')
3 changes: 3 additions & 0 deletions src/garage/torch/policies/__init__.py
@@ -4,6 +4,8 @@
ContextConditionedPolicy)
from garage.torch.policies.deterministic_mlp_policy import (
DeterministicMLPPolicy)
from garage.torch.policies.discrete_qf_argmax_policy import (
DiscreteQFArgmaxPolicy)
from garage.torch.policies.gaussian_mlp_policy import GaussianMLPPolicy
from garage.torch.policies.policy import Policy
from garage.torch.policies.tanh_gaussian_mlp_policy import (
@@ -12,6 +14,7 @@
__all__ = [
'CategoricalCNNPolicy',
'DeterministicMLPPolicy',
'DiscreteQFArgmaxPolicy',
'GaussianMLPPolicy',
'Policy',
'TanhGaussianMLPPolicy',
68 changes: 68 additions & 0 deletions src/garage/torch/policies/discrete_qf_argmax_policy.py
@@ -0,0 +1,68 @@
"""A Discrete QFunction-derived policy.
This policy chooses the action that yields the largest Q-value.
"""
import numpy as np
import torch

from garage.torch.policies.policy import Policy


class DiscreteQFArgmaxPolicy(Policy):
"""Policy that derives its actions from a learned Q function.
The action returned is the one that yields the highest Q value for
a given state, as determined by the supplied Q function.
Args:
qf (object): Q network.
env_spec (EnvSpec): Environment specification.
name (str): Name of this policy.
"""

def __init__(self, qf, env_spec, name='DiscreteQFArgmaxPolicy'):
super().__init__(env_spec, name)
self._qf = qf

# pylint: disable=arguments-differ
def forward(self, observations):
"""Get actions corresponding to a batch of observations.
Args:
observations(torch.Tensor): Batch of observations of shape
:math:`(N, O)`. Observations should be flattened even
if they are images as the underlying Q network handles
unflattening.
Returns:
torch.Tensor: Batch of actions of shape :math:`(N, A)`
"""
qs = self._qf(observations)
return torch.argmax(qs, dim=1)

def get_action(self, observation):
"""Get a single action given an observation.
Args:
observation (np.ndarray): Observation with shape :math:`(O, )`.
Returns:
torch.Tensor: Predicted action with shape :math:`(A, )`.
dict: Empty since this policy does not produce a distribution.
"""
act, dist = self.get_actions(np.expand_dims(observation, axis=0))
return act[0], dist

def get_actions(self, observations):
"""Get actions given observations.
Args:
observations (np.ndarray): Batch of observations, should
have shape :math:`(N, O)`.
Returns:
torch.Tensor: Predicted actions. Tensor has shape :math:`(N, A)`.
dict: Empty since this policy does not produce a distribution.
"""
with torch.no_grad():
return self(torch.Tensor(observations)).numpy(), dict()
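A short sketch (not part of the diff) of the call contract documented above: forward() operates on torch tensors, while get_action()/get_actions() are the numpy-facing entry points that wrap it in torch.no_grad(). The CartPole environment and (8, 8) hidden sizes are placeholder assumptions.

import numpy as np
import torch

from garage.envs import GymEnv
from garage.torch.policies import DiscreteQFArgmaxPolicy
from garage.torch.q_functions import DiscreteMLPQFunction

env = GymEnv('CartPole-v0')
qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(8, 8))
policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)

# Dummy batch of 5 flat observations (all zeros, just to exercise shapes).
obs_batch = np.zeros((5, env.spec.observation_space.flat_dim),
                     dtype=np.float32)

# forward() takes a torch.Tensor of shape (N, O) and returns argmax indices.
action_indices = policy(torch.Tensor(obs_batch))
assert action_indices.shape == (5, )

# get_actions() returns numpy actions plus an empty info dict.
actions, info = policy.get_actions(obs_batch)
assert actions.shape == (5, ) and info == {}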
6 changes: 5 additions & 1 deletion src/garage/torch/q_functions/__init__.py
@@ -3,5 +3,9 @@
ContinuousMLPQFunction)
from garage.torch.q_functions.discrete_cnn_q_function import (
DiscreteCNNQFunction)
from garage.torch.q_functions.discrete_mlp_q_function import (
DiscreteMLPQFunction)

__all__ = ['ContinuousMLPQFunction', 'DiscreteCNNQFunction']
__all__ = [
'ContinuousMLPQFunction', 'DiscreteCNNQFunction', 'DiscreteMLPQFunction'
]
60 changes: 60 additions & 0 deletions src/garage/torch/q_functions/discrete_mlp_q_function.py
@@ -0,0 +1,60 @@
"""This modules creates a continuous Q-function network."""

from torch import nn
from torch.nn import functional as F

from garage.torch.modules import MLPModule


# pytorch v1.6 issue, see https://github.com/pytorch/pytorch/issues/42305
# pylint: disable=abstract-method
class DiscreteMLPQFunction(MLPModule):
"""Implements a discrete MLP Q-value network.
It predicts the Q-value for all possible actions based on the
input state.
Args:
env_spec (EnvSpec): Environment specification.
hidden_sizes (list[int]): Output dimension of dense layer(s).
For example, (32, 32) means this MLP consists of two
hidden layers, each with 32 hidden units.
hidden_nonlinearity (callable or torch.nn.Module): Activation function
for intermediate dense layer(s). It should return a torch.Tensor.
Set it to None to maintain a linear activation.
hidden_w_init (callable): Initializer function for the weight
of intermediate dense layer(s). The function should return a
torch.Tensor.
hidden_b_init (callable): Initializer function for the bias
of intermediate dense layer(s). The function should return a
torch.Tensor.
output_nonlinearity (callable or torch.nn.Module): Activation function
for output dense layer. It should return a torch.Tensor.
Set it to None to maintain a linear activation.
output_w_init (callable): Initializer function for the weight
of output dense layer(s). The function should return a
torch.Tensor.
output_b_init (callable): Initializer function for the bias
of output dense layer(s). The function should return a
torch.Tensor.
layer_normalization (bool): Bool for using layer normalization or not.
"""

def __init__(self,
env_spec,
hidden_sizes,
hidden_nonlinearity=F.relu,
hidden_w_init=nn.init.xavier_normal_,
hidden_b_init=nn.init.zeros_,
output_nonlinearity=None,
output_w_init=nn.init.xavier_normal_,
output_b_init=nn.init.zeros_,
layer_normalization=False):

input_dim = env_spec.observation_space.flat_dim
output_dim = env_spec.action_space.flat_dim
super().__init__(input_dim, output_dim, hidden_sizes,
hidden_nonlinearity, hidden_w_init, hidden_b_init,
output_nonlinearity, output_w_init, output_b_init,
layer_normalization)
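A short sketch (not from the diff) of the shape contract: the Q network maps a batch of flat observations to one Q-value per discrete action, which is exactly what DiscreteQFArgmaxPolicy takes the argmax over. CartPole and the (32, 32) hidden sizes are placeholder assumptions.

import torch

from garage.envs import GymEnv
from garage.torch.q_functions import DiscreteMLPQFunction

env = GymEnv('CartPole-v0')  # 4-dimensional observations, 2 discrete actions
qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(32, 32))

# Batch of 3 dummy observations; only the shapes matter here.
obs = torch.zeros((3, env.spec.observation_space.flat_dim))
q_values = qf(obs)
assert q_values.shape == (3, env.spec.action_space.n)  # one Q-value per action
assert torch.argmax(q_values, dim=1).shape == (3, )    # what the policy returns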
10 changes: 5 additions & 5 deletions tests/garage/tf/algos/test_dqn.py
@@ -13,7 +13,7 @@
from garage.np.exploration_policies import EpsilonGreedyPolicy
from garage.replay_buffer import PathBuffer
from garage.tf.algos import DQN
from garage.tf.policies import DiscreteQfDerivedPolicy
from garage.tf.policies import DiscreteQFArgmaxPolicy
from garage.tf.q_functions import DiscreteMLPQFunction
from garage.trainer import TFTrainer

@@ -34,7 +34,7 @@ def test_dqn_cartpole(self):
env = GymEnv('CartPole-v0')
replay_buffer = PathBuffer(capacity_in_transitions=int(1e4))
qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(64, 64))
policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
epilson_greedy_policy = EpsilonGreedyPolicy(
env_spec=env.spec,
policy=policy,
@@ -75,7 +75,7 @@ def test_dqn_cartpole_double_q(self):
env = GymEnv('CartPole-v0')
replay_buffer = PathBuffer(capacity_in_transitions=int(1e4))
qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(64, 64))
policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
epilson_greedy_policy = EpsilonGreedyPolicy(
env_spec=env.spec,
policy=policy,
@@ -116,7 +116,7 @@ def test_dqn_cartpole_grad_clip(self):
env = GymEnv('CartPole-v0')
replay_buffer = PathBuffer(capacity_in_transitions=int(1e4))
qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(64, 64))
policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
epilson_greedy_policy = EpsilonGreedyPolicy(
env_spec=env.spec,
policy=policy,
@@ -157,7 +157,7 @@ def test_dqn_cartpole_pickle(self):
env = GymEnv('CartPole-v0')
replay_buffer = PathBuffer(capacity_in_transitions=int(1e4))
qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(64, 64))
policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
epilson_greedy_policy = EpsilonGreedyPolicy(
env_spec=env.spec,
policy=policy,
(TF DiscreteQFArgmaxPolicy test file; exact path not captured in this view)
@@ -5,7 +5,7 @@

from garage.envs import GymEnv
from garage.envs.wrappers import AtariEnv
from garage.tf.policies import DiscreteQfDerivedPolicy
from garage.tf.policies import DiscreteQFArgmaxPolicy
from garage.tf.q_functions import DiscreteCNNQFunction

# yapf: disable
@@ -24,12 +24,12 @@ def setup_method(self):
super().setup_method()
self.env = GymEnv(DummyDiscreteEnv())
self.qf = SimpleQFunction(self.env.spec)
self.policy = DiscreteQfDerivedPolicy(env_spec=self.env.spec,
qf=self.qf)
self.policy = DiscreteQFArgmaxPolicy(env_spec=self.env.spec,
qf=self.qf)
self.sess.run(tf.compat.v1.global_variables_initializer())
self.env.reset()

def test_discrete_qf_derived_policy(self):
def test_discrete_qf_argmax_policy(self):
obs = self.env.step(1).observation
action, _ = self.policy.get_action(obs)
assert self.env.action_space.contains(action)
@@ -62,14 +62,14 @@ def test_does_not_support_dict_obs_space(self):
with pytest.raises(ValueError):
qf = SimpleQFunction(env.spec,
name='does_not_support_dict_obs_space')
DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)

def test_invalid_action_spaces(self):
"""Test that policy raises error if passed a dict obs space."""
env = GymEnv(DummyDictEnv(act_space_type='box'))
with pytest.raises(ValueError):
qf = SimpleQFunction(env.spec)
DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)


class TestQfDerivedPolicyImageObs(TfGraphTestCase):
@@ -82,8 +82,8 @@ def setup_method(self):
filters=((1, (1, 1)), ),
strides=(1, ),
dueling=False)
self.policy = DiscreteQfDerivedPolicy(env_spec=self.env.spec,
qf=self.qf)
self.policy = DiscreteQFArgmaxPolicy(env_spec=self.env.spec,
qf=self.qf)
self.sess.run(tf.compat.v1.global_variables_initializer())
self.env.reset()
