Skip to content
This repository has been archived by the owner on Mar 1, 2024. It is now read-only.

Commit

Permalink
cache model output
Browse files Browse the repository at this point in the history
  • Loading branch information
wintonzheng committed Sep 13, 2023
1 parent 016d620 commit 148de8b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
9 changes: 8 additions & 1 deletion wyvern/components/models/model_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,14 @@ def __init__(
self,
*upstreams,
name: Optional[str] = None,
cache_output: bool = True,
):
super().__init__(*upstreams, name=name)
self.model_input_type = self.get_type_args_simple(0)
self.model_output_type = self.get_type_args_simple(1)

self.cache_output = cache_output

@classmethod
def get_type_args_simple(cls, index: int) -> Type:
"""
Expand All @@ -95,10 +98,14 @@ async def execute(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT:
"""
The model_name and model_score will be automatically logged
"""
api_source = request_context.ensure_current_request().url_path
wyvern_request = request_context.ensure_current_request()
api_source = wyvern_request.url_path
request_id = input.request.request_id
model_output = await self.inference(input, **kwargs)

if self.cache_output:
wyvern_request.cache_model_score(self.name, model_output.data)

def events_generator() -> List[ModelEvent]:
timestamp = datetime.utcnow()
return [
Expand Down
9 changes: 9 additions & 0 deletions wyvern/wyvern_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,12 @@ def parse_fastapi_request(
model_score_map={},
request_id=request_id,
)

def cache_model_score(
self,
model_name: str,
data: Dict[Identifier, Union[float, str, List[float], None]],
) -> None:
if model_name not in self.model_score_map:
self.model_score_map[model_name] = {}
self.model_score_map[model_name].update(data)

0 comments on commit 148de8b

Please sign in to comment.