Skip to content
This repository has been archived by the owner on Mar 1, 2024. It is now read-only.

Commit

Permalink
integrate with business logic in SingularPipelineComponent
Browse files Browse the repository at this point in the history
  • Loading branch information
wintonzheng committed Sep 17, 2023
1 parent e62f670 commit e87252c
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 39 deletions.
88 changes: 60 additions & 28 deletions wyvern/components/business_logic/business_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import logging
from datetime import datetime
from typing import Generic, List, Optional
from typing import Generic, List, Optional, Sequence

from ddtrace import tracer
from pydantic.generics import GenericModel
Expand All @@ -15,7 +15,10 @@
from wyvern.entities.candidate_entities import (
GENERALIZED_WYVERN_ENTITY,
ScoredCandidate,
ScoredEntity,
ScoredEntityProtocol,
)
from wyvern.entities.model_entities import MODEL_OUTPUT_DATA_TYPE
from wyvern.event_logging import event_logger
from wyvern.wyvern_typing import REQUEST_ENTITY

Expand All @@ -35,8 +38,8 @@ class BusinessLogicEventData(EntityEventData):

business_logic_pipeline_order: int
business_logic_name: str
old_score: float
new_score: float
old_score: str
new_score: str


class BusinessLogicEvent(LoggedEvent[BusinessLogicEventData]):
Expand Down Expand Up @@ -65,7 +68,7 @@ class BusinessLogicRequest(

class SingularBusinessLogicRequest(
GenericModel,
Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY],
Generic[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY],
):
"""
A request to the business logic layer to perform business logic on a single candidate
Expand All @@ -76,7 +79,7 @@ class SingularBusinessLogicRequest(
"""

request: REQUEST_ENTITY
scored_candidate: ScoredCandidate[GENERALIZED_WYVERN_ENTITY]
scored_entity: ScoredEntity[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE]


# TODO (suchintan): Possibly delete this now that events are gone
Expand All @@ -98,7 +101,7 @@ class BusinessLogicResponse(

class SingularBusinessLogicResponse(
GenericModel,
Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY],
Generic[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY],
):
"""
The response from the business logic layer after performing business logic on a single candidate
Expand All @@ -108,8 +111,12 @@ class SingularBusinessLogicResponse(
adjusted_candidate: The candidate that the business logic layer performed business logic on
"""

request: SingularBusinessLogicRequest[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY]
adjusted_candidate: ScoredCandidate[GENERALIZED_WYVERN_ENTITY]
request: SingularBusinessLogicRequest[
GENERALIZED_WYVERN_ENTITY,
MODEL_OUTPUT_DATA_TYPE,
REQUEST_ENTITY,
]
adjusted_entity: ScoredEntity[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE]


