Skip to content
This repository has been archived by the owner on Mar 1, 2024. It is now read-only.

v0.0.18 - chained model evaluation #63

Merged
merged 11 commits into from
Sep 26, 2023
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "wyvern-ai"
version = "0.0.17"
version = "0.0.18"
description = ""
authors = ["Wyvern AI <[email protected]>"]
readme = "README.md"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(self):
entity_store={},
events=[],
feature_map=FeatureMap(feature_map={}),
model_output_map={},
),
)
return await pipeline.execute(request)
Expand Down
Empty file.
177 changes: 177 additions & 0 deletions tests/scenarios/single_entity_pipelines/test_single_entity_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# -*- coding: utf-8 -*-
from typing import List

import pytest
from fastapi.testclient import TestClient

from wyvern.components.business_logic.business_logic import (
SingleEntityBusinessLogicComponent,
SingleEntityBusinessLogicPipeline,
SingleEntityBusinessLogicRequest,
)
from wyvern.components.models.model_chain_component import SingleEntityModelChain
from wyvern.components.models.model_component import SingleEntityModelComponent
from wyvern.components.single_entity_pipeline import (
SingleEntityPipeline,
SingleEntityPipelineResponse,
)
from wyvern.entities.identifier import Identifier
from wyvern.entities.identifier_entities import WyvernEntity
from wyvern.entities.model_entities import ModelOutput
from wyvern.entities.request import BaseWyvernRequest
from wyvern.service import WyvernService


class Seller(WyvernEntity):
seller_id: str

def generate_identifier(self) -> Identifier:
return Identifier(
identifier=self.seller_id,
identifier_type="seller",
)


class Buyer(WyvernEntity):
buyer_id: str

def generate_identifier(self) -> Identifier:
return Identifier(
identifier=self.buyer_id,
identifier_type="buyer",
)


class Order(WyvernEntity):
order_id: str

def generate_identifier(self) -> Identifier:
return Identifier(
identifier=self.order_id,
identifier_type="order",
)


class FraudRequest(BaseWyvernRequest):
seller: Seller
buyer: Buyer
order: Order


class FraudResponse(SingleEntityPipelineResponse[float]):
reasons: List[str]


class FraudRuleModel(SingleEntityModelComponent[FraudRequest, ModelOutput[float]]):
async def inference(self, input: FraudRequest, **kwargs) -> ModelOutput[float]:
return ModelOutput(
data={
input.order.identifier: 1,
},
)


class FraudAssessmentModel(
SingleEntityModelComponent[FraudRequest, ModelOutput[float]],
):
async def inference(self, input: FraudRequest, **kwargs) -> ModelOutput[float]:
return ModelOutput(
data={
input.order.identifier: 1,
},
)


fraud_model = SingleEntityModelChain[FraudRequest, ModelOutput[float]](
FraudRuleModel(),
FraudAssessmentModel(),
name="fraud_model",
)


class FraudBusinessLogicComponent(
SingleEntityBusinessLogicComponent[FraudRequest, float],
):
async def execute(
self,
input: SingleEntityBusinessLogicRequest[FraudRequest, float],
**kwargs,
) -> float:
if input.request.seller.identifier.identifier == "test_seller_new":
return 0.0
return input.model_output


fraud_biz_pipeline = SingleEntityBusinessLogicPipeline(
FraudBusinessLogicComponent(),
name="fraud_biz_pipeline",
)


class FraudPipeline(SingleEntityPipeline[FraudRequest, float]):
PATH = "/fraud"
REQUEST_SCHEMA_CLASS = FraudRequest
RESPONSE_SCHEMA_CLASS = FraudResponse

def generate_response(
self,
input: FraudRequest,
pipeline_output: float,
) -> FraudResponse:
if pipeline_output == 0.0:
return FraudResponse(
data=pipeline_output,
reasons=["Fraudulent order detected!"],
)
return FraudResponse(
data=pipeline_output,
reasons=[],
)


fraud_pipeline = FraudPipeline(model=fraud_model, business_logic=fraud_biz_pipeline)


@pytest.fixture
def mock_redis(mocker):
with mocker.patch(
"wyvern.redis.wyvern_redis.mget",
return_value=[],
):
yield


@pytest.fixture
def test_client(mock_redis):
wyvern_app = WyvernService.generate_app(
route_components=[fraud_pipeline],
)
yield TestClient(wyvern_app)


