From ed98532fff76c083a3683dc6c0fb4588f36c5838 Mon Sep 17 00:00:00 2001 From: Kerem Yilmaz Date: Mon, 16 Oct 2023 13:08:53 -0700 Subject: [PATCH] Change component.get_feature return type --- wyvern/components/component.py | 31 ++++++++++++++++--- .../components/models/modelbit_component.py | 9 ++---- wyvern/entities/feature_entities.py | 11 +++++++ 3 files changed, 40 insertions(+), 11 deletions(-) diff --git a/wyvern/components/component.py b/wyvern/components/component.py index 1ab8675..a44d651 100644 --- a/wyvern/components/component.py +++ b/wyvern/components/component.py @@ -5,13 +5,13 @@ import logging from enum import Enum from functools import cached_property -from typing import Dict, Generic, List, Optional, Set, Union +from typing import Dict, Generic, List, Optional, Set, Tuple, Union from uuid import uuid4 import polars as pl from wyvern import request_context -from wyvern.entities.identifier import Identifier +from wyvern.entities.identifier import Identifier, get_identifier_key from wyvern.exceptions import WyvernFeatureValueError from wyvern.wyvern_typing import INPUT_TYPE, OUTPUT_TYPE, WyvernFeature @@ -149,13 +149,34 @@ def manifest_feature_names(self) -> Set[str]: def get_features( identifiers: List[Identifier], feature_names: List[str], - ) -> pl.DataFrame: + ) -> List[Tuple[str, List[WyvernFeature]]]: current_request = request_context.ensure_current_request() - return current_request.feature_df.get_features( - identifiers, + identifier_keys = [get_identifier_key(identifier) for identifier in identifiers] + df = current_request.feature_df.get_features_by_identifier_keys( + identifier_keys, feature_names, ) + # build tuples where the identifier column is the first element and the feature columns are the rest + rows = df.rows() + identifier_to_features_dict = { + # row[0] is the identifier column, it is a string + # row[1:] are the feature columns, each column is a WyvernFeature + row[0]: row[1:] + for row in rows + } + + empty_feature_list = [None] * len(feature_names) + tuples = [ + ( + identifier_key, + identifier_to_features_dict.get(identifier_key, empty_feature_list), + ) + for identifier_key in identifier_keys + ] + + return tuples # type: ignore + @staticmethod def get_feature( identifier: Identifier, diff --git a/wyvern/components/models/modelbit_component.py b/wyvern/components/models/modelbit_component.py index fc012af..3a5b5a5 100644 --- a/wyvern/components/models/modelbit_component.py +++ b/wyvern/components/models/modelbit_component.py @@ -161,17 +161,14 @@ async def build_requests( Union[WyvernEntity, BaseWyvernRequest] ] = input.entities or [input.request] target_identifiers = [entity.identifier for entity in target_entities] - features = self.get_features( + identifier_features_tuples = self.get_features( target_identifiers, self.modelbit_features, ) - # Convert the fetched features DataFrame to a list of lists for easy access - features_list = features.rows() - all_requests = [ - [idx + 1, row[1:]] # row[0] is the identifier, so we skip it. - for idx, row in enumerate(features_list) + [idx + 1, features] + for idx, (identifier, features) in enumerate(identifier_features_tuples) ] return target_identifiers, all_requests diff --git a/wyvern/entities/feature_entities.py b/wyvern/entities/feature_entities.py index d67d84c..7475af8 100644 --- a/wyvern/entities/feature_entities.py +++ b/wyvern/entities/feature_entities.py @@ -54,6 +54,17 @@ def get_features( ) -> pl.DataFrame: # Filter the dataframe by identifier. If the identifier is a composite identifier, use the primary identifier identifier_keys = [get_identifier_key(identifier) for identifier in identifiers] + return self.get_features_by_identifier_keys( + identifier_keys=identifier_keys, + feature_names=feature_names, + ) + + def get_features_by_identifier_keys( + self, + identifier_keys: List[str], + feature_names: List[str], + ) -> pl.DataFrame: + # Filter the dataframe by identifier df = self.df.filter(pl.col(IDENTIFIER).is_in(identifier_keys)) # Process feature names, adding identifier to the selection