Skip to content

Commit

Permalink
Query and forward parameter overrides.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 661134789
Change-Id: I3c5dfe27272c57cac9d258db8686b18ec2c032fb
  • Loading branch information
Sax Authors authored and copybara-github committed Aug 9, 2024
1 parent 37492af commit f95a0c2
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
3 changes: 3 additions & 0 deletions saxml/server/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,9 @@ pytype_strict_library(
":model_service_base",
":servable_model_registry",
":spmd_backend",
"//saxml/client/python:sax",
"//saxml/protobuf:admin_py_pb2",
"//saxml/protobuf:admin_py_pb2_grpc",
"//saxml/protobuf:modelet_py_pb2",
"//saxml/protobuf:modelet_py_pb2_grpc",
"//saxml/server/jax:jax_spmd_backend",
Expand Down
9 changes: 8 additions & 1 deletion saxml/server/model_service_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import grpc
import jax
from jax.experimental.compilation_cache import compilation_cache
from saxml.client.python import sax
from saxml.protobuf import modelet_pb2
from saxml.protobuf import modelet_pb2_grpc
from saxml.server import model_service_base
Expand Down Expand Up @@ -124,6 +125,11 @@ def _load_static_model(
logging.info(
'Loading key %s, model %s, checkpoint %s.', model_key, model, checkpoint
)
# Get overrides that might have been provided via 'saxutil publish' and apply
# them.
overrides = sax.ListDetail(model_key).overrides
logging.info('Got overrides: %s', overrides)

if channel_creds is None:
channel = grpc.insecure_channel(f'localhost:{port}')
else:
Expand All @@ -132,7 +138,8 @@ def _load_static_model(
grpc.channel_ready_future(channel).result(timeout=10)
stub = modelet_pb2_grpc.ModeletStub(channel)
req = modelet_pb2.LoadRequest(
model_key=model_key, model_path=model, checkpoint_path=checkpoint
model_key=model_key, model_path=model, checkpoint_path=checkpoint,
overrides=overrides,
)
try:
stub.Load(req)
Expand Down

0 comments on commit f95a0c2

Please sign in to comment.