From 1586b0fb7c94594e114d775c0a872c8932fbe294 Mon Sep 17 00:00:00 2001 From: Christopher Yeh Date: Tue, 19 Nov 2024 03:25:07 +0000 Subject: [PATCH 1/2] Add typing annotations to gpytorch.Module --- gpytorch/module.py | 233 +++++++++++++++++++++++---------------------- 1 file changed, 121 insertions(+), 112 deletions(-) diff --git a/gpytorch/module.py b/gpytorch/module.py index ff431a421..3c8b9bce1 100644 --- a/gpytorch/module.py +++ b/gpytorch/module.py @@ -1,11 +1,12 @@ #!/usr/bin/env python3 +from __future__ import annotations + import copy import inspect import itertools import operator -from collections import OrderedDict -from typing import Union +from typing import Callable, Iterator, Mapping, MutableSet, Optional, TypeVar, Union import torch from linear_operator.operators import LinearOperator @@ -13,14 +14,64 @@ from torch.distributions import Distribution from .constraints import Interval +from .priors import Prior + + +NnModuleSelf = TypeVar("NnModuleSelf", bound=nn.Module) # TODO: replace w/ typing.Self in Python 3.11 +ModuleSelf = TypeVar("ModuleSelf", bound="Module") # TODO: replace w/ typing.Self in Python 3.11 +RandomModuleSelf = TypeVar("RandomModuleSelf", bound="RandomModuleMixin") # TODO: replace w/ typing.Self in Python 3.11 + +Closure = Callable[[nn.Module], Tensor] +SettingClosure = Callable[[ModuleSelf, Union[Tensor, float]], ModuleSelf] +SamplesDict = Mapping[str, Union[Tensor, float]] + + +class RandomModuleMixin: + def initialize(self: RandomModuleSelf, **kwargs) -> RandomModuleSelf: + """ + Set a value for a parameter + + kwargs: (param_name, value) - parameter to initialize. + Can also initialize recursively by passing in the full name of a + parameter. For example if model has attribute model.likelihood, + we can initialize the noise with either + `model.initialize(**{'likelihood.noise': 0.1})` + or + `model.likelihood.initialize(noise=0.1)`. + The former method would allow users to more easily store the + initialization values as one object. + + Value must be a Tensor + """ + for name, value in kwargs.items(): + if not isinstance(value, Tensor): + raise RuntimeError("Initialize in RandomModules can only be done with Tensor values.") + + names = name.rsplit(".") + if len(names) > 1: + mod_name, param_name = names + mod = operator.attrgetter(mod_name)(self) + else: + mod, param_name = self, name + + old_param = getattr(mod, param_name) + is_property = hasattr(type(self), name) and isinstance(getattr(type(self), name), property) + if not isinstance(old_param, torch.nn.Parameter) or is_property: + # Presumably we're calling a getter that will call initialize again on the actual parameter. + setattr(mod, param_name, value.expand(old_param.shape)) + else: + delattr(mod, param_name) + setattr(mod, param_name, value.expand(old_param.shape)) + + return self class Module(nn.Module): def __init__(self): super().__init__() - self._added_loss_terms = OrderedDict() - self._priors = OrderedDict() - self._constraints = OrderedDict() + self._added_loss_terms = {} + self._priors: dict[str, tuple[Prior, Closure, Optional[SettingClosure]]] = {} + self._constraints: dict[str, Interval] = {} self._strict_init = True self._load_strict_shapes = True @@ -40,7 +91,7 @@ def _clear_cache(self): """ pass - def _get_module_and_name(self, parameter_name): + def _get_module_and_name(self, parameter_name: str) -> tuple[nn.Module, str]: """Get module and name from full parameter name.""" module, name = parameter_name.split(".", 1) if module in self._modules: @@ -50,7 +101,7 @@ def _get_module_and_name(self, parameter_name): "Invalid parameter name {}. {} has no module {}".format(parameter_name, type(self).__name__, module) ) - def _strict(self, value): + def _strict(self, value: bool) -> None: _set_strict(self, value) def added_loss_terms(self): @@ -68,7 +119,7 @@ def hyperparameters(self): for _, param in self.named_hyperparameters(): yield param - def initialize(self, **kwargs): + def initialize(self: ModuleSelf, **kwargs) -> ModuleSelf: """ Set a value for a parameter @@ -98,7 +149,7 @@ def initialize(self, **kwargs): raise AttributeError("Unknown parameter {p} for {c}".format(p=name, c=self.__class__.__name__)) elif name not in self._parameters and name not in self._buffers: setattr(self, name, val) - elif torch.is_tensor(val): + elif isinstance(val, Tensor): constraint = self.constraint_for_parameter_name(name) if constraint is not None and constraint.enforced and not constraint.check_raw(val): raise RuntimeError( @@ -158,7 +209,7 @@ def named_hyperparameters(self): for elem in module.named_parameters(prefix=module_prefix, recurse=False): yield elem - def named_priors(self, memo=None, prefix=""): + def named_priors(self) -> Iterator[tuple[str, nn.Module, Prior, Closure, SettingClosure | None]]: """Returns an iterator over the module's priors, yielding the name of the prior, the prior, the associated parameter names, and the transformation callable. @@ -172,7 +223,7 @@ def named_priors(self, memo=None, prefix=""): """ return _extract_named_priors(module=self, prefix="") - def named_constraints(self, memo=None, prefix=""): + def named_constraints(self) -> Iterator[tuple[str, Interval]]: return _extract_named_constraints(module=self, memo=None, prefix="") def named_variational_parameters(self): @@ -186,30 +237,22 @@ def named_variational_parameters(self): def register_added_loss_term(self, name): self._added_loss_terms[name] = None - def register_parameter(self, name, parameter): - r""" - Adds a parameter to the module. The parameter can be accessed as an attribute using the given name. - - Args: - name (str): - The name of the parameter - parameter (torch.nn.Parameter): - The parameter - """ - if "_parameters" not in self.__dict__: - raise AttributeError("Cannot assign parameter before Module.__init__() call") - super().register_parameter(name, parameter) - - def register_prior(self, name, prior, param_or_closure, setting_closure=None): + def register_prior( + self, + name: str, + prior: Prior, + param_or_closure: Union[str, Closure], + setting_closure: Optional[SettingClosure] = None, + ) -> None: """ Adds a prior to the module. The prior can be accessed as an attribute using the given name. Args: - name (str): + name: The name of the prior - prior (Prior): + prior: The prior to be registered` - param_or_closure (string or callable): + param_or_closure: Either the name of the parameter, or a closure (which upon calling evalutes a function on the module instance and one or more parameters): single parameter without a transform: `.register_prior("foo_prior", foo_prior, "foo_param")` @@ -217,33 +260,36 @@ def register_prior(self, name, prior, param_or_closure, setting_closure=None): `.register_prior("foo_prior", NormalPrior(0, 1), lambda module: torch.log(module.foo_param))` function of multiple parameters: `.register_prior("foo2_prior", foo2_prior, lambda module: f(module.param1, module.param2)))` - setting_closure (callable, optional): + setting_closure: A function taking in the module instance and a tensor in (transformed) parameter space, initializing the internal parameter representation to the proper value by applying the inverse transform. Enables setting parametres directly in the transformed space, as well as sampling parameter values from priors (see `sample_from_prior`) - """ if isinstance(param_or_closure, str): - if param_or_closure not in self._parameters and not hasattr(self, param_or_closure): + param = param_or_closure + if param not in self._parameters and not hasattr(self, param): raise AttributeError( - "Unknown parameter {name} for {module}".format( - name=param_or_closure, module=self.__class__.__name__ - ) + "Unknown parameter {name} for {module}".format(name=param, module=self.__class__.__name__) + " Make sure the parameter is registered before registering a prior." ) - def closure(module): - return getattr(module, param_or_closure) + def closure_new(module: nn.Module) -> Tensor: + return getattr(module, param) + + closure = closure_new if setting_closure is not None: raise RuntimeError("Must specify a closure instead of a parameter name when providing setting_closure") - def setting_closure(module, val): - return module.initialize(**{param_or_closure: val}) + def setting_closure_new(module: ModuleSelf, val: Union[Tensor, float]) -> ModuleSelf: + return module.initialize(**{param: val}) + + setting_closure = setting_closure_new else: - if len(inspect.signature(param_or_closure).parameters) == 0: + closure = param_or_closure + if len(inspect.signature(closure).parameters) == 0: raise ValueError( """As of version 1.4, `param_or_closure` must operate on a module instance. For example: @@ -266,12 +312,11 @@ def setting_closure(module, val): ) """ ) - closure = param_or_closure self.add_module(name, prior) self._priors[name] = (prior, closure, setting_closure) - def register_constraint(self, param_name, constraint, replace=True): + def register_constraint(self, param_name: str, constraint: Interval, replace: bool = True) -> None: if param_name not in self._parameters: raise RuntimeError("Attempting to register constraint for nonexistent parameter.") @@ -299,7 +344,7 @@ def train(self, mode=True): self._clear_cache() return super().train(mode=mode) - def constraint_for_parameter_name(self, param_name): + def constraint_for_parameter_name(self, param_name: str) -> Interval | None: base_module = self base_name = param_name @@ -344,11 +389,11 @@ def apply_fn(module): self.apply(apply_fn) - def named_parameters_and_constraints(self): + def named_parameters_and_constraints(self) -> Iterator[tuple[str, nn.Parameter, Interval | None]]: for name, param in self.named_parameters(): yield name, param, self.constraint_for_parameter_name(name) - def sample_from_prior(self, prior_name): + def sample_from_prior(self, prior_name: str) -> None: """Sample parameter values from prior. Modifies the module's parameters in-place.""" if prior_name not in self._priors: raise RuntimeError("Unknown prior name '{}'".format(prior_name)) @@ -357,10 +402,10 @@ def sample_from_prior(self, prior_name): raise RuntimeError("Must provide inverse transform to be able to sample from prior.") setting_closure(self, prior.sample()) - def to_pyro_random_module(self): + def to_pyro_random_module(self) -> Module: return self.to_random_module() - def to_random_module(self): + def to_random_module(self) -> Module: random_module_cls = type("_Random" + self.__class__.__name__, (RandomModuleMixin, self.__class__), {}) if not isinstance(self, random_module_cls): new_module = copy.deepcopy(self) @@ -375,7 +420,7 @@ def to_random_module(self): return new_module - def pyro_sample_from_prior(self): + def pyro_sample_from_prior(self) -> Module: """ For each parameter in this Module and submodule that have defined priors, sample a value for that parameter from its corresponding prior with a pyro.sample primitive and load the resulting value in to the parameter. @@ -386,7 +431,7 @@ def pyro_sample_from_prior(self): new_module = self.to_pyro_random_module() return _pyro_sample_from_prior(module=new_module, memo=None, prefix="") - def local_load_samples(self, samples_dict, memo, prefix): + def local_load_samples(self, samples_dict: SamplesDict, memo: MutableSet[str], prefix: str) -> None: """ Defines local behavior of this Module when loading parameters from a samples_dict generated by a Pyro sampling mechanism. @@ -396,13 +441,15 @@ def local_load_samples(self, samples_dict, memo, prefix): acquire an extra batch dimension corresponding to the number of samples drawn. """ self._strict(False) - for name, (prior, closure, setting_closure) in self._priors.items(): + for name, (prior, _, setting_closure) in self._priors.items(): if prior is not None and prior not in memo: memo.add(prior) + if setting_closure is None: + raise RuntimeError("Must provide setting_closure to load samples.") setting_closure(self, samples_dict[prefix + ("." if prefix else "") + name]) self._strict(True) - def pyro_load_from_samples(self, samples_dict): + def pyro_load_from_samples(self, samples_dict: SamplesDict) -> None: """ Convert this Module in to a batch Module by loading parameters from the given `samples_dict`. `samples_dict` is typically produced by a Pyro sampling mechanism. @@ -412,9 +459,9 @@ def pyro_load_from_samples(self, samples_dict): the prior to properly set the unconstrained parameter. Args: - samples_dict (dict): Dictionary mapping *prior names* to sample values. + samples_dict: Dictionary mapping *prior names* to sample values. """ - return _pyro_load_from_samples(module=self, samples_dict=samples_dict, memo=None, prefix="") + _pyro_load_from_samples(module=self, samples_dict=samples_dict, memo=None, prefix="") def update_added_loss_term(self, name, added_loss_term): from .mlls import AddedLossTerm @@ -432,29 +479,23 @@ def variational_parameters(self): def _validate_module_outputs(outputs): if isinstance(outputs, tuple): - if not all( - torch.is_tensor(output) or isinstance(output, Distribution) or isinstance(output, LinearOperator) - for output in outputs - ): + if not all(isinstance(output, (Tensor, Distribution, LinearOperator)) for output in outputs): raise RuntimeError( - "All outputs must be a Distribution, torch.Tensor, or LinearOperator. " + "All outputs must be a torch.Tensor, Distribution, or LinearOperator. " "Got {}".format([output.__class__.__name__ for output in outputs]) ) if len(outputs) == 1: outputs = outputs[0] return outputs - elif torch.is_tensor(outputs) or isinstance(outputs, Distribution) or isinstance(outputs, LinearOperator): + elif isinstance(outputs, (Tensor, Distribution, LinearOperator)): return outputs else: raise RuntimeError( - "Output must be a Distribution, torch.Tensor, or LinearOperator. Got {}".format(outputs.__class__.__name__) + "Output must be a torch.Tensor, Distribution, or LinearOperator. Got {}".format(outputs.__class__.__name__) ) -def _set_strict(module, value, memo=None): - if memo is None: - memo = set() - +def _set_strict(module: nn.Module, value: bool) -> None: if hasattr(module, "_strict_init"): module._strict_init = value @@ -462,7 +503,9 @@ def _set_strict(module, value, memo=None): _set_strict(module_, value) -def _pyro_sample_from_prior(module, memo=None, prefix=""): +def _pyro_sample_from_prior( + module: NnModuleSelf, memo: Optional[MutableSet[str]] = None, prefix: str = "" +) -> NnModuleSelf: try: import pyro except ImportError: @@ -470,7 +513,7 @@ def _pyro_sample_from_prior(module, memo=None, prefix=""): if memo is None: memo = set() - if hasattr(module, "_priors"): + if isinstance(module, Module): for prior_name, (prior, closure, setting_closure) in module._priors.items(): if prior is not None and prior not in memo: if setting_closure is None: @@ -490,10 +533,12 @@ def _pyro_sample_from_prior(module, memo=None, prefix=""): return module -def _pyro_load_from_samples(module, samples_dict, memo=None, prefix=""): +def _pyro_load_from_samples( + module: nn.Module, samples_dict: SamplesDict, memo: Optional[MutableSet[str]] = None, prefix: str = "" +) -> None: if memo is None: memo = set() - if hasattr(module, "_priors"): + if isinstance(module, Module): module.local_load_samples(samples_dict, memo, prefix) for mname, module_ in module.named_children(): @@ -515,8 +560,10 @@ def _extract_named_added_loss_terms(module, memo=None, prefix=""): yield name, strategy -def _extract_named_priors(module, prefix=""): - if hasattr(module, "_priors"): +def _extract_named_priors( + module: nn.Module, prefix: str = "" +) -> Iterator[tuple[str, nn.Module, Prior, Closure, SettingClosure | None]]: + if isinstance(module, Module): for name, (prior, closure, inv_closure) in module._priors.items(): if prior is not None: full_name = ("." if prefix else "").join([prefix, name]) @@ -527,10 +574,12 @@ def _extract_named_priors(module, prefix=""): yield name, parent_module, prior, closure, inv_closure -def _extract_named_constraints(module, memo=None, prefix=""): +def _extract_named_constraints( + module: nn.Module, memo: Optional[MutableSet[Interval]] = None, prefix: str = "" +) -> Iterator[tuple[str, Interval]]: if memo is None: memo = set() - if hasattr(module, "_constraints"): + if isinstance(module, Module): for name, constraint in module._constraints.items(): if constraint is not None and constraint not in memo: memo.add(constraint) @@ -540,43 +589,3 @@ def _extract_named_constraints(module, memo=None, prefix=""): submodule_prefix = prefix + ("." if prefix else "") + mname for name, constraint in _extract_named_constraints(module_, memo=memo, prefix=submodule_prefix): yield name, constraint - - -class RandomModuleMixin(object): - def initialize(self, **kwargs): - """ - Set a value for a parameter - - kwargs: (param_name, value) - parameter to initialize. - Can also initialize recursively by passing in the full name of a - parameter. For example if model has attribute model.likelihood, - we can initialize the noise with either - `model.initialize(**{'likelihood.noise': 0.1})` - or - `model.likelihood.initialize(noise=0.1)`. - The former method would allow users to more easily store the - initialization values as one object. - - Value can take the form of a tensor, a float, or an int - """ - for name, value in kwargs.items(): - if not torch.is_tensor(value): - raise RuntimeError("Initialize in RandomModules can only be done with tensor values.") - - names = name.rsplit(".") - if len(names) > 1: - mod_name, param_name = names - mod = operator.attrgetter(mod_name)(self) - else: - mod, param_name = self, name - - old_param = getattr(mod, param_name) - is_property = hasattr(type(self), name) and isinstance(getattr(type(self), name), property) - if not isinstance(old_param, torch.nn.Parameter) or is_property: - # Presumably we're calling a getter that will call initialize again on the actual parameter. - setattr(mod, param_name, value.expand(old_param.shape)) - else: - delattr(mod, param_name) - setattr(mod, param_name, value.expand(old_param.shape)) - - return self From 02ac961c6e5b6a3a1280c4372f07361f294a1a96 Mon Sep 17 00:00:00 2001 From: Christopher Yeh Date: Tue, 19 Nov 2024 05:50:31 +0000 Subject: [PATCH 2/2] Undo removal of Module.register_parameter() --- gpytorch/module.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/gpytorch/module.py b/gpytorch/module.py index 3c8b9bce1..57550755a 100644 --- a/gpytorch/module.py +++ b/gpytorch/module.py @@ -237,6 +237,18 @@ def named_variational_parameters(self): def register_added_loss_term(self, name): self._added_loss_terms[name] = None + def register_parameter(self, name: str, parameter: Optional[nn.Parameter]) -> None: + r""" + Adds a parameter to the module. The parameter can be accessed as an attribute using the given name. + + Args: + name: + The name of the parameter + parameter: + The parameter + """ + super().register_parameter(name, parameter) + def register_prior( self, name: str,