Skip to content
This repository has been archived by the owner on Mar 1, 2024. It is now read-only.

Commit

Permalink
support dict model output
Browse files Browse the repository at this point in the history
  • Loading branch information
wintonzheng committed Sep 14, 2023
1 parent ee85723 commit b173d7b
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 16 deletions.
50 changes: 36 additions & 14 deletions wyvern/components/models/model_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class ModelEventData(BaseModel):
model_output: str
entity_identifier: Optional[str] = None
entity_identifier_type: Optional[str] = None
target: Optional[str] = None


class ModelEvent(LoggedEvent[ModelEventData]):
Expand Down Expand Up @@ -108,20 +109,41 @@ async def execute(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT:

def events_generator() -> List[ModelEvent]:
timestamp = datetime.utcnow()
return [
ModelEvent(
request_id=request_id,
api_source=api_source,
event_timestamp=timestamp,
event_data=ModelEventData(
model_name=model_output.model_name or self.__class__.__name__,
model_output=str(output),
entity_identifier=identifier.identifier,
entity_identifier_type=identifier.identifier_type,
),
)
for identifier, output in model_output.data.items()
]
all_events: List[ModelEvent] = []
for identifier, output in model_output.data.items():
if isinstance(output, dict):
for key, value in output.items():
all_events.append(
ModelEvent(
request_id=request_id,
api_source=api_source,
event_timestamp=timestamp,
event_data=ModelEventData(
model_name=model_output.model_name
or self.__class__.__name__,
model_output=str(value),
target=key,
entity_identifier=identifier.identifier,
entity_identifier_type=identifier.identifier_type,
),
),
)
else:
all_events.append(
ModelEvent(
request_id=request_id,
api_source=api_source,
event_timestamp=timestamp,
event_data=ModelEventData(
model_name=model_output.model_name
or self.__class__.__name__,
model_output=str(output),
entity_identifier=identifier.identifier,
entity_identifier_type=identifier.identifier_type,
),
),
)
return all_events

event_logger.log_events(events_generator) # type: ignore

Expand Down
19 changes: 17 additions & 2 deletions wyvern/entities/model_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@

MODEL_OUTPUT_DATA_TYPE = TypeVar(
"MODEL_OUTPUT_DATA_TYPE",
bound=Union[float, str, List[float]],
bound=Union[
float,
str,
List[float],
Dict[str, Optional[Union[float, str, list[float]]]],
],
)
"""
MODEL_OUTPUT_DATA_TYPE is the type of the output of the model. It can be a float, a string, or a list of floats
Expand Down Expand Up @@ -82,7 +87,17 @@ def first_identifier(self) -> Identifier:


class ChainedModelInput(ModelInput[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY]):
upstream_model_output: Dict[Identifier, Optional[Union[float, str, List[float]]]]
upstream_model_output: Dict[
Identifier,
Optional[
Union[
float,
str,
List[float],
Dict[str, Optional[Union[float, str, list[float]]]],
]
],
]
upstream_model_name: Optional[str] = None


Expand Down

0 comments on commit b173d7b

Please sign in to comment.