diff --git a/trulens/nn/models/__init__.py b/trulens/nn/models/__init__.py
index ae23f4873..08c03da07 100644
--- a/trulens/nn/models/__init__.py
+++ b/trulens/nn/models/__init__.py
@@ -74,6 +74,7 @@ def get_model_wrapper(
     default_feed_dict=None,
     session=None,
     backend=None,
+    force_eval=True,
     **kwargs
 ):
     """
@@ -136,6 +137,10 @@ def get_model_wrapper(
         backend:
            _Optional, for forcing a specific backend._ String values recognized
            are pytorch, tensorflow, keras, or tf.keras.
+
+        force_eval:
+            _Optional, True will force a model.eval() call for PyTorch models._
+            False will retain the current model state.
 
     Returns: ModelWrapper
     """
@@ -189,7 +194,10 @@ def get_model_wrapper(
     elif B.backend == Backend.PYTORCH:
         from trulens.nn.models.pytorch import PytorchModelWrapper
         return PytorchModelWrapper(
-            model, logit_layer=logit_layer, device=device
+            model,
+            logit_layer=logit_layer,
+            device=device,
+            force_eval=force_eval
         )
     elif B.backend == Backend.TENSORFLOW:
         import tensorflow as tf
diff --git a/trulens/nn/models/pytorch.py b/trulens/nn/models/pytorch.py
index cad6ddeec..339ae5bcb 100644
--- a/trulens/nn/models/pytorch.py
+++ b/trulens/nn/models/pytorch.py
@@ -33,7 +33,15 @@ class PytorchModelWrapper(ModelWrapper):
     of Pytorch nn.Module objects.
     """
 
-    def __init__(self, model, *, logit_layer=None, device=None, **kwargs):
+    def __init__(
+        self,
+        model,
+        *,
+        logit_layer=None,
+        device=None,
+        force_eval=True,
+        **kwargs
+    ):
         """
         __init__ Constructor
 
@@ -46,6 +54,9 @@ def __init__(self, model, *, logit_layer=None, device=None, **kwargs):
             layer named 'logits' is the logit layer.
         device : string, optional
             device on which to run model, by default None
+        force_eval : bool, optional
+            If True, will call model.eval() to ensure determinism. Otherwise, keeps the current model state, by default True
+
         """
 
         if 'input_shape' in kwargs:
@@ -61,8 +72,9 @@ def __init__(self, model, *, logit_layer=None, device=None, **kwargs):
 
         super().__init__(model, **kwargs)
         # sets self._model, issues cross-backend messages
-
-        model.eval()
+        self.force_eval = force_eval
+        if self.force_eval:
+            model.eval()
 
         if device is None:
             device = pytorch.get_default_device()
@@ -337,7 +349,8 @@ def hookfn(self, inpt, outpt):
         with memory_suggestions(device=self.device):
             # Run the network.
             try:
-                self._model.eval()  # needed for determinism sometimes
+                if self.force_eval:
+                    self._model.eval()  # needed for determinism sometimes
 
                 output = model_inputs.call_on(self._model)
                 if isinstance(output, tuple):
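
A minimal usage sketch of the new flag, assuming an existing torch.nn.Module instance (my_module below is a hypothetical name) and the get_model_wrapper API as changed in this diff:

    from trulens.nn.models import get_model_wrapper

    # force_eval=True (the default) calls model.eval() before the network is run,
    # which keeps attributions deterministic. Passing force_eval=False preserves
    # the model's current train/eval state, e.g. to leave dropout active while
    # attributing a model mid-training.
    wrapper = get_model_wrapper(my_module, backend='pytorch', force_eval=False)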