Skip to content
This repository has been archived by the owner on Mar 1, 2024. It is now read-only.

Commit

Permalink
update to SingleEntity style
Browse files Browse the repository at this point in the history
  • Loading branch information
wintonzheng committed Sep 18, 2023
1 parent bfb61c8 commit 524c58f
Show file tree
Hide file tree
Showing 7 changed files with 189 additions and 194 deletions.
117 changes: 71 additions & 46 deletions wyvern/components/business_logic/business_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from wyvern.entities.candidate_entities import (
GENERALIZED_WYVERN_ENTITY,
ScoredCandidate,
ScoredEntity,
ScoredEntityProtocol,
)
from wyvern.entities.identifier import Identifier
from wyvern.entities.model_entities import MODEL_OUTPUT_DATA_TYPE
from wyvern.event_logging import event_logger
from wyvern.wyvern_typing import REQUEST_ENTITY
Expand Down Expand Up @@ -66,9 +66,9 @@ class BusinessLogicRequest(
scored_candidates: List[ScoredCandidate[GENERALIZED_WYVERN_ENTITY]] = []


class SingularBusinessLogicRequest(
class SingleEntityBusinessLogicRequest(
GenericModel,
Generic[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY],
Generic[MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY],
):
"""
A request to the business logic layer to perform business logic on a single candidate
Expand All @@ -78,8 +78,9 @@ class SingularBusinessLogicRequest(
candidate: The candidate that the business logic layer is being asked to perform business logic on
"""

identifier: Identifier
request: REQUEST_ENTITY
scored_entity: ScoredEntity[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE]
model_output: MODEL_OUTPUT_DATA_TYPE


# TODO (suchintan): Possibly delete this now that events are gone
Expand All @@ -99,9 +100,9 @@ class BusinessLogicResponse(
adjusted_candidates: List[ScoredCandidate[GENERALIZED_WYVERN_ENTITY]]


class SingularBusinessLogicResponse(
class SingleEntityBusinessLogicResponse(
GenericModel,
Generic[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY],
Generic[MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY],
):
"""
The response from the business logic layer after performing business logic on a single candidate
Expand All @@ -111,12 +112,8 @@ class SingularBusinessLogicResponse(
adjusted_candidate: The candidate that the business logic layer performed business logic on
"""

request: SingularBusinessLogicRequest[
GENERALIZED_WYVERN_ENTITY,
MODEL_OUTPUT_DATA_TYPE,
REQUEST_ENTITY,
]
adjusted_entity: ScoredEntity[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE]
request: REQUEST_ENTITY
adjusted_model_output: MODEL_OUTPUT_DATA_TYPE


class BusinessLogicComponent(
Expand All @@ -135,16 +132,15 @@ class BusinessLogicComponent(
pass


class SingularBusinessLogicComponent(
class SingleEntityBusinessLogicComponent(
Component[
SingularBusinessLogicRequest[
GENERALIZED_WYVERN_ENTITY,
SingleEntityBusinessLogicRequest[
MODEL_OUTPUT_DATA_TYPE,
REQUEST_ENTITY,
],
ScoredEntity[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE],
MODEL_OUTPUT_DATA_TYPE,
],
Generic[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY],
Generic[MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY],
):
"""
A component that performs business logic on an entity with a set of candidates
Expand Down Expand Up @@ -287,26 +283,22 @@ def log_events(
)


class SingularBusinessLogicPipeline(
class SingleEntityBusinessLogicPipeline(
Component[
SingularBusinessLogicRequest[
GENERALIZED_WYVERN_ENTITY,
SingleEntityBusinessLogicRequest[
MODEL_OUTPUT_DATA_TYPE,
REQUEST_ENTITY,
],
SingularBusinessLogicResponse[
GENERALIZED_WYVERN_ENTITY,
SingleEntityBusinessLogicResponse[
MODEL_OUTPUT_DATA_TYPE,
REQUEST_ENTITY,
],
],
ExtractEventMixin[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE],
Generic[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY],
Generic[MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY],
):
def __init__(
self,
*upstreams: SingularBusinessLogicComponent[
GENERALIZED_WYVERN_ENTITY,
*upstreams: SingleEntityBusinessLogicComponent[
MODEL_OUTPUT_DATA_TYPE,
REQUEST_ENTITY,
],
Expand All @@ -317,52 +309,85 @@ def __init__(

async def execute(
self,
input: SingularBusinessLogicRequest[
GENERALIZED_WYVERN_ENTITY,
input: SingleEntityBusinessLogicRequest[
MODEL_OUTPUT_DATA_TYPE,
REQUEST_ENTITY,
],
**kwargs,
) -> SingularBusinessLogicResponse[
GENERALIZED_WYVERN_ENTITY,
MODEL_OUTPUT_DATA_TYPE,
REQUEST_ENTITY,
]:
) -> SingleEntityBusinessLogicResponse[MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY]:
argument = input
for (pipeline_index, upstream) in enumerate(self.ordered_upstreams):
old_scores = [argument.scored_entity.score]

# this output might have the same reference as the argument.scored_candidates
old_output = argument.model_output
output = await upstream.execute(argument, **kwargs)

extracted_events: List[
BusinessLogicEvent
] = self.extract_business_logic_events(
[output],
input.identifier,
output,
old_output,
pipeline_index,
upstream.name,
argument.request.request_id,
old_scores,
)

def log_events(
extracted_events: List[BusinessLogicEvent] = extracted_events,
):
return extracted_events

# TODO (suchintan): "invariant" list error
event_logger.log_events(log_events) # type: ignore

argument = SingularBusinessLogicRequest[
GENERALIZED_WYVERN_ENTITY,
argument = SingleEntityBusinessLogicRequest[
MODEL_OUTPUT_DATA_TYPE,
REQUEST_ENTITY,
](
identifier=input.identifier,
request=input.request,
scored_entity=output,
model_output=output,
)

return SingularBusinessLogicResponse(
request=input,
adjusted_entity=argument.scored_entity,
return SingleEntityBusinessLogicResponse(
request=input.request,
adjusted_model_output=argument.model_output,
)

def extract_business_logic_events(
self,
identifier: Identifier,
output: MODEL_OUTPUT_DATA_TYPE,
old_output: MODEL_OUTPUT_DATA_TYPE,
pipeline_index: int,
upstream_name: str,
request_id: str,
) -> List[BusinessLogicEvent]:
"""
Extracts the business logic events from the output of a business logic component
Args:
output: The output of a business logic component
pipeline_index: The index of the business logic component in the business logic pipeline
upstream_name: The name of the business logic component
request_id: The request id of the request that the business logic component was called in
old_output: The old scores of the candidates that the business logic component was called on
Returns:
The business logic events that were extracted from the output of the business logic component
"""
timestamp = datetime.utcnow()
events = [
BusinessLogicEvent(
request_id=request_id,
api_source=request_context.ensure_current_request().url_path,
event_timestamp=timestamp,
event_data=BusinessLogicEventData(
business_logic_pipeline_order=pipeline_index,
business_logic_name=upstream_name,
old_score=str(old_output),
new_score=str(output),
entity_identifier=identifier.identifier,
entity_identifier_type=identifier.identifier_type,
),
),
]

return events
33 changes: 21 additions & 12 deletions wyvern/components/models/model_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,10 @@
from wyvern.config import settings
from wyvern.entities.identifier import Identifier
from wyvern.entities.identifier_entities import WyvernEntity
from wyvern.entities.model_entities import (
MODEL_INPUT,
MODEL_OUTPUT,
SINGULAR_MODEL_INPUT,
)
from wyvern.entities.model_entities import MODEL_INPUT, MODEL_OUTPUT
from wyvern.entities.request import BaseWyvernRequest
from wyvern.event_logging import event_logger
from wyvern.wyvern_typing import INPUT_TYPE, REQUEST_ENTITY

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -59,7 +56,7 @@ class ModelEvent(LoggedEvent[ModelEventData]):

class BaseModelComponent(
Component[
MODEL_INPUT,
INPUT_TYPE,
MODEL_OUTPUT,
],
):
Expand Down Expand Up @@ -100,13 +97,13 @@ def manifest_feature_names(self) -> Set[str]:
"""
return set()

async def execute(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT:
async def execute(self, input: INPUT_TYPE, **kwargs) -> MODEL_OUTPUT:
"""
The model_name and model_score will be automatically logged
"""
wyvern_request = request_context.ensure_current_request()
api_source = wyvern_request.url_path
request_id = input.request.request_id
request_id = self._get_request_id(input)
model_output = await self.inference(input, **kwargs)

if self.cache_output:
Expand Down Expand Up @@ -156,13 +153,16 @@ def events_generator() -> List[ModelEvent]:

async def inference(
self,
input: MODEL_INPUT,
input: INPUT_TYPE,
**kwargs,
) -> MODEL_OUTPUT:
raise NotImplementedError

def _get_request_id(self, input: INPUT_TYPE) -> Optional[str]:
raise NotImplementedError


class ModelComponent(BaseModelComponent[MODEL_INPUT, MODEL_OUTPUT]):
class MultiEntityModelComponent(BaseModelComponent[MODEL_INPUT, MODEL_OUTPUT]):
async def batch_inference(
self,
request: BaseWyvernRequest,
Expand Down Expand Up @@ -218,7 +218,16 @@ async def inference(
model_name=self.name,
)

def _get_request_id(self, input: MODEL_INPUT) -> Optional[str]:
return input.request.request_id


class SingularModelComponent(BaseModelComponent[SINGULAR_MODEL_INPUT, MODEL_OUTPUT]):
async def inference(self, input: SINGULAR_MODEL_INPUT, **kwargs) -> MODEL_OUTPUT:
ModelComponent = MultiEntityModelComponent


class SingleEntityModelComponent(BaseModelComponent[REQUEST_ENTITY, MODEL_OUTPUT]):
async def inference(self, input: REQUEST_ENTITY, **kwargs) -> MODEL_OUTPUT:
raise NotImplementedError

def _get_request_id(self, input: REQUEST_ENTITY) -> Optional[str]:
return input.request_id
93 changes: 93 additions & 0 deletions wyvern/components/single_entity_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
from typing import Any, Dict, Generic, List, Optional, Union

from pydantic.generics import GenericModel

from wyvern.components.business_logic.business_logic import (
SingleEntityBusinessLogicPipeline,
SingleEntityBusinessLogicRequest,
)
from wyvern.components.component import Component
from wyvern.components.events.events import LoggedEvent
from wyvern.components.models.model_component import SingleEntityModelComponent
from wyvern.components.pipeline_component import PipelineComponent
from wyvern.entities.identifier import Identifier
from wyvern.event_logging import event_logger
from wyvern.exceptions import MissingModeloutputError
from wyvern.wyvern_typing import REQUEST_ENTITY, RESPONSE_SCHEMA


class SingleEntityPipelineResponse(GenericModel, Generic[RESPONSE_SCHEMA]):
data: RESPONSE_SCHEMA
events: Optional[List[LoggedEvent[Any]]]


class SingleEntityPipeline(
PipelineComponent[
REQUEST_ENTITY,
SingleEntityPipelineResponse[RESPONSE_SCHEMA],
],
Generic[REQUEST_ENTITY, RESPONSE_SCHEMA],
):
def __init__(
self,
*upstreams: Component,
model: SingleEntityModelComponent,
business_logic: Optional[SingleEntityBusinessLogicPipeline] = None,
name: Optional[str] = None,
handle_feature_store_exceptions: bool = False,
) -> None:
self.model = model
self.business_logic: SingleEntityBusinessLogicPipeline

upstream_components = list(upstreams)
upstream_components.append(self.model)
if business_logic:
self.business_logic = business_logic
else:
self.business_logic = SingleEntityBusinessLogicPipeline()
upstream_components.append(self.business_logic)
super().__init__(
*upstream_components,
name=name,
handle_feature_store_exceptions=handle_feature_store_exceptions,
)

async def execute(
self,
input: REQUEST_ENTITY,
**kwargs,
) -> SingleEntityPipelineResponse[RESPONSE_SCHEMA]:
output = await self.model.execute(input, **kwargs)
identifiers: List[Identifier] = output.keys()
if not identifiers:
raise MissingModeloutputError()
identifier = identifiers[0]
model_output_data: Union[
float,
str,
List[float],
Dict[str, Optional[Union[float, str, list[float]]]],
] = output.data.get(identifier)

business_logic_input = SingleEntityBusinessLogicRequest[
Union[
float,
str,
List[float],
Dict[str, Optional[Union[float, str, list[float]]]],
],
REQUEST_ENTITY,
](
identifier=identifier,
request=input,
model_output=model_output_data,
)
business_logic_output = await self.business_logic.execute(
input=business_logic_input,
**kwargs,
)
return SingleEntityPipelineResponse(
data=business_logic_output.adjusted_model_output,
events=event_logger.get_logged_events() if input.include_events else None,
)
Loading

0 comments on commit 524c58f

Please sign in to comment.