diff --git a/wyvern/components/business_logic/business_logic.py b/wyvern/components/business_logic/business_logic.py index bb30a68..280512a 100644 --- a/wyvern/components/business_logic/business_logic.py +++ b/wyvern/components/business_logic/business_logic.py @@ -15,9 +15,9 @@ from wyvern.entities.candidate_entities import ( GENERALIZED_WYVERN_ENTITY, ScoredCandidate, - ScoredEntity, ScoredEntityProtocol, ) +from wyvern.entities.identifier import Identifier from wyvern.entities.model_entities import MODEL_OUTPUT_DATA_TYPE from wyvern.event_logging import event_logger from wyvern.wyvern_typing import REQUEST_ENTITY @@ -66,9 +66,9 @@ class BusinessLogicRequest( scored_candidates: List[ScoredCandidate[GENERALIZED_WYVERN_ENTITY]] = [] -class SingularBusinessLogicRequest( +class SingleEntityBusinessLogicRequest( GenericModel, - Generic[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY], + Generic[MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY], ): """ A request to the business logic layer to perform business logic on a single candidate @@ -78,8 +78,9 @@ class SingularBusinessLogicRequest( candidate: The candidate that the business logic layer is being asked to perform business logic on """ + identifier: Identifier request: REQUEST_ENTITY - scored_entity: ScoredEntity[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE] + model_output: MODEL_OUTPUT_DATA_TYPE # TODO (suchintan): Possibly delete this now that events are gone @@ -99,9 +100,9 @@ class BusinessLogicResponse( adjusted_candidates: List[ScoredCandidate[GENERALIZED_WYVERN_ENTITY]] -class SingularBusinessLogicResponse( +class SingleEntityBusinessLogicResponse( GenericModel, - Generic[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY], + Generic[MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY], ): """ The response from the business logic layer after performing business logic on a single candidate @@ -111,12 +112,8 @@ class SingularBusinessLogicResponse( adjusted_candidate: The candidate that the business logic layer performed business logic on """ - request: SingularBusinessLogicRequest[ - GENERALIZED_WYVERN_ENTITY, - MODEL_OUTPUT_DATA_TYPE, - REQUEST_ENTITY, - ] - adjusted_entity: ScoredEntity[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE] + request: REQUEST_ENTITY + adjusted_model_output: MODEL_OUTPUT_DATA_TYPE class BusinessLogicComponent( @@ -135,16 +132,15 @@ class BusinessLogicComponent( pass -class SingularBusinessLogicComponent( +class SingleEntityBusinessLogicComponent( Component[ - SingularBusinessLogicRequest[ - GENERALIZED_WYVERN_ENTITY, + SingleEntityBusinessLogicRequest[ MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY, ], - ScoredEntity[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE], + MODEL_OUTPUT_DATA_TYPE, ], - Generic[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY], + Generic[MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY], ): """ A component that performs business logic on an entity with a set of candidates @@ -287,26 +283,22 @@ def log_events( ) -class SingularBusinessLogicPipeline( +class SingleEntityBusinessLogicPipeline( Component[ - SingularBusinessLogicRequest[ - GENERALIZED_WYVERN_ENTITY, + SingleEntityBusinessLogicRequest[ MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY, ], - SingularBusinessLogicResponse[ - GENERALIZED_WYVERN_ENTITY, + SingleEntityBusinessLogicResponse[ MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY, ], ], - ExtractEventMixin[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE], - Generic[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY], + Generic[MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY], ): def __init__( self, - *upstreams: SingularBusinessLogicComponent[ - GENERALIZED_WYVERN_ENTITY, + *upstreams: SingleEntityBusinessLogicComponent[ MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY, ], @@ -317,32 +309,25 @@ def __init__( async def execute( self, - input: SingularBusinessLogicRequest[ - GENERALIZED_WYVERN_ENTITY, + input: SingleEntityBusinessLogicRequest[ MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY, ], **kwargs, - ) -> SingularBusinessLogicResponse[ - GENERALIZED_WYVERN_ENTITY, - MODEL_OUTPUT_DATA_TYPE, - REQUEST_ENTITY, - ]: + ) -> SingleEntityBusinessLogicResponse[MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY]: argument = input for (pipeline_index, upstream) in enumerate(self.ordered_upstreams): - old_scores = [argument.scored_entity.score] - - # this output might have the same reference as the argument.scored_candidates + old_output = argument.model_output output = await upstream.execute(argument, **kwargs) - extracted_events: List[ BusinessLogicEvent ] = self.extract_business_logic_events( - [output], + input.identifier, + output, + old_output, pipeline_index, upstream.name, argument.request.request_id, - old_scores, ) def log_events( @@ -350,19 +335,59 @@ def log_events( ): return extracted_events - # TODO (suchintan): "invariant" list error event_logger.log_events(log_events) # type: ignore - argument = SingularBusinessLogicRequest[ - GENERALIZED_WYVERN_ENTITY, + argument = SingleEntityBusinessLogicRequest[ MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY, ]( + identifier=input.identifier, request=input.request, - scored_entity=output, + model_output=output, ) - return SingularBusinessLogicResponse( - request=input, - adjusted_entity=argument.scored_entity, + return SingleEntityBusinessLogicResponse( + request=input.request, + adjusted_model_output=argument.model_output, ) + + def extract_business_logic_events( + self, + identifier: Identifier, + output: MODEL_OUTPUT_DATA_TYPE, + old_output: MODEL_OUTPUT_DATA_TYPE, + pipeline_index: int, + upstream_name: str, + request_id: str, + ) -> List[BusinessLogicEvent]: + """ + Extracts the business logic events from the output of a business logic component + + Args: + output: The output of a business logic component + pipeline_index: The index of the business logic component in the business logic pipeline + upstream_name: The name of the business logic component + request_id: The request id of the request that the business logic component was called in + old_output: The old scores of the candidates that the business logic component was called on + + Returns: + The business logic events that were extracted from the output of the business logic component + """ + timestamp = datetime.utcnow() + events = [ + BusinessLogicEvent( + request_id=request_id, + api_source=request_context.ensure_current_request().url_path, + event_timestamp=timestamp, + event_data=BusinessLogicEventData( + business_logic_pipeline_order=pipeline_index, + business_logic_name=upstream_name, + old_score=str(old_output), + new_score=str(output), + entity_identifier=identifier.identifier, + entity_identifier_type=identifier.identifier_type, + ), + ), + ] + + return events diff --git a/wyvern/components/models/model_component.py b/wyvern/components/models/model_component.py index 69bf7ef..d676ea2 100644 --- a/wyvern/components/models/model_component.py +++ b/wyvern/components/models/model_component.py @@ -13,13 +13,10 @@ from wyvern.config import settings from wyvern.entities.identifier import Identifier from wyvern.entities.identifier_entities import WyvernEntity -from wyvern.entities.model_entities import ( - MODEL_INPUT, - MODEL_OUTPUT, - SINGULAR_MODEL_INPUT, -) +from wyvern.entities.model_entities import MODEL_INPUT, MODEL_OUTPUT from wyvern.entities.request import BaseWyvernRequest from wyvern.event_logging import event_logger +from wyvern.wyvern_typing import INPUT_TYPE, REQUEST_ENTITY logger = logging.getLogger(__name__) @@ -59,7 +56,7 @@ class ModelEvent(LoggedEvent[ModelEventData]): class BaseModelComponent( Component[ - MODEL_INPUT, + INPUT_TYPE, MODEL_OUTPUT, ], ): @@ -100,13 +97,13 @@ def manifest_feature_names(self) -> Set[str]: """ return set() - async def execute(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT: + async def execute(self, input: INPUT_TYPE, **kwargs) -> MODEL_OUTPUT: """ The model_name and model_score will be automatically logged """ wyvern_request = request_context.ensure_current_request() api_source = wyvern_request.url_path - request_id = input.request.request_id + request_id = self._get_request_id(input) model_output = await self.inference(input, **kwargs) if self.cache_output: @@ -156,13 +153,16 @@ def events_generator() -> List[ModelEvent]: async def inference( self, - input: MODEL_INPUT, + input: INPUT_TYPE, **kwargs, ) -> MODEL_OUTPUT: raise NotImplementedError + def _get_request_id(self, input: INPUT_TYPE) -> Optional[str]: + raise NotImplementedError + -class ModelComponent(BaseModelComponent[MODEL_INPUT, MODEL_OUTPUT]): +class MultiEntityModelComponent(BaseModelComponent[MODEL_INPUT, MODEL_OUTPUT]): async def batch_inference( self, request: BaseWyvernRequest, @@ -218,7 +218,16 @@ async def inference( model_name=self.name, ) + def _get_request_id(self, input: MODEL_INPUT) -> Optional[str]: + return input.request.request_id + -class SingularModelComponent(BaseModelComponent[SINGULAR_MODEL_INPUT, MODEL_OUTPUT]): - async def inference(self, input: SINGULAR_MODEL_INPUT, **kwargs) -> MODEL_OUTPUT: +ModelComponent = MultiEntityModelComponent + + +class SingleEntityModelComponent(BaseModelComponent[REQUEST_ENTITY, MODEL_OUTPUT]): + async def inference(self, input: REQUEST_ENTITY, **kwargs) -> MODEL_OUTPUT: raise NotImplementedError + + def _get_request_id(self, input: REQUEST_ENTITY) -> Optional[str]: + return input.request_id diff --git a/wyvern/components/single_entity_pipeline.py b/wyvern/components/single_entity_pipeline.py new file mode 100644 index 0000000..342c9a2 --- /dev/null +++ b/wyvern/components/single_entity_pipeline.py @@ -0,0 +1,93 @@ +# -*- coding: utf-8 -*- +from typing import Any, Dict, Generic, List, Optional, Union + +from pydantic.generics import GenericModel + +from wyvern.components.business_logic.business_logic import ( + SingleEntityBusinessLogicPipeline, + SingleEntityBusinessLogicRequest, +) +from wyvern.components.component import Component +from wyvern.components.events.events import LoggedEvent +from wyvern.components.models.model_component import SingleEntityModelComponent +from wyvern.components.pipeline_component import PipelineComponent +from wyvern.entities.identifier import Identifier +from wyvern.event_logging import event_logger +from wyvern.exceptions import MissingModeloutputError +from wyvern.wyvern_typing import REQUEST_ENTITY, RESPONSE_SCHEMA + + +class SingleEntityPipelineResponse(GenericModel, Generic[RESPONSE_SCHEMA]): + data: RESPONSE_SCHEMA + events: Optional[List[LoggedEvent[Any]]] + + +class SingleEntityPipeline( + PipelineComponent[ + REQUEST_ENTITY, + SingleEntityPipelineResponse[RESPONSE_SCHEMA], + ], + Generic[REQUEST_ENTITY, RESPONSE_SCHEMA], +): + def __init__( + self, + *upstreams: Component, + model: SingleEntityModelComponent, + business_logic: Optional[SingleEntityBusinessLogicPipeline] = None, + name: Optional[str] = None, + handle_feature_store_exceptions: bool = False, + ) -> None: + self.model = model + self.business_logic: SingleEntityBusinessLogicPipeline + + upstream_components = list(upstreams) + upstream_components.append(self.model) + if business_logic: + self.business_logic = business_logic + else: + self.business_logic = SingleEntityBusinessLogicPipeline() + upstream_components.append(self.business_logic) + super().__init__( + *upstream_components, + name=name, + handle_feature_store_exceptions=handle_feature_store_exceptions, + ) + + async def execute( + self, + input: REQUEST_ENTITY, + **kwargs, + ) -> SingleEntityPipelineResponse[RESPONSE_SCHEMA]: + output = await self.model.execute(input, **kwargs) + identifiers: List[Identifier] = output.keys() + if not identifiers: + raise MissingModeloutputError() + identifier = identifiers[0] + model_output_data: Union[ + float, + str, + List[float], + Dict[str, Optional[Union[float, str, list[float]]]], + ] = output.data.get(identifier) + + business_logic_input = SingleEntityBusinessLogicRequest[ + Union[ + float, + str, + List[float], + Dict[str, Optional[Union[float, str, list[float]]]], + ], + REQUEST_ENTITY, + ]( + identifier=identifier, + request=input, + model_output=model_output_data, + ) + business_logic_output = await self.business_logic.execute( + input=business_logic_input, + **kwargs, + ) + return SingleEntityPipelineResponse( + data=business_logic_output.adjusted_model_output, + events=event_logger.get_logged_events() if input.include_events else None, + ) diff --git a/wyvern/components/singular_pipeline.py b/wyvern/components/singular_pipeline.py deleted file mode 100644 index a706305..0000000 --- a/wyvern/components/singular_pipeline.py +++ /dev/null @@ -1,112 +0,0 @@ -# -*- coding: utf-8 -*- -from typing import Any, Dict, Generic, List, Optional, Union - -from pydantic.generics import GenericModel - -from wyvern.components.business_logic.business_logic import ( - SingularBusinessLogicPipeline, - SingularBusinessLogicRequest, -) -from wyvern.components.component import Component -from wyvern.components.events.events import LoggedEvent -from wyvern.components.models.model_component import SingularModelComponent -from wyvern.components.pipeline_component import PipelineComponent -from wyvern.entities.candidate_entities import ScoredEntity -from wyvern.entities.model_entities import SingularModelInput -from wyvern.entities.request import BaseWyvernRequest -from wyvern.event_logging import event_logger -from wyvern.wyvern_typing import ( - GENERALIZED_WYVERN_ENTITY, - REQUEST_ENTITY, - RESPONSE_SCHEMA, -) - - -class SingularPipelineRequest( - BaseWyvernRequest, - Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], -): - entity: GENERALIZED_WYVERN_ENTITY - request: REQUEST_ENTITY - - -class SingularPipelineResponse(GenericModel, Generic[RESPONSE_SCHEMA]): - data: RESPONSE_SCHEMA - events: Optional[List[LoggedEvent[Any]]] - - -class SingularPipelineComponent( - PipelineComponent[ - SingularPipelineRequest[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], - SingularPipelineResponse[RESPONSE_SCHEMA], - ], - Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY, RESPONSE_SCHEMA], -): - def __init__( - self, - *upstreams: Component, - model: SingularModelComponent, - business_logic: Optional[SingularBusinessLogicPipeline] = None, - name: Optional[str] = None, - handle_feature_store_exceptions: bool = False, - ) -> None: - self.model = model - self.business_logic: SingularBusinessLogicPipeline - - upstream_components = list(upstreams) - upstream_components.append(self.model) - if business_logic: - self.business_logic = business_logic - else: - self.business_logic = SingularBusinessLogicPipeline() - upstream_components.append(self.business_logic) - super().__init__( - *upstream_components, - name=name, - handle_feature_store_exceptions=handle_feature_store_exceptions, - ) - - async def execute( - self, - input: SingularPipelineRequest[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], - **kwargs, - ) -> SingularPipelineResponse[RESPONSE_SCHEMA]: - model_input = SingularModelInput[ - GENERALIZED_WYVERN_ENTITY, - SingularPipelineRequest[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], - ]( - request=input, - entity=input.entity, - ) - output = await self.model.execute(model_input, **kwargs) - entity_data: Union[ - float, - str, - List[float], - Dict[str, Optional[Union[float, str, list[float]]]], - ] = output.data.get(input.entity.identifier) - - business_logic_input = SingularBusinessLogicRequest[ - GENERALIZED_WYVERN_ENTITY, - Union[ - float, - str, - List[float], - Dict[str, Optional[Union[float, str, list[float]]]], - ], - SingularPipelineRequest[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], - ]( - request=input, - scored_entity=ScoredEntity( - entity=input.entity, - score=entity_data, - ), - ) - business_logic_output = await self.business_logic.execute( - input=business_logic_input, - **kwargs, - ) - return SingularPipelineResponse( - data=business_logic_output.adjusted_entity.score, - events=event_logger.get_logged_events() if input.include_events else None, - ) diff --git a/wyvern/entities/candidate_entities.py b/wyvern/entities/candidate_entities.py index e2ecd9a..1aa93da 100644 --- a/wyvern/entities/candidate_entities.py +++ b/wyvern/entities/candidate_entities.py @@ -50,19 +50,3 @@ class CandidateSetEntity( CANDIDATE_SET_ENTITY = TypeVar("CANDIDATE_SET_ENTITY", bound=CandidateSetEntity) - - -class ScoredEntity( - GenericModel, - Generic[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE], -): - """ - An entity with a model score. - - Attributes: - entity: The candidate entity. - score: Type could be float, str, float or dict. The output from the model for the entity. - """ - - entity: GENERALIZED_WYVERN_ENTITY - score: MODEL_OUTPUT_DATA_TYPE diff --git a/wyvern/entities/model_entities.py b/wyvern/entities/model_entities.py index 763f892..71f5e8b 100644 --- a/wyvern/entities/model_entities.py +++ b/wyvern/entities/model_entities.py @@ -86,14 +86,6 @@ def first_identifier(self) -> Identifier: return self.first_entity.identifier -class SingularModelInput( - ModelInput, - Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], -): - entity: GENERALIZED_WYVERN_ENTITY - - -SINGULAR_MODEL_INPUT = TypeVar("SINGULAR_MODEL_INPUT", bound=SingularModelInput) MODEL_INPUT = TypeVar("MODEL_INPUT", bound=ModelInput) MODEL_OUTPUT = TypeVar("MODEL_OUTPUT", bound=ModelOutput) diff --git a/wyvern/exceptions.py b/wyvern/exceptions.py index 68cee4e..3f2abd6 100644 --- a/wyvern/exceptions.py +++ b/wyvern/exceptions.py @@ -158,3 +158,7 @@ class EntityColumnMissingError(WyvernError): class MissingModelChainOutputError(WyvernError): message = "Model chain output is missing" + + +class MissingModeloutputError(WyvernError): + message = "Identifier is missing in the model output"