Skip to content
This repository has been archived by the owner on Mar 1, 2024. It is now read-only.

Commit

Permalink
multi model evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
wintonzheng committed Sep 7, 2023
1 parent 416e082 commit 1296102
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 103 deletions.
7 changes: 2 additions & 5 deletions tests/scenarios/test_product_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@
RealtimeFeatureComponent,
RealtimeFeatureRequest,
)
from wyvern.components.models.model_component import (
ModelComponent,
ModelInput,
ModelOutput,
)
from wyvern.components.models.model_component import ModelComponent
from wyvern.components.pipeline_component import PipelineComponent
from wyvern.config import settings
from wyvern.core.compression import wyvern_encode
Expand All @@ -26,6 +22,7 @@
from wyvern.entities.feature_entities import FeatureData, FeatureMap
from wyvern.entities.identifier import Identifier
from wyvern.entities.identifier_entities import ProductEntity, WyvernEntity
from wyvern.entities.model_entities import ModelInput, ModelOutput
from wyvern.entities.request import BaseWyvernRequest
from wyvern.service import WyvernService
from wyvern.wyvern_request import WyvernRequest
Expand Down
94 changes: 2 additions & 92 deletions wyvern/components/models/model_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,19 @@
import logging
from datetime import datetime
from functools import cached_property
from typing import (
Dict,
Generic,
List,
Optional,
Sequence,
Set,
Type,
TypeVar,
Union,
get_args,
)
from typing import Dict, List, Optional, Sequence, Set, Type, Union, get_args

from pydantic import BaseModel
from pydantic.generics import GenericModel

from wyvern import request_context
from wyvern.components.component import Component
from wyvern.components.events.events import EventType, LoggedEvent
from wyvern.config import settings
from wyvern.entities.identifier import Identifier
from wyvern.entities.identifier_entities import WyvernEntity
from wyvern.entities.model_entities import MODEL_INPUT, MODEL_OUTPUT
from wyvern.entities.request import BaseWyvernRequest
from wyvern.event_logging import event_logger
from wyvern.exceptions import WyvernModelInputError
from wyvern.wyvern_typing import GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY

MODEL_OUTPUT_DATA_TYPE = TypeVar(
"MODEL_OUTPUT_DATA_TYPE",
bound=Union[float, str, List[float]],
)
"""
MODEL_OUTPUT_DATA_TYPE is the type of the output of the model. It can be a float, a string, or a list of floats
(e.g. a list of probabilities, embeddings, etc.)
"""

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -71,74 +49,6 @@ class ModelEvent(LoggedEvent[ModelEventData]):
event_type: EventType = EventType.MODEL


class ModelOutput(GenericModel, Generic[MODEL_OUTPUT_DATA_TYPE]):
"""
This class defines the output of a model.
Args:
data: A dictionary mapping entity identifiers to model outputs. The model outputs can also be None.
model_name: The name of the model. This is optional.
"""

data: Dict[Identifier, Optional[MODEL_OUTPUT_DATA_TYPE]]
model_name: Optional[str] = None

def get_entity_output(
self,
identifier: Identifier,
) -> Optional[MODEL_OUTPUT_DATA_TYPE]:
"""
Get the model output for a given entity identifier.
Args:
identifier: The identifier of the entity.
Returns:
The model output for the given entity identifier. This can also be None if the model output is None.
"""
return self.data.get(identifier)


class ModelInput(GenericModel, Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY]):
"""
This class defines the input to a model.
Args:
request: The request that will be used to generate the model input.
entities: A list of entities that will be used to generate the model input.
"""

request: REQUEST_ENTITY
entities: List[GENERALIZED_WYVERN_ENTITY] = []

@property
def first_entity(self) -> GENERALIZED_WYVERN_ENTITY:
"""
Get the first entity in the list of entities. This is useful when you know that there is only one entity.
Returns:
The first entity in the list of entities.
"""
if not self.entities:
raise WyvernModelInputError(model_input=self)
return self.entities[0]

@property
def first_identifier(self) -> Identifier:
"""
Get the identifier of the first entity in the list of entities. This is useful when you know that there is only
one entity.
Returns:
The identifier of the first entity in the list of entities.
"""
return self.first_entity.identifier


MODEL_INPUT = TypeVar("MODEL_INPUT", bound=ModelInput)
MODEL_OUTPUT = TypeVar("MODEL_OUTPUT", bound=ModelOutput)