def test_end_to_end(test_client):
response = test_client.post(
"/api/v1/fraud",
json={
"request_id": "test_request_id",
"seller": {"seller_id": "test_seller_id"},
"buyer": {"buyer_id": "test_buyer_id"},
"order": {"order_id": "test_order_id"},
},
)
assert response.status_code == 200
assert response.json() == {"data": 1.0, "reasons": []}


def test_end_to_end__new_seller(test_client):
response = test_client.post(
"/api/v1/fraud",
json={
"request_id": "test_request_id",
"seller": {"seller_id": "test_seller_new"},
"buyer": {"buyer_id": "test_buyer_id"},
"order": {"order_id": "test_order_id"},
},
)
assert response.status_code == 200
assert response.json() == {"data": 0.0, "reasons": ["Fraudulent order detected!"]}
14 changes: 7 additions & 7 deletions tests/scenarios/test_product_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@
RealtimeFeatureComponent,
RealtimeFeatureRequest,
)
from wyvern.components.models.model_component import (
ModelComponent,
ModelInput,
ModelOutput,
)
from wyvern.components.models.model_component import ModelComponent
from wyvern.components.pipeline_component import PipelineComponent
from wyvern.config import settings
from wyvern.core.compression import wyvern_encode
Expand All @@ -26,6 +22,7 @@
from wyvern.entities.feature_entities import FeatureData, FeatureMap
from wyvern.entities.identifier import Identifier
from wyvern.entities.identifier_entities import ProductEntity, WyvernEntity
from wyvern.entities.model_entities import ModelInput, ModelOutput
from wyvern.entities.request import BaseWyvernRequest
from wyvern.service import WyvernService
from wyvern.wyvern_request import WyvernRequest
Expand Down Expand Up @@ -321,10 +318,11 @@ async def execute(

@pytest.fixture
def test_client(mock_redis):
wyvern_service = WyvernService.generate(
wyvern_app = WyvernService.generate_app(
route_components=[RankingComponent],
realtime_feature_components=[],
)
yield TestClient(wyvern_service.service.app)
yield TestClient(wyvern_app)


def test_get_all_identifiers():
Expand Down Expand Up @@ -387,6 +385,7 @@ async def test_hydrate(mock_redis):
json=json_input,
headers={},
entity_store={},
model_output_map={},
events=[],
feature_map=FeatureMap(feature_map={}),
)
Expand Down Expand Up @@ -450,6 +449,7 @@ async def test_hydrate__duplicate_brand(mock_redis__duplicate_brand):
entity_store={},
events=[],
feature_map=FeatureMap(feature_map={}),
model_output_map={},
)
request_context.set(test_wyvern_request)

Expand Down
16 changes: 14 additions & 2 deletions wyvern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,22 @@
from wyvern.components.features.realtime_features_component import (
RealtimeFeatureComponent,
)
from wyvern.components.models.model_chain_component import SingleEntityModelChain
from wyvern.components.models.model_component import (
ModelComponent,
ModelInput,
ModelOutput,
MultiEntityModelComponent,
SingleEntityModelComponent,
)
from wyvern.components.pipeline_component import PipelineComponent
from wyvern.components.ranking_pipeline import (
RankingPipeline,
RankingRequest,
RankingResponse,
)
from wyvern.components.single_entity_pipeline import (
SingleEntityPipeline,
SingleEntityPipelineResponse,
)
from wyvern.entities.candidate_entities import CandidateSetEntity
from wyvern.entities.feature_entities import FeatureData, FeatureMap
from wyvern.entities.identifier import CompositeIdentifier, Identifier, IdentifierType
Expand All @@ -23,6 +28,7 @@
WyvernDataModel,
WyvernEntity,
)
from wyvern.entities.model_entities import ChainedModelInput, ModelInput, ModelOutput
from wyvern.feature_store.feature_server import generate_wyvern_store_app
from wyvern.service import WyvernService
from wyvern.wyvern_logging import setup_logging
Expand All @@ -36,6 +42,7 @@
__all__ = [
"generate_wyvern_store_app",
"CandidateSetEntity",
"ChainedModelInput",
"CompositeIdentifier",
"FeatureData",
"FeatureMap",
Expand All @@ -44,13 +51,18 @@
"ModelComponent",
"ModelInput",
"ModelOutput",
"MultiEntityModelComponent",
"PipelineComponent",
"ProductEntity",
"QueryEntity",
"RankingPipeline",
"RankingResponse",
"RankingRequest",
"RealtimeFeatureComponent",
"SingleEntityModelChain",
"SingleEntityModelComponent",
"SingleEntityPipeline",
"SingleEntityPipelineResponse",
"UserEntity",
"WyvernDataModel",
"WyvernEntity",
Expand Down
Loading