From 2e1909ef2d5148e03347810e53ad03b9c462eef9 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Mon, 18 Sep 2023 15:52:27 -0700 Subject: [PATCH] update to SingleEntity style --- .../business_logic/business_logic.py | 88 +++++++++---------- wyvern/entities/candidate_entities.py | 11 +-- 2 files changed, 43 insertions(+), 56 deletions(-) diff --git a/wyvern/components/business_logic/business_logic.py b/wyvern/components/business_logic/business_logic.py index 280512a..d186902 100644 --- a/wyvern/components/business_logic/business_logic.py +++ b/wyvern/components/business_logic/business_logic.py @@ -15,7 +15,6 @@ from wyvern.entities.candidate_entities import ( GENERALIZED_WYVERN_ENTITY, ScoredCandidate, - ScoredEntityProtocol, ) from wyvern.entities.identifier import Identifier from wyvern.entities.model_entities import MODEL_OUTPUT_DATA_TYPE @@ -151,56 +150,11 @@ class SingleEntityBusinessLogicComponent( pass -class ExtractEventMixin(Generic[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE]): - def extract_business_logic_events( - self, - output: Sequence[ScoredEntityProtocol], - pipeline_index: int, - upstream_name: str, - request_id: str, - old_scores: List, - ) -> List[BusinessLogicEvent]: - """ - Extracts the business logic events from the output of a business logic component - - Args: - output: The output of a business logic component - pipeline_index: The index of the business logic component in the business logic pipeline - upstream_name: The name of the business logic component - request_id: The request id of the request that the business logic component was called in - old_scores: The old scores of the candidates that the business logic component was called on - - Returns: - The business logic events that were extracted from the output of the business logic component - """ - timestamp = datetime.utcnow() - events = [ - BusinessLogicEvent( - request_id=request_id, - api_source=request_context.ensure_current_request().url_path, - event_timestamp=timestamp, - event_data=BusinessLogicEventData( - business_logic_pipeline_order=pipeline_index, - business_logic_name=upstream_name, - old_score=str(old_scores[j]), - new_score=str(output[j].score), - entity_identifier=candidate.entity.identifier.identifier, - entity_identifier_type=candidate.entity.identifier.identifier_type, - ), - ) - for (j, candidate) in enumerate(output) - if output[j].score != old_scores[j] - ] - - return events - - class BusinessLogicPipeline( Component[ BusinessLogicRequest[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], BusinessLogicResponse[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], ], - ExtractEventMixin[GENERALIZED_WYVERN_ENTITY, float], Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY], ): """ @@ -282,6 +236,48 @@ def log_events( adjusted_candidates=output, ) + def extract_business_logic_events( + self, + output: Sequence[ScoredCandidate[GENERALIZED_WYVERN_ENTITY]], + pipeline_index: int, + upstream_name: str, + request_id: str, + old_scores: List, + ) -> List[BusinessLogicEvent]: + """ + Extracts the business logic events from the output of a business logic component + + Args: + output: The output of a business logic component + pipeline_index: The index of the business logic component in the business logic pipeline + upstream_name: The name of the business logic component + request_id: The request id of the request that the business logic component was called in + old_scores: The old scores of the candidates that the business logic component was called on + + Returns: + The business logic events that were extracted from the output of the business logic component + """ + timestamp = datetime.utcnow() + events = [ + BusinessLogicEvent( + request_id=request_id, + api_source=request_context.ensure_current_request().url_path, + event_timestamp=timestamp, + event_data=BusinessLogicEventData( + business_logic_pipeline_order=pipeline_index, + business_logic_name=upstream_name, + old_score=str(old_scores[j]), + new_score=str(output[j].score), + entity_identifier=candidate.entity.identifier.identifier, + entity_identifier_type=candidate.entity.identifier.identifier_type, + ), + ) + for (j, candidate) in enumerate(output) + if output[j].score != old_scores[j] + ] + + return events + class SingleEntityBusinessLogicPipeline( Component[ diff --git a/wyvern/entities/candidate_entities.py b/wyvern/entities/candidate_entities.py index 1aa93da..8c529a1 100644 --- a/wyvern/entities/candidate_entities.py +++ b/wyvern/entities/candidate_entities.py @@ -1,23 +1,14 @@ # -*- coding: utf-8 -*- from __future__ import annotations -from typing import Generic, List, Protocol, TypeVar +from typing import Generic, List, TypeVar from pydantic.generics import GenericModel from wyvern.entities.identifier_entities import WyvernDataModel -from wyvern.entities.model_entities import MODEL_OUTPUT_DATA_TYPE from wyvern.wyvern_typing import GENERALIZED_WYVERN_ENTITY -class ScoredEntityProtocol( - Protocol, - Generic[GENERALIZED_WYVERN_ENTITY, MODEL_OUTPUT_DATA_TYPE], -): - entity: GENERALIZED_WYVERN_ENTITY - score: MODEL_OUTPUT_DATA_TYPE - - # TODO (suchintan): This should be renamed to ScoredEntity probably class ScoredCandidate( GenericModel,