From 1f6d5f0c95e4a1bfc74506585f80a3a429c3e88d Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Mon, 18 Sep 2023 20:16:38 -0700 Subject: [PATCH] SingleEntityModelbitComponent --- .../components/models/modelbit_component.py | 110 ++++++++++++------ 1 file changed, 72 insertions(+), 38 deletions(-) diff --git a/wyvern/components/models/modelbit_component.py b/wyvern/components/models/modelbit_component.py index 39c3e5e..450c39a 100644 --- a/wyvern/components/models/modelbit_component.py +++ b/wyvern/components/models/modelbit_component.py @@ -4,7 +4,11 @@ from functools import cached_property from typing import Any, Dict, List, Optional, Set, Tuple, TypeAlias, Union -from wyvern.components.models.model_component import ModelComponent +from wyvern.components.models.model_component import ( + BaseModelComponent, + MultiEntityModelComponent, + SingleEntityModelComponent, +) from wyvern.config import settings from wyvern.core.http import aiohttp_client from wyvern.entities.identifier import Identifier @@ -15,23 +19,13 @@ WyvernModelbitTokenMissingError, WyvernModelbitValidationError, ) +from wyvern.wyvern_typing import INPUT_TYPE, REQUEST_ENTITY JSON: TypeAlias = Union[Dict[str, "JSON"], List["JSON"], str, int, float, bool, None] logger = logging.getLogger(__name__) -class ModelbitComponent(ModelComponent[MODEL_INPUT, MODEL_OUTPUT]): - """ - ModelbitComponent is a base class for all modelbit model components. It provides a common interface to implement - all modelbit models. - - ModelbitComponent is a subclass of ModelComponent. - - Attributes: - AUTH_TOKEN: A class variable that stores the auth token for Modelbit. - URL: A class variable that stores the url for Modelbit. - """ - +class ModelbitMixin(BaseModelComponent[INPUT_TYPE, MODEL_OUTPUT]): AUTH_TOKEN: str = "" URL: str = "" @@ -80,31 +74,7 @@ def manifest_feature_names(self) -> Set[str]: """ return set(self.modelbit_features) - async def build_requests( - self, - input: MODEL_INPUT, - ) -> Tuple[List[Identifier], List[Any]]: - """ - Please refer to modlebit batch inference API: - https://doc.modelbit.com/deployments/rest-api/ - """ - target_entities: List[ - Union[WyvernEntity, BaseWyvernRequest] - ] = input.entities or [input.request] - target_identifiers = [entity.identifier for entity in target_entities] - all_requests = [ - [ - idx + 1, - [ - self.get_feature(identifier, feature_name) - for feature_name in self.modelbit_features - ], - ] - for idx, identifier in enumerate(target_identifiers) - ] - return target_identifiers, all_requests - - async def inference(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT: + async def inference(self, input: INPUT_TYPE, **kwargs) -> MODEL_OUTPUT: """ This method sends a request to Modelbit and returns the output. """ @@ -153,3 +123,67 @@ async def inference(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT: data=output_data, model_name=self.name, ) + + async def build_requests( + self, + input: INPUT_TYPE, + ) -> Tuple[List[Identifier], List[Any]]: + """ + This method builds requests for Modelbit. This method should be implemented by the subclass. + """ + raise NotImplementedError + + +class ModelbitComponent( + ModelbitMixin[MODEL_INPUT, MODEL_OUTPUT], + MultiEntityModelComponent[MODEL_INPUT, MODEL_OUTPUT], +): + """ + ModelbitComponent is a base class for all modelbit model components. It provides a common interface to implement + all modelbit models. + + ModelbitComponent is a subclass of ModelComponent. + + Attributes: + AUTH_TOKEN: A class variable that stores the auth token for Modelbit. + URL: A class variable that stores the url for Modelbit. + """ + + async def build_requests( + self, + input: MODEL_INPUT, + ) -> Tuple[List[Identifier], List[Any]]: + """ + Please refer to modlebit batch inference API: + https://doc.modelbit.com/deployments/rest-api/ + """ + target_entities: List[ + Union[WyvernEntity, BaseWyvernRequest] + ] = input.entities or [input.request] + target_identifiers = [entity.identifier for entity in target_entities] + all_requests = [ + [ + idx + 1, + [ + self.get_feature(identifier, feature_name) + for feature_name in self.modelbit_features + ], + ] + for idx, identifier in enumerate(target_identifiers) + ] + return target_identifiers, all_requests + + +class SingleEntityModelbitComponent( + SingleEntityModelComponent[REQUEST_ENTITY, MODEL_OUTPUT], +): + async def build_requests( + self, + input: REQUEST_ENTITY, + ) -> Tuple[List[Identifier], List[Any]]: + target_identifier, request = await self.build_request(input) + modelbit_request = [[1, request]] + return [target_identifier], modelbit_request + + async def build_request(self, input: REQUEST_ENTITY) -> Tuple[Identifier, Any]: + raise NotImplementedError