This repository has been archived by the owner on Mar 1, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 7
SingleEntityPipeline (v0.0.18-beta5) #69
Merged
Merged
Changes from 6 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
f63a4f3
introduce singular model component and singular pipeline
wintonzheng cce97e3
SingularBusinessLogicPipeline
wintonzheng bfb61c8
integrate with business logic in SingularPipelineComponent
wintonzheng 524c58f
update to SingleEntity style
wintonzheng 2e1909e
update to SingleEntity style
wintonzheng 88951c1
single entity pipeline
wintonzheng f7008f3
update generic type order for business logic components
wintonzheng c8bc660
SingleEntityModelbitComponent
wintonzheng File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Empty file.
177 changes: 177 additions & 0 deletions
177
tests/scenarios/single_entity_pipelines/test_single_entity_pipeline.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
# -*- coding: utf-8 -*- | ||
from typing import List | ||
|
||
import pytest | ||
from fastapi.testclient import TestClient | ||
|
||
from wyvern.components.business_logic.business_logic import ( | ||
SingleEntityBusinessLogicComponent, | ||
SingleEntityBusinessLogicPipeline, | ||
SingleEntityBusinessLogicRequest, | ||
) | ||
from wyvern.components.models.model_chain_component import SingleEntityModelChain | ||
from wyvern.components.models.model_component import SingleEntityModelComponent | ||
from wyvern.components.single_entity_pipeline import ( | ||
SingleEntityPipeline, | ||
SingleEntityPipelineResponse, | ||
) | ||
from wyvern.entities.identifier import Identifier | ||
from wyvern.entities.identifier_entities import WyvernEntity | ||
from wyvern.entities.model_entities import ModelOutput | ||
from wyvern.entities.request import BaseWyvernRequest | ||
from wyvern.service import WyvernService | ||
|
||
|
||
class Seller(WyvernEntity): | ||
seller_id: str | ||
|
||
def generate_identifier(self) -> Identifier: | ||
return Identifier( | ||
identifier=self.seller_id, | ||
identifier_type="seller", | ||
) | ||
|
||
|
||
class Buyer(WyvernEntity): | ||
buyer_id: str | ||
|
||
def generate_identifier(self) -> Identifier: | ||
return Identifier( | ||
identifier=self.buyer_id, | ||
identifier_type="buyer", | ||
) | ||
|
||
|
||
class Order(WyvernEntity): | ||
order_id: str | ||
|
||
def generate_identifier(self) -> Identifier: | ||
return Identifier( | ||
identifier=self.order_id, | ||
identifier_type="order", | ||
) | ||
|
||
|
||
class FraudRequest(BaseWyvernRequest): | ||
seller: Seller | ||
buyer: Buyer | ||
order: Order | ||
|
||
|
||
class FraudResponse(SingleEntityPipelineResponse[float]): | ||
reasons: List[str] | ||
|
||
|
||
class FraudRuleModel(SingleEntityModelComponent[FraudRequest, ModelOutput[float]]): | ||
async def inference(self, input: FraudRequest, **kwargs) -> ModelOutput[float]: | ||
return ModelOutput( | ||
data={ | ||
input.order.identifier: 1, | ||
}, | ||
) | ||
|
||
|
||
class FraudAssessmentModel( | ||
SingleEntityModelComponent[FraudRequest, ModelOutput[float]], | ||
): | ||
async def inference(self, input: FraudRequest, **kwargs) -> ModelOutput[float]: | ||
return ModelOutput( | ||
data={ | ||
input.order.identifier: 1, | ||
}, | ||
) | ||
|
||
|
||
fraud_model = SingleEntityModelChain[FraudRequest, ModelOutput[float]]( | ||
FraudRuleModel(), | ||
FraudAssessmentModel(), | ||
name="fraud_model", | ||
) | ||
|
||
|
||
class FraudBusinessLogicComponent( | ||
SingleEntityBusinessLogicComponent[float, FraudRequest], | ||
): | ||
async def execute( | ||
self, | ||
input: SingleEntityBusinessLogicRequest[float, FraudRequest], | ||
**kwargs, | ||
) -> float: | ||
if input.request.seller.identifier.identifier == "test_seller_new": | ||
return 0.0 | ||
return input.model_output | ||
|
||
|
||
fraud_biz_pipeline = SingleEntityBusinessLogicPipeline( | ||
FraudBusinessLogicComponent(), | ||
name="fraud_biz_pipeline", | ||
) | ||
|
||
|
||
class FraudPipeline(SingleEntityPipeline[FraudRequest, float]): | ||
PATH = "/fraud" | ||
REQUEST_SCHEMA_CLASS = FraudRequest | ||
RESPONSE_SCHEMA_CLASS = FraudResponse | ||
|
||
def generate_response( | ||
self, | ||
input: FraudRequest, | ||
pipeline_output: float, | ||
) -> FraudResponse: | ||
if pipeline_output == 0.0: | ||
return FraudResponse( | ||
data=pipeline_output, | ||
reasons=["Fraudulent order detected!"], | ||
) | ||
return FraudResponse( | ||
data=pipeline_output, | ||
reasons=[], | ||
) | ||
|
||
|
||
fraud_pipeline = FraudPipeline(model=fraud_model, business_logic=fraud_biz_pipeline) | ||
|
||
|
||
@pytest.fixture | ||
def mock_redis(mocker): | ||
with mocker.patch( | ||
"wyvern.redis.wyvern_redis.mget", | ||
return_value=[], | ||
): | ||
yield | ||
|
||
|
||
@pytest.fixture | ||
def test_client(mock_redis): | ||
wyvern_app = WyvernService.generate_app( | ||
route_components=[fraud_pipeline], | ||
) | ||
yield TestClient(wyvern_app) | ||
|
||
|
||
def test_end_to_end(test_client): | ||
response = test_client.post( | ||
"/api/v1/fraud", | ||
json={ | ||
"request_id": "test_request_id", | ||
"seller": {"seller_id": "test_seller_id"}, | ||
"buyer": {"buyer_id": "test_buyer_id"}, | ||
"order": {"order_id": "test_order_id"}, | ||
}, | ||
) | ||
assert response.status_code == 200 | ||
assert response.json() == {"data": 1.0, "reasons": []} | ||
|
||
|
||
def test_end_to_end__new_seller(test_client): | ||
response = test_client.post( | ||
"/api/v1/fraud", | ||
json={ | ||
"request_id": "test_request_id", | ||
"seller": {"seller_id": "test_seller_new"}, | ||
"buyer": {"buyer_id": "test_buyer_id"}, | ||
"order": {"order_id": "test_order_id"}, | ||
}, | ||
) | ||
assert response.status_code == 200 | ||
assert response.json() == {"data": 0.0, "reasons": ["Fraudulent order detected!"]} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is the "new experience" of passing the model and business_logic to the pipeline to define the pipeline.
Does it look simpler than the current RankingPipeline.get_model pattern?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ykeremy i think the
model=xxx
pattern definitely makes it easier to play around with different version of models in the future for a pipeline. i prefermodel=xxx
pattern. wdyt?