Skip to content

Commit

Permalink
Query and forward parameter overrides.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 661328902
Change-Id: If987cd214ce059cea652990943e6ea495bd8af19
  • Loading branch information
Sax Authors authored and copybara-github committed Aug 9, 2024
1 parent aa563db commit f1dd7c1
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
4 changes: 4 additions & 0 deletions saxml/server/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,9 @@ pytype_strict_library(
":model_service_base",
":servable_model_registry",
":spmd_backend",
"//saxml/client/python:sax",
"//saxml/protobuf:admin_py_pb2",
"//saxml/protobuf:admin_py_pb2_grpc",
"//saxml/protobuf:modelet_py_pb2",
"//saxml/protobuf:modelet_py_pb2_grpc",
"//saxml/server/jax:jax_spmd_backend",
Expand All @@ -321,6 +324,7 @@ pytype_strict_library(
"//third_party/py/grpcio",
"//third_party/py/jax",
"//third_party/py/tensorflow:tensorflow_no_contrib",
"@pybind11_abseil//pybind11_abseil:status",
],
)

Expand Down
22 changes: 20 additions & 2 deletions saxml/server/model_service_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,17 @@
import grpc
import jax
from jax.experimental.compilation_cache import compilation_cache
from saxml.client.python import sax
from saxml.protobuf import modelet_pb2
from saxml.protobuf import modelet_pb2_grpc
from saxml.server import model_service_base
from saxml.server import servable_model_registry
from saxml.server import spmd_backend
import tensorflow as tf

from google3.third_party.pybind11_abseil import status as absl_status


_SAX_CELL = flags.DEFINE_string(
'sax_cell',
None,
Expand Down Expand Up @@ -119,11 +123,23 @@ def _load_static_model(
model_key: str,
checkpoint: str,
channel_creds: Optional[grpc.ChannelCredentials],
sax_cell: Optional[str],
) -> None:
"""Loads statically specified model to a started service."""
logging.info(
'Loading key %s, model %s, checkpoint %s.', model_key, model, checkpoint
)
# Get overrides that might have been provided via 'saxutil publish' and apply
# them.
overrides = {}
if sax_cell:
try:
overrides = sax.ListDetail(model_key).overrides
logging.info('Got overrides: %s', overrides)
except absl_status.StatusNotOk as e:
logging.warning(
"Could not get model details, not applying overrides: '%s'", e
)
if channel_creds is None:
channel = grpc.insecure_channel(f'localhost:{port}')
else:
Expand All @@ -132,7 +148,8 @@ def _load_static_model(
grpc.channel_ready_future(channel).result(timeout=10)
stub = modelet_pb2_grpc.ModeletStub(channel)
req = modelet_pb2.LoadRequest(
model_key=model_key, model_path=model, checkpoint_path=checkpoint
model_key=model_key, model_path=model, checkpoint_path=checkpoint,
overrides=overrides,
)
try:
stub.Load(req)
Expand Down Expand Up @@ -235,7 +252,8 @@ def run(channel_creds: Optional[grpc.ChannelCredentials]) -> None:
for model, key, ckpt in zip(
_MODELS.value, _MODEL_KEYS.value, _CHECKPOINTS.value
):
_load_static_model(_PORT.value, model, key, ckpt, channel_creds)
_load_static_model(_PORT.value, model, key, ckpt, channel_creds,
_SAX_CELL.value)
runner.on_initial_models_load_completion()
runner.wait()
finally:
Expand Down

0 comments on commit f1dd7c1

Please sign in to comment.