diff --git a/pyproject.toml b/pyproject.toml index 6924830..342146e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "wyvern-ai" -version = "0.0.17" +version = "0.0.18" description = "" authors = ["Wyvern AI "] readme = "README.md" diff --git a/tests/components/business_logic/test_pinning_business_logic.py b/tests/components/business_logic/test_pinning_business_logic.py index f0b945f..edc30a7 100644 --- a/tests/components/business_logic/test_pinning_business_logic.py +++ b/tests/components/business_logic/test_pinning_business_logic.py @@ -66,6 +66,7 @@ def __init__(self): entity_store={}, events=[], feature_map=FeatureMap(feature_map={}), + model_output_map={}, ), ) return await pipeline.execute(request) diff --git a/tests/scenarios/single_entity_pipelines/__init__.py b/tests/scenarios/single_entity_pipelines/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/scenarios/single_entity_pipelines/test_single_entity_pipeline.py b/tests/scenarios/single_entity_pipelines/test_single_entity_pipeline.py new file mode 100644 index 0000000..55e7881 --- /dev/null +++ b/tests/scenarios/single_entity_pipelines/test_single_entity_pipeline.py @@ -0,0 +1,177 @@ +# -*- coding: utf-8 -*- +from typing import List + +import pytest +from fastapi.testclient import TestClient + +from wyvern.components.business_logic.business_logic import ( + SingleEntityBusinessLogicComponent, + SingleEntityBusinessLogicPipeline, + SingleEntityBusinessLogicRequest, +) +from wyvern.components.models.model_chain_component import SingleEntityModelChain +from wyvern.components.models.model_component import SingleEntityModelComponent +from wyvern.components.single_entity_pipeline import ( + SingleEntityPipeline, + SingleEntityPipelineResponse, +) +from wyvern.entities.identifier import Identifier +from wyvern.entities.identifier_entities import WyvernEntity +from wyvern.entities.model_entities import ModelOutput +from wyvern.entities.request import BaseWyvernRequest +from wyvern.service import WyvernService + + +class Seller(WyvernEntity): + seller_id: str + + def generate_identifier(self) -> Identifier: + return Identifier( + identifier=self.seller_id, + identifier_type="seller", + ) + + +class Buyer(WyvernEntity): + buyer_id: str + + def generate_identifier(self) -> Identifier: + return Identifier( + identifier=self.buyer_id, + identifier_type="buyer", + ) + + +class Order(WyvernEntity): + order_id: str + + def generate_identifier(self) -> Identifier: + return Identifier( + identifier=self.order_id, + identifier_type="order", + ) + + +class FraudRequest(BaseWyvernRequest): + seller: Seller + buyer: Buyer + order: Order + + +class FraudResponse(SingleEntityPipelineResponse[float]): + reasons: List[str] + + +class FraudRuleModel(SingleEntityModelComponent[FraudRequest, ModelOutput[float]]): + async def inference(self, input: FraudRequest, **kwargs) -> ModelOutput[float]: + return ModelOutput( + data={ + input.order.identifier: 1, + }, + ) + + +class FraudAssessmentModel( + SingleEntityModelComponent[FraudRequest, ModelOutput[float]], +): + async def inference(self, input: FraudRequest, **kwargs) -> ModelOutput[float]: + return ModelOutput( + data={ + input.order.identifier: 1, + }, + ) + + +fraud_model = SingleEntityModelChain[FraudRequest, ModelOutput[float]]( + FraudRuleModel(), + FraudAssessmentModel(), + name="fraud_model", +) + + +class FraudBusinessLogicComponent( + SingleEntityBusinessLogicComponent[FraudRequest, float], +): + async def execute( + self, + input: SingleEntityBusinessLogicRequest[FraudRequest, float], + **kwargs, + ) -> float: + if input.request.seller.identifier.identifier == "test_seller_new": + return 0.0 + return input.model_output + + +fraud_biz_pipeline = SingleEntityBusinessLogicPipeline( + FraudBusinessLogicComponent(), + name="fraud_biz_pipeline", +) + + +class FraudPipeline(SingleEntityPipeline[FraudRequest, float]): + PATH = "/fraud" + REQUEST_SCHEMA_CLASS = FraudRequest + RESPONSE_SCHEMA_CLASS = FraudResponse + + def generate_response( + self, + input: FraudRequest, + pipeline_output: float, + ) -> FraudResponse: + if pipeline_output == 0.0: + return FraudResponse( + data=pipeline_output, + reasons=["Fraudulent order detected!"], + ) + return FraudResponse( + data=pipeline_output, + reasons=[], + ) + + +fraud_pipeline = FraudPipeline(model=fraud_model, business_logic=fraud_biz_pipeline) + + +@pytest.fixture +def mock_redis(mocker): + with mocker.patch( + "wyvern.redis.wyvern_redis.mget", + return_value=[], + ): + yield + + +@pytest.fixture +def test_client(mock_redis): + wyvern_app = WyvernService.generate_app( + route_components=[fraud_pipeline], + ) + yield TestClient(wyvern_app) + + +def test_end_to_end(test_client): + response = test_client.post( + "/api/v1/fraud", + json={ + "request_id": "test_request_id", + "seller": {"seller_id": "test_seller_id"}, + "buyer": {"buyer_id": "test_buyer_id"}, + "order": {"order_id": "test_order_id"}, + }, + ) + assert response.status_code == 200 + assert response.json() == {"data": 1.0, "reasons": []} + + +def test_end_to_end__new_seller(test_client): + response = test_client.post( + "/api/v1/fraud", + json={ + "request_id": "test_request_id", + "seller": {"seller_id": "test_seller_new"}, + "buyer": {"buyer_id": "test_buyer_id"}, + "order": {"order_id": "test_order_id"}, + }, + ) + assert response.status_code == 200 + assert response.json() == {"data": 0.0, "reasons": ["Fraudulent order detected!"]} diff --git a/tests/scenarios/test_product_ranking.py b/tests/scenarios/test_product_ranking.py index 68ff3fe..697249b 100644 --- a/tests/scenarios/test_product_ranking.py +++ b/tests/scenarios/test_product_ranking.py @@ -13,11 +13,7 @@ RealtimeFeatureComponent, RealtimeFeatureRequest, ) -from wyvern.components.models.model_component import ( - ModelComponent, - ModelInput, - ModelOutput, -) +from wyvern.components.models.model_component import ModelComponent from wyvern.components.pipeline_component import PipelineComponent from wyvern.config import settings from wyvern.core.compression import wyvern_encode @@ -26,6 +22,7 @@ from wyvern.entities.feature_entities import FeatureData, FeatureMap from wyvern.entities.identifier import Identifier from wyvern.entities.identifier_entities import ProductEntity, WyvernEntity +from wyvern.entities.model_entities import ModelInput, ModelOutput from wyvern.entities.request import BaseWyvernRequest from wyvern.service import WyvernService from wyvern.wyvern_request import WyvernRequest @@ -321,10 +318,11 @@ async def execute( @pytest.fixture def test_client(mock_redis): - wyvern_service = WyvernService.generate( + wyvern_app = WyvernService.generate_app( route_components=[RankingComponent], + realtime_feature_components=[], ) - yield TestClient(wyvern_service.service.app) + yield TestClient(wyvern_app) def test_get_all_identifiers(): @@ -387,6 +385,7 @@ async def test_hydrate(mock_redis): json=json_input, headers={}, entity_store={}, + model_output_map={}, events=[], feature_map=FeatureMap(feature_map={}), ) @@ -450,6 +449,7 @@ async def test_hydrate__duplicate_brand(mock_redis__duplicate_brand): entity_store={}, events=[], feature_map=FeatureMap(feature_map={}), + model_output_map={}, ) request_context.set(test_wyvern_request) diff --git a/wyvern/__init__.py b/wyvern/__init__.py index 08823ba..168f860 100644 --- a/wyvern/__init__.py +++ b/wyvern/__init__.py @@ -2,10 +2,11 @@ from wyvern.components.features.realtime_features_component import ( RealtimeFeatureComponent, ) +from wyvern.components.models.model_chain_component import SingleEntityModelChain from wyvern.components.models.model_component import ( ModelComponent, - ModelInput, - ModelOutput, + MultiEntityModelComponent, + SingleEntityModelComponent, ) from wyvern.components.pipeline_component import PipelineComponent from wyvern.components.ranking_pipeline import ( @@ -13,6 +14,10 @@ RankingRequest, RankingResponse, ) +from wyvern.components.single_entity_pipeline import ( + SingleEntityPipeline, + SingleEntityPipelineResponse, +) from wyvern.entities.candidate_entities import CandidateSetEntity from wyvern.entities.feature_entities import FeatureData, FeatureMap from wyvern.entities.identifier import CompositeIdentifier, Identifier, IdentifierType @@ -23,6 +28,7 @@ WyvernDataModel, WyvernEntity, ) +from wyvern.entities.model_entities import ChainedModelInput, ModelInput, ModelOutput from wyvern.feature_store.feature_server import generate_wyvern_store_app from wyvern.service import WyvernService from wyvern.wyvern_logging import setup_logging @@ -36,6 +42,7 @@ __all__ = [ "generate_wyvern_store_app", "CandidateSetEntity", + "ChainedModelInput", "CompositeIdentifier", "FeatureData", "FeatureMap", @@ -44,6 +51,7 @@ "ModelComponent", "ModelInput", "ModelOutput", + "MultiEntityModelComponent", "PipelineComponent", "ProductEntity", "QueryEntity", @@ -51,6 +59,10 @@ "RankingResponse", "RankingRequest", "RealtimeFeatureComponent", + "SingleEntityModelChain", + "SingleEntityModelComponent", + "SingleEntityPipeline", + "SingleEntityPipelineResponse", "UserEntity", "WyvernDataModel", "WyvernEntity", diff --git a/wyvern/components/business_logic/business_logic.py b/wyvern/components/business_logic/business_logic.py index a2ee695..1f5ec14 100644 --- a/wyvern/components/business_logic/business_logic.py +++ b/wyvern/components/business_logic/business_logic.py @@ -3,7 +3,7 @@ import logging from datetime import datetime -from typing import Generic, List, Optional +from typing import Generic, List, Optional, Sequence from ddtrace import tracer from pydantic.generics import GenericModel @@ -16,6 +16,8 @@ GENERALIZED_WYVERN_ENTITY, ScoredCandidate, ) +from wyvern.entities.identifier import Identifier +from wyvern.entities.model_entities import MODEL_OUTPUT_DATA_TYPE from wyvern.event_logging import event_logger from wyvern.wyvern_typing import REQUEST_ENTITY @@ -35,8 +37,8 @@ class BusinessLogicEventData(EntityEventData): business_logic_pipeline_order: int business_logic_name: str - old_score: float - new_score: float + old_score: str + new_score: str class BusinessLogicEvent(LoggedEvent[BusinessLogicEventData]): @@ -60,10 +62,24 @@ class BusinessLogicRequest( """ request: REQUEST_ENTITY - scored_candidates: List[ScoredCandidate[GENERALIZED_WYVERN_ENTITY]] + scored_candidates: List[ScoredCandidate[GENERALIZED_WYVERN_ENTITY]] = [] - # TODO (suchintan): Give business logic layer access to the feature map in the future - # feature_map: FeatureMap + +class SingleEntityBusinessLogicRequest( + GenericModel, + Generic[REQUEST_ENTITY, MODEL_OUTPUT_DATA_TYPE], +): + """ + A request to the business logic layer to perform business logic on a single candidate + + Parameters: + request: The request that the business logic layer is being asked to perform business logic on + candidate: The candidate that the business logic layer is being asked to perform business logic on + """ + + identifier: Identifier + request: REQUEST_ENTITY + model_output: MODEL_OUTPUT_DATA_TYPE # TODO (suchintan): Possibly delete this now that events are gone @@ -83,6 +99,22 @@ class BusinessLogicResponse( adjusted_candidates: List[ScoredCandidate[GENERALIZED_WYVERN_ENTITY]] +class SingleEntityBusinessLogicResponse( + GenericModel, + Generic[REQUEST_ENTITY, MODEL_OUTPUT_DATA_TYPE], +): + """ + The response from the business logic layer after performing business logic on a single candidate + + Parameters: + request: The request that the business logic layer was asked to perform business logic on + adjusted_candidate: The candidate that the business logic layer performed business logic on + """ + + request: REQUEST_ENTITY + adjusted_output: MODEL_OUTPUT_DATA_TYPE + + class BusinessLogicComponent( Component[ BusinessLogicRequest[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], @@ -99,6 +131,25 @@ class BusinessLogicComponent( pass +class SingleEntityBusinessLogicComponent( + Component[ + SingleEntityBusinessLogicRequest[ + REQUEST_ENTITY, + MODEL_OUTPUT_DATA_TYPE, + ], + MODEL_OUTPUT_DATA_TYPE, + ], + Generic[REQUEST_ENTITY, MODEL_OUTPUT_DATA_TYPE], +): + """ + A component that performs business logic on an entity with a set of candidates + + The request itself could contain more than just entities, for example it may contain a query and so on + """ + + pass + + class BusinessLogicPipeline( Component[ BusinessLogicRequest[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], @@ -141,7 +192,7 @@ async def execute( """ argument = input - # Make sure that the inputted candidates are actually sorted + # Make sure that the input candidates are actually sorted output = await self.sorting_component.execute(input.scored_candidates) for (pipeline_index, upstream) in enumerate(self.ordered_upstreams): @@ -187,11 +238,11 @@ def log_events( def extract_business_logic_events( self, - output: List[ScoredCandidate[GENERALIZED_WYVERN_ENTITY]], + output: Sequence[ScoredCandidate[GENERALIZED_WYVERN_ENTITY]], pipeline_index: int, upstream_name: str, request_id: str, - old_scores: List[float], + old_scores: List, ) -> List[BusinessLogicEvent]: """ Extracts the business logic events from the output of a business logic component @@ -215,8 +266,8 @@ def extract_business_logic_events( event_data=BusinessLogicEventData( business_logic_pipeline_order=pipeline_index, business_logic_name=upstream_name, - old_score=old_scores[j], - new_score=output[j].score, + old_score=str(old_scores[j]), + new_score=str(output[j].score), entity_identifier=candidate.entity.identifier.identifier, entity_identifier_type=candidate.entity.identifier.identifier_type, ), @@ -226,3 +277,115 @@ def extract_business_logic_events( ] return events + + +class SingleEntityBusinessLogicPipeline( + Component[ + SingleEntityBusinessLogicRequest[ + REQUEST_ENTITY, + MODEL_OUTPUT_DATA_TYPE, + ], + SingleEntityBusinessLogicResponse[ + REQUEST_ENTITY, + MODEL_OUTPUT_DATA_TYPE, + ], + ], + Generic[REQUEST_ENTITY, MODEL_OUTPUT_DATA_TYPE], +): + def __init__( + self, + *upstreams: SingleEntityBusinessLogicComponent[ + REQUEST_ENTITY, + MODEL_OUTPUT_DATA_TYPE, + ], + name: Optional[str] = None, + ): + self.ordered_upstreams = upstreams + super().__init__(*upstreams, name=name) + + async def execute( + self, + input: SingleEntityBusinessLogicRequest[ + REQUEST_ENTITY, + MODEL_OUTPUT_DATA_TYPE, + ], + **kwargs, + ) -> SingleEntityBusinessLogicResponse[REQUEST_ENTITY, MODEL_OUTPUT_DATA_TYPE]: + argument = input + for (pipeline_index, upstream) in enumerate(self.ordered_upstreams): + old_output = str(argument.model_output) + + output = await upstream.execute(argument, **kwargs) + extracted_events: List[ + BusinessLogicEvent + ] = self.extract_business_logic_events( + input.identifier, + str(output), + old_output, + pipeline_index, + upstream.name, + input.request.request_id, + ) + + def log_events( + extracted_events: List[BusinessLogicEvent] = extracted_events, + ): + return extracted_events + + event_logger.log_events(log_events) # type: ignore + + argument = SingleEntityBusinessLogicRequest[ + REQUEST_ENTITY, + MODEL_OUTPUT_DATA_TYPE, + ]( + identifier=input.identifier, + request=input.request, + model_output=output, + ) + + return SingleEntityBusinessLogicResponse( + request=input.request, + adjusted_output=argument.model_output, + ) + + def extract_business_logic_events( + self, + identifier: Identifier, + output: str, + old_output: str, + pipeline_index: int, + upstream_name: str, + request_id: str, + ) -> List[BusinessLogicEvent]: + """ + Extracts the business logic events from the output of a business logic component + + Args: + identifier: The identifier of the entity that the describe what the model output is for + output: The output of a business logic component + old_output: The old scores of the candidates that the business logic component was called on + pipeline_index: The index of the business logic component in the business logic pipeline + upstream_name: The name of the business logic component + request_id: The request id of the request that the business logic component was called in + + Returns: + The business logic events that were extracted from the output of the business logic component + """ + timestamp = datetime.utcnow() + events = [ + BusinessLogicEvent( + request_id=request_id, + api_source=request_context.ensure_current_request().url_path, + event_timestamp=timestamp, + event_data=BusinessLogicEventData( + business_logic_pipeline_order=pipeline_index, + business_logic_name=upstream_name, + old_score=old_output, + new_score=output, + entity_identifier=identifier.identifier, + entity_identifier_type=identifier.identifier_type, + ), + ), + ] + + return events diff --git a/wyvern/components/component.py b/wyvern/components/component.py index b660097..4c1f3d0 100644 --- a/wyvern/components/component.py +++ b/wyvern/components/component.py @@ -5,7 +5,7 @@ import logging from enum import Enum from functools import cached_property -from typing import Dict, Generic, Optional, Set +from typing import Dict, Generic, List, Optional, Set, Union from uuid import uuid4 from wyvern import request_context @@ -177,3 +177,25 @@ def get_all_features( if not feature_data: return {} return feature_data.features + + def get_model_output( + self, + model_name: str, + identifier: Identifier, + ) -> Optional[ + Union[ + float, + str, + List[float], + Dict[str, Optional[Union[float, str, list[float]]]], + ] + ]: + """ + Gets the model output for the given identifier + + Args: + model_name: str. The name of the model + identifier: Identifier. The entity identifier + """ + current_request = request_context.ensure_current_request() + return current_request.get_model_output(model_name, identifier) diff --git a/wyvern/components/features/feature_retrieval_pipeline.py b/wyvern/components/features/feature_retrieval_pipeline.py index e8a073d..5575367 100644 --- a/wyvern/components/features/feature_retrieval_pipeline.py +++ b/wyvern/components/features/feature_retrieval_pipeline.py @@ -6,6 +6,7 @@ from ddtrace import tracer from pydantic.generics import GenericModel +from wyvern import request_context from wyvern.components.component import Component from wyvern.components.features.feature_logger import ( FeatureEventLoggingComponent, @@ -126,12 +127,15 @@ async def execute( # Or the client wants to evaluate the feature # TODO (suchintan): We don't support "chained" real-time features yet.. hopefully soon real_time_features = self._generate_real_time_features(input) - + real_time_feature_component_names = { + real_time_feature_component.name + for real_time_feature_component in real_time_features + } # Figure out which features are real-time features based on the definitions within the real-time feature object features_requested_by_real_time_features = { feature_name - for real_time_feature in real_time_features - for feature_name in real_time_feature.output_feature_names + for feature_name in input.requested_feature_names + if feature_name.split(":")[0] in real_time_feature_component_names } # Figure out which features come from the feature store @@ -156,6 +160,8 @@ async def execute( **kwargs, ) ) + current_request = request_context.ensure_current_request() + current_request.feature_map = feature_retrieval_response """ TODO (suchintan): diff --git a/wyvern/components/models/model_chain_component.py b/wyvern/components/models/model_chain_component.py new file mode 100644 index 0000000..5085126 --- /dev/null +++ b/wyvern/components/models/model_chain_component.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- +from functools import cached_property +from typing import Optional, Set + +from wyvern.components.models.model_component import ( + BaseModelComponent, + MultiEntityModelComponent, + SingleEntityModelComponent, +) +from wyvern.entities.model_entities import MODEL_INPUT, MODEL_OUTPUT, ChainedModelInput +from wyvern.exceptions import MissingModelChainOutputError +from wyvern.wyvern_typing import REQUEST_ENTITY + + +class MultiEntityModelChain(MultiEntityModelComponent[MODEL_INPUT, MODEL_OUTPUT]): + def __init__(self, *upstreams: BaseModelComponent, name: Optional[str] = None): + super().__init__(*upstreams, name=name) + self.chain = upstreams + + @cached_property + def manifest_feature_names(self) -> Set[str]: + feature_names: Set[str] = set() + for model in self.chain: + feature_names = feature_names.union(model.manifest_feature_names) + return feature_names + + async def inference(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT: + output = None + prev_model: Optional[BaseModelComponent] = None + for model in self.chain: + curr_input: ChainedModelInput + if prev_model is not None and output is not None: + curr_input = ChainedModelInput( + request=input.request, + entities=input.entities, + upstream_model_name=prev_model.name, + upstream_model_output=output.data, + ) + else: + curr_input = ChainedModelInput( + request=input.request, + entities=input.entities, + upstream_model_name=None, + upstream_model_output={}, + ) + output = await model.execute(curr_input, **kwargs) + prev_model = model + + if output is None: + raise MissingModelChainOutputError() + + # TODO: do type checking to make sure the output is of the correct type + return output + + +class SingleEntityModelChain(SingleEntityModelComponent[REQUEST_ENTITY, MODEL_OUTPUT]): + def __init__( + self, *upstreams: SingleEntityModelComponent, name: Optional[str] = None + ): + super().__init__(*upstreams, name=name) + self.chain = upstreams + + @cached_property + def manifest_feature_names(self) -> Set[str]: + feature_names: Set[str] = set() + for model in self.chain: + feature_names = feature_names.union(model.manifest_feature_names) + return feature_names + + async def inference(self, input: REQUEST_ENTITY, **kwargs) -> MODEL_OUTPUT: + output = None + for model in self.chain: + output = await model.execute(input, **kwargs) + + if output is None: + raise MissingModelChainOutputError() + + # TODO: do type checking to make sure the output is of the correct type + return output diff --git a/wyvern/components/models/model_component.py b/wyvern/components/models/model_component.py index 6aac719..d676ea2 100644 --- a/wyvern/components/models/model_component.py +++ b/wyvern/components/models/model_component.py @@ -3,21 +3,9 @@ import logging from datetime import datetime from functools import cached_property -from typing import ( - Dict, - Generic, - List, - Optional, - Sequence, - Set, - Type, - TypeVar, - Union, - get_args, -) +from typing import Dict, List, Optional, Sequence, Set, Type, Union, get_args from pydantic import BaseModel -from pydantic.generics import GenericModel from wyvern import request_context from wyvern.components.component import Component @@ -25,19 +13,10 @@ from wyvern.config import settings from wyvern.entities.identifier import Identifier from wyvern.entities.identifier_entities import WyvernEntity +from wyvern.entities.model_entities import MODEL_INPUT, MODEL_OUTPUT from wyvern.entities.request import BaseWyvernRequest from wyvern.event_logging import event_logger -from wyvern.exceptions import WyvernModelInputError -from wyvern.wyvern_typing import GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY - -MODEL_OUTPUT_DATA_TYPE = TypeVar( - "MODEL_OUTPUT_DATA_TYPE", - bound=Union[float, str, List[float]], -) -""" -MODEL_OUTPUT_DATA_TYPE is the type of the output of the model. It can be a float, a string, or a list of floats -(e.g. a list of probabilities, embeddings, etc.) -""" +from wyvern.wyvern_typing import INPUT_TYPE, REQUEST_ENTITY logger = logging.getLogger(__name__) @@ -52,12 +31,16 @@ class ModelEventData(BaseModel): entity_identifier: The identifier of the entity that was used to generate the model output. This is optional. entity_identifier_type: The type of the identifier of the entity that was used to generate the model output. This is optional. + model_key: The key in the dictionary output. + This attribute will only appear when the output of the model is a dictionary. + This is optional. """ model_name: str model_output: str entity_identifier: Optional[str] = None entity_identifier_type: Optional[str] = None + model_key: Optional[str] = None class ModelEvent(LoggedEvent[ModelEventData]): @@ -71,77 +54,9 @@ class ModelEvent(LoggedEvent[ModelEventData]): event_type: EventType = EventType.MODEL -class ModelOutput(GenericModel, Generic[MODEL_OUTPUT_DATA_TYPE]): - """ - This class defines the output of a model. - - Args: - data: A dictionary mapping entity identifiers to model outputs. The model outputs can also be None. - model_name: The name of the model. This is optional. - """ - - data: Dict[Identifier, Optional[MODEL_OUTPUT_DATA_TYPE]] - model_name: Optional[str] = None - - def get_entity_output( - self, - identifier: Identifier, - ) -> Optional[MODEL_OUTPUT_DATA_TYPE]: - """ - Get the model output for a given entity identifier. - - Args: - identifier: The identifier of the entity. - - Returns: - The model output for the given entity identifier. This can also be None if the model output is None. - """ - return self.data.get(identifier) - - -class ModelInput(GenericModel, Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY]): - """ - This class defines the input to a model. - - Args: - request: The request that will be used to generate the model input. - entities: A list of entities that will be used to generate the model input. - """ - - request: REQUEST_ENTITY - entities: List[GENERALIZED_WYVERN_ENTITY] = [] - - @property - def first_entity(self) -> GENERALIZED_WYVERN_ENTITY: - """ - Get the first entity in the list of entities. This is useful when you know that there is only one entity. - - Returns: - The first entity in the list of entities. - """ - if not self.entities: - raise WyvernModelInputError(model_input=self) - return self.entities[0] - - @property - def first_identifier(self) -> Identifier: - """ - Get the identifier of the first entity in the list of entities. This is useful when you know that there is only - one entity. - - Returns: - The identifier of the first entity in the list of entities. - """ - return self.first_entity.identifier - - -MODEL_INPUT = TypeVar("MODEL_INPUT", bound=ModelInput) -MODEL_OUTPUT = TypeVar("MODEL_OUTPUT", bound=ModelOutput) - - -class ModelComponent( +class BaseModelComponent( Component[ - MODEL_INPUT, + INPUT_TYPE, MODEL_OUTPUT, ], ): @@ -155,11 +70,14 @@ def __init__( self, *upstreams, name: Optional[str] = None, + cache_output: bool = False, ): super().__init__(*upstreams, name=name) self.model_input_type = self.get_type_args_simple(0) self.model_output_type = self.get_type_args_simple(1) + self.cache_output = cache_output + @classmethod def get_type_args_simple(cls, index: int) -> Type: """ @@ -174,42 +92,77 @@ def manifest_feature_names(self) -> Set[str]: Our system will automatically fetch the required features from the feature store to make this model evaluation possible + + By default, a model component does not require any features, so this function returns an empty set """ - raise NotImplementedError( - f"{self.__class__.__name__} is a ModelComponent. " - "The @cached_property function `manifest_feature_names` must be " - "implemented to define features required for the model.", - ) + return set() - async def execute(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT: + async def execute(self, input: INPUT_TYPE, **kwargs) -> MODEL_OUTPUT: """ The model_name and model_score will be automatically logged """ - api_source = request_context.ensure_current_request().url_path - request_id = input.request.request_id + wyvern_request = request_context.ensure_current_request() + api_source = wyvern_request.url_path + request_id = self._get_request_id(input) model_output = await self.inference(input, **kwargs) + if self.cache_output: + wyvern_request.cache_model_output(self.name, model_output.data) + def events_generator() -> List[ModelEvent]: timestamp = datetime.utcnow() - return [ - ModelEvent( - request_id=request_id, - api_source=api_source, - event_timestamp=timestamp, - event_data=ModelEventData( - model_name=model_output.model_name or self.__class__.__name__, - model_output=str(output), - entity_identifier=identifier.identifier, - entity_identifier_type=identifier.identifier_type, - ), - ) - for identifier, output in model_output.data.items() - ] + all_events: List[ModelEvent] = [] + for identifier, output in model_output.data.items(): + if isinstance(output, dict): + for key, value in output.items(): + all_events.append( + ModelEvent( + request_id=request_id, + api_source=api_source, + event_timestamp=timestamp, + event_data=ModelEventData( + model_name=model_output.model_name + or self.__class__.__name__, + model_output=str(value), + model_key=key, + entity_identifier=identifier.identifier, + entity_identifier_type=identifier.identifier_type, + ), + ), + ) + else: + all_events.append( + ModelEvent( + request_id=request_id, + api_source=api_source, + event_timestamp=timestamp, + event_data=ModelEventData( + model_name=model_output.model_name + or self.__class__.__name__, + model_output=str(output), + entity_identifier=identifier.identifier, + entity_identifier_type=identifier.identifier_type, + ), + ), + ) + return all_events event_logger.log_events(events_generator) # type: ignore return model_output + async def inference( + self, + input: INPUT_TYPE, + **kwargs, + ) -> MODEL_OUTPUT: + raise NotImplementedError + + def _get_request_id(self, input: INPUT_TYPE) -> Optional[str]: + raise NotImplementedError + + +class MultiEntityModelComponent(BaseModelComponent[MODEL_INPUT, MODEL_OUTPUT]): async def batch_inference( self, request: BaseWyvernRequest, @@ -264,3 +217,17 @@ async def inference( data=output_data, model_name=self.name, ) + + def _get_request_id(self, input: MODEL_INPUT) -> Optional[str]: + return input.request.request_id + + +ModelComponent = MultiEntityModelComponent + + +class SingleEntityModelComponent(BaseModelComponent[REQUEST_ENTITY, MODEL_OUTPUT]): + async def inference(self, input: REQUEST_ENTITY, **kwargs) -> MODEL_OUTPUT: + raise NotImplementedError + + def _get_request_id(self, input: REQUEST_ENTITY) -> Optional[str]: + return input.request_id diff --git a/wyvern/components/models/modelbit_component.py b/wyvern/components/models/modelbit_component.py index 92d60c5..74756ef 100644 --- a/wyvern/components/models/modelbit_component.py +++ b/wyvern/components/models/modelbit_component.py @@ -2,39 +2,30 @@ import asyncio import logging from functools import cached_property -from typing import Any, Dict, List, Optional, Set, Tuple, TypeAlias, Union +from typing import Any, Dict, List, Optional, Set, Tuple, TypeAlias, Union, final from wyvern.components.models.model_component import ( - MODEL_INPUT, - MODEL_OUTPUT, - ModelComponent, + BaseModelComponent, + MultiEntityModelComponent, + SingleEntityModelComponent, ) from wyvern.config import settings from wyvern.core.http import aiohttp_client from wyvern.entities.identifier import Identifier from wyvern.entities.identifier_entities import WyvernEntity +from wyvern.entities.model_entities import MODEL_INPUT, MODEL_OUTPUT from wyvern.entities.request import BaseWyvernRequest from wyvern.exceptions import ( WyvernModelbitTokenMissingError, WyvernModelbitValidationError, ) +from wyvern.wyvern_typing import INPUT_TYPE, REQUEST_ENTITY JSON: TypeAlias = Union[Dict[str, "JSON"], List["JSON"], str, int, float, bool, None] logger = logging.getLogger(__name__) -class ModelbitComponent(ModelComponent[MODEL_INPUT, MODEL_OUTPUT]): - """ - ModelbitComponent is a base class for all modelbit model components. It provides a common interface to implement - all modelbit models. - - ModelbitComponent is a subclass of ModelComponent. - - Attributes: - AUTH_TOKEN: A class variable that stores the auth token for Modelbit. - URL: A class variable that stores the url for Modelbit. - """ - +class ModelbitMixin(BaseModelComponent[INPUT_TYPE, MODEL_OUTPUT]): AUTH_TOKEN: str = "" URL: str = "" @@ -44,6 +35,7 @@ def __init__( name: Optional[str] = None, auth_token: Optional[str] = None, url: Optional[str] = None, + cache_output: bool = False, ) -> None: """ Args: @@ -55,7 +47,7 @@ def __init__( Raises: WyvernModelbitTokenMissingError: If the auth token is not provided. """ - super().__init__(*upstreams, name=name) + super().__init__(*upstreams, name=name, cache_output=cache_output) self._auth_token = auth_token or self.AUTH_TOKEN self._modelbit_url = url or self.URL self.headers = { @@ -82,31 +74,7 @@ def manifest_feature_names(self) -> Set[str]: """ return set(self.modelbit_features) - async def build_requests( - self, - input: MODEL_INPUT, - ) -> Tuple[List[Identifier], List[Any]]: - """ - Please refer to modlebit batch inference API: - https://doc.modelbit.com/deployments/rest-api/ - """ - target_entities: List[ - Union[WyvernEntity, BaseWyvernRequest] - ] = input.entities or [input.request] - target_identifiers = [entity.identifier for entity in target_entities] - all_requests = [ - [ - idx + 1, - [ - self.get_feature(identifier, feature_name) - for feature_name in self.modelbit_features - ], - ] - for idx, identifier in enumerate(target_identifiers) - ] - return target_identifiers, all_requests - - async def inference(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT: + async def inference(self, input: INPUT_TYPE, **kwargs) -> MODEL_OUTPUT: """ This method sends a request to Modelbit and returns the output. """ @@ -155,3 +123,69 @@ async def inference(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT: data=output_data, model_name=self.name, ) + + async def build_requests( + self, + input: INPUT_TYPE, + ) -> Tuple[List[Identifier], List[Any]]: + """ + This method builds requests for Modelbit. This method should be implemented by the subclass. + """ + raise NotImplementedError + + +class ModelbitComponent( + ModelbitMixin[MODEL_INPUT, MODEL_OUTPUT], + MultiEntityModelComponent[MODEL_INPUT, MODEL_OUTPUT], +): + """ + ModelbitComponent is a base class for all modelbit model components. It provides a common interface to implement + all modelbit models. + + ModelbitComponent is a subclass of ModelComponent. + + Attributes: + AUTH_TOKEN: A class variable that stores the auth token for Modelbit. + URL: A class variable that stores the url for Modelbit. + """ + + async def build_requests( + self, + input: MODEL_INPUT, + ) -> Tuple[List[Identifier], List[Any]]: + """ + Please refer to modlebit batch inference API: + https://doc.modelbit.com/deployments/rest-api/ + """ + target_entities: List[ + Union[WyvernEntity, BaseWyvernRequest] + ] = input.entities or [input.request] + target_identifiers = [entity.identifier for entity in target_entities] + all_requests = [ + [ + idx + 1, + [ + self.get_feature(identifier, feature_name) + for feature_name in self.modelbit_features + ], + ] + for idx, identifier in enumerate(target_identifiers) + ] + return target_identifiers, all_requests + + +class SingleEntityModelbitComponent( + ModelbitMixin[REQUEST_ENTITY, MODEL_OUTPUT], + SingleEntityModelComponent[REQUEST_ENTITY, MODEL_OUTPUT], +): + @final + async def build_requests( + self, + input: REQUEST_ENTITY, + ) -> Tuple[List[Identifier], List[Any]]: + target_identifier, request = await self.build_request(input) + all_requests = [[1, request]] + return [target_identifier], all_requests + + async def build_request(self, input: REQUEST_ENTITY) -> Tuple[Identifier, Any]: + raise NotImplementedError diff --git a/wyvern/components/ranking_pipeline.py b/wyvern/components/ranking_pipeline.py index bdbbafd..3625b40 100644 --- a/wyvern/components/ranking_pipeline.py +++ b/wyvern/components/ranking_pipeline.py @@ -13,7 +13,7 @@ ImpressionEventLoggingComponent, ImpressionEventLoggingRequest, ) -from wyvern.components.models.model_component import ModelComponent, ModelInput +from wyvern.components.models.model_component import ModelComponent from wyvern.components.pagination.pagination_component import ( PaginationComponent, PaginationRequest, @@ -22,6 +22,7 @@ from wyvern.components.pipeline_component import PipelineComponent from wyvern.entities.candidate_entities import ScoredCandidate from wyvern.entities.identifier_entities import QueryEntity +from wyvern.entities.model_entities import ModelInput from wyvern.entities.request import BaseWyvernRequest from wyvern.event_logging import event_logger from wyvern.wyvern_typing import WYVERN_ENTITY diff --git a/wyvern/components/single_entity_pipeline.py b/wyvern/components/single_entity_pipeline.py new file mode 100644 index 0000000..ed3a700 --- /dev/null +++ b/wyvern/components/single_entity_pipeline.py @@ -0,0 +1,99 @@ +# -*- coding: utf-8 -*- +from typing import Any, Generic, List, Optional + +from pydantic.generics import GenericModel + +from wyvern.components.business_logic.business_logic import ( + SingleEntityBusinessLogicPipeline, + SingleEntityBusinessLogicRequest, +) +from wyvern.components.component import Component +from wyvern.components.events.events import LoggedEvent +from wyvern.components.models.model_component import SingleEntityModelComponent +from wyvern.components.pipeline_component import PipelineComponent +from wyvern.entities.identifier import Identifier +from wyvern.entities.model_entities import MODEL_OUTPUT_DATA_TYPE +from wyvern.event_logging import event_logger +from wyvern.exceptions import MissingModelOutputError +from wyvern.wyvern_typing import REQUEST_ENTITY + + +class SingleEntityPipelineResponse(GenericModel, Generic[MODEL_OUTPUT_DATA_TYPE]): + data: MODEL_OUTPUT_DATA_TYPE + events: Optional[List[LoggedEvent[Any]]] = None + + +class SingleEntityPipeline( + PipelineComponent[ + REQUEST_ENTITY, + SingleEntityPipelineResponse[MODEL_OUTPUT_DATA_TYPE], + ], + Generic[REQUEST_ENTITY, MODEL_OUTPUT_DATA_TYPE], +): + def __init__( + self, + *upstreams: Component, + model: SingleEntityModelComponent, + business_logic: Optional[ + SingleEntityBusinessLogicPipeline[REQUEST_ENTITY, MODEL_OUTPUT_DATA_TYPE] + ] = None, + name: Optional[str] = None, + handle_feature_store_exceptions: bool = False, + ) -> None: + upstream_components = list(upstreams) + + self.model = model + upstream_components.append(self.model) + + if not business_logic: + business_logic = SingleEntityBusinessLogicPipeline[ + REQUEST_ENTITY, + MODEL_OUTPUT_DATA_TYPE, + ]() + self.business_logic = business_logic + upstream_components.append(self.business_logic) + + super().__init__( + *upstream_components, + name=name, + handle_feature_store_exceptions=handle_feature_store_exceptions, + ) + + async def execute( + self, + input: REQUEST_ENTITY, + **kwargs, + ) -> SingleEntityPipelineResponse[MODEL_OUTPUT_DATA_TYPE]: + output = await self.model.execute(input, **kwargs) + identifiers: List[Identifier] = list(output.data.keys()) + if not identifiers: + raise MissingModelOutputError() + identifier = identifiers[0] + model_output_data: MODEL_OUTPUT_DATA_TYPE = output.data.get(identifier) + + business_logic_input = SingleEntityBusinessLogicRequest[ + REQUEST_ENTITY, + MODEL_OUTPUT_DATA_TYPE, + ]( + identifier=identifier, + request=input, + model_output=model_output_data, + ) + business_logic_output = await self.business_logic.execute( + input=business_logic_input, + **kwargs, + ) + return self.generate_response( + input, + business_logic_output.adjusted_output, + ) + + def generate_response( + self, + input: REQUEST_ENTITY, + pipeline_output: MODEL_OUTPUT_DATA_TYPE, + ) -> SingleEntityPipelineResponse[MODEL_OUTPUT_DATA_TYPE]: + return SingleEntityPipelineResponse[MODEL_OUTPUT_DATA_TYPE]( + data=pipeline_output, + events=event_logger.get_logged_events() if input.include_events else None, + ) diff --git a/wyvern/entities/candidate_entities.py b/wyvern/entities/candidate_entities.py index 3f5c0fc..8c529a1 100644 --- a/wyvern/entities/candidate_entities.py +++ b/wyvern/entities/candidate_entities.py @@ -10,7 +10,10 @@ # TODO (suchintan): This should be renamed to ScoredEntity probably -class ScoredCandidate(GenericModel, Generic[GENERALIZED_WYVERN_ENTITY]): +class ScoredCandidate( + GenericModel, + Generic[GENERALIZED_WYVERN_ENTITY], +): """ A candidate entity with a score. @@ -30,7 +33,6 @@ class CandidateSetEntity( ): """ A set of candidate entities. This is a generic model that can be used to represent a set of candidate entities. - Attributes: candidates: The list of candidate entities. """ diff --git a/wyvern/entities/model_entities.py b/wyvern/entities/model_entities.py new file mode 100644 index 0000000..f674ccf --- /dev/null +++ b/wyvern/entities/model_entities.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +from typing import Any, Dict, Generic, List, Optional, TypeVar, Union + +from pydantic.generics import GenericModel + +from wyvern.entities.identifier import Identifier +from wyvern.exceptions import WyvernModelInputError +from wyvern.wyvern_typing import GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY + +MODEL_OUTPUT_DATA_TYPE = TypeVar( + "MODEL_OUTPUT_DATA_TYPE", + bound=Union[ + float, + str, + List[float], + Dict[str, Any], + ], +) +""" +MODEL_OUTPUT_DATA_TYPE is the type of the output of the model. It can be a float, a string, or a list of floats +(e.g. a list of probabilities, embeddings, etc.) +""" + + +class ModelOutput(GenericModel, Generic[MODEL_OUTPUT_DATA_TYPE]): + """ + This class defines the output of a model. + + Args: + data: A dictionary mapping entity identifiers to model outputs. The model outputs can also be None. + model_name: The name of the model. This is optional. + """ + + data: Dict[Identifier, Optional[MODEL_OUTPUT_DATA_TYPE]] + model_name: Optional[str] = None + + def get_entity_output( + self, + identifier: Identifier, + ) -> Optional[MODEL_OUTPUT_DATA_TYPE]: + """ + Get the model output for a given entity identifier. + + Args: + identifier: The identifier of the entity. + + Returns: + The model output for the given entity identifier. This can also be None if the model output is None. + """ + return self.data.get(identifier) + + +class ModelInput(GenericModel, Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY]): + """ + This class defines the input to a model. + + Args: + request: The request that will be used to generate the model input. + entities: A list of entities that will be used to generate the model input. + """ + + request: REQUEST_ENTITY + entities: List[GENERALIZED_WYVERN_ENTITY] = [] + + @property + def first_entity(self) -> GENERALIZED_WYVERN_ENTITY: + """ + Get the first entity in the list of entities. This is useful when you know that there is only one entity. + + Returns: + The first entity in the list of entities. + """ + if not self.entities: + raise WyvernModelInputError(model_input=self) + return self.entities[0] + + @property + def first_identifier(self) -> Identifier: + """ + Get the identifier of the first entity in the list of entities. This is useful when you know that there is only + one entity. + + Returns: + The identifier of the first entity in the list of entities. + """ + return self.first_entity.identifier + + +MODEL_INPUT = TypeVar("MODEL_INPUT", bound=ModelInput) +MODEL_OUTPUT = TypeVar("MODEL_OUTPUT", bound=ModelOutput) + + +class ChainedModelInput(ModelInput, Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY]): + upstream_model_output: Dict[ + Identifier, + Optional[ + Union[ + float, + str, + List[float], + Dict[str, Optional[Union[float, str, list[float]]]], + ] + ], + ] + upstream_model_name: Optional[str] = None diff --git a/wyvern/exceptions.py b/wyvern/exceptions.py index 2007d88..a67b6e4 100644 --- a/wyvern/exceptions.py +++ b/wyvern/exceptions.py @@ -154,3 +154,11 @@ class ExperimentationClientInitializationError(WyvernError): class EntityColumnMissingError(WyvernError): message = "Entity column {entity} is missing in the entity data" + + +class MissingModelChainOutputError(WyvernError): + message = "Model chain output is missing" + + +class MissingModelOutputError(WyvernError): + message = "Identifier is missing in the model output" diff --git a/wyvern/service.py b/wyvern/service.py index 597e01a..687d018 100644 --- a/wyvern/service.py +++ b/wyvern/service.py @@ -2,7 +2,7 @@ from __future__ import annotations import asyncio -from typing import List, Optional, Type +from typing import List, Optional, Type, Union from dotenv import load_dotenv from fastapi import FastAPI @@ -40,7 +40,7 @@ def __init__( async def register_routes( self, - route_components: List[Type[APIRouteComponent]], + route_components: List[Union[Type[APIRouteComponent], APIRouteComponent]], ) -> None: """ Register the routes for the Wyvern service @@ -69,7 +69,9 @@ def _run( @staticmethod def generate( *, - route_components: Optional[List[Type[APIRouteComponent]]] = None, + route_components: Optional[ + List[Union[Type[APIRouteComponent], APIRouteComponent]] + ] = None, realtime_feature_components: Optional[ List[Type[RealtimeFeatureComponent]] ] = None, @@ -105,7 +107,7 @@ def generate( @staticmethod def run( *, - route_components: List[Type[APIRouteComponent]], + route_components: List[Union[Type[APIRouteComponent], APIRouteComponent]], realtime_feature_components: Optional[ List[Type[RealtimeFeatureComponent]] ] = None, @@ -135,7 +137,9 @@ def run( @staticmethod def generate_app( *, - route_components: Optional[List[Type[APIRouteComponent]]] = None, + route_components: Optional[ + List[Union[Type[APIRouteComponent], APIRouteComponent]] + ] = None, realtime_feature_components: Optional[ List[Type[RealtimeFeatureComponent]] ] = None, diff --git a/wyvern/web_frameworks/fastapi.py b/wyvern/web_frameworks/fastapi.py index 7ef438a..f344f6b 100644 --- a/wyvern/web_frameworks/fastapi.py +++ b/wyvern/web_frameworks/fastapi.py @@ -2,7 +2,7 @@ import logging import time from contextlib import asynccontextmanager -from typing import Dict, Type +from typing import Dict, Type, Union import uvicorn from fastapi import BackgroundTasks, FastAPI, HTTPException, Request @@ -101,7 +101,7 @@ async def request_middleware(request: Request, call_next): async def register_route( self, - route_component: Type[APIRouteComponent], + route_component: Union[Type[APIRouteComponent], APIRouteComponent], ) -> None: """ Register a route component. This will register the route with FastAPI and also initialize the route component. @@ -112,10 +112,12 @@ async def register_route( Raises: WyvernRouteRegistrationError: If the route component is not a subclass of APIRouteComponent. """ - if not issubclass(route_component, APIRouteComponent): + if isinstance(route_component, APIRouteComponent): + root_component = route_component + elif not issubclass(route_component, APIRouteComponent): raise WyvernRouteRegistrationError(component=route_component) - - root_component = route_component() + else: + root_component = route_component() await root_component.initialize_wrapper() path = _massage_path(f"/api/{root_component.API_VERSION}/{root_component.PATH}") diff --git a/wyvern/wyvern_request.py b/wyvern/wyvern_request.py index f80342e..690d84a 100644 --- a/wyvern/wyvern_request.py +++ b/wyvern/wyvern_request.py @@ -2,7 +2,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from urllib.parse import urlparse import fastapi @@ -10,6 +10,7 @@ from wyvern.components.events.events import LoggedEvent from wyvern.entities.feature_entities import FeatureMap +from wyvern.entities.identifier import Identifier @dataclass @@ -44,6 +45,21 @@ class WyvernRequest: feature_map: FeatureMap + # the key is the name of the model and the value is a map of the identifier to the model score + model_output_map: Dict[ + str, + Dict[ + Identifier, + Union[ + float, + str, + List[float], + Dict[str, Optional[Union[float, str, list[float]]]], + None, + ], + ], + ] + request_id: Optional[str] = None # TODO: params @@ -75,5 +91,41 @@ def parse_fastapi_request( entity_store={}, events=[], feature_map=FeatureMap(feature_map={}), + model_output_map={}, request_id=request_id, ) + + def cache_model_output( + self, + model_name: str, + data: Dict[ + Identifier, + Union[ + float, + str, + List[float], + Dict[str, Optional[Union[float, str, list[float]]]], + None, + ], + ], + ) -> None: + if model_name not in self.model_output_map: + self.model_output_map[model_name] = {} + self.model_output_map[model_name].update(data) + + def get_model_output( + self, + model_name: str, + identifier: Identifier, + ) -> Optional[ + Union[ + float, + str, + List[float], + Dict[str, Optional[Union[float, str, list[float]]]], + None, + ] + ]: + if model_name not in self.model_output_map: + return None + return self.model_output_map[model_name].get(identifier)