Skip to content

Commit

Permalink
Make apply_model_overrides a classmethod
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 697795928
Change-Id: I3d6da3f128348d89945d4c980d0a3d8b18ecca2c
  • Loading branch information
alanwaketan authored and copybara-github committed Nov 19, 2024
1 parent c66d36d commit 1b0f132
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 12 deletions.
2 changes: 1 addition & 1 deletion saxml/server/model_service_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,8 +742,8 @@ def load(
if not issubclass(model_class, servable_model_params.ServableModelParams):
raise ValueError(f'{model_path} is not a ServableModelParams')
# pytype: disable=not-instantiable
model_class = model_class.apply_model_overrides(overrides)
params = model_class()
params.apply_model_overrides(overrides)
loaded = params.load(key, ckpt_path, self._primary_process_id, prng_key)
# pytype: enable=not-instantiable
loaded.set_acls(acls)
Expand Down
26 changes: 18 additions & 8 deletions saxml/server/servable_model_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Base classes for servable model and method config classes."""

import abc
import copy
import json
from typing import Any, Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -117,9 +118,10 @@ def sax_registration_name(cls) -> Optional[str]:
"""Returns an optional custom registration name for the model."""
return None

# TODO(jwtan): Make this a classmethod as most of the config attributes are
# class attributes.
def apply_model_overrides(self, overrides: Dict[str, Any]) -> None:
@classmethod
def apply_model_overrides(
cls, overrides: Dict[str, Any]
) -> type['ServableModelParams']:
"""Applies model config overrides received from Publish.
The default handling of overrides is as follows:
Expand All @@ -132,21 +134,29 @@ def apply_model_overrides(self, overrides: Dict[str, Any]) -> None:
Args:
overrides: Model config key-value pairs supplied by the Publish command.
Returns:
A new ServableMethodParams instance with overrides applied.
"""
new_cls = copy.deepcopy(cls)
for k, v_raw in overrides.items():
if not hasattr(self, k):
logging.warning("Can't override %s because it's not set on %s", k, self)
if not hasattr(new_cls, k):
logging.warning(
"Can't override %s because it's not set on %s", k, new_cls
)
continue
try:
v = json.loads(v_raw)
except Exception as e: # pylint: disable=broad-exception-caught
logging.warning('Not a valid json value: %s %s', v_raw, e)
continue
cur_v = getattr(self, k)
cur_v = getattr(new_cls, k)
if v is not None and cur_v is not None and type(v) != type(cur_v): # pylint: disable=unidiomatic-typecheck
raise ValueError(
'Mismatched type of override: original: %s; override: %s'
% (cur_v, v)
)
setattr(self, k, v)
logging.info('Set override %s to %s on %s', k, v, self)
setattr(new_cls, k, v)
logging.info('Set override %s to %s on %s', k, v, new_cls)

return new_cls
6 changes: 3 additions & 3 deletions saxml/server/servable_model_params_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ class ServableModelParamsTest(absltest.TestCase):
def setUp(self):
super().setUp()
servable_model_params.ServableModelParams.__abstractmethods__ = set()
self.params = servable_model_params.ServableModelParams()
self.params = servable_model_params.ServableModelParams

def test_overrides(self):
params = self.params
params.INT_KEY = 42
params.STR_KEY = "hi there"
params.LIST_KEY = [128, 256]
params.ANOTHER_LIST_KEY = [1, 2]
params.apply_model_overrides(dict(
params = params.apply_model_overrides(dict(
INT_KEY="100",
STR_KEY="\"foo\"",
LIST_KEY="[55, 65, 75]",
Expand All @@ -43,7 +43,7 @@ def test_overrides(self):
def test_skip_on_missing_field(self):
params = self.params
params.INT_KEY = 42
params.apply_model_overrides(dict(ANOTHER_INT_KEY="100",))
params = params.apply_model_overrides(dict(ANOTHER_INT_KEY="100",))
self.assertEqual(params.INT_KEY, 42)

def test_exception_on_different_type(self):
Expand Down

0 comments on commit 1b0f132

Please sign in to comment.