Skip to content

Commit

Permalink
Query and forward parameter overrides.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 661290405
Change-Id: I4c51a9f542a00a94424fedf2369128aa6d073f1a
  • Loading branch information
Sax Authors authored and copybara-github committed Aug 9, 2024
1 parent f95a0c2 commit aa563db
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 11 deletions.
3 changes: 0 additions & 3 deletions saxml/server/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,6 @@ pytype_strict_library(
":model_service_base",
":servable_model_registry",
":spmd_backend",
"//saxml/client/python:sax",
"//saxml/protobuf:admin_py_pb2",
"//saxml/protobuf:admin_py_pb2_grpc",
"//saxml/protobuf:modelet_py_pb2",
"//saxml/protobuf:modelet_py_pb2_grpc",
"//saxml/server/jax:jax_spmd_backend",
Expand Down
9 changes: 1 addition & 8 deletions saxml/server/model_service_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import grpc
import jax
from jax.experimental.compilation_cache import compilation_cache
from saxml.client.python import sax
from saxml.protobuf import modelet_pb2
from saxml.protobuf import modelet_pb2_grpc
from saxml.server import model_service_base
Expand Down Expand Up @@ -125,11 +124,6 @@ def _load_static_model(
logging.info(
'Loading key %s, model %s, checkpoint %s.', model_key, model, checkpoint
)
# Get overrides that might have been provided via 'saxutil publish' and apply
# them.
overrides = sax.ListDetail(model_key).overrides
logging.info('Got overrides: %s', overrides)

if channel_creds is None:
channel = grpc.insecure_channel(f'localhost:{port}')
else:
Expand All @@ -138,8 +132,7 @@ def _load_static_model(
grpc.channel_ready_future(channel).result(timeout=10)
stub = modelet_pb2_grpc.ModeletStub(channel)
req = modelet_pb2.LoadRequest(
model_key=model_key, model_path=model, checkpoint_path=checkpoint,
overrides=overrides,
model_key=model_key, model_path=model, checkpoint_path=checkpoint
)
try:
stub.Load(req)
Expand Down

0 comments on commit aa563db

Please sign in to comment.