Skip to content

Commit

Permalink
Optional param to call model.eval() (#82)
Browse files Browse the repository at this point in the history
* param to toggle forced model.eval()

* brevity

* extend to get_model_wrapper
  • Loading branch information
coreyhu authored Jul 5, 2022
1 parent bcf1028 commit aba5933
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 5 deletions.
10 changes: 9 additions & 1 deletion trulens/nn/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def get_model_wrapper(
default_feed_dict=None,
session=None,
backend=None,
force_eval=True,
**kwargs
):
"""
Expand Down Expand Up @@ -136,6 +137,10 @@ def get_model_wrapper(
backend:
_Optional, for forcing a specific backend._ String values recognized
are pytorch, tensorflow, keras, or tf.keras.
force_eval:
_Optional, True will force a model.eval() call for PyTorch models. False
will retain current model state
Returns: ModelWrapper
"""
Expand Down Expand Up @@ -189,7 +194,10 @@ def get_model_wrapper(
elif B.backend == Backend.PYTORCH:
from trulens.nn.models.pytorch import PytorchModelWrapper
return PytorchModelWrapper(
model, logit_layer=logit_layer, device=device
model,
logit_layer=logit_layer,
device=device,
force_eval=force_eval
)
elif B.backend == Backend.TENSORFLOW:
import tensorflow as tf
Expand Down
21 changes: 17 additions & 4 deletions trulens/nn/models/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,15 @@ class PytorchModelWrapper(ModelWrapper):
of Pytorch nn.Module objects.
"""

def __init__(self, model, *, logit_layer=None, device=None, **kwargs):
def __init__(
self,
model,
*,
logit_layer=None,
device=None,
force_eval=True,
**kwargs
):
"""
__init__ Constructor
Expand All @@ -46,6 +54,9 @@ def __init__(self, model, *, logit_layer=None, device=None, **kwargs):
layer named 'logits' is the logit layer.
device : string, optional
device on which to run model, by default None
force_eval : bool, optional
If True, will call model.eval() to ensure determinism. Otherwise, keeps current model state, by default True
"""

if 'input_shape' in kwargs:
Expand All @@ -61,8 +72,9 @@ def __init__(self, model, *, logit_layer=None, device=None, **kwargs):

super().__init__(model, **kwargs)
# sets self._model, issues cross-backend messages

model.eval()
self.force_eval = force_eval
if self.force_eval:
model.eval()

if device is None:
device = pytorch.get_default_device()
Expand Down Expand Up @@ -337,7 +349,8 @@ def hookfn(self, inpt, outpt):
with memory_suggestions(device=self.device):
# Run the network.
try:
self._model.eval() # needed for determinism sometimes
if self.force_eval:
self._model.eval() # needed for determinism sometimes
output = model_inputs.call_on(self._model)

if isinstance(output, tuple):
Expand Down

0 comments on commit aba5933

Please sign in to comment.