class BusinessLogicComponent(
Expand All @@ -130,10 +137,14 @@ class BusinessLogicComponent(

class SingularBusinessLogicComponent(
Component[
SingularBusinessLogicRequest[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY],
ScoredCandidate[GENERALIZED_WYVERN_ENTITY],
SingularBusinessLogicRequest[
GENERALIZED_WYVERN_ENTITY,
MODEL_OUTPUT_DATA_TYPE,
REQUEST_ENTITY,
],
ScoredEntity[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE],
],
Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY],
Generic[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY],
):
"""
A component that performs business logic on an entity with a set of candidates
Expand All @@ -144,14 +155,14 @@ class SingularBusinessLogicComponent(
pass


class ExtractEventMixin(Generic[GENERALIZED_WYVERN_ENTITY]):
class ExtractEventMixin(Generic[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE]):
def extract_business_logic_events(
self,
output: List[ScoredCandidate[GENERALIZED_WYVERN_ENTITY]],
output: Sequence[ScoredEntityProtocol],
pipeline_index: int,
upstream_name: str,
request_id: str,
old_scores: List[float],
old_scores: List,
) -> List[BusinessLogicEvent]:
"""
Extracts the business logic events from the output of a business logic component
Expand All @@ -175,8 +186,8 @@ def extract_business_logic_events(
event_data=BusinessLogicEventData(
business_logic_pipeline_order=pipeline_index,
business_logic_name=upstream_name,
old_score=old_scores[j],
new_score=output[j].score,
old_score=str(old_scores[j]),
new_score=str(output[j].score),
entity_identifier=candidate.entity.identifier.identifier,
entity_identifier_type=candidate.entity.identifier.identifier_type,
),
Expand All @@ -193,7 +204,7 @@ class BusinessLogicPipeline(
BusinessLogicRequest[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY],
BusinessLogicResponse[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY],
],
ExtractEventMixin[GENERALIZED_WYVERN_ENTITY],
ExtractEventMixin[GENERALIZED_WYVERN_ENTITY, float],
Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY],
):
"""
Expand Down Expand Up @@ -278,16 +289,25 @@ def log_events(

class SingularBusinessLogicPipeline(
Component[
SingularBusinessLogicRequest[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY],
SingularBusinessLogicResponse[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY],
SingularBusinessLogicRequest[
GENERALIZED_WYVERN_ENTITY,
MODEL_OUTPUT_DATA_TYPE,
REQUEST_ENTITY,
],
SingularBusinessLogicResponse[
GENERALIZED_WYVERN_ENTITY,
MODEL_OUTPUT_DATA_TYPE,
REQUEST_ENTITY,
],
],
ExtractEventMixin[GENERALIZED_WYVERN_ENTITY],
Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY],
ExtractEventMixin[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE],
Generic[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE, REQUEST_ENTITY],
):
def __init__(
self,
*upstreams: SingularBusinessLogicComponent[
GENERALIZED_WYVERN_ENTITY,
MODEL_OUTPUT_DATA_TYPE,
REQUEST_ENTITY,
],
name: Optional[str] = None,
Expand All @@ -297,15 +317,23 @@ def __init__(

async def execute(
self,
input: SingularBusinessLogicRequest[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY],
input: SingularBusinessLogicRequest[
GENERALIZED_WYVERN_ENTITY,
MODEL_OUTPUT_DATA_TYPE,
REQUEST_ENTITY,
],
**kwargs,
) -> SingularBusinessLogicResponse[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY]:
) -> SingularBusinessLogicResponse[
GENERALIZED_WYVERN_ENTITY,
MODEL_OUTPUT_DATA_TYPE,
REQUEST_ENTITY,
]:
argument = input
for (pipeline_index, upstream) in enumerate(self.ordered_upstreams):
old_scores = [argument.scored_candidate.score]
old_scores = [argument.scored_entity.score]

# this output might have the same reference as the argument.scored_candidates
output = await upstream.execute(input=argument, **kwargs)
output = await upstream.execute(argument, **kwargs)

extracted_events: List[
BusinessLogicEvent
Expand All @@ -325,12 +353,16 @@ def log_events(
# TODO (suchintan): "invariant" list error
event_logger.log_events(log_events) # type: ignore

argument = SingularBusinessLogicRequest(
argument = SingularBusinessLogicRequest[
GENERALIZED_WYVERN_ENTITY,
MODEL_OUTPUT_DATA_TYPE,
REQUEST_ENTITY,
](
request=input.request,
scored_candidate=output,
scored_entity=output,
)

return SingularBusinessLogicResponse(
request=input,
adjusted_candidate=argument.scored_candidate,
adjusted_entity=argument.scored_entity,
)
50 changes: 42 additions & 8 deletions wyvern/components/singular_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
# -*- coding: utf-8 -*-
from typing import Any, Generic, List, Optional
from typing import Any, Dict, Generic, List, Optional, Union

from pydantic.generics import GenericModel

from wyvern.components.business_logic.business_logic import BusinessLogicPipeline
from wyvern.components.business_logic.business_logic import (
SingularBusinessLogicPipeline,
SingularBusinessLogicRequest,
)
from wyvern.components.component import Component
from wyvern.components.events.events import LoggedEvent
from wyvern.components.models.model_component import SingularModelComponent
from wyvern.components.pipeline_component import PipelineComponent
from wyvern.entities.candidate_entities import ScoredEntity
from wyvern.entities.model_entities import SingularModelInput
from wyvern.entities.request import BaseWyvernRequest
from wyvern.event_logging import event_logger
Expand Down Expand Up @@ -42,16 +46,20 @@ def __init__(
self,
*upstreams: Component,
model: SingularModelComponent,
business_logic: Optional[BusinessLogicPipeline] = None,
business_logic: Optional[SingularBusinessLogicPipeline] = None,
name: Optional[str] = None,
handle_feature_store_exceptions: bool = False,
) -> None:
self.model = model
self.business_logic = business_logic
self.business_logic: SingularBusinessLogicPipeline

upstream_components = list(upstreams)
upstream_components.append(self.model)
if self.business_logic:
upstream_components.append(self.business_logic)
if business_logic:
self.business_logic = business_logic
else:
self.business_logic = SingularBusinessLogicPipeline()
upstream_components.append(self.business_logic)
super().__init__(
*upstream_components,
name=name,
Expand All @@ -71,8 +79,34 @@ async def execute(
entity=input.entity,
)
output = await self.model.execute(model_input, **kwargs)
entity_data = output.data.get(input.entity.identifier)
entity_data: Union[
float,
str,
List[float],
Dict[str, Optional[Union[float, str, list[float]]]],
] = output.data.get(input.entity.identifier)

business_logic_input = SingularBusinessLogicRequest[
GENERALIZED_WYVERN_ENTITY,
Union[
float,
str,
List[float],
Dict[str, Optional[Union[float, str, list[float]]]],
],
SingularPipelineRequest[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY],
](
request=input,
scored_entity=ScoredEntity(
entity=input.entity,
score=entity_data,
),
)
business_logic_output = await self.business_logic.execute(
input=business_logic_input,
**kwargs,
)
return SingularPipelineResponse(
data=entity_data,
data=business_logic_output.adjusted_entity.score,
events=event_logger.get_logged_events() if input.include_events else None,
)
35 changes: 32 additions & 3 deletions wyvern/entities/candidate_entities.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,29 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

from typing import Generic, List, TypeVar
from typing import Generic, List, Protocol, TypeVar

from pydantic.generics import GenericModel

from wyvern.entities.identifier_entities import WyvernDataModel
from wyvern.entities.model_entities import MODEL_OUTPUT_DATA_TYPE
from wyvern.wyvern_typing import GENERALIZED_WYVERN_ENTITY


class ScoredEntityProtocol(
Protocol,
Generic[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE],
):
entity: GENERALIZED_WYVERN_ENTITY
score: MODEL_OUTPUT_DATA_TYPE


# TODO (suchintan): This should be renamed to ScoredEntity probably
class ScoredCandidate(GenericModel, Generic[GENERALIZED_WYVERN_ENTITY]):
class ScoredCandidate(
ScoredEntityProtocol[GENERALIZED_WYVERN_ENTITY, float],
GenericModel,
Generic[GENERALIZED_WYVERN_ENTITY],
):
"""
A candidate entity with a score.
Expand All @@ -30,7 +43,6 @@ class CandidateSetEntity(
):
"""
A set of candidate entities. This is a generic model that can be used to represent a set of candidate entities.
Attributes:
candidates: The list of candidate entities.
"""
Expand All @@ -39,3 +51,20 @@ class CandidateSetEntity(


CANDIDATE_SET_ENTITY = TypeVar("CANDIDATE_SET_ENTITY", bound=CandidateSetEntity)


class ScoredEntity(
ScoredEntityProtocol[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE],
GenericModel,
Generic[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE],
):
"""
An entity with a model score.
Attributes:
entity: The candidate entity.
score: Type could be float, str, float or dict. The output from the model for the entity.
"""

entity: GENERALIZED_WYVERN_ENTITY
score: MODEL_OUTPUT_DATA_TYPE

0 comments on commit e87252c

Please sign in to comment.