Skip to content
This repository has been archived by the owner on Mar 1, 2024. It is now read-only.

Commit

Permalink
experimentation model component
Browse files Browse the repository at this point in the history
  • Loading branch information
wintonzheng committed Oct 27, 2023
1 parent b92fa23 commit 59e2e71
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 9 deletions.
37 changes: 29 additions & 8 deletions wyvern/components/models/model_experimentation_component.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,43 @@
# -*- coding: utf-8 -*-
from typing import Optional
from typing import Dict, Optional

from wyvern.components.models.model_component import BaseModelComponent
from wyvern.entities.model_entities import MODEL_OUTPUT
from wyvern.experimentation.client import experimentation_client
from wyvern.wyvern_typing import INPUT_TYPE


class ModelExperimentation(BaseModelComponent[INPUT_TYPE, MODEL_OUTPUT]):
def __init__(
self,
*upstreams,
first_model: BaseModelComponent[INPUT_TYPE, MODEL_OUTPUT],
second_model: BaseModelComponent[INPUT_TYPE, MODEL_OUTPUT],
# TODO: find a better name for assignment_mapping
# this is the assignment var -> model mapping
assignment_mapping: Dict[str, BaseModelComponent[INPUT_TYPE, MODEL_OUTPUT]],
experiment_id: str,
raise_error_on_none: bool = True,
name: Optional[str] = None,
):
super().__init__(*upstreams, name=name)
self.first_model = first_model
self.second_model = second_model
all_models = list(assignment_mapping.values())
super().__init__(*all_models, name=name)
self.experiment_id = experiment_id
self.assignment_mapping = assignment_mapping
self.raise_error_on_none = raise_error_on_none

async def execute(self, input: INPUT_TYPE, **kwargs) -> MODEL_OUTPUT:
return await super().execute(input, **kwargs)
treatment = experimentation_client.get_experiment_result(
self.experiment_id,
self.get_entity_id(input),
)
# TODO: validation
if treatment is None:
# if self.raise_error_on_none:
# raise ValueError("treatment is None")
# else:
# # use a default model?
raise ValueError("treatment is None")

model = self.assignment_mapping[treatment]
return await model.execute(input, **kwargs)

def get_entity_id(self, input: INPUT_TYPE) -> str:
raise NotImplementedError
5 changes: 4 additions & 1 deletion wyvern/components/single_entity_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from wyvern.components.component import Component
from wyvern.components.events.events import LoggedEvent
from wyvern.components.models.model_component import SingleEntityModelComponent
from wyvern.components.models.model_experimentation_component import (
ModelExperimentation,
)
from wyvern.components.pipeline_component import PipelineComponent
from wyvern.entities.identifier import Identifier
from wyvern.entities.model_entities import MODEL_OUTPUT_DATA_TYPE
Expand All @@ -33,7 +36,7 @@ class SingleEntityPipeline(
def __init__(
self,
*upstreams: Component,
model: SingleEntityModelComponent,
model: SingleEntityModelComponent | ModelExperimentation,
business_logic: Optional[
SingleEntityBusinessLogicPipeline[REQUEST_ENTITY, MODEL_OUTPUT_DATA_TYPE]
] = None,
Expand Down

0 comments on commit 59e2e71

Please sign in to comment.