Skip to content

Commit

Permalink
Adds a new "video_to_token" api to sax.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 657254357
Change-Id: I924df899ccef3c043f2417754e56e94d1b2db425
  • Loading branch information
bignamehyp authored and copybara-github committed Jul 29, 2024
1 parent 0198c25 commit 76a6ca6
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 0 deletions.
12 changes: 12 additions & 0 deletions saxml/common/testutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,18 @@ func (s *stubVisionModelServer) VideoToText(ctx context.Context, in *vmpb.VideoT
}, nil
}

// VideoToToken is a stub implementation that ignores the video contents and
// returns three fixed tokens followed by the number of image frames in the
// request, so callers can check that the frames were forwarded.
func (s *stubVisionModelServer) VideoToToken(ctx context.Context, in *vmpb.VideoToTokenRequest) (*vmpb.VideoToTokenResponse, error) {
	frameCount := float64(len(in.GetImageFrames()))
	resp := &vmpb.VideoToTokenResponse{
		Tokens: []float64{5.0, 6.0, 1.0, frameCount},
	}
	return resp, nil
}

type stubAudioModelServer struct{}

func (s *stubAudioModelServer) Recognize(ctx context.Context, in *ampb.AsrRequest) (*ampb.AsrResponse, error) {
Expand Down
13 changes: 13 additions & 0 deletions saxml/protobuf/vision.proto
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,16 @@ message VideoToTextResponse {
repeated DecodedText texts = 1;
}

// Request message for the VisionService.VideoToToken RPC.
message VideoToTokenRequest {
  string model_key = 1;
  repeated bytes image_frames = 2; // Video composed of multiple image frames.
  .sax.ExtraInputs extra_inputs = 3;
}

// Response message for the VisionService.VideoToToken RPC.
message VideoToTokenResponse {
  repeated double tokens = 1; // Quantized or soft tokens.
}

service VisionService {
// Returns the score (e.g., log pplx) given the text.
rpc Classify(ClassifyRequest) returns (ClassifyResponse);
Expand All @@ -206,4 +216,7 @@ service VisionService {

// Returns text generation results given video.
rpc VideoToText(VideoToTextRequest) returns (VideoToTextResponse);

// Returns video token results given video.
rpc VideoToToken(VideoToTokenRequest) returns (VideoToTokenResponse);
}
102 changes: 102 additions & 0 deletions saxml/server/pax/vision/servable_vision_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,21 @@ class VideoToTextHParams(servable_model_params.ServableMethodParams):
model_method_name: Optional[str] = None


@dataclasses.dataclass
class VideoToTokenHParams(servable_model_params.ServableMethodParams):
  """HParameters for VideoToToken method.

  Attributes:
    image_preprocessor: Pre-processing function to convert a single frame
      image_bytes into image tensor. Required.
    model_method_name: The name of the method to call to return tokens from an
      input video. Required.
  """

  image_preprocessor: Optional[Callable[[str], tf.Tensor]] = None
  model_method_name: Optional[str] = None


class VisionModelParamsBase(servable_model_params.ServableModelParams):
"""Base Vision Model params.
Expand Down Expand Up @@ -217,6 +232,9 @@ def methods(self) -> Dict[str, servable_model_params.ServableMethodParams]:
video_to_text_params = self.video_to_text()
if video_to_text_params is not None:
methods[VisionMethodName.VIDEO_TO_TEXT] = video_to_text_params
video_to_token_params = self.video_to_token()
if video_to_token_params is not None:
methods[VisionMethodName.VIDEO_TO_TOKEN] = video_to_token_params
# pylint: enable=assignment-from-none
return methods

Expand Down Expand Up @@ -250,6 +268,9 @@ def image_to_text(self) -> Optional[ImageToTextHParams]:
def video_to_text(self) -> Optional[VideoToTextHParams]:
return None

def video_to_token(self) -> Optional[VideoToTokenHParams]:
  """Returns the params for the VideoToToken method, or None if unsupported.

  This default implementation returns None; `methods()` only registers
  `VisionMethodName.VIDEO_TO_TOKEN` when the returned value is not None.
  """
  params = None
  return params

def create_model(self, primary_process_id: int) -> 'VisionModel':
return VisionModel(
self,
Expand Down Expand Up @@ -835,6 +856,69 @@ def _preprocess_images(self, raw_input: Any) -> NestedNpTensor:
return image_data


class VideoToToken(servable_model.ServableMethod):
  """Servable method that converts a video (a list of image frames) to tokens."""

  def __init__(
      self,
      model,
      model_fn_name: str,
      model_state,
      method_hparams: VideoToTokenHParams,
      prng_key,
      dummy_input_sample: Any,
      model_config: Any,
  ):
    """Sets up the frame preprocessor, then initializes the base method.

    Args:
      model: Forwarded unchanged to the base class constructor.
      model_fn_name: Forwarded unchanged to the base class constructor.
      model_state: Forwarded unchanged to the base class constructor.
      method_hparams: Method params; `image_preprocessor` must be set.
      prng_key: Forwarded unchanged to the base class constructor.
      dummy_input_sample: Forwarded unchanged to the base class constructor.
      model_config: Stored on the instance as `_model_config`.

    Raises:
      ValueError: If `method_hparams.image_preprocessor` is None.
    """
    self._model_config = model_config

    preprocessor = method_hparams.image_preprocessor
    if preprocessor is None:
      raise ValueError(
          'image_preprocessor method must be defined in VideoToTokenHParams'
      )

    # Work on a private copy of the current cluster with eval mode enabled.
    eval_cluster = copy.deepcopy(cluster_factory.Current())
    eval_cluster.params.do_eval = True
    self._cluster = eval_cluster
    with self._cluster:
      self._image_preprocessor = preprocessor

    # The base constructor runs last, after the preprocessor and cluster
    # attributes above have been assigned.
    super().__init__(
        model,
        model_fn_name,
        model_state,
        method_hparams,
        prng_key,
        dummy_input_sample,
    )

  @classmethod
  def service_id(cls) -> str:
    """Returns the id of the vision service this method belongs to."""
    return vision_service.SERVICE_ID

  def fetch_output(
      self, model_fn_outputs: NestedJTensor, model_fn_inputs: NestedJTensor
  ) -> NestedJTensor:
    """Selects the first model function output and exposes it as `tokens`."""
    first_output = model_fn_outputs[0]
    return NestedMap(tokens=first_output)

  def pre_processing(self, raw_inputs: List[Any]) -> NestedNpTensor:
    """Converts an unpadded batch of raw inputs into a host numpy array.

    Each raw input is indexed by 'image_frames'. Every frame is run through
    the image preprocessor, the frames of one input are stacked into a single
    array, and the per-input arrays are stacked again to form the batch.

    Args:
      raw_inputs: Raw inputs, each providing an 'image_frames' sequence.

    Returns:
      A NestedMap whose `images` entry holds the stacked batch.
    """
    per_example_videos = [
        np.stack([
            self._image_preprocessor(frame).numpy()
            for frame in example['image_frames']
        ])
        for example in raw_inputs
    ]
    return NestedMap(images=np.stack(per_example_videos))

  def post_processing(self, compute_outputs: NestedNpTensor) -> List[Any]:
    """Turns the computed `tokens` array into a list of per-example outputs.

    Outputs whose dtype is neither float32 nor float64 are cast to float32.

    Args:
      compute_outputs: Mapping holding the `tokens` array, [batch, ...].

    Returns:
      A list with one entry per element along the leading axis of `tokens`.
    """
    tokens = compute_outputs['tokens']  # [batch, ...]
    already_float = tokens.dtype in (np.float32, np.float64)
    if not already_float:
      tokens = tokens.astype(np.float32)
    return list(tokens)


class VisionModel(servable_model.ServableModel):
"""Model for vision tasks."""

Expand Down Expand Up @@ -961,5 +1045,23 @@ def init_method(
dummy_input_sample=dummy_input,
model_config=self.model_config,
)
elif method == VisionMethodName.VIDEO_TO_TOKEN:
assert isinstance(method_params, VideoToTokenHParams)
if method_params.model_method_name is None:
raise ValueError(
'Must specify `model_method_name` in VideoToTokenHParams.'
)
# TODO(huangyp): Use model-specific dummy input.
image_bytes = tf.image.encode_jpeg(np.ones((256, 256, 3), dtype=np.uint8))
dummy_input = {'image_frames': [image_bytes]}
return VideoToToken(
model,
method_params.model_method_name,
model_state,
method_params,
prng_key=prng_key,
dummy_input_sample=dummy_input,
model_config=self.model_config,
)
else:
raise NotImplementedError(f'method {method} not implemented.')
21 changes: 21 additions & 0 deletions saxml/server/services/vision_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class VisionMethodName:
IMAGE_TO_TEXT = 'vm.image_to_text'
IMAGE_TO_IMAGE = 'vm.image_to_image'
VIDEO_TO_TEXT = 'vm.video_to_text'
VIDEO_TO_TOKEN = 'vm.video_to_token'


class VisionService(model_service_base.ModelService):
Expand Down Expand Up @@ -81,6 +82,10 @@ def ParseMethodRPCRequest(self, method_name: str, request: Any) -> Any:
'image_frames': list(request.image_frames),
'text': np.array(request.text),
}
if method_name == VisionMethodName.VIDEO_TO_TOKEN:
return {
'image_frames': list(request.image_frames),
}
raise NotImplementedError(f'Method {method_name} unimplemented.')

def FillRPCResponse(
Expand Down Expand Up @@ -163,6 +168,11 @@ def FillRPCResponse(
for text, score in zip(texts, scores):
response.texts.append(vision_pb2.DecodedText(text=text, score=score))
return
if method_name == VisionMethodName.VIDEO_TO_TOKEN:
tokens = method_outputs
for token in tokens:
response.tokens.append(token)
return
raise NotImplementedError(f'Method {method_name} unimplemented.')


Expand Down Expand Up @@ -255,3 +265,14 @@ async def VideoToText(self, request, context):
resp,
)
return resp

async def VideoToToken(self, request, context):
  """Handles the VideoToToken RPC.

  Creates an empty `VideoToTokenResponse`, enqueues the request under
  `VisionMethodName.VIDEO_TO_TOKEN` together with that response object, and
  returns the response once `EnqueueRequest` has been awaited.
  """
  response = vision_pb2.VideoToTokenResponse()
  await self.EnqueueRequest(
      VisionMethodName.VIDEO_TO_TOKEN,
      request.model_key,
      context,
      request,
      response,
  )
  return response

0 comments on commit 76a6ca6

Please sign in to comment.