From abad535d34790665863a7c284ed3e12955327bf3 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Thu, 7 Sep 2023 00:04:11 -0700 Subject: [PATCH 01/10] multi model evaluation --- tests/scenarios/test_product_ranking.py | 7 +- wyvern/components/models/model_component.py | 94 +------------------ .../components/models/modelbit_component.py | 7 +- wyvern/components/ranking_pipeline.py | 3 +- wyvern/entities/model_entities.py | 85 +++++++++++++++++ wyvern/wyvern_request.py | 3 + 6 files changed, 96 insertions(+), 103 deletions(-) create mode 100644 wyvern/entities/model_entities.py diff --git a/tests/scenarios/test_product_ranking.py b/tests/scenarios/test_product_ranking.py index 68ff3fe..096df65 100644 --- a/tests/scenarios/test_product_ranking.py +++ b/tests/scenarios/test_product_ranking.py @@ -13,11 +13,7 @@ RealtimeFeatureComponent, RealtimeFeatureRequest, ) -from wyvern.components.models.model_component import ( - ModelComponent, - ModelInput, - ModelOutput, -) +from wyvern.components.models.model_component import ModelComponent from wyvern.components.pipeline_component import PipelineComponent from wyvern.config import settings from wyvern.core.compression import wyvern_encode @@ -26,6 +22,7 @@ from wyvern.entities.feature_entities import FeatureData, FeatureMap from wyvern.entities.identifier import Identifier from wyvern.entities.identifier_entities import ProductEntity, WyvernEntity +from wyvern.entities.model_entities import ModelInput, ModelOutput from wyvern.entities.request import BaseWyvernRequest from wyvern.service import WyvernService from wyvern.wyvern_request import WyvernRequest diff --git a/wyvern/components/models/model_component.py b/wyvern/components/models/model_component.py index 6aac719..128c25b 100644 --- a/wyvern/components/models/model_component.py +++ b/wyvern/components/models/model_component.py @@ -3,21 +3,9 @@ import logging from datetime import datetime from functools import cached_property -from typing import ( - Dict, - Generic, - List, - Optional, - Sequence, - Set, - Type, - TypeVar, - Union, - get_args, -) +from typing import Dict, List, Optional, Sequence, Set, Type, Union, get_args from pydantic import BaseModel -from pydantic.generics import GenericModel from wyvern import request_context from wyvern.components.component import Component @@ -25,19 +13,9 @@ from wyvern.config import settings from wyvern.entities.identifier import Identifier from wyvern.entities.identifier_entities import WyvernEntity +from wyvern.entities.model_entities import MODEL_INPUT, MODEL_OUTPUT from wyvern.entities.request import BaseWyvernRequest from wyvern.event_logging import event_logger -from wyvern.exceptions import WyvernModelInputError -from wyvern.wyvern_typing import GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY - -MODEL_OUTPUT_DATA_TYPE = TypeVar( - "MODEL_OUTPUT_DATA_TYPE", - bound=Union[float, str, List[float]], -) -""" -MODEL_OUTPUT_DATA_TYPE is the type of the output of the model. It can be a float, a string, or a list of floats -(e.g. a list of probabilities, embeddings, etc.) -""" logger = logging.getLogger(__name__) @@ -71,74 +49,6 @@ class ModelEvent(LoggedEvent[ModelEventData]): event_type: EventType = EventType.MODEL -class ModelOutput(GenericModel, Generic[MODEL_OUTPUT_DATA_TYPE]): - """ - This class defines the output of a model. - - Args: - data: A dictionary mapping entity identifiers to model outputs. The model outputs can also be None. - model_name: The name of the model. This is optional. - """ - - data: Dict[Identifier, Optional[MODEL_OUTPUT_DATA_TYPE]] - model_name: Optional[str] = None - - def get_entity_output( - self, - identifier: Identifier, - ) -> Optional[MODEL_OUTPUT_DATA_TYPE]: - """ - Get the model output for a given entity identifier. - - Args: - identifier: The identifier of the entity. - - Returns: - The model output for the given entity identifier. This can also be None if the model output is None. - """ - return self.data.get(identifier) - - -class ModelInput(GenericModel, Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY]): - """ - This class defines the input to a model. - - Args: - request: The request that will be used to generate the model input. - entities: A list of entities that will be used to generate the model input. - """ - - request: REQUEST_ENTITY - entities: List[GENERALIZED_WYVERN_ENTITY] = [] - - @property - def first_entity(self) -> GENERALIZED_WYVERN_ENTITY: - """ - Get the first entity in the list of entities. This is useful when you know that there is only one entity. - - Returns: - The first entity in the list of entities. - """ - if not self.entities: - raise WyvernModelInputError(model_input=self) - return self.entities[0] - - @property - def first_identifier(self) -> Identifier: - """ - Get the identifier of the first entity in the list of entities. This is useful when you know that there is only - one entity. - - Returns: - The identifier of the first entity in the list of entities. - """ - return self.first_entity.identifier - - -MODEL_INPUT = TypeVar("MODEL_INPUT", bound=ModelInput) -MODEL_OUTPUT = TypeVar("MODEL_OUTPUT", bound=ModelOutput) - - class ModelComponent( Component[ MODEL_INPUT, diff --git a/wyvern/components/models/modelbit_component.py b/wyvern/components/models/modelbit_component.py index 92d60c5..9b1ada1 100644 --- a/wyvern/components/models/modelbit_component.py +++ b/wyvern/components/models/modelbit_component.py @@ -4,15 +4,12 @@ from functools import cached_property from typing import Any, Dict, List, Optional, Set, Tuple, TypeAlias, Union -from wyvern.components.models.model_component import ( - MODEL_INPUT, - MODEL_OUTPUT, - ModelComponent, -) +from wyvern.components.models.model_component import ModelComponent from wyvern.config import settings from wyvern.core.http import aiohttp_client from wyvern.entities.identifier import Identifier from wyvern.entities.identifier_entities import WyvernEntity +from wyvern.entities.model_entities import MODEL_INPUT, MODEL_OUTPUT from wyvern.entities.request import BaseWyvernRequest from wyvern.exceptions import ( WyvernModelbitTokenMissingError, diff --git a/wyvern/components/ranking_pipeline.py b/wyvern/components/ranking_pipeline.py index bdbbafd..3625b40 100644 --- a/wyvern/components/ranking_pipeline.py +++ b/wyvern/components/ranking_pipeline.py @@ -13,7 +13,7 @@ ImpressionEventLoggingComponent, ImpressionEventLoggingRequest, ) -from wyvern.components.models.model_component import ModelComponent, ModelInput +from wyvern.components.models.model_component import ModelComponent from wyvern.components.pagination.pagination_component import ( PaginationComponent, PaginationRequest, @@ -22,6 +22,7 @@ from wyvern.components.pipeline_component import PipelineComponent from wyvern.entities.candidate_entities import ScoredCandidate from wyvern.entities.identifier_entities import QueryEntity +from wyvern.entities.model_entities import ModelInput from wyvern.entities.request import BaseWyvernRequest from wyvern.event_logging import event_logger from wyvern.wyvern_typing import WYVERN_ENTITY diff --git a/wyvern/entities/model_entities.py b/wyvern/entities/model_entities.py new file mode 100644 index 0000000..c80aec9 --- /dev/null +++ b/wyvern/entities/model_entities.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- +from typing import Dict, Generic, List, Optional, TypeVar, Union + +from pydantic.generics import GenericModel + +from wyvern.entities.identifier import Identifier +from wyvern.exceptions import WyvernModelInputError +from wyvern.wyvern_typing import GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY + +MODEL_OUTPUT_DATA_TYPE = TypeVar( + "MODEL_OUTPUT_DATA_TYPE", + bound=Union[float, str, List[float]], +) +""" +MODEL_OUTPUT_DATA_TYPE is the type of the output of the model. It can be a float, a string, or a list of floats +(e.g. a list of probabilities, embeddings, etc.) +""" + + +class ModelOutput(GenericModel, Generic[MODEL_OUTPUT_DATA_TYPE]): + """ + This class defines the output of a model. + + Args: + data: A dictionary mapping entity identifiers to model outputs. The model outputs can also be None. + model_name: The name of the model. This is optional. + """ + + data: Dict[Identifier, Optional[MODEL_OUTPUT_DATA_TYPE]] + model_name: Optional[str] = None + + def get_entity_output( + self, + identifier: Identifier, + ) -> Optional[MODEL_OUTPUT_DATA_TYPE]: + """ + Get the model output for a given entity identifier. + + Args: + identifier: The identifier of the entity. + + Returns: + The model output for the given entity identifier. This can also be None if the model output is None. + """ + return self.data.get(identifier) + + +class ModelInput(GenericModel, Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY]): + """ + This class defines the input to a model. + + Args: + request: The request that will be used to generate the model input. + entities: A list of entities that will be used to generate the model input. + """ + + request: REQUEST_ENTITY + entities: List[GENERALIZED_WYVERN_ENTITY] = [] + + @property + def first_entity(self) -> GENERALIZED_WYVERN_ENTITY: + """ + Get the first entity in the list of entities. This is useful when you know that there is only one entity. + + Returns: + The first entity in the list of entities. + """ + if not self.entities: + raise WyvernModelInputError(model_input=self) + return self.entities[0] + + @property + def first_identifier(self) -> Identifier: + """ + Get the identifier of the first entity in the list of entities. This is useful when you know that there is only + one entity. + + Returns: + The identifier of the first entity in the list of entities. + """ + return self.first_entity.identifier + + +MODEL_INPUT = TypeVar("MODEL_INPUT", bound=ModelInput) +MODEL_OUTPUT = TypeVar("MODEL_OUTPUT", bound=ModelOutput) diff --git a/wyvern/wyvern_request.py b/wyvern/wyvern_request.py index f80342e..61b50ca 100644 --- a/wyvern/wyvern_request.py +++ b/wyvern/wyvern_request.py @@ -10,6 +10,7 @@ from wyvern.components.events.events import LoggedEvent from wyvern.entities.feature_entities import FeatureMap +from wyvern.entities.identifier import Identifier @dataclass @@ -43,6 +44,7 @@ class WyvernRequest: events: List[Callable[[], List[LoggedEvent[Any]]]] feature_map: FeatureMap + model_score_map: Dict[str, Dict[Identifier, float]] request_id: Optional[str] = None @@ -75,5 +77,6 @@ def parse_fastapi_request( entity_store={}, events=[], feature_map=FeatureMap(feature_map={}), + model_score_map={}, request_id=request_id, ) From 67b793a40831ef6394ae487b883a4fd5a6ef4d96 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Thu, 7 Sep 2023 00:45:41 -0700 Subject: [PATCH 02/10] fix test --- tests/components/business_logic/test_pinning_business_logic.py | 1 + tests/scenarios/test_product_ranking.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/tests/components/business_logic/test_pinning_business_logic.py b/tests/components/business_logic/test_pinning_business_logic.py index f0b945f..665fc91 100644 --- a/tests/components/business_logic/test_pinning_business_logic.py +++ b/tests/components/business_logic/test_pinning_business_logic.py @@ -66,6 +66,7 @@ def __init__(self): entity_store={}, events=[], feature_map=FeatureMap(feature_map={}), + model_score_map={}, ), ) return await pipeline.execute(request) diff --git a/tests/scenarios/test_product_ranking.py b/tests/scenarios/test_product_ranking.py index 096df65..8cf4296 100644 --- a/tests/scenarios/test_product_ranking.py +++ b/tests/scenarios/test_product_ranking.py @@ -384,6 +384,7 @@ async def test_hydrate(mock_redis): json=json_input, headers={}, entity_store={}, + model_score_map={}, events=[], feature_map=FeatureMap(feature_map={}), ) @@ -447,6 +448,7 @@ async def test_hydrate__duplicate_brand(mock_redis__duplicate_brand): entity_store={}, events=[], feature_map=FeatureMap(feature_map={}), + model_score_map={}, ) request_context.set(test_wyvern_request) From 016d620fbaecd376ab28c5434a952ea1ea3c403d Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Wed, 13 Sep 2023 11:09:16 -0700 Subject: [PATCH 03/10] chained model evaluation --- .../models/multi_model_component.py | 51 +++++++++++++++++++ wyvern/entities/model_entities.py | 5 ++ wyvern/exceptions.py | 4 ++ wyvern/wyvern_request.py | 6 ++- 4 files changed, 64 insertions(+), 2 deletions(-) create mode 100644 wyvern/components/models/multi_model_component.py diff --git a/wyvern/components/models/multi_model_component.py b/wyvern/components/models/multi_model_component.py new file mode 100644 index 0000000..bc37639 --- /dev/null +++ b/wyvern/components/models/multi_model_component.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +from functools import cached_property +from typing import Optional, Set + +from wyvern.components.models.model_component import ModelComponent +from wyvern.entities.model_entities import MODEL_INPUT, MODEL_OUTPUT, ChainedModelInput +from wyvern.exceptions import MissingModelChainOutputError + + +class ModelChainComponent(ModelComponent[MODEL_INPUT, MODEL_OUTPUT]): + """ + Model chaining allows you to chain models together so that the output of one model can be the input to another model + + For all the models in the chain, all the request and entities in the model input are the same + """ + + def __init__(self, *upstreams: ModelComponent, name: Optional[str] = None): + super().__init__(*upstreams, name=name) + self.chain = upstreams + + @cached_property + def manifest_feature_names(self) -> Set[str]: + return set() + + async def inference(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT: + output = None + prev_model: Optional[ModelComponent] = None + for model in self.chain: + curr_input: ChainedModelInput + if prev_model is not None and output is not None: + curr_input = ChainedModelInput( + request=input.request, + entities=input.entities, + upstream_model_name=prev_model.name, + upstream_model_output=output.data, + ) + else: + curr_input = ChainedModelInput( + request=input.request, + entities=input.entities, + upstream_model_name=None, + upstream_model_output={}, + ) + output = await model.execute(curr_input, **kwargs) + prev_model = model + + if output is None: + raise MissingModelChainOutputError() + + # TODO: do type checking to make sure the output is of the correct type + return output diff --git a/wyvern/entities/model_entities.py b/wyvern/entities/model_entities.py index c80aec9..cc384ba 100644 --- a/wyvern/entities/model_entities.py +++ b/wyvern/entities/model_entities.py @@ -81,5 +81,10 @@ def first_identifier(self) -> Identifier: return self.first_entity.identifier +class ChainedModelInput(ModelInput[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY]): + upstream_model_output: Dict[Identifier, Optional[Union[float, str, List[float]]]] + upstream_model_name: Optional[str] = None + + MODEL_INPUT = TypeVar("MODEL_INPUT", bound=ModelInput) MODEL_OUTPUT = TypeVar("MODEL_OUTPUT", bound=ModelOutput) diff --git a/wyvern/exceptions.py b/wyvern/exceptions.py index 2007d88..68cee4e 100644 --- a/wyvern/exceptions.py +++ b/wyvern/exceptions.py @@ -154,3 +154,7 @@ class ExperimentationClientInitializationError(WyvernError): class EntityColumnMissingError(WyvernError): message = "Entity column {entity} is missing in the entity data" + + +class MissingModelChainOutputError(WyvernError): + message = "Model chain output is missing" diff --git a/wyvern/wyvern_request.py b/wyvern/wyvern_request.py index 61b50ca..e6b073b 100644 --- a/wyvern/wyvern_request.py +++ b/wyvern/wyvern_request.py @@ -2,7 +2,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from urllib.parse import urlparse import fastapi @@ -44,7 +44,9 @@ class WyvernRequest: events: List[Callable[[], List[LoggedEvent[Any]]]] feature_map: FeatureMap - model_score_map: Dict[str, Dict[Identifier, float]] + + # the key is the name of the model and the value is a map of the identifier to the model score + model_score_map: Dict[str, Dict[Identifier, Union[float, str, List[float], None]]] request_id: Optional[str] = None From 148de8b7884fc7fc13fd184e851483b0b8c32cf9 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Wed, 13 Sep 2023 11:22:41 -0700 Subject: [PATCH 04/10] cache model output --- wyvern/components/models/model_component.py | 9 ++++++++- wyvern/wyvern_request.py | 9 +++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/wyvern/components/models/model_component.py b/wyvern/components/models/model_component.py index 128c25b..48ebe7e 100644 --- a/wyvern/components/models/model_component.py +++ b/wyvern/components/models/model_component.py @@ -65,11 +65,14 @@ def __init__( self, *upstreams, name: Optional[str] = None, + cache_output: bool = True, ): super().__init__(*upstreams, name=name) self.model_input_type = self.get_type_args_simple(0) self.model_output_type = self.get_type_args_simple(1) + self.cache_output = cache_output + @classmethod def get_type_args_simple(cls, index: int) -> Type: """ @@ -95,10 +98,14 @@ async def execute(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT: """ The model_name and model_score will be automatically logged """ - api_source = request_context.ensure_current_request().url_path + wyvern_request = request_context.ensure_current_request() + api_source = wyvern_request.url_path request_id = input.request.request_id model_output = await self.inference(input, **kwargs) + if self.cache_output: + wyvern_request.cache_model_score(self.name, model_output.data) + def events_generator() -> List[ModelEvent]: timestamp = datetime.utcnow() return [ diff --git a/wyvern/wyvern_request.py b/wyvern/wyvern_request.py index e6b073b..41cb396 100644 --- a/wyvern/wyvern_request.py +++ b/wyvern/wyvern_request.py @@ -82,3 +82,12 @@ def parse_fastapi_request( model_score_map={}, request_id=request_id, ) + + def cache_model_score( + self, + model_name: str, + data: Dict[Identifier, Union[float, str, List[float], None]], + ) -> None: + if model_name not in self.model_score_map: + self.model_score_map[model_name] = {} + self.model_score_map[model_name].update(data) From ee85723650b320df50eb3db8fec8016787943e54 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Wed, 13 Sep 2023 12:34:36 -0700 Subject: [PATCH 05/10] v bump --- pyproject.toml | 2 +- wyvern/__init__.py | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6924830..342146e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "wyvern-ai" -version = "0.0.17" +version = "0.0.18" description = "" authors = ["Wyvern AI "] readme = "README.md" diff --git a/wyvern/__init__.py b/wyvern/__init__.py index 08823ba..96129d5 100644 --- a/wyvern/__init__.py +++ b/wyvern/__init__.py @@ -2,11 +2,7 @@ from wyvern.components.features.realtime_features_component import ( RealtimeFeatureComponent, ) -from wyvern.components.models.model_component import ( - ModelComponent, - ModelInput, - ModelOutput, -) +from wyvern.components.models.model_component import ModelComponent from wyvern.components.pipeline_component import PipelineComponent from wyvern.components.ranking_pipeline import ( RankingPipeline, @@ -23,6 +19,7 @@ WyvernDataModel, WyvernEntity, ) +from wyvern.entities.model_entities import ModelInput, ModelOutput from wyvern.feature_store.feature_server import generate_wyvern_store_app from wyvern.service import WyvernService from wyvern.wyvern_logging import setup_logging From 7cebacec889b190a81625a95d493b7cb6c4a5728 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Wed, 13 Sep 2023 17:38:17 -0700 Subject: [PATCH 06/10] support dict model output --- wyvern/components/models/model_component.py | 50 +++++++++++++++------ wyvern/entities/model_entities.py | 19 +++++++- wyvern/wyvern_request.py | 25 ++++++++++- 3 files changed, 76 insertions(+), 18 deletions(-) diff --git a/wyvern/components/models/model_component.py b/wyvern/components/models/model_component.py index 48ebe7e..cdaa3ec 100644 --- a/wyvern/components/models/model_component.py +++ b/wyvern/components/models/model_component.py @@ -36,6 +36,7 @@ class ModelEventData(BaseModel): model_output: str entity_identifier: Optional[str] = None entity_identifier_type: Optional[str] = None + target: Optional[str] = None class ModelEvent(LoggedEvent[ModelEventData]): @@ -108,20 +109,41 @@ async def execute(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT: def events_generator() -> List[ModelEvent]: timestamp = datetime.utcnow() - return [ - ModelEvent( - request_id=request_id, - api_source=api_source, - event_timestamp=timestamp, - event_data=ModelEventData( - model_name=model_output.model_name or self.__class__.__name__, - model_output=str(output), - entity_identifier=identifier.identifier, - entity_identifier_type=identifier.identifier_type, - ), - ) - for identifier, output in model_output.data.items() - ] + all_events: List[ModelEvent] = [] + for identifier, output in model_output.data.items(): + if isinstance(output, dict): + for key, value in output.items(): + all_events.append( + ModelEvent( + request_id=request_id, + api_source=api_source, + event_timestamp=timestamp, + event_data=ModelEventData( + model_name=model_output.model_name + or self.__class__.__name__, + model_output=str(value), + target=key, + entity_identifier=identifier.identifier, + entity_identifier_type=identifier.identifier_type, + ), + ), + ) + else: + all_events.append( + ModelEvent( + request_id=request_id, + api_source=api_source, + event_timestamp=timestamp, + event_data=ModelEventData( + model_name=model_output.model_name + or self.__class__.__name__, + model_output=str(output), + entity_identifier=identifier.identifier, + entity_identifier_type=identifier.identifier_type, + ), + ), + ) + return all_events event_logger.log_events(events_generator) # type: ignore diff --git a/wyvern/entities/model_entities.py b/wyvern/entities/model_entities.py index cc384ba..f67842a 100644 --- a/wyvern/entities/model_entities.py +++ b/wyvern/entities/model_entities.py @@ -9,7 +9,12 @@ MODEL_OUTPUT_DATA_TYPE = TypeVar( "MODEL_OUTPUT_DATA_TYPE", - bound=Union[float, str, List[float]], + bound=Union[ + float, + str, + List[float], + Dict[str, Optional[Union[float, str, list[float]]]], + ], ) """ MODEL_OUTPUT_DATA_TYPE is the type of the output of the model. It can be a float, a string, or a list of floats @@ -82,7 +87,17 @@ def first_identifier(self) -> Identifier: class ChainedModelInput(ModelInput[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY]): - upstream_model_output: Dict[Identifier, Optional[Union[float, str, List[float]]]] + upstream_model_output: Dict[ + Identifier, + Optional[ + Union[ + float, + str, + List[float], + Dict[str, Optional[Union[float, str, list[float]]]], + ] + ], + ] upstream_model_name: Optional[str] = None diff --git a/wyvern/wyvern_request.py b/wyvern/wyvern_request.py index 41cb396..8d9fc1a 100644 --- a/wyvern/wyvern_request.py +++ b/wyvern/wyvern_request.py @@ -46,7 +46,19 @@ class WyvernRequest: feature_map: FeatureMap # the key is the name of the model and the value is a map of the identifier to the model score - model_score_map: Dict[str, Dict[Identifier, Union[float, str, List[float], None]]] + model_score_map: Dict[ + str, + Dict[ + Identifier, + Union[ + float, + str, + List[float], + Dict[str, Optional[Union[float, str, list[float]]]], + None, + ], + ], + ] request_id: Optional[str] = None @@ -86,7 +98,16 @@ def parse_fastapi_request( def cache_model_score( self, model_name: str, - data: Dict[Identifier, Union[float, str, List[float], None]], + data: Dict[ + Identifier, + Union[ + float, + str, + List[float], + Dict[str, Optional[Union[float, str, list[float]]]], + None, + ], + ], ) -> None: if model_name not in self.model_score_map: self.model_score_map[model_name] = {} From d16a30f6c29f298df783ab63678c29a1a24a03b3 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Wed, 13 Sep 2023 19:10:40 -0700 Subject: [PATCH 07/10] rename --- wyvern/__init__.py | 2 ++ .../{multi_model_component.py => model_chain_component.py} | 0 2 files changed, 2 insertions(+) rename wyvern/components/models/{multi_model_component.py => model_chain_component.py} (100%) diff --git a/wyvern/__init__.py b/wyvern/__init__.py index 96129d5..cc7906c 100644 --- a/wyvern/__init__.py +++ b/wyvern/__init__.py @@ -2,6 +2,7 @@ from wyvern.components.features.realtime_features_component import ( RealtimeFeatureComponent, ) +from wyvern.components.models.model_chain_component import ModelChainComponent from wyvern.components.models.model_component import ModelComponent from wyvern.components.pipeline_component import PipelineComponent from wyvern.components.ranking_pipeline import ( @@ -38,6 +39,7 @@ "FeatureMap", "Identifier", "IdentifierType", + "ModelChainComponent", "ModelComponent", "ModelInput", "ModelOutput", diff --git a/wyvern/components/models/multi_model_component.py b/wyvern/components/models/model_chain_component.py similarity index 100% rename from wyvern/components/models/multi_model_component.py rename to wyvern/components/models/model_chain_component.py From b155ada97573cf396141765f7d4e1c688f9300d6 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Wed, 13 Sep 2023 19:32:18 -0700 Subject: [PATCH 08/10] add get_model_output support --- .../test_pinning_business_logic.py | 2 +- tests/scenarios/test_product_ranking.py | 4 +-- wyvern/components/component.py | 17 ++++++++++- wyvern/components/models/model_component.py | 2 +- wyvern/wyvern_request.py | 29 +++++++++++++++---- 5 files changed, 43 insertions(+), 11 deletions(-) diff --git a/tests/components/business_logic/test_pinning_business_logic.py b/tests/components/business_logic/test_pinning_business_logic.py index 665fc91..edc30a7 100644 --- a/tests/components/business_logic/test_pinning_business_logic.py +++ b/tests/components/business_logic/test_pinning_business_logic.py @@ -66,7 +66,7 @@ def __init__(self): entity_store={}, events=[], feature_map=FeatureMap(feature_map={}), - model_score_map={}, + model_output_map={}, ), ) return await pipeline.execute(request) diff --git a/tests/scenarios/test_product_ranking.py b/tests/scenarios/test_product_ranking.py index 8cf4296..a5c233a 100644 --- a/tests/scenarios/test_product_ranking.py +++ b/tests/scenarios/test_product_ranking.py @@ -384,7 +384,7 @@ async def test_hydrate(mock_redis): json=json_input, headers={}, entity_store={}, - model_score_map={}, + model_output_map={}, events=[], feature_map=FeatureMap(feature_map={}), ) @@ -448,7 +448,7 @@ async def test_hydrate__duplicate_brand(mock_redis__duplicate_brand): entity_store={}, events=[], feature_map=FeatureMap(feature_map={}), - model_score_map={}, + model_output_map={}, ) request_context.set(test_wyvern_request) diff --git a/wyvern/components/component.py b/wyvern/components/component.py index b660097..ded81ca 100644 --- a/wyvern/components/component.py +++ b/wyvern/components/component.py @@ -5,7 +5,7 @@ import logging from enum import Enum from functools import cached_property -from typing import Dict, Generic, Optional, Set +from typing import Dict, Generic, List, Optional, Set, Union from uuid import uuid4 from wyvern import request_context @@ -177,3 +177,18 @@ def get_all_features( if not feature_data: return {} return feature_data.features + + def get_model_output( + self, + model_name: str, + identifier: Identifier, + ) -> Optional[ + Union[ + float, + str, + List[float], + Dict[str, Optional[Union[float, str, list[float]]]], + ] + ]: + current_request = request_context.ensure_current_request() + return current_request.get_model_output(model_name, identifier) diff --git a/wyvern/components/models/model_component.py b/wyvern/components/models/model_component.py index cdaa3ec..3eaa4f4 100644 --- a/wyvern/components/models/model_component.py +++ b/wyvern/components/models/model_component.py @@ -105,7 +105,7 @@ async def execute(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT: model_output = await self.inference(input, **kwargs) if self.cache_output: - wyvern_request.cache_model_score(self.name, model_output.data) + wyvern_request.cache_model_output(self.name, model_output.data) def events_generator() -> List[ModelEvent]: timestamp = datetime.utcnow() diff --git a/wyvern/wyvern_request.py b/wyvern/wyvern_request.py index 8d9fc1a..690d84a 100644 --- a/wyvern/wyvern_request.py +++ b/wyvern/wyvern_request.py @@ -46,7 +46,7 @@ class WyvernRequest: feature_map: FeatureMap # the key is the name of the model and the value is a map of the identifier to the model score - model_score_map: Dict[ + model_output_map: Dict[ str, Dict[ Identifier, @@ -91,11 +91,11 @@ def parse_fastapi_request( entity_store={}, events=[], feature_map=FeatureMap(feature_map={}), - model_score_map={}, + model_output_map={}, request_id=request_id, ) - def cache_model_score( + def cache_model_output( self, model_name: str, data: Dict[ @@ -109,6 +109,23 @@ def cache_model_score( ], ], ) -> None: - if model_name not in self.model_score_map: - self.model_score_map[model_name] = {} - self.model_score_map[model_name].update(data) + if model_name not in self.model_output_map: + self.model_output_map[model_name] = {} + self.model_output_map[model_name].update(data) + + def get_model_output( + self, + model_name: str, + identifier: Identifier, + ) -> Optional[ + Union[ + float, + str, + List[float], + Dict[str, Optional[Union[float, str, list[float]]]], + None, + ] + ]: + if model_name not in self.model_output_map: + return None + return self.model_output_map[model_name].get(identifier) From 3cfe5b4ec441e24755487e15c4ea99e1a716bcf9 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Thu, 14 Sep 2023 03:11:17 -0700 Subject: [PATCH 09/10] fix --- wyvern/components/models/model_chain_component.py | 5 ++++- wyvern/entities/model_entities.py | 10 +++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/wyvern/components/models/model_chain_component.py b/wyvern/components/models/model_chain_component.py index bc37639..a7d10e1 100644 --- a/wyvern/components/models/model_chain_component.py +++ b/wyvern/components/models/model_chain_component.py @@ -20,7 +20,10 @@ def __init__(self, *upstreams: ModelComponent, name: Optional[str] = None): @cached_property def manifest_feature_names(self) -> Set[str]: - return set() + feature_names: Set[str] = set() + for model in self.chain: + feature_names = feature_names.union(model.manifest_feature_names) + return feature_names async def inference(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT: output = None diff --git a/wyvern/entities/model_entities.py b/wyvern/entities/model_entities.py index f67842a..71f5e8b 100644 --- a/wyvern/entities/model_entities.py +++ b/wyvern/entities/model_entities.py @@ -86,7 +86,11 @@ def first_identifier(self) -> Identifier: return self.first_entity.identifier -class ChainedModelInput(ModelInput[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY]): +MODEL_INPUT = TypeVar("MODEL_INPUT", bound=ModelInput) +MODEL_OUTPUT = TypeVar("MODEL_OUTPUT", bound=ModelOutput) + + +class ChainedModelInput(ModelInput, Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY]): upstream_model_output: Dict[ Identifier, Optional[ @@ -99,7 +103,3 @@ class ChainedModelInput(ModelInput[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY]): ], ] upstream_model_name: Optional[str] = None - - -MODEL_INPUT = TypeVar("MODEL_INPUT", bound=ModelInput) -MODEL_OUTPUT = TypeVar("MODEL_OUTPUT", bound=ModelOutput) From f6a12c016ff1365995b7f36a70c09d9e3eb97f0d Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Thu, 14 Sep 2023 03:14:04 -0700 Subject: [PATCH 10/10] 0.0.18-beta1 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 342146e..db33e04 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "wyvern-ai" -version = "0.0.18" +version = "0.0.18-beta1" description = "" authors = ["Wyvern AI "] readme = "README.md"