Skip to content
This repository has been archived by the owner on Mar 1, 2024. It is now read-only.

SingleEntityPipeline (v0.0.18-beta5) #69

Merged
merged 8 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
177 changes: 177 additions & 0 deletions tests/scenarios/single_entity_pipelines/test_single_entity_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# -*- coding: utf-8 -*-
from typing import List

import pytest
from fastapi.testclient import TestClient

from wyvern.components.business_logic.business_logic import (
SingleEntityBusinessLogicComponent,
SingleEntityBusinessLogicPipeline,
SingleEntityBusinessLogicRequest,
)
from wyvern.components.models.model_chain_component import SingleEntityModelChain
from wyvern.components.models.model_component import SingleEntityModelComponent
from wyvern.components.single_entity_pipeline import (
SingleEntityPipeline,
SingleEntityPipelineResponse,
)
from wyvern.entities.identifier import Identifier
from wyvern.entities.identifier_entities import WyvernEntity
from wyvern.entities.model_entities import ModelOutput
from wyvern.entities.request import BaseWyvernRequest
from wyvern.service import WyvernService


class Seller(WyvernEntity):
seller_id: str

def generate_identifier(self) -> Identifier:
return Identifier(
identifier=self.seller_id,
identifier_type="seller",
)


class Buyer(WyvernEntity):
buyer_id: str

def generate_identifier(self) -> Identifier:
return Identifier(
identifier=self.buyer_id,
identifier_type="buyer",
)


class Order(WyvernEntity):
order_id: str

def generate_identifier(self) -> Identifier:
return Identifier(
identifier=self.order_id,
identifier_type="order",
)


class FraudRequest(BaseWyvernRequest):
seller: Seller
buyer: Buyer
order: Order


class FraudResponse(SingleEntityPipelineResponse[float]):
reasons: List[str]


class FraudRuleModel(SingleEntityModelComponent[FraudRequest, ModelOutput[float]]):
async def inference(self, input: FraudRequest, **kwargs) -> ModelOutput[float]:
return ModelOutput(
data={
input.order.identifier: 1,
},
)


class FraudAssessmentModel(
SingleEntityModelComponent[FraudRequest, ModelOutput[float]],
):
async def inference(self, input: FraudRequest, **kwargs) -> ModelOutput[float]:
return ModelOutput(
data={
input.order.identifier: 1,
},
)


fraud_model = SingleEntityModelChain[FraudRequest, ModelOutput[float]](
FraudRuleModel(),
FraudAssessmentModel(),
name="fraud_model",
)


class FraudBusinessLogicComponent(
SingleEntityBusinessLogicComponent[float, FraudRequest],
):
async def execute(
self,
input: SingleEntityBusinessLogicRequest[float, FraudRequest],
**kwargs,
) -> float:
if input.request.seller.identifier.identifier == "test_seller_new":
return 0.0
return input.model_output


fraud_biz_pipeline = SingleEntityBusinessLogicPipeline(
FraudBusinessLogicComponent(),
name="fraud_biz_pipeline",
)


class FraudPipeline(SingleEntityPipeline[FraudRequest, float]):
PATH = "/fraud"
REQUEST_SCHEMA_CLASS = FraudRequest
RESPONSE_SCHEMA_CLASS = FraudResponse

def generate_response(
self,
input: FraudRequest,
pipeline_output: float,
) -> FraudResponse:
if pipeline_output == 0.0:
return FraudResponse(
data=pipeline_output,
reasons=["Fraudulent order detected!"],
)
return FraudResponse(
data=pipeline_output,
reasons=[],
)


fraud_pipeline = FraudPipeline(model=fraud_model, business_logic=fraud_biz_pipeline)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the "new experience" of passing the model and business_logic to the pipeline to define the pipeline.

Does it look simpler than the current RankingPipeline.get_model pattern?

def get_model(self) -> ModelCompolent:
    return SomeModelComponent()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ykeremy i think the model=xxx pattern definitely makes it easier to play around with different version of models in the future for a pipeline. i prefer model=xxx pattern. wdyt?



@pytest.fixture
def mock_redis(mocker):
with mocker.patch(
"wyvern.redis.wyvern_redis.mget",
return_value=[],
):
yield


@pytest.fixture
def test_client(mock_redis):
wyvern_app = WyvernService.generate_app(
route_components=[fraud_pipeline],
)
yield TestClient(wyvern_app)


def test_end_to_end(test_client):
response = test_client.post(
"/api/v1/fraud",
json={
"request_id": "test_request_id",
"seller": {"seller_id": "test_seller_id"},
"buyer": {"buyer_id": "test_buyer_id"},
"order": {"order_id": "test_order_id"},
},
)
assert response.status_code == 200
assert response.json() == {"data": 1.0, "reasons": []}


def test_end_to_end__new_seller(test_client):
response = test_client.post(
"/api/v1/fraud",
json={
"request_id": "test_request_id",
"seller": {"seller_id": "test_seller_new"},
"buyer": {"buyer_id": "test_buyer_id"},
"order": {"order_id": "test_order_id"},
},
)
assert response.status_code == 200
assert response.json() == {"data": 0.0, "reasons": ["Fraudulent order detected!"]}
5 changes: 3 additions & 2 deletions tests/scenarios/test_product_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,10 +318,11 @@ async def execute(

@pytest.fixture
def test_client(mock_redis):
wyvern_service = WyvernService.generate(
wyvern_app = WyvernService.generate_app(
route_components=[RankingComponent],
realtime_feature_components=[],
)
yield TestClient(wyvern_service.service.app)
yield TestClient(wyvern_app)


def test_get_all_identifiers():
Expand Down
18 changes: 15 additions & 3 deletions wyvern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,22 @@
from wyvern.components.features.realtime_features_component import (
RealtimeFeatureComponent,
)
from wyvern.components.models.model_chain_component import ModelChainComponent
from wyvern.components.models.model_component import ModelComponent
from wyvern.components.models.model_chain_component import SingleEntityModelChain
from wyvern.components.models.model_component import (
ModelComponent,
MultiEntityModelComponent,
SingleEntityModelComponent,
)
from wyvern.components.pipeline_component import PipelineComponent
from wyvern.components.ranking_pipeline import (
RankingPipeline,
RankingRequest,
RankingResponse,
)
from wyvern.components.single_entity_pipeline import (
SingleEntityPipeline,
SingleEntityPipelineResponse,
)
from wyvern.entities.candidate_entities import CandidateSetEntity
from wyvern.entities.feature_entities import FeatureData, FeatureMap
from wyvern.entities.identifier import CompositeIdentifier, Identifier, IdentifierType
Expand Down Expand Up @@ -40,17 +48,21 @@
"FeatureMap",
"Identifier",
"IdentifierType",
"ModelChainComponent",
"ModelComponent",
"ModelInput",
"ModelOutput",
"MultiEntityModelComponent",
"PipelineComponent",
"ProductEntity",
"QueryEntity",
"RankingPipeline",
"RankingResponse",
"RankingRequest",
"RealtimeFeatureComponent",
"SingleEntityModelChain",
"SingleEntityModelComponent",
"SingleEntityPipeline",
"SingleEntityPipelineResponse",
"UserEntity",
"WyvernDataModel",
"WyvernEntity",
Expand Down
Loading