Skip to content

Commit

Permalink
Add healthcheck support for JetStream (#90)
Browse files Browse the repository at this point in the history
* Add healthcheck support for JetStream

* fix indentation

* fix pylint unit test

* use pyink to reformat generated protos
  • Loading branch information
vivianrwu authored May 29, 2024
1 parent a223df9 commit 0c56aac
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 3 deletions.
14 changes: 14 additions & 0 deletions jetstream/core/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,3 +885,17 @@ async def Decode( # pylint: disable=invalid-overridden-method
)
# Reset buffer after flushed.
buffered_response_list = []

async def HealthCheck( # pylint: disable=invalid-overridden-method
self,
request: jetstream_pb2.HealthCheckRequest,
context: Optional[grpc.aio.ServicerContext] = None,
) -> jetstream_pb2.HealthCheckResponse:
"""HealthCheck."""
if context is None:
logging.warning(
"LLM orchestrator is being used in offline test mode, and will not"
" respond to gRPC queries - only direct function calls."
)
is_live = self._driver.live
return jetstream_pb2.HealthCheckResponse(is_live=is_live)
9 changes: 9 additions & 0 deletions jetstream/core/proto/jetstream.proto
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ package jetstream_proto;
service Orchestrator {
// Query LLM to generate text or tokens.
rpc Decode(DecodeRequest) returns (stream DecodeResponse) {}
// Checks if the model server is live.
rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse) {}
}

message DecodeRequest {
Expand Down Expand Up @@ -74,4 +76,11 @@ message DecodeResponse {
}
reserved 1;
// Next ID: 4
}

message HealthCheckRequest {}

message HealthCheckResponse {
// Denotes whether the model server is live
bool is_live = 1;
}
10 changes: 7 additions & 3 deletions jetstream/core/proto/jetstream_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"\xa7\x02\n\rDecodeRequest\x12\x15\n\rsession_cache\x18\x01 \x01(\t\x12\x10\n\x08priority\x18\x03 \x01(\x05\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x02\x10\x03"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02\x32]\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x62\x06proto3'
b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"\xa7\x02\n\rDecodeRequest\x12\x15\n\rsession_cache\x18\x01 \x01(\t\x12\x10\n\x08priority\x18\x03 \x01(\x05\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x02\x10\x03"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02"\x14\n\x12HealthCheckRequest"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse"\x00\x62\x06proto3'
)

_globals = globals()
Expand All @@ -52,6 +52,10 @@
_globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_end = 670
_globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_start = 629
_globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_end = 670
_globals["_ORCHESTRATOR"]._serialized_start = 689
_globals["_ORCHESTRATOR"]._serialized_end = 782
_globals["_HEALTHCHECKREQUEST"]._serialized_start = 689
_globals["_HEALTHCHECKREQUEST"]._serialized_end = 709
_globals["_HEALTHCHECKRESPONSE"]._serialized_start = 711
_globals["_HEALTHCHECKRESPONSE"]._serialized_end = 749
_globals["_ORCHESTRATOR"]._serialized_start = 752
_globals["_ORCHESTRATOR"]._serialized_end = 937
# @@protoc_insertion_point(module_scope)
45 changes: 45 additions & 0 deletions jetstream/core/proto/jetstream_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ def __init__(self, channel):
request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.SerializeToString,
response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.FromString,
)
self.HealthCheck = channel.unary_unary(
"/jetstream_proto.Orchestrator/HealthCheck",
request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.SerializeToString,
response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.FromString,
)


class OrchestratorServicer(object):
Expand All @@ -45,6 +50,12 @@ def Decode(self, request, context):
context.set_details("Method not implemented!")
raise NotImplementedError("Method not implemented!")

def HealthCheck(self, request, context):
"""Checks if the model server is live."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details("Method not implemented!")
raise NotImplementedError("Method not implemented!")


def add_OrchestratorServicer_to_server(servicer, server):
rpc_method_handlers = {
Expand All @@ -53,6 +64,11 @@ def add_OrchestratorServicer_to_server(servicer, server):
request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.FromString,
response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.SerializeToString,
),
"HealthCheck": grpc.unary_unary_rpc_method_handler(
servicer.HealthCheck,
request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.FromString,
response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
"jetstream_proto.Orchestrator", rpc_method_handlers
Expand Down Expand Up @@ -92,3 +108,32 @@ def Decode(
timeout,
metadata,
)

@staticmethod
def HealthCheck(
request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None,
):
return grpc.experimental.unary_unary(
request,
target,
"/jetstream_proto.Orchestrator/HealthCheck",
jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.SerializeToString,
jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
)
6 changes: 6 additions & 0 deletions jetstream/tests/core/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ async def test_server(
) as channel:
stub = jetstream_pb2_grpc.OrchestratorStub(channel)

healthcheck_request = jetstream_pb2.HealthCheckRequest()
healthcheck_response = stub.HealthCheck(healthcheck_request)
healthcheck_response = await healthcheck_response

assert healthcheck_response.is_live is True

# The string representation of np.array([[65, 66]]), [2] will be prepended
# as BOS
text = "AB"
Expand Down

0 comments on commit 0c56aac

Please sign in to comment.