diff --git a/wyvern/components/models/modelbit_component.py b/wyvern/components/models/modelbit_component.py index 9b1ada1..d41522c 100644 --- a/wyvern/components/models/modelbit_component.py +++ b/wyvern/components/models/modelbit_component.py @@ -2,7 +2,7 @@ import asyncio import logging from functools import cached_property -from typing import Any, Dict, List, Optional, Set, Tuple, TypeAlias, Union +from typing import Any, Dict, List, Optional, Set, Tuple, TypeAlias, Union, cast from wyvern.components.models.model_component import ModelComponent from wyvern.config import settings @@ -146,9 +146,33 @@ async def inference(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT: # individual_output[1] is the actual output output_data[ target_identifiers[batch_idx * settings.MODELBIT_BATCH_SIZE + idx] - ] = individual_output[1] + ] = self.transform_response(individual_output[1]) return self.model_output_type( data=output_data, model_name=self.name, ) + + def transform_response( + self, + modelbit_resp: Any, + ) -> Optional[Union[float, str, List[float]]]: + """ + This method parses the response from Modelbit. + """ + if isinstance(modelbit_resp, list): + return cast(List[float], modelbit_resp) + if isinstance(modelbit_resp, bool): + return float(modelbit_resp) + if isinstance(modelbit_resp, dict): + return self.transform_dict_response(modelbit_resp) + return modelbit_resp + + def transform_dict_response( + self, + modelbit_resp: Dict[str, Any], + ) -> Optional[Union[float, str, List[float]]]: + """ + This method parses the response from Modelbit and return the data format that's supported by wyvern. + """ + raise NotImplementedError