From 3c4a9e6b6f7432f4ffcf6c3a437f0ccd68f9f90a Mon Sep 17 00:00:00 2001 From: Kerem Yilmaz Date: Mon, 16 Oct 2023 23:20:50 +0300 Subject: [PATCH] Utilize Polars.DataFrame for performance in ModelbitComponent (#80) * Add polars dependency * Use data frame to gather features * Add ddtrace to feature retrieval * Bump version * Set retrieved features to request context * Handle case when identifier type isn't in feature map * Remove FeatureMap and FeatureData * Bump version * Use class str method for key generation, fix crucial bug in build_requests * Check if there are RT features before merging DFs * Initialize RT feature df with empty df * Cast all f32 columns to f64 before concat * Remove feature map * Fix some bugs during testing * Update tests, fix concat vs join bugs * Update feature_store_main.py * Fix logging to use original identifiers instead of primary identifiers * Don't replace : with __ for requests * Add a null check * Keep replacing only 1 feature name separator * Change component.get_feature return type --- .pre-commit-config.yaml | 1 + examples/feature_store_main.py | 24 +- examples/real_time_features_main.py | 22 +- poetry.lock | 62 +++--- pyproject.toml | 3 +- .../test_pinning_business_logic.py | 6 +- .../feature_store/test_real_time_features.py | 208 ++++-------------- tests/scenarios/test_product_ranking.py | 9 +- wyvern/__init__.py | 3 - wyvern/components/component.py | 71 +++++- wyvern/components/features/feature_logger.py | 31 ++- .../features/feature_retrieval_pipeline.py | 121 ++++++---- wyvern/components/features/feature_store.py | 79 ++++--- .../features/realtime_features_component.py | 70 ++++-- wyvern/components/helpers/polars.py | 9 + .../components/models/modelbit_component.py | 15 +- wyvern/components/pipeline_component.py | 9 +- wyvern/entities/feature_entities.py | 89 ++++++-- wyvern/entities/feature_entity_helpers.py | 33 --- wyvern/entities/identifier.py | 12 + wyvern/exceptions.py | 16 ++ wyvern/wyvern_request.py | 39 +++- 22 files changed, 537 insertions(+), 395 deletions(-) create mode 100644 wyvern/components/helpers/polars.py delete mode 100644 wyvern/entities/feature_entity_helpers.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c1e074b..d43c2cd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -62,6 +62,7 @@ repos: - types-tqdm - nest-asyncio - aiohttp + - polars exclude: "^tests/" # Check for spelling diff --git a/examples/feature_store_main.py b/examples/feature_store_main.py index fed0ff3..a30c546 100644 --- a/examples/feature_store_main.py +++ b/examples/feature_store_main.py @@ -5,6 +5,7 @@ import typer from pydantic import BaseModel +from wyvern import Identifier from wyvern.components.api_route_component import APIRouteComponent from wyvern.components.features.feature_store import ( FeatureStoreRetrievalRequest, @@ -33,13 +34,24 @@ async def execute( self, input: FeatureStoreRetrievalRequest, **kwargs ) -> FeatureStoreResponse: logger.info(f"Executing input {input}") - feature_map = await feature_store_retrieval_component.execute(input) - + feature_df = await feature_store_retrieval_component.execute(input) + feature_dicts = feature_df.df.to_dicts() + feature_data: Dict[str, FeatureData] = { + str(feature_dict["IDENTIFIER"]): FeatureData( + identifier=Identifier( + identifier_type=feature_dict["IDENTIFIER"].split("::")[0], + identifier=feature_dict["IDENTIFIER"].split("::")[1], + ), + features={ + feature_name: feature_value + for feature_name, feature_value in feature_dict.items() + if feature_name 
!= "IDENTIFIER" + }, + ) + for feature_dict in feature_dicts + } return FeatureStoreResponse( - feature_data={ - identifier.identifier: feature_map.feature_map[identifier] - for identifier in feature_map.feature_map.keys() - }, + feature_data=feature_data, ) diff --git a/examples/real_time_features_main.py b/examples/real_time_features_main.py index ef047be..1b8c33e 100644 --- a/examples/real_time_features_main.py +++ b/examples/real_time_features_main.py @@ -337,15 +337,27 @@ async def execute( ) time_start = time() - feature_map = await self.feature_retrieval_pipeline.execute(request) + feature_df = await self.feature_retrieval_pipeline.execute(request) logger.info(f"operation feature_retrieval took:{time()-time_start:2.4f} sec") profiler.stop() profiler.print() + feature_dicts = feature_df.df.to_dicts() + feature_data: Dict[str, FeatureData] = { + str(feature_dict["IDENTIFIER"]): FeatureData( + identifier=Identifier( + identifier_type=feature_dict["IDENTIFIER"].split("::")[0], + identifier=feature_dict["IDENTIFIER"].split("::")[1], + ), + features={ + feature_name: feature_value + for feature_name, feature_value in feature_dict.items() + if feature_name != "IDENTIFIER" + }, + ) + for feature_dict in feature_dicts + } return FeatureStoreResponse( - feature_data={ - str(identifier): feature_map.feature_map[identifier] - for identifier in feature_map.feature_map.keys() - }, + feature_data=feature_data, ) diff --git a/poetry.lock b/poetry.lock index 6477954..44e181f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1727,7 +1727,6 @@ files = [ {file = "greenlet-2.0.2-cp27-cp27m-win32.whl", hash = "sha256:6c3acb79b0bfd4fe733dff8bc62695283b57949ebcca05ae5c129eb606ff2d74"}, {file = "greenlet-2.0.2-cp27-cp27m-win_amd64.whl", hash = "sha256:283737e0da3f08bd637b5ad058507e578dd462db259f7f6e4c5c365ba4ee9343"}, {file = "greenlet-2.0.2-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:d27ec7509b9c18b6d73f2f5ede2622441de812e7b1a80bbd446cb0633bd3d5ae"}, - {file = "greenlet-2.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d967650d3f56af314b72df7089d96cda1083a7fc2da05b375d2bc48c82ab3f3c"}, {file = "greenlet-2.0.2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:30bcf80dda7f15ac77ba5af2b961bdd9dbc77fd4ac6105cee85b0d0a5fcf74df"}, {file = "greenlet-2.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:26fbfce90728d82bc9e6c38ea4d038cba20b7faf8a0ca53a9c07b67318d46088"}, {file = "greenlet-2.0.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9190f09060ea4debddd24665d6804b995a9c122ef5917ab26e1566dcc712ceeb"}, @@ -1736,7 +1735,6 @@ files = [ {file = "greenlet-2.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:76ae285c8104046b3a7f06b42f29c7b73f77683df18c49ab5af7983994c2dd91"}, {file = "greenlet-2.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:2d4686f195e32d36b4d7cf2d166857dbd0ee9f3d20ae349b6bf8afc8485b3645"}, {file = "greenlet-2.0.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c4302695ad8027363e96311df24ee28978162cdcdd2006476c43970b384a244c"}, - {file = "greenlet-2.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d4606a527e30548153be1a9f155f4e283d109ffba663a15856089fb55f933e47"}, {file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c48f54ef8e05f04d6eff74b8233f6063cb1ed960243eacc474ee73a2ea8573ca"}, {file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a1846f1b999e78e13837c93c778dcfc3365902cfb8d1bdb7dd73ead37059f0d0"}, 
{file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a06ad5312349fec0ab944664b01d26f8d1f05009566339ac6f63f56589bc1a2"}, @@ -1766,7 +1764,6 @@ files = [ {file = "greenlet-2.0.2-cp37-cp37m-win32.whl", hash = "sha256:3f6ea9bd35eb450837a3d80e77b517ea5bc56b4647f5502cd28de13675ee12f7"}, {file = "greenlet-2.0.2-cp37-cp37m-win_amd64.whl", hash = "sha256:7492e2b7bd7c9b9916388d9df23fa49d9b88ac0640db0a5b4ecc2b653bf451e3"}, {file = "greenlet-2.0.2-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b864ba53912b6c3ab6bcb2beb19f19edd01a6bfcbdfe1f37ddd1778abfe75a30"}, - {file = "greenlet-2.0.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1087300cf9700bbf455b1b97e24db18f2f77b55302a68272c56209d5587c12d1"}, {file = "greenlet-2.0.2-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:ba2956617f1c42598a308a84c6cf021a90ff3862eddafd20c3333d50f0edb45b"}, {file = "greenlet-2.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc3a569657468b6f3fb60587e48356fe512c1754ca05a564f11366ac9e306526"}, {file = "greenlet-2.0.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8eab883b3b2a38cc1e050819ef06a7e6344d4a990d24d45bc6f2cf959045a45b"}, @@ -1775,7 +1772,6 @@ files = [ {file = "greenlet-2.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b0ef99cdbe2b682b9ccbb964743a6aca37905fda5e0452e5ee239b1654d37f2a"}, {file = "greenlet-2.0.2-cp38-cp38-win32.whl", hash = "sha256:b80f600eddddce72320dbbc8e3784d16bd3fb7b517e82476d8da921f27d4b249"}, {file = "greenlet-2.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:4d2e11331fc0c02b6e84b0d28ece3a36e0548ee1a1ce9ddde03752d9b79bba40"}, - {file = "greenlet-2.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8512a0c38cfd4e66a858ddd1b17705587900dd760c6003998e9472b77b56d417"}, {file = "greenlet-2.0.2-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:88d9ab96491d38a5ab7c56dd7a3cc37d83336ecc564e4e8816dbed12e5aaefc8"}, {file = "greenlet-2.0.2-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:561091a7be172ab497a3527602d467e2b3fbe75f9e783d8b8ce403fa414f71a6"}, {file = "greenlet-2.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:971ce5e14dc5e73715755d0ca2975ac88cfdaefcaab078a284fea6cfabf866df"}, @@ -2602,16 +2598,6 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -3289,6 +3275,42 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "polars" +version = "0.19.6" +description = "Blazingly fast DataFrame library" +optional = false +python-versions = ">=3.8" +files = [ + {file = "polars-0.19.6-cp38-abi3-macosx_10_7_x86_64.whl", hash = "sha256:a9667e1afcada45c0a32df7c1cd3b588a32424249487db4ef931841859020db4"}, + {file = "polars-0.19.6-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:21e1b6648fcbf79ccb69a1f09c490574b3ec6556af4d3044da6ccf8353a77915"}, + {file = "polars-0.19.6-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b8f5ca5a682ae3ebfda5993847cee4818f546ea6ab40fe5fb275093db2a14ac"}, + {file = "polars-0.19.6-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:52259072fa9c5aac0a54a52fe74039fe5bf157ddc7ec38df19cee6c8594918d9"}, + {file = "polars-0.19.6-cp38-abi3-win_amd64.whl", hash = "sha256:f4eb5860301bb3ad4d6e9bb6c319e099fc5d0a6dbe47c60b44323b96a09ec1ea"}, + {file = "polars-0.19.6.tar.gz", hash = "sha256:b0e4be7db019152ee5fa26b42515474037472521a1ce5a113b0f0bb4209205fd"}, +] + +[package.extras] +adbc = ["adbc_driver_sqlite"] +all = ["polars[adbc,cloudpickle,connectorx,deltalake,fsspec,gevent,matplotlib,numpy,pandas,pyarrow,pydantic,pyiceberg,sqlalchemy,timezone,xlsx2csv,xlsxwriter]"] +cloudpickle = ["cloudpickle"] +connectorx = ["connectorx"] +deltalake = ["deltalake (>=0.10.0)"] +fsspec = ["fsspec"] +gevent = ["gevent"] +matplotlib = ["matplotlib"] +numpy = ["numpy (>=1.16.0)"] +openpyxl = ["openpyxl (>=3.0.0)"] +pandas = ["pandas", "pyarrow (>=7.0.0)"] +pyarrow = ["pyarrow (>=7.0.0)"] +pydantic = ["pydantic"] +pyiceberg = ["pyiceberg (>=0.5.0)"] +pyxlsb = ["pyxlsb (>=1.0)"] +sqlalchemy = ["pandas", "sqlalchemy"] +timezone = ["backports.zoneinfo", "tzdata"] +xlsx2csv = ["xlsx2csv (>=0.8.0)"] +xlsxwriter = ["xlsxwriter"] + [[package]] name = "posthog" version = "3.0.2" @@ -3917,7 +3939,6 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -3925,15 +3946,8 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -3950,7 +3964,6 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -3958,7 +3971,6 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -5313,4 +5325,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.8,<3.12" -content-hash = "8a3e8ff16402de8c47b84bf254b80d702818269c5ef2c4c25d3ec657c4460264" +content-hash = "5be697802a6d07b5e8b227ef9274b46358c16e1835731bfebe3b5c66cc8d167f" diff --git a/pyproject.toml b/pyproject.toml index 4a7af71..9fedea0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "wyvern-ai" -version = "0.0.26" +version = "0.0.27" description = "" authors = ["Wyvern AI "] readme = "README.md" @@ -33,6 +33,7 @@ aiohttp = {extras = ["speedups"], version = "^3.8.5"} requests = "^2.31.0" platformdirs = "^3.8" posthog = "^3.0.2" +polars = "^0.19.6" [tool.poetry.group.dev.dependencies] diff --git a/tests/components/business_logic/test_pinning_business_logic.py b/tests/components/business_logic/test_pinning_business_logic.py index edc30a7..8a4bd34 100644 --- a/tests/components/business_logic/test_pinning_business_logic.py +++ b/tests/components/business_logic/test_pinning_business_logic.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +from collections import defaultdict from typing import Dict, List import pytest @@ -13,7 +14,7 @@ PinningBusinessLogicComponent, ) from wyvern.entities.candidate_entities import ScoredCandidate -from wyvern.entities.feature_entities import FeatureMap +from wyvern.entities.feature_entities import FeatureDataFrame from wyvern.entities.identifier_entities import ProductEntity from wyvern.entities.request import BaseWyvernRequest from wyvern.wyvern_request import WyvernRequest @@ -65,7 +66,8 @@ def __init__(self): headers={}, entity_store={}, events=[], - feature_map=FeatureMap(feature_map={}), + feature_df=FeatureDataFrame(), + feature_orig_identifiers=defaultdict(dict), model_output_map={}, ), ) diff --git 
a/tests/feature_store/test_real_time_features.py b/tests/feature_store/test_real_time_features.py index 158e9d2..520cce4 100644 --- a/tests/feature_store/test_real_time_features.py +++ b/tests/feature_store/test_real_time_features.py @@ -8,7 +8,7 @@ RankingRealtimeFeatureComponent, ) from wyvern.components.features.feature_store import feature_store_retrieval_component -from wyvern.entities.feature_entities import FeatureMap +from wyvern.entities.feature_entities import FeatureDataFrame from wyvern.feature_store.historical_feature_util import separate_real_time_features from wyvern.service import WyvernService @@ -29,7 +29,7 @@ def mock_feature_store(mocker): mocker.patch.object( feature_store_retrieval_component, "fetch_features_from_feature_store", - return_value=FeatureMap(feature_map={}), + return_value=FeatureDataFrame(), ) @@ -78,101 +78,49 @@ async def test_end_to_end(mock_redis, test_client, mock_feature_store): "RealTimeQueryFeature:f_query_length": 6.0, "RealTimeStringFeature:f_query": "candle", "RealTimeEmbeddingFeature:f_query_embedding_vector_8": [ - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + 7.0, + 8.0, ], + "RealTimeUserQueryFeature:f_user_query_name_edit_distance": 3.0, + "RealTimeUserQueryFeature:f_user_query_name_jaccard_similarity": -3.0, }, }, "user::1234": { "identifier": {"identifier": "1234", "identifier_type": "user"}, "features": {"RealTimeUserFeature:f_user_name_length": 9.0}, }, - "query:user::candle:1234": { - "identifier": { - "identifier": "candle:1234", - "identifier_type": "query:user", - }, - "features": { - "RealTimeUserQueryFeature:f_user_query_name_edit_distance": 3.0, - "RealTimeUserQueryFeature:f_user_query_name_jaccard_similarity": -3.0, - }, - }, "product::p1": { "identifier": {"identifier": "p1", "identifier_type": "product"}, - "features": {"RealTimeProductFeature:f_opensearch_score": 1.0}, - }, - "product::p2": { - "identifier": {"identifier": "p2", "identifier_type": "product"}, - "features": {}, - }, - "product::p3": { - "identifier": {"identifier": "p3", "identifier_type": "product"}, - "features": {}, - }, - "product:query::p1:candle": { - "identifier": { - "identifier": "p1:candle", - "identifier_type": "product:query", - }, "features": { "RealTimeMatchedQueriesProductFeature:f_matched_queries_QUERY_1": 1.0, "RealTimeMatchedQueriesProductFeature:f_matched_queries_QUERY_2": 1.0, + "RealTimeProductFeature:f_opensearch_score": 1.0, "RealTimeQueryProductFeature:f_query_product_name_edit_distance": 4.0, "RealTimeQueryProductFeature:f_query_product_name_jaccard_similarity": -4.0, - }, - }, - "product:query::p2:candle": { - "identifier": { - "identifier": "p2:candle", - "identifier_type": "product:query", - }, - "features": { - "RealTimeQueryProductFeature:f_query_product_name_edit_distance": 4.0, - "RealTimeQueryProductFeature:f_query_product_name_jaccard_similarity": -4.0, - }, - }, - "product:query::p3:candle": { - "identifier": { - "identifier": "p3:candle", - "identifier_type": "product:query", - }, - "features": { - "RealTimeQueryProductFeature:f_query_product_name_edit_distance": 4.0, - "RealTimeQueryProductFeature:f_query_product_name_jaccard_similarity": -4.0, - }, - }, - "product:user::p1:1234": { - "identifier": { - "identifier": "p1:1234", - "identifier_type": "product:user", - }, - "features": { "RealTimeUserProductFeature:f_user_product_name_edit_distance": 7.0, "RealTimeUserProductFeature:f_user_product_name_jaccard_similarity": -7.0, }, }, - "product:user::p2:1234": { - "identifier": { - 
"identifier": "p2:1234", - "identifier_type": "product:user", - }, + "product::p2": { + "identifier": {"identifier": "p2", "identifier_type": "product"}, "features": { + "RealTimeQueryProductFeature:f_query_product_name_edit_distance": 4.0, + "RealTimeQueryProductFeature:f_query_product_name_jaccard_similarity": -4.0, "RealTimeUserProductFeature:f_user_product_name_edit_distance": 7.0, "RealTimeUserProductFeature:f_user_product_name_jaccard_similarity": -7.0, }, }, - "product:user::p3:1234": { - "identifier": { - "identifier": "p3:1234", - "identifier_type": "product:user", - }, + "product::p3": { + "identifier": {"identifier": "p3", "identifier_type": "product"}, "features": { + "RealTimeQueryProductFeature:f_query_product_name_edit_distance": 4.0, + "RealTimeQueryProductFeature:f_query_product_name_jaccard_similarity": -4.0, "RealTimeUserProductFeature:f_user_product_name_edit_distance": 7.0, "RealTimeUserProductFeature:f_user_product_name_jaccard_similarity": -7.0, }, @@ -234,127 +182,61 @@ async def test_end_to_end__2(mock_redis__2, test_client): "RealTimeQueryFeature:f_query_length": 6.0, "RealTimeStringFeature:f_query": "candle", "RealTimeEmbeddingFeature:f_query_embedding_vector_8": [ - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + 7.0, + 8.0, ], + "RealTimeUserQueryFeature:f_user_query_name_edit_distance": 3.0, + "RealTimeUserQueryFeature:f_user_query_name_jaccard_similarity": -3.0, }, }, "user::1234": { "identifier": {"identifier": "1234", "identifier_type": "user"}, "features": {"RealTimeUserFeature:f_user_name_length": 9.0}, }, - "query:user::candle:1234": { - "identifier": { - "identifier": "candle:1234", - "identifier_type": "query:user", - }, - "features": { - "RealTimeUserQueryFeature:f_user_query_name_edit_distance": 3.0, - "RealTimeUserQueryFeature:f_user_query_name_jaccard_similarity": -3.0, - }, - }, "product::p1": { "identifier": {"identifier": "p1", "identifier_type": "product"}, - "features": {"RealTimeProductFeature:f_opensearch_score": 1.0}, - }, - "product::p2": { - "identifier": {"identifier": "p2", "identifier_type": "product"}, - "features": {}, - }, - "product::p3": { - "identifier": {"identifier": "p3", "identifier_type": "product"}, - "features": {}, - }, - "product::p4": { - "identifier": {"identifier": "p4", "identifier_type": "product"}, - "features": {"RealTimeProductFeature:f_opensearch_score": 100.0}, - }, - "product:query::p1:candle": { - "identifier": { - "identifier": "p1:candle", - "identifier_type": "product:query", - }, "features": { + "RealTimeProductFeature:f_opensearch_score": 1.0, "RealTimeMatchedQueriesProductFeature:f_matched_queries_QUERY_1": 1.0, "RealTimeMatchedQueriesProductFeature:f_matched_queries_QUERY_2": 1.0, "RealTimeQueryProductFeature:f_query_product_name_edit_distance": 4.0, "RealTimeQueryProductFeature:f_query_product_name_jaccard_similarity": -4.0, + "RealTimeUserProductFeature:f_user_product_name_edit_distance": 7.0, + "RealTimeUserProductFeature:f_user_product_name_jaccard_similarity": -7.0, }, }, - "product:query::p2:candle": { - "identifier": { - "identifier": "p2:candle", - "identifier_type": "product:query", - }, + "product::p2": { + "identifier": {"identifier": "p2", "identifier_type": "product"}, "features": { "RealTimeQueryProductFeature:f_query_product_name_edit_distance": 4.0, "RealTimeQueryProductFeature:f_query_product_name_jaccard_similarity": -4.0, + "RealTimeUserProductFeature:f_user_product_name_edit_distance": 7.0, + 
"RealTimeUserProductFeature:f_user_product_name_jaccard_similarity": -7.0, }, }, - "product:query::p3:candle": { - "identifier": { - "identifier": "p3:candle", - "identifier_type": "product:query", - }, + "product::p3": { + "identifier": {"identifier": "p3", "identifier_type": "product"}, "features": { "RealTimeQueryProductFeature:f_query_product_name_edit_distance": 4.0, "RealTimeQueryProductFeature:f_query_product_name_jaccard_similarity": -4.0, + "RealTimeUserProductFeature:f_user_product_name_edit_distance": 7.0, + "RealTimeUserProductFeature:f_user_product_name_jaccard_similarity": -7.0, }, }, - "product:query::p4:candle": { - "identifier": { - "identifier": "p4:candle", - "identifier_type": "product:query", - }, + "product::p4": { + "identifier": {"identifier": "p4", "identifier_type": "product"}, "features": { + "RealTimeProductFeature:f_opensearch_score": 100.0, "RealTimeMatchedQueriesProductFeature:f_matched_queries_MATIAS": 1.0, "RealTimeMatchedQueriesProductFeature:f_matched_queries_QUERY_2": 1.0, "RealTimeQueryProductFeature:f_query_product_name_edit_distance": 4.0, "RealTimeQueryProductFeature:f_query_product_name_jaccard_similarity": -4.0, - }, - }, - "product:user::p1:1234": { - "identifier": { - "identifier": "p1:1234", - "identifier_type": "product:user", - }, - "features": { - "RealTimeUserProductFeature:f_user_product_name_edit_distance": 7.0, - "RealTimeUserProductFeature:f_user_product_name_jaccard_similarity": -7.0, - }, - }, - "product:user::p2:1234": { - "identifier": { - "identifier": "p2:1234", - "identifier_type": "product:user", - }, - "features": { - "RealTimeUserProductFeature:f_user_product_name_edit_distance": 7.0, - "RealTimeUserProductFeature:f_user_product_name_jaccard_similarity": -7.0, - }, - }, - "product:user::p3:1234": { - "identifier": { - "identifier": "p3:1234", - "identifier_type": "product:user", - }, - "features": { - "RealTimeUserProductFeature:f_user_product_name_edit_distance": 7.0, - "RealTimeUserProductFeature:f_user_product_name_jaccard_similarity": -7.0, - }, - }, - "product:user::p4:1234": { - "identifier": { - "identifier": "p4:1234", - "identifier_type": "product:user", - }, - "features": { "RealTimeUserProductFeature:f_user_product_name_edit_distance": 7.0, "RealTimeUserProductFeature:f_user_product_name_jaccard_similarity": -7.0, }, diff --git a/tests/scenarios/test_product_ranking.py b/tests/scenarios/test_product_ranking.py index 697249b..a196f18 100644 --- a/tests/scenarios/test_product_ranking.py +++ b/tests/scenarios/test_product_ranking.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +from collections import defaultdict from functools import cached_property from typing import Any, Dict, List, Optional, Set @@ -19,7 +20,7 @@ from wyvern.core.compression import wyvern_encode from wyvern.core.http import aiohttp_client from wyvern.entities.candidate_entities import CandidateSetEntity -from wyvern.entities.feature_entities import FeatureData, FeatureMap +from wyvern.entities.feature_entities import FeatureData, FeatureDataFrame from wyvern.entities.identifier import Identifier from wyvern.entities.identifier_entities import ProductEntity, WyvernEntity from wyvern.entities.model_entities import ModelInput, ModelOutput @@ -387,7 +388,8 @@ async def test_hydrate(mock_redis): entity_store={}, model_output_map={}, events=[], - feature_map=FeatureMap(feature_map={}), + feature_df=FeatureDataFrame(), + feature_orig_identifiers=defaultdict(dict), ) request_context.set(test_wyvern_request) @@ -448,8 +450,9 @@ async def 
test_hydrate__duplicate_brand(mock_redis__duplicate_brand): headers={}, entity_store={}, events=[], - feature_map=FeatureMap(feature_map={}), + feature_df=FeatureDataFrame(), model_output_map={}, + feature_orig_identifiers=defaultdict(dict), ) request_context.set(test_wyvern_request) diff --git a/wyvern/__init__.py b/wyvern/__init__.py index 168f860..c0d36ad 100644 --- a/wyvern/__init__.py +++ b/wyvern/__init__.py @@ -19,7 +19,6 @@ SingleEntityPipelineResponse, ) from wyvern.entities.candidate_entities import CandidateSetEntity -from wyvern.entities.feature_entities import FeatureData, FeatureMap from wyvern.entities.identifier import CompositeIdentifier, Identifier, IdentifierType from wyvern.entities.identifier_entities import ( ProductEntity, @@ -44,8 +43,6 @@ "CandidateSetEntity", "ChainedModelInput", "CompositeIdentifier", - "FeatureData", - "FeatureMap", "Identifier", "IdentifierType", "ModelComponent", diff --git a/wyvern/components/component.py b/wyvern/components/component.py index 4c1f3d0..a44d651 100644 --- a/wyvern/components/component.py +++ b/wyvern/components/component.py @@ -5,11 +5,14 @@ import logging from enum import Enum from functools import cached_property -from typing import Dict, Generic, List, Optional, Set, Union +from typing import Dict, Generic, List, Optional, Set, Tuple, Union from uuid import uuid4 +import polars as pl + from wyvern import request_context -from wyvern.entities.identifier import Identifier +from wyvern.entities.identifier import Identifier, get_identifier_key +from wyvern.exceptions import WyvernFeatureValueError from wyvern.wyvern_typing import INPUT_TYPE, OUTPUT_TYPE, WyvernFeature logger = logging.getLogger(__name__) @@ -142,8 +145,40 @@ def manifest_feature_names(self) -> Set[str]: """ return set() + @staticmethod + def get_features( + identifiers: List[Identifier], + feature_names: List[str], + ) -> List[Tuple[str, List[WyvernFeature]]]: + current_request = request_context.ensure_current_request() + identifier_keys = [get_identifier_key(identifier) for identifier in identifiers] + df = current_request.feature_df.get_features_by_identifier_keys( + identifier_keys, + feature_names, + ) + + # build tuples where the identifier column is the first element and the feature columns are the rest + rows = df.rows() + identifier_to_features_dict = { + # row[0] is the identifier column, it is a string + # row[1:] are the feature columns, each column is a WyvernFeature + row[0]: row[1:] + for row in rows + } + + empty_feature_list = [None] * len(feature_names) + tuples = [ + ( + identifier_key, + identifier_to_features_dict.get(identifier_key, empty_feature_list), + ) + for identifier_key in identifier_keys + ] + + return tuples # type: ignore + + @staticmethod def get_feature( - self, identifier: Identifier, feature_name: str, ) -> WyvernFeature: @@ -159,12 +194,19 @@ def get_feature( you just have to pass in feature_name="wyvern_feature". 
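        A minimal illustrative call (the identifier and feature name here are
        hypothetical examples, not guaranteed to exist in any deployment):

            value = self.get_feature(
                Identifier(identifier="p1", identifier_type="product"),
                feature_name="RealTimeProductFeature:f_opensearch_score",
            )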
""" current_request = request_context.ensure_current_request() - feature_data = current_request.feature_map.feature_map.get(identifier) - if not feature_data: - return None - return feature_data.features.get(feature_name) + df = current_request.feature_df.get_features( + [identifier], + [feature_name], + ) + df = df.filter(pl.col(feature_name).is_not_null()) + if len(df) > 1: + raise WyvernFeatureValueError( + identifier=identifier, + feature_name=feature_name, + ) + return df[feature_name][0] if not df[feature_name].is_empty() else None - def get_all_features( + def get_all_features_for_identifier( self, identifier: Identifier, ) -> Dict[str, WyvernFeature]: @@ -173,10 +215,15 @@ def get_all_features( The features are cached once fetched/evaluated. """ current_request = request_context.ensure_current_request() - feature_data = current_request.feature_map.feature_map.get(identifier) - if not feature_data: - return {} - return feature_data.features + df = current_request.feature_df.get_all_features_for_identifier(identifier) + feature_dict = df.to_dict() + result: Dict[str, WyvernFeature] = {} + for key, value in feature_dict.items(): + if len(value) > 1: + raise WyvernFeatureValueError(identifier=identifier, feature_name=key) + result[key] = value[0] if value else None + + return result def get_model_output( self, diff --git a/wyvern/components/features/feature_logger.py b/wyvern/components/features/feature_logger.py index 9c7ab02..62e282c 100644 --- a/wyvern/components/features/feature_logger.py +++ b/wyvern/components/features/feature_logger.py @@ -8,7 +8,7 @@ from wyvern import request_context from wyvern.components.component import Component from wyvern.components.events.events import EventType, LoggedEvent -from wyvern.entities.feature_entities import FeatureMap +from wyvern.entities.feature_entities import IDENTIFIER from wyvern.event_logging import event_logger from wyvern.wyvern_typing import REQUEST_ENTITY, WyvernFeature @@ -47,11 +47,13 @@ class FeatureEventLoggingRequest( Attributes: request: The request to log feature events for. - feature_map: The feature map to log. + feature_df: The feature data frame to log. """ request: REQUEST_ENTITY - feature_map: FeatureMap + + class Config: + arbitrary_types_allowed = True class FeatureEventLoggingComponent( @@ -75,6 +77,10 @@ def feature_event_generator(): A list of feature events. 
""" timestamp = datetime.utcnow() + + # Extract column names excluding "IDENTIFIER" + feature_columns = wyvern_request.feature_df.df.columns[1:] + return [ FeatureEvent( request_id=input.request.request_id, @@ -82,14 +88,21 @@ def feature_event_generator(): api_source=url_path, event_timestamp=timestamp, event_data=FeatureLogEventData( - feature_identifier=feature_data.identifier.identifier, - feature_identifier_type=feature_data.identifier.identifier_type, - feature_name=feature_name, - feature_value=feature_value, + feature_identifier_type=wyvern_request.get_original_identifier( + row[IDENTIFIER], + col, + ).identifier_type, + feature_identifier=wyvern_request.get_original_identifier( + row[IDENTIFIER], + col, + ).identifier, + feature_name=col, + feature_value=row[col], ), ) - for feature_data in input.feature_map.feature_map.values() - for feature_name, feature_value in feature_data.features.items() + for row in wyvern_request.feature_df.df.iter_rows(named=True) + for col in feature_columns + if row[col] ] event_logger.log_events(feature_event_generator) # type: ignore diff --git a/wyvern/components/features/feature_retrieval_pipeline.py b/wyvern/components/features/feature_retrieval_pipeline.py index 5575367..94799fc 100644 --- a/wyvern/components/features/feature_retrieval_pipeline.py +++ b/wyvern/components/features/feature_retrieval_pipeline.py @@ -1,8 +1,10 @@ # -*- coding: utf-8 -*- import asyncio import logging -from typing import Generic, List, Optional, Set, Type +from collections import defaultdict +from typing import Generic, List, Optional, Set, Tuple, Type +import polars as pl from ddtrace import tracer from pydantic.generics import GenericModel @@ -21,9 +23,9 @@ RealtimeFeatureComponent, RealtimeFeatureRequest, ) +from wyvern.components.helpers.polars import cast_float32_to_float64 from wyvern.entities.candidate_entities import CandidateSetEntity -from wyvern.entities.feature_entities import FeatureData, FeatureMap -from wyvern.entities.feature_entity_helpers import feature_map_create, feature_map_join +from wyvern.entities.feature_entities import IDENTIFIER, FeatureDataFrame from wyvern.entities.identifier_entities import WyvernEntity from wyvern.wyvern_typing import REQUEST_ENTITY @@ -49,7 +51,7 @@ class FeatureRetrievalPipelineRequest(GenericModel, Generic[REQUEST_ENTITY]): class FeatureRetrievalPipeline( - Component[FeatureRetrievalPipelineRequest[REQUEST_ENTITY], FeatureMap], + Component[FeatureRetrievalPipelineRequest[REQUEST_ENTITY], FeatureDataFrame], Generic[REQUEST_ENTITY], ): """ @@ -92,6 +94,42 @@ def __init__( name=name, ) + @tracer.wrap(name="FeatureRetrievalPipeline._concat_real_time_features") + def _concat_real_time_features( + self, + real_time_feature_dfs: List[Tuple[str, Optional[pl.DataFrame]]], + ) -> Optional[pl.DataFrame]: + """ + This method is used to cast and concatenate real-time features into one DataFrame. + + Args: + real_time_feature_dfs: A list of DataFrames that contain real-time features. + + Returns: + A DataFrame that contains all the real-time features. 
+ """ + grouped_features = defaultdict(list) + for key, value in real_time_feature_dfs: + if value is not None: + grouped_features[key].append(cast_float32_to_float64(value)) + + merged_features = [ + pl.concat(value, how="diagonal") if len(value) > 1 else value[0] + for value in grouped_features.values() + ] + + if not merged_features: + return None + + real_time_feature_merged_df = merged_features[0] + for df in merged_features[1:]: + real_time_feature_merged_df = real_time_feature_merged_df.join( + df, + on=IDENTIFIER, + how="outer", + ) + return real_time_feature_merged_df + @tracer.wrap(name="FeatureRetrievalPipeline._generate_real_time_features") def _generate_real_time_features( self, @@ -112,7 +150,7 @@ def _generate_real_time_features( @tracer.wrap(name="FeatureRetrievalPipeline.execute") async def execute( self, input: FeatureRetrievalPipelineRequest[REQUEST_ENTITY], **kwargs - ) -> FeatureMap: + ) -> FeatureDataFrame: """ This method is used to retrieve features for a given request. @@ -153,16 +191,13 @@ async def execute( feature_names=list(feature_names_to_retrieve_from_feature_store), ) - feature_retrieval_response: FeatureMap = ( - await self.feature_retrieval_component.execute( - input=feature_retrieval_request, - handle_exceptions=self.handle_exceptions, - **kwargs, - ) + feature_df = await self.feature_retrieval_component.execute( + input=feature_retrieval_request, + handle_exceptions=self.handle_exceptions, + **kwargs, ) current_request = request_context.ensure_current_request() - current_request.feature_map = feature_retrieval_response - + current_request.feature_df = feature_df """ TODO (suchintan): 1. Figure out a set of: (Candidate entities), (Non-candidate entities), (Request) @@ -207,10 +242,10 @@ async def execute( with tracer.trace("FeatureRetrievalPipeline.real_time_no_entity_features"): request = RealtimeFeatureRequest[REQUEST_ENTITY]( request=input.request, - feature_retrieval_response=feature_retrieval_response, + feature_retrieval_response=feature_df, ) real_time_request_no_entity_features: List[ - Optional[FeatureData] + Tuple[str, Optional[pl.DataFrame]] ] = await asyncio.gather( *[ real_time_feature.compute_request_features_wrapper( @@ -223,7 +258,7 @@ async def execute( with tracer.trace("FeatureRetrievalPipeline.real_time_entity_features"): real_time_request_features: List[ - Optional[FeatureData] + Tuple[str, Optional[pl.DataFrame]] ] = await asyncio.gather( *[ real_time_feature.compute_features_wrapper( @@ -238,7 +273,7 @@ async def execute( with tracer.trace("FeatureRetrievalPipeline.real_time_combination_features"): real_time_request_combination_features: List[ - Optional[FeatureData] + Tuple[str, Optional[pl.DataFrame]] ] = await asyncio.gather( *[ real_time_feature.compute_composite_features_wrapper( @@ -258,8 +293,10 @@ async def execute( ] ) - real_time_candidate_features: List[Optional[FeatureData]] = [] - real_time_candidate_combination_features: List[Optional[FeatureData]] = [] + real_time_candidate_features: List[Tuple[str, Optional[pl.DataFrame]]] = [] + real_time_candidate_combination_features: List[ + Tuple[str, Optional[pl.DataFrame]] + ] = [] if isinstance(input.request, CandidateSetEntity): with tracer.trace("FeatureRetrievalPipeline.real_time_candidate_features"): @@ -305,26 +342,34 @@ async def execute( # Idea 2: Define feature views that have the same interface, # and we collect them together ahead of this dict comprehension # pytest / linter validation: we should assert for feature name conflicts -- this should never happen - with 
tracer.trace("FeatureRetrievalPipeline.create_feature_map"): - real_time_feature_responses = feature_map_create( - *real_time_request_no_entity_features, - *real_time_request_features, - *real_time_request_combination_features, - *real_time_candidate_features, - *real_time_candidate_combination_features, + with tracer.trace("FeatureRetrievalPipeline.merge_feature_dfs"): + real_time_feature_merged_df = self._concat_real_time_features( + [ + *real_time_request_no_entity_features, + *real_time_request_features, + *real_time_request_combination_features, + *real_time_candidate_features, + *real_time_candidate_combination_features, + ], ) with tracer.trace("FeatureRetrievalPipeline.create_feature_response"): - await self.feature_logger_component.execute( - FeatureEventLoggingRequest( - request=input.request, - feature_map=real_time_feature_responses, - ), - ) - # TODO (suchintan): Improve performance of this - feature_responses = feature_map_join( - feature_retrieval_response, - real_time_feature_responses, - ) + if ( + real_time_feature_merged_df is None + or real_time_feature_merged_df.is_empty() + ): + feature_responses = feature_df.df + else: + await self.feature_logger_component.execute( + FeatureEventLoggingRequest( + request=input.request, + ), + ) + feature_responses = feature_df.df.join( + real_time_feature_merged_df, + on=IDENTIFIER, + how="outer", + ) - return feature_responses + current_request.feature_df = FeatureDataFrame(df=feature_responses) + return current_request.feature_df diff --git a/wyvern/components/features/feature_store.py b/wyvern/components/features/feature_store.py index 3c8cd37..c0f352d 100644 --- a/wyvern/components/features/feature_store.py +++ b/wyvern/components/features/feature_store.py @@ -1,21 +1,18 @@ # -*- coding: utf-8 -*- import logging -from typing import Dict, List, Optional +from typing import List, Optional +import polars as pl from ddtrace import tracer from pydantic import BaseModel +from wyvern import request_context from wyvern.components.component import Component from wyvern.config import settings from wyvern.core.http import aiohttp_client -from wyvern.entities.feature_entities import ( - FeatureData, - FeatureMap, - build_empty_feature_map, -) -from wyvern.entities.identifier import Identifier +from wyvern.entities.feature_entities import IDENTIFIER, FeatureDataFrame +from wyvern.entities.identifier import Identifier, get_identifier_key from wyvern.exceptions import WyvernFeatureNameError, WyvernFeatureStoreError -from wyvern.wyvern_typing import WyvernFeature logger = logging.getLogger(__name__) @@ -37,7 +34,7 @@ class FeatureStoreRetrievalRequest(BaseModel): class FeatureStoreRetrievalComponent( - Component[FeatureStoreRetrievalRequest, FeatureMap], + Component[FeatureStoreRetrievalRequest, FeatureDataFrame], ): """ Component to retrieve features from the feature store. This component is responsible for fetching features from @@ -71,7 +68,7 @@ async def fetch_features_from_feature_store( self, identifiers: List[Identifier], feature_names: List[str], - ) -> FeatureMap: + ) -> FeatureDataFrame: """ Fetches features from the feature store for the given identifiers and feature names. @@ -83,7 +80,7 @@ async def fetch_features_from_feature_store( FeatureMap containing the features for the given identifiers and feature names. 
""" if not feature_names or not settings.FEATURE_STORE_ENABLED: - return FeatureMap(feature_map={}) + return FeatureDataFrame() logger.info(f"Fetching features from feature store: {feature_names}") invalid_feature_names: List[str] = [ @@ -94,7 +91,7 @@ async def fetch_features_from_feature_store( request_body = { "features": feature_names, "entities": { - "IDENTIFIER": [identifier.identifier for identifier in identifiers], + IDENTIFIER: [identifier.identifier for identifier in identifiers], }, "full_feature_names": True, } @@ -117,32 +114,49 @@ async def fetch_features_from_feature_store( response_json = await response.json() feature_names = response_json["metadata"]["feature_names"] - feature_name_keys = [ + feature_names = [ feature_name.replace("__", ":", 1) for feature_name in feature_names ] - results = response_json["results"] - response_identifiers = results[0]["values"] identifier_by_identifiers = { identifier.identifier: identifier for identifier in identifiers } - feature_map: Dict[Identifier, FeatureData] = {} - for i in range(len(response_identifiers)): - feature_data: Dict[str, WyvernFeature] = { - feature_name_keys[j]: results[j]["values"][i] - # the first feature_name is IDENTIFIER which we will skip - for j in range(1, len(feature_names)) - } - - identifier = identifier_by_identifiers[response_identifiers[i]] - feature_map[identifier] = FeatureData( - identifier=identifier, - features=feature_data, - ) + current_request = request_context.ensure_current_request() + current_request.feature_orig_identifiers.update( + { + feature_name: { + get_identifier_key( + identifier_by_identifiers[identifier], + ): identifier_by_identifiers[identifier] + for identifier in results[0]["values"] + } + # skip identifier column itself + for feature_name in feature_names[1:] + }, + ) + + # Start with the IDENTIFIER column since we need to map the str -> Identifier + df_columns = [ + # get_identifier_key will return the primary identifier for composite identifiers + pl.Series( + name=IDENTIFIER, + values=[ + get_identifier_key(identifier_by_identifiers[identifier]) + for identifier in results[0]["values"] + ], + ), + ] + df_columns.extend( + [ + pl.Series(name=feature_name, values=results[i + 1]["values"]) + for i, feature_name in enumerate(feature_names[1:]) + ], + ) + df = pl.DataFrame().with_columns(df_columns) - return FeatureMap(feature_map=feature_map) + return FeatureDataFrame(df=df) @tracer.wrap(name="FeatureStoreRetrievalComponent.execute") async def execute( @@ -150,7 +164,7 @@ async def execute( input: FeatureStoreRetrievalRequest, handle_exceptions: bool = False, **kwargs, - ) -> FeatureMap: + ) -> FeatureDataFrame: """ Fetches features from the feature store for the given identifiers and feature names. 
This method is a wrapper around `fetch_features_from_feature_store` which handles exceptions and returns an empty FeatureMap in case of @@ -167,7 +181,10 @@ async def execute( except WyvernFeatureStoreError as e: if handle_exceptions: # logging is handled where the exception is raised - return build_empty_feature_map(input.identifiers, input.feature_names) + return FeatureDataFrame.build_empty_df( + input.identifiers, + input.feature_names, + ) else: raise e diff --git a/wyvern/components/features/realtime_features_component.py b/wyvern/components/features/realtime_features_component.py index f1a1f76..a23d2f4 100644 --- a/wyvern/components/features/realtime_features_component.py +++ b/wyvern/components/features/realtime_features_component.py @@ -15,10 +15,13 @@ get_args, ) +import polars as pl from pydantic.generics import GenericModel +from wyvern import request_context from wyvern.components.component import Component -from wyvern.entities.feature_entities import FeatureData, FeatureMap +from wyvern.entities.feature_entities import IDENTIFIER, FeatureData, FeatureDataFrame +from wyvern.entities.identifier import get_identifier_key from wyvern.entities.identifier_entities import WyvernEntity from wyvern.feature_store.constants import ( FULL_FEATURE_NAME_SEPARATOR, @@ -47,7 +50,7 @@ class RealtimeFeatureRequest(GenericModel, Generic[REQUEST_ENTITY]): """ request: REQUEST_ENTITY - feature_retrieval_response: FeatureMap + feature_retrieval_response: FeatureDataFrame class RealtimeFeatureEntity(GenericModel, Generic[PRIMARY_ENTITY, SECONDARY_ENTITY]): @@ -67,7 +70,7 @@ class RealtimeFeatureComponent( RealtimeFeatureRequest[REQUEST_ENTITY], RealtimeFeatureEntity[PRIMARY_ENTITY, SECONDARY_ENTITY], ], - Optional[FeatureData], + Tuple[str, Optional[pl.DataFrame]], ], Generic[PRIMARY_ENTITY, SECONDARY_ENTITY, REQUEST_ENTITY], ): @@ -258,7 +261,7 @@ async def execute( RealtimeFeatureEntity[PRIMARY_ENTITY, SECONDARY_ENTITY], ], **kwargs, - ) -> Optional[FeatureData]: + ) -> Tuple[str, Optional[pl.DataFrame]]: # TODO (Suchintan): Delete this method -- this has been fully delegated upwards? request = input[0] entities = input[1] @@ -268,7 +271,9 @@ async def execute( entities.primary_entity, entities.secondary_entity, ): - return None + return self.name, pl.DataFrame().with_columns( + pl.Series(name=IDENTIFIER, dtype=pl.Utf8), + ) if ( entities.secondary_entity is not None @@ -285,7 +290,7 @@ async def execute( f"Failed to compute composite features for " f"{self} {entities.primary_entity.identifier} {entities.secondary_entity.identifier}", ) - return resp + return self.create_df_with_full_feature_name(resp) if entities.primary_entity is not None: resp = await self.compute_features( @@ -298,7 +303,7 @@ async def execute( f"Failed to compute features for " f"{self} {entities.primary_entity.identifier}", ) - return resp + return self.create_df_with_full_feature_name(resp) # TODO (suchintan): Lowercase feature names? 
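        # Note: every branch of this method returns a (component_name, DataFrame)
        # tuple; a None frame signals that this component produced no features,
        # so the retrieval pipeline can skip it when grouping frames for merging.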
resp = await self.compute_request_features(request) @@ -308,7 +313,7 @@ async def execute( logger.info( f"Failed to compute request features for {self} {request.request}", ) - return resp + return self.create_df_with_full_feature_name(resp) async def compute_request_features( self, @@ -334,45 +339,62 @@ async def compute_composite_features( async def compute_request_features_wrapper( self, request: RealtimeFeatureRequest[REQUEST_ENTITY], - ) -> Optional[FeatureData]: + ) -> Tuple[str, Optional[pl.DataFrame]]: feature_data = await self.compute_request_features(request) - return self.set_full_feature_name(feature_data) + return self.create_df_with_full_feature_name(feature_data) async def compute_features_wrapper( self, entity: PRIMARY_ENTITY, request: RealtimeFeatureRequest[REQUEST_ENTITY], - ) -> Optional[FeatureData]: + ) -> Tuple[str, Optional[pl.DataFrame]]: feature_data = await self.compute_features(entity, request) - return self.set_full_feature_name(feature_data) + return self.create_df_with_full_feature_name(feature_data) async def compute_composite_features_wrapper( self, primary_entity: PRIMARY_ENTITY, secondary_entity: SECONDARY_ENTITY, request: RealtimeFeatureRequest[REQUEST_ENTITY], - ) -> Optional[FeatureData]: + ) -> Tuple[str, Optional[pl.DataFrame]]: feature_data = await self.compute_composite_features( primary_entity, secondary_entity, request, ) - return self.set_full_feature_name(feature_data) + return self.create_df_with_full_feature_name(feature_data) - def set_full_feature_name( + def create_df_with_full_feature_name( self, feature_data: Optional[FeatureData], - ) -> Optional[FeatureData]: + ) -> Tuple[str, Optional[pl.DataFrame]]: """ - Sets the full feature name for the feature data + Creates a dataframe with the full feature name for the feature data """ if not feature_data: - return None - - return FeatureData( - identifier=feature_data.identifier, - features={ - f"{self.name}:{feature_name}": feature_value + return self.name, None + + current_request = request_context.ensure_current_request() + for feature_name in feature_data.features.keys(): + feature_name = f"{self.name}:{feature_name}" + dict_to_update = current_request.feature_orig_identifiers[feature_name] + dict_to_update.update( + {get_identifier_key(feature_data.identifier): feature_data.identifier}, + ) + current_request.feature_orig_identifiers[feature_name] = dict_to_update + + df = pl.DataFrame().with_columns( + [ + pl.Series( + name=IDENTIFIER, + values=[get_identifier_key(feature_data.identifier)], + ), + ], + ) + df = df.with_columns( + [ + pl.Series(name=f"{self.name}:{feature_name}", values=[feature_value]) for feature_name, feature_value in feature_data.features.items() - }, + ], ) + return self.name, df diff --git a/wyvern/components/helpers/polars.py b/wyvern/components/helpers/polars.py new file mode 100644 index 0000000..f2f1b6b --- /dev/null +++ b/wyvern/components/helpers/polars.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- +import polars as pl + + +def cast_float32_to_float64(df) -> pl.DataFrame: + float32_cols = [ + col for col, dtype in zip(df.columns, df.dtypes) if dtype == pl.Float32 + ] + return df.with_columns([df[col].cast(pl.Float64) for col in float32_cols]) diff --git a/wyvern/components/models/modelbit_component.py b/wyvern/components/models/modelbit_component.py index 74756ef..3a5b5a5 100644 --- a/wyvern/components/models/modelbit_component.py +++ b/wyvern/components/models/modelbit_component.py @@ -161,15 +161,14 @@ async def build_requests( Union[WyvernEntity, 
BaseWyvernRequest] ] = input.entities or [input.request] target_identifiers = [entity.identifier for entity in target_entities] + identifier_features_tuples = self.get_features( + target_identifiers, + self.modelbit_features, + ) + all_requests = [ - [ - idx + 1, - [ - self.get_feature(identifier, feature_name) - for feature_name in self.modelbit_features - ], - ] - for idx, identifier in enumerate(target_identifiers) + [idx + 1, features] + for idx, (identifier, features) in enumerate(identifier_features_tuples) ] return target_identifiers, all_requests diff --git a/wyvern/components/pipeline_component.py b/wyvern/components/pipeline_component.py index 1680010..2644dfc 100644 --- a/wyvern/components/pipeline_component.py +++ b/wyvern/components/pipeline_component.py @@ -2,7 +2,8 @@ from functools import cached_property from typing import Optional, Set, Type -from wyvern import request_context +from ddtrace import tracer + from wyvern.components.api_route_component import APIRouteComponent from wyvern.components.component import Component from wyvern.components.features.feature_retrieval_pipeline import ( @@ -58,6 +59,7 @@ async def initialize(self) -> None: for feature_name in component.manifest_feature_names: self.feature_names.add(feature_name) + @tracer.wrap(name="PipelineComponent.retrieve_features") async def retrieve_features(self, request: REQUEST_ENTITY) -> None: """ TODO shu: it doesn't support feature overrides. Write code to support that @@ -67,12 +69,9 @@ async def retrieve_features(self, request: REQUEST_ENTITY) -> None: requested_feature_names=self.feature_names, feature_overrides=self.realtime_features_overrides, ) - - feature_map = await self.feature_retrieval_pipeline.execute( + await self.feature_retrieval_pipeline.execute( feature_request, ) - current_request = request_context.ensure_current_request() - current_request.feature_map = feature_map async def warm_up(self, input: REQUEST_ENTITY) -> None: await super().warm_up(input) diff --git a/wyvern/entities/feature_entities.py b/wyvern/entities/feature_entities.py index 4d4aa50..7475af8 100644 --- a/wyvern/entities/feature_entities.py +++ b/wyvern/entities/feature_entities.py @@ -1,13 +1,19 @@ # -*- coding: utf-8 -*- from __future__ import annotations +import logging from typing import Dict, List +import polars as pl from pydantic.main import BaseModel -from wyvern.entities.identifier import Identifier +from wyvern.entities.identifier import Identifier, get_identifier_key from wyvern.wyvern_typing import WyvernFeature +logger = logging.getLogger(__name__) + +IDENTIFIER = "IDENTIFIER" + class FeatureData(BaseModel, frozen=True): """ @@ -28,30 +34,67 @@ def __repr__(self): return self.__str__() -class FeatureMap(BaseModel, frozen=True): +class FeatureDataFrame(BaseModel): """ - A class to represent a map of identifiers to feature data. - - TODO (kerem): Fix the data duplication between this class and the FeatureData class. The identifier field in the - FeatureData class is redundant. + A class to store features in a polars dataframe. """ - feature_map: Dict[Identifier, FeatureData] + df: pl.DataFrame = pl.DataFrame().with_columns( + pl.Series(name=IDENTIFIER, dtype=pl.Utf8), + ) + class Config: + arbitrary_types_allowed = True + frozen = True -def build_empty_feature_map( - identifiers: List[Identifier], - feature_names: List[str], -) -> FeatureMap: - """ - Builds an empty feature map with the given identifiers and feature names. 
- """ - return FeatureMap( - feature_map={ - identifier: FeatureData( - identifier=identifier, - features={feature: None for feature in feature_names}, - ) - for identifier in identifiers - }, - ) + def get_features( + self, + identifiers: List[Identifier], + feature_names: List[str], + ) -> pl.DataFrame: + # Filter the dataframe by identifier. If the identifier is a composite identifier, use the primary identifier + identifier_keys = [get_identifier_key(identifier) for identifier in identifiers] + return self.get_features_by_identifier_keys( + identifier_keys=identifier_keys, + feature_names=feature_names, + ) + + def get_features_by_identifier_keys( + self, + identifier_keys: List[str], + feature_names: List[str], + ) -> pl.DataFrame: + # Filter the dataframe by identifier + df = self.df.filter(pl.col(IDENTIFIER).is_in(identifier_keys)) + + # Process feature names, adding identifier to the selection + feature_names = [IDENTIFIER] + feature_names + existing_cols = df.columns + for col_name in feature_names: + if col_name not in existing_cols: + # Add a new column filled with None values if it doesn't exist + df = df.with_columns(pl.lit(None).alias(col_name)) + df = df.select(feature_names) + + return df + + def get_all_features_for_identifier(self, identifier: Identifier) -> pl.DataFrame: + identifier_key = get_identifier_key(identifier) + return self.df.filter(pl.col(IDENTIFIER) == identifier_key) + + @staticmethod + def build_empty_df( + identifiers: List[Identifier], + feature_names: List[str], + ) -> FeatureDataFrame: + """ + Builds an empty polars df with the given identifiers and feature names. + """ + identifier_keys = [get_identifier_key(identifier) for identifier in identifiers] + df_columns = [ + pl.Series(name=IDENTIFIER, values=identifier_keys, dtype=pl.Object), + ] + df_columns.extend( + [pl.lit(None).alias(feature_name) for feature_name in feature_names], # type: ignore + ) + return FeatureDataFrame(df=pl.DataFrame().with_columns(df_columns)) diff --git a/wyvern/entities/feature_entity_helpers.py b/wyvern/entities/feature_entity_helpers.py deleted file mode 100644 index 954d091..0000000 --- a/wyvern/entities/feature_entity_helpers.py +++ /dev/null @@ -1,33 +0,0 @@ -# -*- coding: utf-8 -*- -from typing import Dict, Optional - -from wyvern.entities.feature_entities import FeatureData, FeatureMap -from wyvern.entities.identifier import Identifier - - -def feature_map_join(*feature_maps: FeatureMap) -> FeatureMap: - """ - Joins multiple feature maps into a single feature map. Used to join feature maps from different sources. - """ - return feature_map_create( - *[value for map in feature_maps for value in map.feature_map.values()] - ) - - -def feature_map_create(*feature_data: Optional[FeatureData]) -> FeatureMap: - """ - Creates a feature map from a list of feature data. Used to create feature maps from different sources. 
- """ - feature_map: Dict[Identifier, FeatureData] = {} - for data in feature_data: - if data is None: - continue - - if data.identifier in feature_map: - # print(f"Duplicate keys found in feature map {data}") - # TODO (suchintan): handle duplicate keys at this stage - feature_map[data.identifier].features.update(data.features) - else: - feature_map[data.identifier] = data - - return FeatureMap(feature_map=feature_map) diff --git a/wyvern/entities/identifier.py b/wyvern/entities/identifier.py index f19dddd..e6c980c 100644 --- a/wyvern/entities/identifier.py +++ b/wyvern/entities/identifier.py @@ -139,3 +139,15 @@ def __init__( secondary_identifier=secondary_identifier, **kwargs, ) + + +def get_identifier_key( + identifier: Identifier, +) -> str: + """ + Returns the identifier key for a given identifier. If the identifier is a composite identifier, the primary + identifier is used. This is useful while doing feature retrievals for composite entities. + """ + if isinstance(identifier, CompositeIdentifier): + return str(identifier.primary_identifier) + return str(identifier) diff --git a/wyvern/exceptions.py b/wyvern/exceptions.py index 6f3c342..46b643b 100644 --- a/wyvern/exceptions.py +++ b/wyvern/exceptions.py @@ -99,6 +99,14 @@ class WyvernFeatureNameError(WyvernError): ) +class WyvernFeatureValueError(WyvernError): + """ + Raised when there is an error in feature value + """ + + message = "More than one feature value found for identifier={identifier} feature_name={feature_name}." + + class WyvernModelInputError(WyvernError): """ Raised when there is an error in model input @@ -163,3 +171,11 @@ class MissingModelChainOutputError(WyvernError): class MissingModelOutputError(WyvernError): message = "Identifier is missing in the model output" + + +class WyvernLoggingOriginalIdentifierMissingError(WyvernError): + """ + Raised when original identifier is missing during feature logging + """ + + message = "Original identifier is missing for primary identifier={identifier} feature_name={feature_name}." diff --git a/wyvern/wyvern_request.py b/wyvern/wyvern_request.py index 6b87184..d206e64 100644 --- a/wyvern/wyvern_request.py +++ b/wyvern/wyvern_request.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import annotations +from collections import defaultdict from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Union from urllib.parse import urlparse @@ -9,8 +10,9 @@ from pydantic import BaseModel from wyvern.components.events.events import LoggedEvent -from wyvern.entities.feature_entities import FeatureMap +from wyvern.entities.feature_entities import FeatureDataFrame from wyvern.entities.identifier import Identifier +from wyvern.exceptions import WyvernLoggingOriginalIdentifierMissingError @dataclass @@ -28,7 +30,7 @@ class WyvernRequest: entity_store: A dictionary that can be used to store entities that are created during the request events: A list of functions that return a list of LoggedEvents. 
These functions are called at the end of the request to log events to the event store
-        feature_map: A FeatureMap that can be used to store features that are created during the request
+        feature_df: The feature data frame that is created during the request
         request_id: The request ID of the request
     """

@@ -43,7 +45,12 @@ class WyvernRequest:
     # The list of list here is a minor performance optimization to prevent copying of lists for events
     events: List[Callable[[], List[LoggedEvent[Any]]]]

-    feature_map: FeatureMap
+    feature_df: FeatureDataFrame
+    # feature_orig_identifiers is a hack to get around the fact that the feature dataframe does not store
+    # the original identifiers of the entities. This is needed for logging the features with the correct
+    # identifiers. The below map is a map of the feature name to the primary identifier key of the entity to the
+    # original identifier of the entity
+    feature_orig_identifiers: Dict[str, Dict[str, Identifier]]

     # the key is the name of the model and the value is a map of the identifier to the model score
     model_output_map: Dict[
@@ -92,7 +99,8 @@ def parse_fastapi_request(
             headers=dict(req.headers),
             entity_store={},
             events=[],
-            feature_map=FeatureMap(feature_map={}),
+            feature_df=FeatureDataFrame(),
+            feature_orig_identifiers=defaultdict(dict),
             model_output_map={},
             request_id=request_id,
             run_id=run_id,
@@ -132,3 +140,25 @@ def get_model_output(
         if model_name not in self.model_output_map:
             return None
         return self.model_output_map[model_name].get(identifier)
+
+    def get_original_identifier(
+        self,
+        primary_identifier_key: str,
+        feature_name: str,
+    ) -> Identifier:
+        """Gets the original identifier for a feature name and primary identifier key.
+
+        Args:
+            primary_identifier_key: The primary identifier key.
+            feature_name: The name of the feature.
+
+        Returns:
+            The original identifier.
+        """
+        try:
+            return self.feature_orig_identifiers[feature_name][primary_identifier_key]
+        except KeyError:
+            raise WyvernLoggingOriginalIdentifierMissingError(
+                identifier=primary_identifier_key,
+                feature_name=feature_name,
+            )
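
A usage note on the new casting helper ("Cast all f32 columns to f64 before concat"): polars refuses to concatenate frames whose shared columns disagree on float width, so every Float32 column is widened to Float64 before the feature-store and real-time frames are merged. A minimal sketch, with made-up column names and identifier keys:

    # -*- coding: utf-8 -*-
    import polars as pl

    from wyvern.components.helpers.polars import cast_float32_to_float64

    # Hypothetical feature frames: df_b carries a Float32 column (e.g. a
    # store backed by float32 values), df_a the default Float64.
    df_a = pl.DataFrame({"IDENTIFIER": ["product::p1"], "fs:price": [9.99]})
    df_b = pl.DataFrame(
        {
            "IDENTIFIER": ["product::p2"],
            "fs:price": pl.Series(values=[19.99], dtype=pl.Float32),
        },
    )

    # Without the cast, pl.concat fails on the Float32/Float64 mismatch;
    # widening both frames to Float64 first makes the schemas agree.
    merged = pl.concat(
        [cast_float32_to_float64(df_a), cast_float32_to_float64(df_b)],
        how="diagonal",
    )
    print(merged["fs:price"].dtype)  # Float64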
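The retrieval contract of FeatureDataFrame.get_features is that callers always get back exactly the columns they asked for: requested features that were never computed come back as null columns instead of raising. A minimal sketch, assuming Identifier's str() form is "<type>::<id>" and using hypothetical feature names:

    # -*- coding: utf-8 -*-
    import polars as pl

    from wyvern.entities.feature_entities import FeatureDataFrame
    from wyvern.entities.identifier import Identifier

    # Hypothetical store: one computed feature for two product entities.
    feature_df = FeatureDataFrame(
        df=pl.DataFrame(
            {
                "IDENTIFIER": ["product::p1", "product::p2"],
                "fs:price": [9.99, 19.99],
            },
        ),
    )

    # "fs:brand" was never computed, so it comes back as a null column;
    # the result is filtered to the requested identifiers and columns.
    result = feature_df.get_features(
        identifiers=[Identifier(identifier_type="product", identifier="p1")],
        feature_names=["fs:price", "fs:brand"],
    )
    print(result)  # 1 row; columns: IDENTIFIER, fs:price, fs:brand (null)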
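Composite features are keyed by their primary identifier in the dataframe, and the original composite identifier is parked in feature_orig_identifiers so feature logging can recover it. A sketch of that round trip, assuming CompositeIdentifier accepts primary_identifier/secondary_identifier keyword arguments (as its __init__ above suggests) and using a made-up feature name:

    # -*- coding: utf-8 -*-
    from collections import defaultdict
    from typing import Dict

    from wyvern.entities.identifier import (
        CompositeIdentifier,
        Identifier,
        get_identifier_key,
    )

    product = Identifier(identifier_type="product", identifier="p1")
    user = Identifier(identifier_type="user", identifier="u1")
    composite = CompositeIdentifier(
        primary_identifier=product,
        secondary_identifier=user,
    )

    # The dataframe row for a composite feature is keyed by the primary
    # identifier only.
    assert get_identifier_key(composite) == str(product)  # "product::p1"

    # Same shape as WyvernRequest.feature_orig_identifiers:
    # feature name -> primary identifier key -> original identifier.
    feature_orig_identifiers: Dict[str, Dict[str, Identifier]] = defaultdict(dict)
    feature_orig_identifiers["rtf:is_pinned"][get_identifier_key(composite)] = composite

    # Logging recovers the full composite identifier from the primary key.
    assert feature_orig_identifiers["rtf:is_pinned"][str(product)] == composite

WyvernRequest.get_original_identifier performs the same lookup against the request-scoped map, raising WyvernLoggingOriginalIdentifierMissingError when the key was never recorded.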