Skip to content
This repository has been archived by the owner on Mar 1, 2024. It is now read-only.

Commit

Permalink
introduce singular model component and singular pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
wintonzheng committed Sep 17, 2023
1 parent 91ddb8e commit 20adf41
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 2 deletions.
22 changes: 20 additions & 2 deletions wyvern/components/models/model_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
from wyvern.config import settings
from wyvern.entities.identifier import Identifier
from wyvern.entities.identifier_entities import WyvernEntity
from wyvern.entities.model_entities import MODEL_INPUT, MODEL_OUTPUT
from wyvern.entities.model_entities import (
MODEL_INPUT,
MODEL_OUTPUT,
SINGULAR_MODEL_INPUT,
)
from wyvern.entities.request import BaseWyvernRequest
from wyvern.event_logging import event_logger

Expand Down Expand Up @@ -53,7 +57,7 @@ class ModelEvent(LoggedEvent[ModelEventData]):
event_type: EventType = EventType.MODEL


class ModelComponent(
class BaseModelComponent(
Component[
MODEL_INPUT,
MODEL_OUTPUT,
Expand Down Expand Up @@ -150,6 +154,15 @@ def events_generator() -> List[ModelEvent]:

return model_output

async def inference(
self,
input: MODEL_INPUT,
**kwargs,
) -> MODEL_OUTPUT:
raise NotImplementedError


class ModelComponent(BaseModelComponent[MODEL_INPUT, MODEL_OUTPUT]):
async def batch_inference(
self,
request: BaseWyvernRequest,
Expand Down Expand Up @@ -204,3 +217,8 @@ async def inference(
data=output_data,
model_name=self.name,
)


class SingularModelComponent(BaseModelComponent[SINGULAR_MODEL_INPUT, MODEL_OUTPUT]):
async def inference(self, input: SINGULAR_MODEL_INPUT, **kwargs) -> MODEL_OUTPUT:
raise NotImplementedError
78 changes: 78 additions & 0 deletions wyvern/components/singular_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# -*- coding: utf-8 -*-
from typing import Any, Generic, List, Optional

from pydantic.generics import GenericModel

from wyvern.components.business_logic.business_logic import BusinessLogicPipeline
from wyvern.components.component import Component
from wyvern.components.events.events import LoggedEvent
from wyvern.components.models.model_component import SingularModelComponent
from wyvern.components.pipeline_component import PipelineComponent
from wyvern.entities.model_entities import SingularModelInput
from wyvern.entities.request import BaseWyvernRequest
from wyvern.event_logging import event_logger
from wyvern.wyvern_typing import (
GENERALIZED_WYVERN_ENTITY,
REQUEST_ENTITY,
RESPONSE_SCHEMA,
)


class SingularPipelineRequest(
BaseWyvernRequest,
Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY],
):
entity: GENERALIZED_WYVERN_ENTITY
request: REQUEST_ENTITY


class SingularPipelineResponse(GenericModel, Generic[RESPONSE_SCHEMA]):
data: RESPONSE_SCHEMA
events: Optional[List[LoggedEvent[Any]]]


class SingularPipelineComponent(
PipelineComponent[
SingularPipelineRequest[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY],
SingularPipelineResponse[RESPONSE_SCHEMA],
],
Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY, RESPONSE_SCHEMA],
):
def __init__(
self,
*upstreams: Component,
model: SingularModelComponent,
business_logic: Optional[BusinessLogicPipeline] = None,
name: Optional[str] = None,
handle_feature_store_exceptions: bool = False,
) -> None:
self.model = model
self.business_logic = business_logic
upstream_components = list(upstreams)
upstream_components.append(self.model)
if self.business_logic:
upstream_components.append(self.business_logic)
super().__init__(
*upstream_components,
name=name,
handle_feature_store_exceptions=handle_feature_store_exceptions,
)

async def execute(
self,
input: SingularPipelineRequest[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY],
**kwargs,
) -> SingularPipelineResponse[RESPONSE_SCHEMA]:
model_input = SingularModelInput[
GENERALIZED_WYVERN_ENTITY,
SingularPipelineRequest[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY],
](
request=input,
entity=input.entity,
)
output = await self.model.execute(model_input, **kwargs)
entity_data = output.data.get(input.entity.identifier)
return SingularPipelineResponse(
data=entity_data,
events=event_logger.get_logged_events() if input.include_events else None,
)
8 changes: 8 additions & 0 deletions wyvern/entities/model_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,14 @@ def first_identifier(self) -> Identifier:
return self.first_entity.identifier


class SingularModelInput(
ModelInput,
Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY],
):
entity: GENERALIZED_WYVERN_ENTITY


SINGULAR_MODEL_INPUT = TypeVar("SINGULAR_MODEL_INPUT", bound=SingularModelInput)
MODEL_INPUT = TypeVar("MODEL_INPUT", bound=ModelInput)
MODEL_OUTPUT = TypeVar("MODEL_OUTPUT", bound=ModelOutput)

Expand Down

0 comments on commit 20adf41

Please sign in to comment.