diff --git a/wyvern/components/models/model_experimentation_component.py b/wyvern/components/models/model_experimentation_component.py index 323c0d3..af66600 100644 --- a/wyvern/components/models/model_experimentation_component.py +++ b/wyvern/components/models/model_experimentation_component.py @@ -1,22 +1,43 @@ # -*- coding: utf-8 -*- -from typing import Optional +from typing import Dict, Optional from wyvern.components.models.model_component import BaseModelComponent from wyvern.entities.model_entities import MODEL_OUTPUT +from wyvern.experimentation.client import experimentation_client from wyvern.wyvern_typing import INPUT_TYPE class ModelExperimentation(BaseModelComponent[INPUT_TYPE, MODEL_OUTPUT]): def __init__( self, - *upstreams, - first_model: BaseModelComponent[INPUT_TYPE, MODEL_OUTPUT], - second_model: BaseModelComponent[INPUT_TYPE, MODEL_OUTPUT], + # TODO: find a better name for assignment_mapping + # this is the assignment var -> model mapping + assignment_mapping: Dict[str, BaseModelComponent[INPUT_TYPE, MODEL_OUTPUT]], + experiment_id: str, + raise_error_on_none: bool = True, name: Optional[str] = None, ): - super().__init__(*upstreams, name=name) - self.first_model = first_model - self.second_model = second_model + all_models = list(assignment_mapping.values()) + super().__init__(*all_models, name=name) + self.experiment_id = experiment_id + self.assignment_mapping = assignment_mapping + self.raise_error_on_none = raise_error_on_none async def execute(self, input: INPUT_TYPE, **kwargs) -> MODEL_OUTPUT: - return await super().execute(input, **kwargs) + treatment = experimentation_client.get_experiment_result( + self.experiment_id, + self.get_entity_id(input), + ) + # TODO: validation + if treatment is None: + # if self.raise_error_on_none: + # raise ValueError("treatment is None") + # else: + # # use a default model? + raise ValueError("treatment is None") + + model = self.assignment_mapping[treatment] + return await model.execute(input, **kwargs) + + def get_entity_id(self, input: INPUT_TYPE) -> str: + raise NotImplementedError diff --git a/wyvern/components/single_entity_pipeline.py b/wyvern/components/single_entity_pipeline.py index 3a04ad1..7476e0e 100644 --- a/wyvern/components/single_entity_pipeline.py +++ b/wyvern/components/single_entity_pipeline.py @@ -10,6 +10,9 @@ from wyvern.components.component import Component from wyvern.components.events.events import LoggedEvent from wyvern.components.models.model_component import SingleEntityModelComponent +from wyvern.components.models.model_experimentation_component import ( + ModelExperimentation, +) from wyvern.components.pipeline_component import PipelineComponent from wyvern.entities.identifier import Identifier from wyvern.entities.model_entities import MODEL_OUTPUT_DATA_TYPE @@ -33,7 +36,7 @@ class SingleEntityPipeline( def __init__( self, *upstreams: Component, - model: SingleEntityModelComponent, + model: SingleEntityModelComponent | ModelExperimentation, business_logic: Optional[ SingleEntityBusinessLogicPipeline[REQUEST_ENTITY, MODEL_OUTPUT_DATA_TYPE] ] = None,