Skip to content
This repository has been archived by the owner on Mar 1, 2024. It is now read-only.

v0.0.18 - chained model evaluation #63

Merged
merged 11 commits into from
Sep 26, 2023
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "wyvern-ai"
version = "0.0.17"
version = "0.0.18-beta3"
description = ""
authors = ["Wyvern AI <[email protected]>"]
readme = "README.md"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(self):
entity_store={},
events=[],
feature_map=FeatureMap(feature_map={}),
model_output_map={},
),
)
return await pipeline.execute(request)
Expand Down
9 changes: 4 additions & 5 deletions tests/scenarios/test_product_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@
RealtimeFeatureComponent,
RealtimeFeatureRequest,
)
from wyvern.components.models.model_component import (
ModelComponent,
ModelInput,
ModelOutput,
)
from wyvern.components.models.model_component import ModelComponent
from wyvern.components.pipeline_component import PipelineComponent
from wyvern.config import settings
from wyvern.core.compression import wyvern_encode
Expand All @@ -26,6 +22,7 @@
from wyvern.entities.feature_entities import FeatureData, FeatureMap
from wyvern.entities.identifier import Identifier
from wyvern.entities.identifier_entities import ProductEntity, WyvernEntity
from wyvern.entities.model_entities import ModelInput, ModelOutput
from wyvern.entities.request import BaseWyvernRequest
from wyvern.service import WyvernService
from wyvern.wyvern_request import WyvernRequest
Expand Down Expand Up @@ -387,6 +384,7 @@ async def test_hydrate(mock_redis):
json=json_input,
headers={},
entity_store={},
model_output_map={},
events=[],
feature_map=FeatureMap(feature_map={}),
)
Expand Down Expand Up @@ -450,6 +448,7 @@ async def test_hydrate__duplicate_brand(mock_redis__duplicate_brand):
entity_store={},
events=[],
feature_map=FeatureMap(feature_map={}),
model_output_map={},
)
request_context.set(test_wyvern_request)

