diff --git a/saxml/server/BUILD b/saxml/server/BUILD index 08a2fd2..204226b 100644 --- a/saxml/server/BUILD +++ b/saxml/server/BUILD @@ -312,6 +312,9 @@ pytype_strict_library( ":model_service_base", ":servable_model_registry", ":spmd_backend", + "//saxml/client/python:sax", + "//saxml/protobuf:admin_py_pb2", + "//saxml/protobuf:admin_py_pb2_grpc", "//saxml/protobuf:modelet_py_pb2", "//saxml/protobuf:modelet_py_pb2_grpc", "//saxml/server/jax:jax_spmd_backend", @@ -321,6 +324,7 @@ pytype_strict_library( "//third_party/py/grpcio", "//third_party/py/jax", "//third_party/py/tensorflow:tensorflow_no_contrib", + "@pybind11_abseil//pybind11_abseil:status", ], ) diff --git a/saxml/server/model_service_main.py b/saxml/server/model_service_main.py index 3a4d2e0..ef4f017 100644 --- a/saxml/server/model_service_main.py +++ b/saxml/server/model_service_main.py @@ -22,6 +22,7 @@ import grpc import jax from jax.experimental.compilation_cache import compilation_cache +from saxml.client.python import sax from saxml.protobuf import modelet_pb2 from saxml.protobuf import modelet_pb2_grpc from saxml.server import model_service_base @@ -29,6 +30,9 @@ from saxml.server import spmd_backend import tensorflow as tf +from google3.third_party.pybind11_abseil import status as absl_status + + _SAX_CELL = flags.DEFINE_string( 'sax_cell', None, @@ -119,11 +123,23 @@ def _load_static_model( model_key: str, checkpoint: str, channel_creds: Optional[grpc.ChannelCredentials], + sax_cell: Optional[str], ) -> None: """Loads statically specified model to a started service.""" logging.info( 'Loading key %s, model %s, checkpoint %s.', model_key, model, checkpoint ) + # Get overrides that might have been provided via 'saxutil publish' and apply + # them. + overrides = {} + if sax_cell: + try: + overrides = sax.ListDetail(model_key).overrides + logging.info('Got overrides: %s', overrides) + except absl_status.StatusNotOk as e: + logging.warning( + "Could not get model details, not applying overrides: '%s'", e + ) if channel_creds is None: channel = grpc.insecure_channel(f'localhost:{port}') else: @@ -132,7 +148,8 @@ def _load_static_model( grpc.channel_ready_future(channel).result(timeout=10) stub = modelet_pb2_grpc.ModeletStub(channel) req = modelet_pb2.LoadRequest( - model_key=model_key, model_path=model, checkpoint_path=checkpoint + model_key=model_key, model_path=model, checkpoint_path=checkpoint, + overrides=overrides, ) try: stub.Load(req) @@ -235,7 +252,8 @@ def run(channel_creds: Optional[grpc.ChannelCredentials]) -> None: for model, key, ckpt in zip( _MODELS.value, _MODEL_KEYS.value, _CHECKPOINTS.value ): - _load_static_model(_PORT.value, model, key, ckpt, channel_creds) + _load_static_model(_PORT.value, model, key, ckpt, channel_creds, + _SAX_CELL.value) runner.on_initial_models_load_completion() runner.wait() finally: