diff --git a/wyvern/components/business_logic/business_logic.py b/wyvern/components/business_logic/business_logic.py index 1cdd9bc..bb30a68 100644 --- a/wyvern/components/business_logic/business_logic.py +++ b/wyvern/components/business_logic/business_logic.py @@ -3,7 +3,7 @@ import logging from datetime import datetime -from typing import Generic, List, Optional +from typing import Generic, List, Optional, Sequence from ddtrace import tracer from pydantic.generics import GenericModel @@ -15,7 +15,10 @@ from wyvern.entities.candidate_entities import ( GENERALIZED_WYVERN_ENTITY, ScoredCandidate, + ScoredEntity, + ScoredEntityProtocol, ) +from wyvern.entities.model_entities import MODEL_OUTPUT_DATA_TYPE from wyvern.event_logging import event_logger from wyvern.wyvern_typing import REQUEST_ENTITY @@ -35,8 +38,8 @@ class BusinessLogicEventData(EntityEventData): business_logic_pipeline_order: int business_logic_name: str - old_score: float - new_score: float + old_score: str + new_score: str class BusinessLogicEvent(LoggedEvent[BusinessLogicEventData]): @@ -65,7 +68,7 @@ class BusinessLogicRequest( class SingularBusinessLogicRequest( GenericModel, - Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], + Generic[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY], ): """ A request to the business logic layer to perform business logic on a single candidate @@ -76,7 +79,7 @@ class SingularBusinessLogicRequest( """ request: REQUEST_ENTITY - scored_candidate: ScoredCandidate[GENERALIZED_WYVERN_ENTITY] + scored_entity: ScoredEntity[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE] # TODO (suchintan): Possibly delete this now that events are gone @@ -98,7 +101,7 @@ class BusinessLogicResponse( class SingularBusinessLogicResponse( GenericModel, - Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], + Generic[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY], ): """ The response from the business logic layer after performing business logic on a single candidate @@ -108,8 +111,12 @@ class SingularBusinessLogicResponse( adjusted_candidate: The candidate that the business logic layer performed business logic on """ - request: SingularBusinessLogicRequest[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY] - adjusted_candidate: ScoredCandidate[GENERALIZED_WYVERN_ENTITY] + request: SingularBusinessLogicRequest[ + GENERALIZED_WYVERN_ENTITY, + MODEL_OUTPUT_DATA_TYPE, + REQUEST_ENTITY, + ] + adjusted_entity: ScoredEntity[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE] class BusinessLogicComponent( @@ -130,10 +137,14 @@ class BusinessLogicComponent( class SingularBusinessLogicComponent( Component[ - SingularBusinessLogicRequest[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], - ScoredCandidate[GENERALIZED_WYVERN_ENTITY], + SingularBusinessLogicRequest[ + GENERALIZED_WYVERN_ENTITY, + MODEL_OUTPUT_DATA_TYPE, + REQUEST_ENTITY, + ], + ScoredEntity[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE], ], - Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], + Generic[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY], ): """ A component that performs business logic on an entity with a set of candidates @@ -144,14 +155,14 @@ class SingularBusinessLogicComponent( pass -class ExtractEventMixin(Generic[GENERALIZED_WYVERN_ENTITY]): +class ExtractEventMixin(Generic[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE]): def extract_business_logic_events( self, - output: List[ScoredCandidate[GENERALIZED_WYVERN_ENTITY]], + output: Sequence[ScoredEntityProtocol], pipeline_index: int, upstream_name: str, request_id: str, - old_scores: List[float], + old_scores: List, ) -> List[BusinessLogicEvent]: """ Extracts the business logic events from the output of a business logic component @@ -175,8 +186,8 @@ def extract_business_logic_events( event_data=BusinessLogicEventData( business_logic_pipeline_order=pipeline_index, business_logic_name=upstream_name, - old_score=old_scores[j], - new_score=output[j].score, + old_score=str(old_scores[j]), + new_score=str(output[j].score), entity_identifier=candidate.entity.identifier.identifier, entity_identifier_type=candidate.entity.identifier.identifier_type, ), @@ -193,7 +204,7 @@ class BusinessLogicPipeline( BusinessLogicRequest[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], BusinessLogicResponse[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], ], - ExtractEventMixin[GENERALIZED_WYVERN_ENTITY], + ExtractEventMixin[GENERALIZED_WYVERN_ENTITY, float], Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], ): """ @@ -278,16 +289,25 @@ def log_events( class SingularBusinessLogicPipeline( Component[ - SingularBusinessLogicRequest[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], - SingularBusinessLogicResponse[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], + SingularBusinessLogicRequest[ + GENERALIZED_WYVERN_ENTITY, + MODEL_OUTPUT_DATA_TYPE, + REQUEST_ENTITY, + ], + SingularBusinessLogicResponse[ + GENERALIZED_WYVERN_ENTITY, + MODEL_OUTPUT_DATA_TYPE, + REQUEST_ENTITY, + ], ], - ExtractEventMixin[GENERALIZED_WYVERN_ENTITY], - Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], + ExtractEventMixin[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE], + Generic[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY], ): def __init__( self, *upstreams: SingularBusinessLogicComponent[ GENERALIZED_WYVERN_ENTITY, + MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY, ], name: Optional[str] = None, @@ -297,15 +317,23 @@ def __init__( async def execute( self, - input: SingularBusinessLogicRequest[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], + input: SingularBusinessLogicRequest[ + GENERALIZED_WYVERN_ENTITY, + MODEL_OUTPUT_DATA_TYPE, + REQUEST_ENTITY, + ], **kwargs, - ) -> SingularBusinessLogicResponse[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY]: + ) -> SingularBusinessLogicResponse[ + GENERALIZED_WYVERN_ENTITY, + MODEL_OUTPUT_DATA_TYPE, + REQUEST_ENTITY, + ]: argument = input for (pipeline_index, upstream) in enumerate(self.ordered_upstreams): - old_scores = [argument.scored_candidate.score] + old_scores = [argument.scored_entity.score] # this output might have the same reference as the argument.scored_candidates - output = await upstream.execute(input=argument, **kwargs) + output = await upstream.execute(argument, **kwargs) extracted_events: List[ BusinessLogicEvent @@ -325,12 +353,16 @@ def log_events( # TODO (suchintan): "invariant" list error event_logger.log_events(log_events) # type: ignore - argument = SingularBusinessLogicRequest( + argument = SingularBusinessLogicRequest[ + GENERALIZED_WYVERN_ENTITY, + MODEL_OUTPUT_DATA_TYPE, + REQUEST_ENTITY, + ]( request=input.request, - scored_candidate=output, + scored_entity=output, ) return SingularBusinessLogicResponse( request=input, - adjusted_candidate=argument.scored_candidate, + adjusted_entity=argument.scored_entity, ) diff --git a/wyvern/components/singular_pipeline.py b/wyvern/components/singular_pipeline.py index def1f12..a706305 100644 --- a/wyvern/components/singular_pipeline.py +++ b/wyvern/components/singular_pipeline.py @@ -1,13 +1,17 @@ # -*- coding: utf-8 -*- -from typing import Any, Generic, List, Optional +from typing import Any, Dict, Generic, List, Optional, Union from pydantic.generics import GenericModel -from wyvern.components.business_logic.business_logic import BusinessLogicPipeline +from wyvern.components.business_logic.business_logic import ( + SingularBusinessLogicPipeline, + SingularBusinessLogicRequest, +) from wyvern.components.component import Component from wyvern.components.events.events import LoggedEvent from wyvern.components.models.model_component import SingularModelComponent from wyvern.components.pipeline_component import PipelineComponent +from wyvern.entities.candidate_entities import ScoredEntity from wyvern.entities.model_entities import SingularModelInput from wyvern.entities.request import BaseWyvernRequest from wyvern.event_logging import event_logger @@ -42,16 +46,20 @@ def __init__( self, *upstreams: Component, model: SingularModelComponent, - business_logic: Optional[BusinessLogicPipeline] = None, + business_logic: Optional[SingularBusinessLogicPipeline] = None, name: Optional[str] = None, handle_feature_store_exceptions: bool = False, ) -> None: self.model = model - self.business_logic = business_logic + self.business_logic: SingularBusinessLogicPipeline + upstream_components = list(upstreams) upstream_components.append(self.model) - if self.business_logic: - upstream_components.append(self.business_logic) + if business_logic: + self.business_logic = business_logic + else: + self.business_logic = SingularBusinessLogicPipeline() + upstream_components.append(self.business_logic) super().__init__( *upstream_components, name=name, @@ -71,8 +79,34 @@ async def execute( entity=input.entity, ) output = await self.model.execute(model_input, **kwargs) - entity_data = output.data.get(input.entity.identifier) + entity_data: Union[ + float, + str, + List[float], + Dict[str, Optional[Union[float, str, list[float]]]], + ] = output.data.get(input.entity.identifier) + + business_logic_input = SingularBusinessLogicRequest[ + GENERALIZED_WYVERN_ENTITY, + Union[ + float, + str, + List[float], + Dict[str, Optional[Union[float, str, list[float]]]], + ], + SingularPipelineRequest[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], + ]( + request=input, + scored_entity=ScoredEntity( + entity=input.entity, + score=entity_data, + ), + ) + business_logic_output = await self.business_logic.execute( + input=business_logic_input, + **kwargs, + ) return SingularPipelineResponse( - data=entity_data, + data=business_logic_output.adjusted_entity.score, events=event_logger.get_logged_events() if input.include_events else None, ) diff --git a/wyvern/entities/candidate_entities.py b/wyvern/entities/candidate_entities.py index 3f5c0fc..4aac961 100644 --- a/wyvern/entities/candidate_entities.py +++ b/wyvern/entities/candidate_entities.py @@ -1,16 +1,29 @@ # -*- coding: utf-8 -*- from __future__ import annotations -from typing import Generic, List, TypeVar +from typing import Generic, List, Protocol, TypeVar from pydantic.generics import GenericModel from wyvern.entities.identifier_entities import WyvernDataModel +from wyvern.entities.model_entities import MODEL_OUTPUT_DATA_TYPE from wyvern.wyvern_typing import GENERALIZED_WYVERN_ENTITY +class ScoredEntityProtocol( + Protocol, + Generic[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE], +): + entity: GENERALIZED_WYVERN_ENTITY + score: MODEL_OUTPUT_DATA_TYPE + + # TODO (suchintan): This should be renamed to ScoredEntity probably -class ScoredCandidate(GenericModel, Generic[GENERALIZED_WYVERN_ENTITY]): +class ScoredCandidate( + ScoredEntityProtocol[GENERALIZED_WYVERN_ENTITY, float], + GenericModel, + Generic[GENERALIZED_WYVERN_ENTITY], +): """ A candidate entity with a score. @@ -30,7 +43,6 @@ class CandidateSetEntity( ): """ A set of candidate entities. This is a generic model that can be used to represent a set of candidate entities. - Attributes: candidates: The list of candidate entities. """ @@ -39,3 +51,20 @@ class CandidateSetEntity( CANDIDATE_SET_ENTITY = TypeVar("CANDIDATE_SET_ENTITY", bound=CandidateSetEntity) + + +class ScoredEntity( + ScoredEntityProtocol[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE], + GenericModel, + Generic[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE], +): + """ + An entity with a model score. + + Attributes: + entity: The candidate entity. + score: Type could be float, str, float or dict. The output from the model for the entity. + """ + + entity: GENERALIZED_WYVERN_ENTITY + score: MODEL_OUTPUT_DATA_TYPE