diff --git a/saxml/common/testutil.go b/saxml/common/testutil.go
index 789bad35..39373de4 100644
--- a/saxml/common/testutil.go
+++ b/saxml/common/testutil.go
@@ -593,6 +593,18 @@ func (s *stubVisionModelServer) VideoToText(ctx context.Context, in *vmpb.VideoT
 	}, nil
 }
 
+func (s *stubVisionModelServer) VideoToToken(ctx context.Context, in *vmpb.VideoToTokenRequest) (*vmpb.VideoToTokenResponse, error) {
+	value := float64(len(in.GetImageFrames()))
+	return &vmpb.VideoToTokenResponse{
+		Tokens: []float64{
+			5.0,
+			6.0,
+			1.0,
+			value,
+		},
+	}, nil
+}
+
 type stubAudioModelServer struct{}
 
 func (s *stubAudioModelServer) Recognize(ctx context.Context, in *ampb.AsrRequest) (*ampb.AsrResponse, error) {
diff --git a/saxml/protobuf/vision.proto b/saxml/protobuf/vision.proto
index 9f966d7c..d7526f11 100644
--- a/saxml/protobuf/vision.proto
+++ b/saxml/protobuf/vision.proto
@@ -180,6 +180,16 @@ message VideoToTextResponse {
   repeated DecodedText texts = 1;
 }
 
+message VideoToTokenRequest {
+  string model_key = 1;
+  repeated bytes image_frames = 2;  // Video composed of multiple image frames.
+  .sax.ExtraInputs extra_inputs = 3;
+}
+
+message VideoToTokenResponse {
+  repeated double tokens = 1;  // Quantized or soft tokens.
+}
+
 service VisionService {
   // Returns the score (e.g., log pplx) given the text.
   rpc Classify(ClassifyRequest) returns (ClassifyResponse);
@@ -206,4 +216,7 @@ service VisionService {
 
   // Returns text generation results given video.
   rpc VideoToText(VideoToTextRequest) returns (VideoToTextResponse);
+
+  // Returns video tokens given the input video.
+  rpc VideoToToken(VideoToTokenRequest) returns (VideoToTokenResponse);
 }
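For reference, the new messages round-trip as plain protobufs. A minimal sketch (the `saxml.protobuf` import path and the model key are assumptions based on the proto's location; the frame bytes are placeholders):

```python
from saxml.protobuf import vision_pb2  # assumed import path for the generated module

# Two placeholder JPEG frames for a hypothetical published model.
request = vision_pb2.VideoToTokenRequest(
    model_key='/sax/test/vision',
    image_frames=[b'<jpeg frame 0>', b'<jpeg frame 1>'],
)

# The stub server above answers with three fixed values followed by the frame
# count, so a two-frame request yields [5.0, 6.0, 1.0, 2.0].
response = vision_pb2.VideoToTokenResponse(tokens=[5.0, 6.0, 1.0, 2.0])
assert list(response.tokens) == [5.0, 6.0, 1.0, 2.0]
```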
diff --git a/saxml/server/pax/vision/servable_vision_model.py b/saxml/server/pax/vision/servable_vision_model.py
index 59ee0e64..d3b07765 100644
--- a/saxml/server/pax/vision/servable_vision_model.py
+++ b/saxml/server/pax/vision/servable_vision_model.py
@@ -172,6 +172,21 @@ class VideoToTextHParams(servable_model_params.ServableMethodParams):
   model_method_name: Optional[str] = None
 
 
+@dataclasses.dataclass
+class VideoToTokenHParams(servable_model_params.ServableMethodParams):
+  """HParams for the VideoToToken method.
+
+  Attributes:
+    image_preprocessor: Pre-processing function that converts a single frame's
+      image bytes into an image tensor. Required.
+    model_method_name: The name of the model method to call to compute tokens
+      from input video frames. Required.
+  """
+
+  image_preprocessor: Optional[Callable[[str], tf.Tensor]] = None
+  model_method_name: Optional[str] = None
+
+
 class VisionModelParamsBase(servable_model_params.ServableModelParams):
   """Base Vision Model params.
@@ -217,6 +232,9 @@ def methods(self) -> Dict[str, servable_model_params.ServableMethodParams]:
     video_to_text_params = self.video_to_text()
     if video_to_text_params is not None:
       methods[VisionMethodName.VIDEO_TO_TEXT] = video_to_text_params
+    video_to_token_params = self.video_to_token()
+    if video_to_token_params is not None:
+      methods[VisionMethodName.VIDEO_TO_TOKEN] = video_to_token_params
     # pylint: enable=assignment-from-none
     return methods
 
@@ -250,6 +268,9 @@ def image_to_text(self) -> Optional[ImageToTextHParams]:
   def video_to_text(self) -> Optional[VideoToTextHParams]:
     return None
 
+  def video_to_token(self) -> Optional[VideoToTokenHParams]:
+    return None
+
   def create_model(self, primary_process_id: int) -> 'VisionModel':
     return VisionModel(
         self,
@@ -835,6 +856,69 @@ def _preprocess_images(self, raw_input: Any) -> NestedNpTensor:
     return image_data
 
 
+class VideoToToken(servable_model.ServableMethod):
+  """Method for implementing video tokenization."""
+
+  def __init__(
+      self,
+      model,
+      model_fn_name: str,
+      model_state,
+      method_hparams: VideoToTokenHParams,
+      prng_key,
+      dummy_input_sample: Any,
+      model_config: Any,
+  ):
+    self._model_config = model_config
+    if method_hparams.image_preprocessor is None:
+      raise ValueError(
+          '`image_preprocessor` must be defined in VideoToTokenHParams.'
+      )
+    self._cluster = copy.deepcopy(cluster_factory.Current())
+    self._cluster.params.do_eval = True
+    with self._cluster:
+      self._image_preprocessor = method_hparams.image_preprocessor
+
+    super().__init__(
+        model,
+        model_fn_name,
+        model_state,
+        method_hparams,
+        prng_key,
+        dummy_input_sample,
+    )
+
+  @classmethod
+  def service_id(cls) -> str:
+    return vision_service.SERVICE_ID
+
+  def fetch_output(
+      self, model_fn_outputs: NestedJTensor, model_fn_inputs: NestedJTensor
+  ) -> NestedJTensor:
+    """Fetches useful output tensors from the model function outputs."""
+    return NestedMap(tokens=model_fn_outputs[0])
+
+  def pre_processing(self, raw_inputs: List[Any]) -> NestedNpTensor:
+    """Preprocesses an unpadded batch of data into host numpy arrays."""
+    batched_video_tensors = []
+    for inp in raw_inputs:
+      video_tensors = []
+      for image_frame in inp['image_frames']:
+        image_tensor = self._image_preprocessor(image_frame)
+        video_tensors.append(image_tensor.numpy())
+      video_tensors = np.stack(video_tensors)
+      batched_video_tensors.append(video_tensors)
+    return NestedMap(images=np.stack(batched_video_tensors))
+
+  def post_processing(self, compute_outputs: NestedNpTensor) -> List[Any]:
+    """Postprocesses the output numpy arrays to final host output."""
+    # Cast tokens to float so they can fill the `repeated double` proto field.
+    tokens = compute_outputs['tokens']  # [batch, ...]
+    if tokens.dtype not in [np.float32, np.float64]:
+      tokens = tokens.astype(np.float32)
+    return list(tokens)
+
+
 class VisionModel(servable_model.ServableModel):
   """Model for vision tasks."""
 
@@ -961,5 +1045,23 @@ def init_method(
           dummy_input_sample=dummy_input,
           model_config=self.model_config,
       )
+    elif method == VisionMethodName.VIDEO_TO_TOKEN:
+      assert isinstance(method_params, VideoToTokenHParams)
+      if method_params.model_method_name is None:
+        raise ValueError(
+            'Must specify `model_method_name` in VideoToTokenHParams.'
+        )
+      # TODO(huangyp): Use model-specific dummy input.
+      image_bytes = tf.image.encode_jpeg(np.ones((256, 256, 3), dtype=np.uint8))
+      dummy_input = {'image_frames': [image_bytes]}
+      return VideoToToken(
+          model,
+          method_params.model_method_name,
+          model_state,
+          method_params,
+          prng_key=prng_key,
+          dummy_input_sample=dummy_input,
+          model_config=self.model_config,
+      )
     else:
       raise NotImplementedError(f'method {method} not implemented.')
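To expose the method, a serving config derived from `VisionModelParamsBase` overrides `video_to_token()`. A minimal sketch, assuming a hypothetical `MyVideoTokenizer` config whose underlying Pax model defines a method named `video_to_token`; only `VideoToTokenHParams` and the `video_to_token()` hook come from this change:

```python
from typing import Optional

import tensorflow as tf

from saxml.server.pax.vision import servable_vision_model as svm


def _decode_frame(image_bytes: bytes) -> tf.Tensor:
  """Decodes one JPEG frame into a float32 image tensor in [0, 1]."""
  image = tf.io.decode_jpeg(image_bytes, channels=3)
  return tf.image.convert_image_dtype(image, tf.float32)


class MyVideoTokenizer(svm.VisionModelParamsBase):  # hypothetical config
  """Config that serves vm.video_to_token."""

  def video_to_token(self) -> Optional[svm.VideoToTokenHParams]:
    return svm.VideoToTokenHParams(
        image_preprocessor=_decode_frame,
        # Must name a method on the underlying Pax model.
        model_method_name='video_to_token',
    )
```

With this override in place, `methods()` registers `VisionMethodName.VIDEO_TO_TOKEN` and `init_method` builds the `VideoToToken` servable method; its `pre_processing` stacks the per-frame tensors into an `images` array of shape [batch, num_frames, ...].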
diff --git a/saxml/server/services/vision_service.py b/saxml/server/services/vision_service.py
index 4397cc86..680dad96 100644
--- a/saxml/server/services/vision_service.py
+++ b/saxml/server/services/vision_service.py
@@ -32,6 +32,7 @@ class VisionMethodName:
   IMAGE_TO_TEXT = 'vm.image_to_text'
   IMAGE_TO_IMAGE = 'vm.image_to_image'
   VIDEO_TO_TEXT = 'vm.video_to_text'
+  VIDEO_TO_TOKEN = 'vm.video_to_token'
 
 
 class VisionService(model_service_base.ModelService):
@@ -81,6 +82,10 @@ def ParseMethodRPCRequest(self, method_name: str, request: Any) -> Any:
           'image_frames': list(request.image_frames),
           'text': np.array(request.text),
       }
+    if method_name == VisionMethodName.VIDEO_TO_TOKEN:
+      return {
+          'image_frames': list(request.image_frames),
+      }
     raise NotImplementedError(f'Method {method_name} unimplemented.')
 
   def FillRPCResponse(
@@ -163,6 +168,11 @@ def FillRPCResponse(
       for text, score in zip(texts, scores):
         response.texts.append(vision_pb2.DecodedText(text=text, score=score))
       return
+    if method_name == VisionMethodName.VIDEO_TO_TOKEN:
+      tokens = method_outputs
+      for token in tokens:
+        response.tokens.append(token)
+      return
     raise NotImplementedError(f'Method {method_name} unimplemented.')
 
 
@@ -255,3 +265,14 @@ async def VideoToText(self, request, context):
         resp,
     )
     return resp
+
+  async def VideoToToken(self, request, context):
+    resp = vision_pb2.VideoToTokenResponse()
+    await self.EnqueueRequest(
+        VisionMethodName.VIDEO_TO_TOKEN,
+        request.model_key,
+        context,
+        request,
+        resp,
+    )
+    return resp
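End to end, a client reaches the new handler through the generated service stub. A sketch, assuming standard gRPC Python codegen (`vision_pb2_grpc.VisionServiceStub`), an assumed import path, and placeholder address and model key:

```python
import grpc

from saxml.protobuf import vision_pb2, vision_pb2_grpc  # assumed codegen modules


def video_to_token(address: str, model_key: str, frames: list[bytes]) -> list[float]:
  """Sends video frames to a published model and returns its token values."""
  with grpc.insecure_channel(address) as channel:
    stub = vision_pb2_grpc.VisionServiceStub(channel)
    response = stub.VideoToToken(
        vision_pb2.VideoToTokenRequest(model_key=model_key, image_frames=frames)
    )
  return list(response.tokens)


# Example (placeholder address and key):
# tokens = video_to_token('localhost:10001', '/sax/test/vision', [frame0, frame1])
```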