This repository has been archived by the owner on Mar 1, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* multi model evaluation * fix test * chained model evaluation * cache model output * v bump * support dict model output * rename * add get_model_output support * fix * 0.0.18-beta1
- Loading branch information
1 parent
8dc9cef
commit 904a622
Showing
12 changed files
with
292 additions
and
126 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[tool.poetry] | ||
name = "wyvern-ai" | ||
version = "0.0.17" | ||
version = "0.0.18-beta1" | ||
description = "" | ||
authors = ["Wyvern AI <[email protected]>"] | ||
readme = "README.md" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# -*- coding: utf-8 -*- | ||
from functools import cached_property | ||
from typing import Optional, Set | ||
|
||
from wyvern.components.models.model_component import ModelComponent | ||
from wyvern.entities.model_entities import MODEL_INPUT, MODEL_OUTPUT, ChainedModelInput | ||
from wyvern.exceptions import MissingModelChainOutputError | ||
|
||
|
||
class ModelChainComponent(ModelComponent[MODEL_INPUT, MODEL_OUTPUT]): | ||
""" | ||
Model chaining allows you to chain models together so that the output of one model can be the input to another model | ||
For all the models in the chain, all the request and entities in the model input are the same | ||
""" | ||
|
||
def __init__(self, *upstreams: ModelComponent, name: Optional[str] = None): | ||
super().__init__(*upstreams, name=name) | ||
self.chain = upstreams | ||
|
||
@cached_property | ||
def manifest_feature_names(self) -> Set[str]: | ||
feature_names: Set[str] = set() | ||
for model in self.chain: | ||
feature_names = feature_names.union(model.manifest_feature_names) | ||
return feature_names | ||
|
||
async def inference(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT: | ||
output = None | ||
prev_model: Optional[ModelComponent] = None | ||
for model in self.chain: | ||
curr_input: ChainedModelInput | ||
if prev_model is not None and output is not None: | ||
curr_input = ChainedModelInput( | ||
request=input.request, | ||
entities=input.entities, | ||
upstream_model_name=prev_model.name, | ||
upstream_model_output=output.data, | ||
) | ||
else: | ||
curr_input = ChainedModelInput( | ||
request=input.request, | ||
entities=input.entities, | ||
upstream_model_name=None, | ||
upstream_model_output={}, | ||
) | ||
output = await model.execute(curr_input, **kwargs) | ||
prev_model = model | ||
|
||
if output is None: | ||
raise MissingModelChainOutputError() | ||
|
||
# TODO: do type checking to make sure the output is of the correct type | ||
return output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.