This repository has been archived by the owner on Mar 1, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b92fa23
commit 59e2e71
Showing
2 changed files
with
33 additions
and
9 deletions.
There are no files selected for viewing
37 changes: 29 additions & 8 deletions
37
wyvern/components/models/model_experimentation_component.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,43 @@ | ||
# -*- coding: utf-8 -*- | ||
from typing import Optional | ||
from typing import Dict, Optional | ||
|
||
from wyvern.components.models.model_component import BaseModelComponent | ||
from wyvern.entities.model_entities import MODEL_OUTPUT | ||
from wyvern.experimentation.client import experimentation_client | ||
from wyvern.wyvern_typing import INPUT_TYPE | ||
|
||
|
||
class ModelExperimentation(BaseModelComponent[INPUT_TYPE, MODEL_OUTPUT]): | ||
def __init__( | ||
self, | ||
*upstreams, | ||
first_model: BaseModelComponent[INPUT_TYPE, MODEL_OUTPUT], | ||
second_model: BaseModelComponent[INPUT_TYPE, MODEL_OUTPUT], | ||
# TODO: find a better name for assignment_mapping | ||
# this is the assignment var -> model mapping | ||
assignment_mapping: Dict[str, BaseModelComponent[INPUT_TYPE, MODEL_OUTPUT]], | ||
experiment_id: str, | ||
raise_error_on_none: bool = True, | ||
name: Optional[str] = None, | ||
): | ||
super().__init__(*upstreams, name=name) | ||
self.first_model = first_model | ||
self.second_model = second_model | ||
all_models = list(assignment_mapping.values()) | ||
super().__init__(*all_models, name=name) | ||
self.experiment_id = experiment_id | ||
self.assignment_mapping = assignment_mapping | ||
self.raise_error_on_none = raise_error_on_none | ||
|
||
async def execute(self, input: INPUT_TYPE, **kwargs) -> MODEL_OUTPUT: | ||
return await super().execute(input, **kwargs) | ||
treatment = experimentation_client.get_experiment_result( | ||
self.experiment_id, | ||
self.get_entity_id(input), | ||
) | ||
# TODO: validation | ||
if treatment is None: | ||
# if self.raise_error_on_none: | ||
# raise ValueError("treatment is None") | ||
# else: | ||
# # use a default model? | ||
raise ValueError("treatment is None") | ||
|
||
model = self.assignment_mapping[treatment] | ||
return await model.execute(input, **kwargs) | ||
|
||
def get_entity_id(self, input: INPUT_TYPE) -> str: | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters