diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 37735152..eed35f8c 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -885,3 +885,17 @@ async def Decode( # pylint: disable=invalid-overridden-method ) # Reset buffer after flushed. buffered_response_list = [] + + async def HealthCheck( # pylint: disable=invalid-overridden-method + self, + request: jetstream_pb2.HealthCheckRequest, + context: Optional[grpc.aio.ServicerContext] = None, + ) -> jetstream_pb2.HealthCheckResponse: + """HealthCheck.""" + if context is None: + logging.warning( + "LLM orchestrator is being used in offline test mode, and will not" + " respond to gRPC queries - only direct function calls." + ) + is_live = self._driver.live + return jetstream_pb2.HealthCheckResponse(is_live=is_live) diff --git a/jetstream/core/proto/jetstream.proto b/jetstream/core/proto/jetstream.proto index 540eb48e..5f2e8869 100644 --- a/jetstream/core/proto/jetstream.proto +++ b/jetstream/core/proto/jetstream.proto @@ -21,6 +21,8 @@ package jetstream_proto; service Orchestrator { // Query LLM to generate text or tokens. rpc Decode(DecodeRequest) returns (stream DecodeResponse) {} + // Checks if the model server is live. + rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse) {} } message DecodeRequest { @@ -74,4 +76,11 @@ message DecodeResponse { } reserved 1; // Next ID: 4 +} + +message HealthCheckRequest {} + +message HealthCheckResponse { + // Denotes whether the model server is live + bool is_live = 1; } \ No newline at end of file diff --git a/jetstream/core/proto/jetstream_pb2.py b/jetstream/core/proto/jetstream_pb2.py index 4f39c52d..3fadd54c 100644 --- a/jetstream/core/proto/jetstream_pb2.py +++ b/jetstream/core/proto/jetstream_pb2.py @@ -28,7 +28,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"\xa7\x02\n\rDecodeRequest\x12\x15\n\rsession_cache\x18\x01 \x01(\t\x12\x10\n\x08priority\x18\x03 \x01(\x05\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x02\x10\x03"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02\x32]\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x62\x06proto3' + b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"\xa7\x02\n\rDecodeRequest\x12\x15\n\rsession_cache\x18\x01 \x01(\t\x12\x10\n\x08priority\x18\x03 \x01(\x05\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x02\x10\x03"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02"\x14\n\x12HealthCheckRequest"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse"\x00\x62\x06proto3' ) _globals = globals() @@ -52,6 +52,10 @@ _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_end = 670 _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_start = 629 _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_end = 670 - _globals["_ORCHESTRATOR"]._serialized_start = 689 - _globals["_ORCHESTRATOR"]._serialized_end = 782 + _globals["_HEALTHCHECKREQUEST"]._serialized_start = 689 + _globals["_HEALTHCHECKREQUEST"]._serialized_end = 709 + _globals["_HEALTHCHECKRESPONSE"]._serialized_start = 711 + _globals["_HEALTHCHECKRESPONSE"]._serialized_end = 749 + _globals["_ORCHESTRATOR"]._serialized_start = 752 + _globals["_ORCHESTRATOR"]._serialized_end = 937 # @@protoc_insertion_point(module_scope) diff --git a/jetstream/core/proto/jetstream_pb2_grpc.py b/jetstream/core/proto/jetstream_pb2_grpc.py index 044f6d80..84521185 100644 --- a/jetstream/core/proto/jetstream_pb2_grpc.py +++ b/jetstream/core/proto/jetstream_pb2_grpc.py @@ -34,6 +34,11 @@ def __init__(self, channel): request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.SerializeToString, response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.FromString, ) + self.HealthCheck = channel.unary_unary( + "/jetstream_proto.Orchestrator/HealthCheck", + request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.SerializeToString, + response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.FromString, + ) class OrchestratorServicer(object): @@ -45,6 +50,12 @@ def Decode(self, request, context): context.set_details("Method not implemented!") raise NotImplementedError("Method not implemented!") + def HealthCheck(self, request, context): + """Checks if the model server is live.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + def add_OrchestratorServicer_to_server(servicer, server): rpc_method_handlers = { @@ -53,6 +64,11 @@ def add_OrchestratorServicer_to_server(servicer, server): request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.FromString, response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.SerializeToString, ), + "HealthCheck": grpc.unary_unary_rpc_method_handler( + servicer.HealthCheck, + request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.FromString, + response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( "jetstream_proto.Orchestrator", rpc_method_handlers @@ -92,3 +108,32 @@ def Decode( timeout, metadata, ) + + @staticmethod + def HealthCheck( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/jetstream_proto.Orchestrator/HealthCheck", + jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.SerializeToString, + jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/jetstream/tests/core/test_server.py b/jetstream/tests/core/test_server.py index f083a823..150ac39d 100644 --- a/jetstream/tests/core/test_server.py +++ b/jetstream/tests/core/test_server.py @@ -82,6 +82,12 @@ async def test_server( ) as channel: stub = jetstream_pb2_grpc.OrchestratorStub(channel) + healthcheck_request = jetstream_pb2.HealthCheckRequest() + healthcheck_response = stub.HealthCheck(healthcheck_request) + healthcheck_response = await healthcheck_response + + assert healthcheck_response.is_live is True + # The string representation of np.array([[65, 66]]), [2] will be prepended # as BOS text = "AB"