Skip to content
This repository has been archived by the owner on Mar 1, 2024. It is now read-only.

Commit

Permalink
update to SingleEntity style
Browse files Browse the repository at this point in the history
  • Loading branch information
wintonzheng committed Sep 18, 2023
1 parent 524c58f commit 2e1909e
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 56 deletions.
88 changes: 42 additions & 46 deletions wyvern/components/business_logic/business_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from wyvern.entities.candidate_entities import (
GENERALIZED_WYVERN_ENTITY,
ScoredCandidate,
ScoredEntityProtocol,
)
from wyvern.entities.identifier import Identifier
from wyvern.entities.model_entities import MODEL_OUTPUT_DATA_TYPE
Expand Down Expand Up @@ -151,56 +150,11 @@ class SingleEntityBusinessLogicComponent(
pass


class ExtractEventMixin(Generic[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE]):
def extract_business_logic_events(
self,
output: Sequence[ScoredEntityProtocol],
pipeline_index: int,
upstream_name: str,
request_id: str,
old_scores: List,
) -> List[BusinessLogicEvent]:
"""
Extracts the business logic events from the output of a business logic component
Args:
output: The output of a business logic component
pipeline_index: The index of the business logic component in the business logic pipeline
upstream_name: The name of the business logic component
request_id: The request id of the request that the business logic component was called in
old_scores: The old scores of the candidates that the business logic component was called on
Returns:
The business logic events that were extracted from the output of the business logic component
"""
timestamp = datetime.utcnow()
events = [
BusinessLogicEvent(
request_id=request_id,
api_source=request_context.ensure_current_request().url_path,
event_timestamp=timestamp,
event_data=BusinessLogicEventData(
business_logic_pipeline_order=pipeline_index,
business_logic_name=upstream_name,
old_score=str(old_scores[j]),
new_score=str(output[j].score),
entity_identifier=candidate.entity.identifier.identifier,
entity_identifier_type=candidate.entity.identifier.identifier_type,
),
)
for (j, candidate) in enumerate(output)
if output[j].score != old_scores[j]
]

return events


class BusinessLogicPipeline(
Component[
BusinessLogicRequest[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY],
BusinessLogicResponse[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY],
],
ExtractEventMixin[GENERALIZED_WYVERN_ENTITY, float],
Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY],
):
"""
Expand Down Expand Up @@ -282,6 +236,48 @@ def log_events(
adjusted_candidates=output,
)

def extract_business_logic_events(
self,
output: Sequence[ScoredCandidate[GENERALIZED_WYVERN_ENTITY]],
pipeline_index: int,
upstream_name: str,
request_id: str,
old_scores: List,
) -> List[BusinessLogicEvent]:
"""
Extracts the business logic events from the output of a business logic component
Args:
output: The output of a business logic component
pipeline_index: The index of the business logic component in the business logic pipeline
upstream_name: The name of the business logic component
request_id: The request id of the request that the business logic component was called in
old_scores: The old scores of the candidates that the business logic component was called on
Returns:
The business logic events that were extracted from the output of the business logic component
"""
timestamp = datetime.utcnow()
events = [
BusinessLogicEvent(
request_id=request_id,
api_source=request_context.ensure_current_request().url_path,
event_timestamp=timestamp,
event_data=BusinessLogicEventData(
business_logic_pipeline_order=pipeline_index,
business_logic_name=upstream_name,
old_score=str(old_scores[j]),
new_score=str(output[j].score),
entity_identifier=candidate.entity.identifier.identifier,
entity_identifier_type=candidate.entity.identifier.identifier_type,
),
)
for (j, candidate) in enumerate(output)
if output[j].score != old_scores[j]
]

return events


class SingleEntityBusinessLogicPipeline(
Component[
Expand Down
11 changes: 1 addition & 10 deletions wyvern/entities/candidate_entities.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,14 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

from typing import Generic, List, Protocol, TypeVar
from typing import Generic, List, TypeVar

from pydantic.generics import GenericModel

from wyvern.entities.identifier_entities import WyvernDataModel
from wyvern.entities.model_entities import MODEL_OUTPUT_DATA_TYPE
from wyvern.wyvern_typing import GENERALIZED_WYVERN_ENTITY


class ScoredEntityProtocol(
Protocol,
Generic[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE],
):
entity: GENERALIZED_WYVERN_ENTITY
score: MODEL_OUTPUT_DATA_TYPE


# TODO (suchintan): This should be renamed to ScoredEntity probably
class ScoredCandidate(
GenericModel,
Expand Down

0 comments on commit 2e1909e

Please sign in to comment.