Skip to content
This repository has been archived by the owner on Mar 1, 2024. It is now read-only.

Commit

Permalink
SingleEntityModelbitComponent
Browse files Browse the repository at this point in the history
  • Loading branch information
wintonzheng committed Sep 19, 2023
1 parent f7008f3 commit 1f6d5f0
Showing 1 changed file with 72 additions and 38 deletions.
110 changes: 72 additions & 38 deletions wyvern/components/models/modelbit_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from functools import cached_property
from typing import Any, Dict, List, Optional, Set, Tuple, TypeAlias, Union

from wyvern.components.models.model_component import ModelComponent
from wyvern.components.models.model_component import (
BaseModelComponent,
MultiEntityModelComponent,
SingleEntityModelComponent,
)
from wyvern.config import settings
from wyvern.core.http import aiohttp_client
from wyvern.entities.identifier import Identifier
Expand All @@ -15,23 +19,13 @@
WyvernModelbitTokenMissingError,
WyvernModelbitValidationError,
)
from wyvern.wyvern_typing import INPUT_TYPE, REQUEST_ENTITY

JSON: TypeAlias = Union[Dict[str, "JSON"], List["JSON"], str, int, float, bool, None]
logger = logging.getLogger(__name__)


class ModelbitComponent(ModelComponent[MODEL_INPUT, MODEL_OUTPUT]):
"""
ModelbitComponent is a base class for all modelbit model components. It provides a common interface to implement
all modelbit models.
ModelbitComponent is a subclass of ModelComponent.
Attributes:
AUTH_TOKEN: A class variable that stores the auth token for Modelbit.
URL: A class variable that stores the url for Modelbit.
"""

class ModelbitMixin(BaseModelComponent[INPUT_TYPE, MODEL_OUTPUT]):
AUTH_TOKEN: str = ""
URL: str = ""

Expand Down Expand Up @@ -80,31 +74,7 @@ def manifest_feature_names(self) -> Set[str]:
"""
return set(self.modelbit_features)

async def build_requests(
self,
input: MODEL_INPUT,
) -> Tuple[List[Identifier], List[Any]]:
"""
Please refer to modlebit batch inference API:
https://doc.modelbit.com/deployments/rest-api/
"""
target_entities: List[
Union[WyvernEntity, BaseWyvernRequest]
] = input.entities or [input.request]
target_identifiers = [entity.identifier for entity in target_entities]
all_requests = [
[
idx + 1,
[
self.get_feature(identifier, feature_name)
for feature_name in self.modelbit_features
],
]
for idx, identifier in enumerate(target_identifiers)
]
return target_identifiers, all_requests

async def inference(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT:
async def inference(self, input: INPUT_TYPE, **kwargs) -> MODEL_OUTPUT:
"""
This method sends a request to Modelbit and returns the output.
"""
Expand Down Expand Up @@ -153,3 +123,67 @@ async def inference(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT:
data=output_data,
model_name=self.name,
)

async def build_requests(
self,
input: INPUT_TYPE,
) -> Tuple[List[Identifier], List[Any]]:
"""
This method builds requests for Modelbit. This method should be implemented by the subclass.
"""
raise NotImplementedError


class ModelbitComponent(
ModelbitMixin[MODEL_INPUT, MODEL_OUTPUT],
MultiEntityModelComponent[MODEL_INPUT, MODEL_OUTPUT],
):
"""
ModelbitComponent is a base class for all modelbit model components. It provides a common interface to implement
all modelbit models.
ModelbitComponent is a subclass of ModelComponent.
Attributes:
AUTH_TOKEN: A class variable that stores the auth token for Modelbit.
URL: A class variable that stores the url for Modelbit.
"""

async def build_requests(
self,
input: MODEL_INPUT,
) -> Tuple[List[Identifier], List[Any]]:
"""
Please refer to modlebit batch inference API:
https://doc.modelbit.com/deployments/rest-api/
"""
target_entities: List[
Union[WyvernEntity, BaseWyvernRequest]
] = input.entities or [input.request]
target_identifiers = [entity.identifier for entity in target_entities]
all_requests = [
[
idx + 1,
[
self.get_feature(identifier, feature_name)
for feature_name in self.modelbit_features
],
]
for idx, identifier in enumerate(target_identifiers)
]
return target_identifiers, all_requests


class SingleEntityModelbitComponent(
SingleEntityModelComponent[REQUEST_ENTITY, MODEL_OUTPUT],
):
async def build_requests(
self,
input: REQUEST_ENTITY,
) -> Tuple[List[Identifier], List[Any]]:
target_identifier, request = await self.build_request(input)
modelbit_request = [[1, request]]
return [target_identifier], modelbit_request

async def build_request(self, input: REQUEST_ENTITY) -> Tuple[Identifier, Any]:
raise NotImplementedError

0 comments on commit 1f6d5f0

Please sign in to comment.