Expand Down
10 changes: 5 additions & 5 deletions wyvern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@
from wyvern.components.features.realtime_features_component import (
RealtimeFeatureComponent,
)
from wyvern.components.models.model_component import (
ModelComponent,
ModelInput,
ModelOutput,
)
from wyvern.components.models.model_chain_component import ModelChainComponent
from wyvern.components.models.model_component import ModelComponent
from wyvern.components.pipeline_component import PipelineComponent
from wyvern.components.ranking_pipeline import (
RankingPipeline,
Expand All @@ -23,6 +20,7 @@
WyvernDataModel,
WyvernEntity,
)
from wyvern.entities.model_entities import ChainedModelInput, ModelInput, ModelOutput
from wyvern.feature_store.feature_server import generate_wyvern_store_app
from wyvern.service import WyvernService
from wyvern.wyvern_logging import setup_logging
Expand All @@ -36,11 +34,13 @@
__all__ = [
"generate_wyvern_store_app",
"CandidateSetEntity",
"ChainedModelInput",
"CompositeIdentifier",
"FeatureData",
"FeatureMap",
"Identifier",
"IdentifierType",
"ModelChainComponent",
"ModelComponent",
"ModelInput",
"ModelOutput",
Expand Down
24 changes: 23 additions & 1 deletion wyvern/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
from enum import Enum
from functools import cached_property
from typing import Dict, Generic, Optional, Set
from typing import Dict, Generic, List, Optional, Set, Union
from uuid import uuid4

from wyvern import request_context
Expand Down Expand Up @@ -177,3 +177,25 @@ def get_all_features(
if not feature_data:
return {}
return feature_data.features

def get_model_output(
self,
model_name: str,
identifier: Identifier,
) -> Optional[
Union[
float,
str,
List[float],
Dict[str, Optional[Union[float, str, list[float]]]],
]
]:
"""
Gets the model output for the given identifier

Args:
model_name: str. The name of the model
identifier: Identifier. The entity identifier
"""
current_request = request_context.ensure_current_request()
return current_request.get_model_output(model_name, identifier)
54 changes: 54 additions & 0 deletions wyvern/components/models/model_chain_component.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
from functools import cached_property
from typing import Optional, Set

from wyvern.components.models.model_component import ModelComponent
from wyvern.entities.model_entities import MODEL_INPUT, MODEL_OUTPUT, ChainedModelInput
from wyvern.exceptions import MissingModelChainOutputError


class ModelChainComponent(ModelComponent[MODEL_INPUT, MODEL_OUTPUT]):
"""
Model chaining allows you to chain models together so that the output of one model can be the input to another model

For all the models in the chain, all the request and entities in the model input are the same
"""

def __init__(self, *upstreams: ModelComponent, name: Optional[str] = None):
super().__init__(*upstreams, name=name)
self.chain = upstreams

@cached_property
def manifest_feature_names(self) -> Set[str]:
feature_names: Set[str] = set()
for model in self.chain:
feature_names = feature_names.union(model.manifest_feature_names)
return feature_names

async def inference(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT:
output = None
prev_model: Optional[ModelComponent] = None
for model in self.chain:
curr_input: ChainedModelInput
if prev_model is not None and output is not None:
curr_input = ChainedModelInput(
request=input.request,
entities=input.entities,
upstream_model_name=prev_model.name,
upstream_model_output=output.data,
)
else:
curr_input = ChainedModelInput(
request=input.request,
entities=input.entities,
upstream_model_name=None,
upstream_model_output={},
)
output = await model.execute(curr_input, **kwargs)
prev_model = model

if output is None:
raise MissingModelChainOutputError()

# TODO: do type checking to make sure the output is of the correct type
return output
156 changes: 49 additions & 107 deletions wyvern/components/models/model_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,19 @@
import logging
from datetime import datetime
from functools import cached_property
from typing import (
Dict,
Generic,
List,
Optional,
Sequence,
Set,
Type,
TypeVar,
Union,
get_args,
)
from typing import Dict, List, Optional, Sequence, Set, Type, Union, get_args

from pydantic import BaseModel
from pydantic.generics import GenericModel

from wyvern import request_context
from wyvern.components.component import Component
from wyvern.components.events.events import EventType, LoggedEvent
from wyvern.config import settings
from wyvern.entities.identifier import Identifier
from wyvern.entities.identifier_entities import WyvernEntity
from wyvern.entities.model_entities import MODEL_INPUT, MODEL_OUTPUT
from wyvern.entities.request import BaseWyvernRequest
from wyvern.event_logging import event_logger
from wyvern.exceptions import WyvernModelInputError
from wyvern.wyvern_typing import GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY

MODEL_OUTPUT_DATA_TYPE = TypeVar(
"MODEL_OUTPUT_DATA_TYPE",
bound=Union[float, str, List[float]],
)
"""
MODEL_OUTPUT_DATA_TYPE is the type of the output of the model. It can be a float, a string, or a list of floats
(e.g. a list of probabilities, embeddings, etc.)
"""

logger = logging.getLogger(__name__)

Expand All @@ -52,12 +30,16 @@ class ModelEventData(BaseModel):
entity_identifier: The identifier of the entity that was used to generate the model output. This is optional.
entity_identifier_type: The type of the identifier of the entity that was used to generate the model output.
This is optional.
target: The key in the dictionary output.
This attribute will only appear when the output of the model is a dictionary.
This is optional.
"""

model_name: str
model_output: str
entity_identifier: Optional[str] = None
entity_identifier_type: Optional[str] = None
target: Optional[str] = None


class ModelEvent(LoggedEvent[ModelEventData]):
Expand All @@ -71,74 +53,6 @@ class ModelEvent(LoggedEvent[ModelEventData]):
event_type: EventType = EventType.MODEL


class ModelOutput(GenericModel, Generic[MODEL_OUTPUT_DATA_TYPE]):
"""
This class defines the output of a model.

Args:
data: A dictionary mapping entity identifiers to model outputs. The model outputs can also be None.
model_name: The name of the model. This is optional.
"""

data: Dict[Identifier, Optional[MODEL_OUTPUT_DATA_TYPE]]
model_name: Optional[str] = None

def get_entity_output(
self,
identifier: Identifier,
) -> Optional[MODEL_OUTPUT_DATA_TYPE]:
"""
Get the model output for a given entity identifier.

Args:
identifier: The identifier of the entity.

Returns:
The model output for the given entity identifier. This can also be None if the model output is None.
"""
return self.data.get(identifier)


class ModelInput(GenericModel, Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY]):
"""
This class defines the input to a model.

Args:
request: The request that will be used to generate the model input.
entities: A list of entities that will be used to generate the model input.
"""

request: REQUEST_ENTITY
entities: List[GENERALIZED_WYVERN_ENTITY] = []

@property
def first_entity(self) -> GENERALIZED_WYVERN_ENTITY:
"""
Get the first entity in the list of entities. This is useful when you know that there is only one entity.

Returns:
The first entity in the list of entities.
"""
if not self.entities:
raise WyvernModelInputError(model_input=self)
return self.entities[0]

@property
def first_identifier(self) -> Identifier:
"""
Get the identifier of the first entity in the list of entities. This is useful when you know that there is only
one entity.

Returns:
The identifier of the first entity in the list of entities.
"""
return self.first_entity.identifier


MODEL_INPUT = TypeVar("MODEL_INPUT", bound=ModelInput)
MODEL_OUTPUT = TypeVar("MODEL_OUTPUT", bound=ModelOutput)


class ModelComponent(
Component[
MODEL_INPUT,
Expand All @@ -155,11 +69,14 @@ def __init__(
self,
*upstreams,
name: Optional[str] = None,
cache_output: bool = False,
):
super().__init__(*upstreams, name=name)
self.model_input_type = self.get_type_args_simple(0)
self.model_output_type = self.get_type_args_simple(1)

self.cache_output = cache_output

@classmethod
def get_type_args_simple(cls, index: int) -> Type:
"""
Expand All @@ -185,26 +102,51 @@ async def execute(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT:
"""
The model_name and model_score will be automatically logged
"""
api_source = request_context.ensure_current_request().url_path
wyvern_request = request_context.ensure_current_request()
api_source = wyvern_request.url_path
request_id = input.request.request_id
model_output = await self.inference(input, **kwargs)

if self.cache_output:
wyvern_request.cache_model_output(self.name, model_output.data)

def events_generator() -> List[ModelEvent]:
timestamp = datetime.utcnow()
return [
ModelEvent(
request_id=request_id,
api_source=api_source,
event_timestamp=timestamp,
event_data=ModelEventData(
model_name=model_output.model_name or self.__class__.__name__,
model_output=str(output),
entity_identifier=identifier.identifier,
entity_identifier_type=identifier.identifier_type,
),
)
for identifier, output in model_output.data.items()
]
all_events: List[ModelEvent] = []
for identifier, output in model_output.data.items():
if isinstance(output, dict):
wintonzheng marked this conversation as resolved.
Show resolved Hide resolved
for key, value in output.items():
all_events.append(
ModelEvent(
request_id=request_id,
api_source=api_source,
event_timestamp=timestamp,
event_data=ModelEventData(
model_name=model_output.model_name
or self.__class__.__name__,
model_output=str(value),
target=key,
entity_identifier=identifier.identifier,
entity_identifier_type=identifier.identifier_type,
),
),
)
else:
all_events.append(
ModelEvent(
request_id=request_id,
api_source=api_source,
event_timestamp=timestamp,
event_data=ModelEventData(
model_name=model_output.model_name
or self.__class__.__name__,
model_output=str(output),
entity_identifier=identifier.identifier,
entity_identifier_type=identifier.identifier_type,
),
),
)
return all_events

event_logger.log_events(events_generator) # type: ignore

Expand Down
Loading