class ModelComponent(
Component[
MODEL_INPUT,
Expand Down
7 changes: 2 additions & 5 deletions wyvern/components/models/modelbit_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,12 @@
from functools import cached_property
from typing import Any, Dict, List, Optional, Set, Tuple, TypeAlias, Union

from wyvern.components.models.model_component import (
MODEL_INPUT,
MODEL_OUTPUT,
ModelComponent,
)
from wyvern.components.models.model_component import ModelComponent
from wyvern.config import settings
from wyvern.core.http import aiohttp_client
from wyvern.entities.identifier import Identifier
from wyvern.entities.identifier_entities import WyvernEntity
from wyvern.entities.model_entities import MODEL_INPUT, MODEL_OUTPUT
from wyvern.entities.request import BaseWyvernRequest
from wyvern.exceptions import (
WyvernModelbitTokenMissingError,
Expand Down
3 changes: 2 additions & 1 deletion wyvern/components/ranking_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
ImpressionEventLoggingComponent,
ImpressionEventLoggingRequest,
)
from wyvern.components.models.model_component import ModelComponent, ModelInput
from wyvern.components.models.model_component import ModelComponent
from wyvern.components.pagination.pagination_component import (
PaginationComponent,
PaginationRequest,
Expand All @@ -22,6 +22,7 @@
from wyvern.components.pipeline_component import PipelineComponent
from wyvern.entities.candidate_entities import ScoredCandidate
from wyvern.entities.identifier_entities import QueryEntity
from wyvern.entities.model_entities import ModelInput
from wyvern.entities.request import BaseWyvernRequest
from wyvern.event_logging import event_logger
from wyvern.wyvern_typing import WYVERN_ENTITY
Expand Down
85 changes: 85 additions & 0 deletions wyvern/entities/model_entities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
from typing import Dict, Generic, List, Optional, TypeVar, Union

from pydantic.generics import GenericModel

from wyvern.entities.identifier import Identifier
from wyvern.exceptions import WyvernModelInputError
from wyvern.wyvern_typing import GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY

MODEL_OUTPUT_DATA_TYPE = TypeVar(
"MODEL_OUTPUT_DATA_TYPE",
bound=Union[float, str, List[float]],
)
"""
MODEL_OUTPUT_DATA_TYPE is the type of the output of the model. It can be a float, a string, or a list of floats
(e.g. a list of probabilities, embeddings, etc.)
"""


class ModelOutput(GenericModel, Generic[MODEL_OUTPUT_DATA_TYPE]):
"""
This class defines the output of a model.
Args:
data: A dictionary mapping entity identifiers to model outputs. The model outputs can also be None.
model_name: The name of the model. This is optional.
"""

data: Dict[Identifier, Optional[MODEL_OUTPUT_DATA_TYPE]]
model_name: Optional[str] = None

def get_entity_output(
self,
identifier: Identifier,
) -> Optional[MODEL_OUTPUT_DATA_TYPE]:
"""
Get the model output for a given entity identifier.
Args:
identifier: The identifier of the entity.
Returns:
The model output for the given entity identifier. This can also be None if the model output is None.
"""
return self.data.get(identifier)


class ModelInput(GenericModel, Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY]):
"""
This class defines the input to a model.
Args:
request: The request that will be used to generate the model input.
entities: A list of entities that will be used to generate the model input.
"""

request: REQUEST_ENTITY
entities: List[GENERALIZED_WYVERN_ENTITY] = []

@property
def first_entity(self) -> GENERALIZED_WYVERN_ENTITY:
"""
Get the first entity in the list of entities. This is useful when you know that there is only one entity.
Returns:
The first entity in the list of entities.
"""
if not self.entities:
raise WyvernModelInputError(model_input=self)
return self.entities[0]

@property
def first_identifier(self) -> Identifier:
"""
Get the identifier of the first entity in the list of entities. This is useful when you know that there is only
one entity.
Returns:
The identifier of the first entity in the list of entities.
"""
return self.first_entity.identifier


MODEL_INPUT = TypeVar("MODEL_INPUT", bound=ModelInput)
MODEL_OUTPUT = TypeVar("MODEL_OUTPUT", bound=ModelOutput)
3 changes: 3 additions & 0 deletions wyvern/wyvern_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from wyvern.components.events.events import LoggedEvent
from wyvern.entities.feature_entities import FeatureMap
from wyvern.entities.identifier import Identifier


@dataclass
Expand Down Expand Up @@ -43,6 +44,7 @@ class WyvernRequest:
events: List[Callable[[], List[LoggedEvent[Any]]]]

feature_map: FeatureMap
model_score_map: Dict[str, Dict[Identifier, float]]

request_id: Optional[str] = None

Expand Down Expand Up @@ -75,5 +77,6 @@ def parse_fastapi_request(
entity_store={},
events=[],
feature_map=FeatureMap(feature_map={}),
model_score_map={},
request_id=request_id,
)

0 comments on commit 1296102

Please sign in to comment.