Skip to content
This repository has been archived by the owner on Mar 1, 2024. It is now read-only.

Commit

Permalink
Change component.get_feature return type
Browse files (browse the repository at this point in the history)
  • Loading branch information
ykeremy committed Oct 16, 2023
1 parent fa723e7 commit ed98532
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 11 deletions.
31 changes: 26 additions & 5 deletions wyvern/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import logging
from enum import Enum
from functools import cached_property
from typing import Dict, Generic, List, Optional, Set, Union
from typing import Dict, Generic, List, Optional, Set, Tuple, Union
from uuid import uuid4

import polars as pl

from wyvern import request_context
from wyvern.entities.identifier import Identifier
from wyvern.entities.identifier import Identifier, get_identifier_key
from wyvern.exceptions import WyvernFeatureValueError
from wyvern.wyvern_typing import INPUT_TYPE, OUTPUT_TYPE, WyvernFeature

Expand Down Expand Up @@ -149,13 +149,34 @@ def manifest_feature_names(self) -> Set[str]:
def get_features(
identifiers: List[Identifier],
feature_names: List[str],
) -> pl.DataFrame:
) -> List[Tuple[str, List[WyvernFeature]]]:
current_request = request_context.ensure_current_request()
return current_request.feature_df.get_features(
identifiers,
identifier_keys = [get_identifier_key(identifier) for identifier in identifiers]
df = current_request.feature_df.get_features_by_identifier_keys(
identifier_keys,
feature_names,
)

# build tuples where the identifier column is the first element and the feature columns are the rest
rows = df.rows()
identifier_to_features_dict = {
# row[0] is the identifier column, it is a string
# row[1:] are the feature columns, each column is a WyvernFeature
row[0]: row[1:]
for row in rows
}

empty_feature_list = [None] * len(feature_names)
tuples = [
(
identifier_key,
identifier_to_features_dict.get(identifier_key, empty_feature_list),
)
for identifier_key in identifier_keys
]

return tuples # type: ignore

@staticmethod
def get_feature(
identifier: Identifier,
Expand Down
9 changes: 3 additions & 6 deletions wyvern/components/models/modelbit_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,17 +161,14 @@ async def build_requests(
Union[WyvernEntity, BaseWyvernRequest]
] = input.entities or [input.request]
target_identifiers = [entity.identifier for entity in target_entities]
features = self.get_features(
identifier_features_tuples = self.get_features(
target_identifiers,
self.modelbit_features,
)

# Convert the fetched features DataFrame to a list of lists for easy access
features_list = features.rows()

all_requests = [
[idx + 1, row[1:]] # row[0] is the identifier, so we skip it.
for idx, row in enumerate(features_list)
[idx + 1, features]
for idx, (identifier, features) in enumerate(identifier_features_tuples)
]
return target_identifiers, all_requests

Expand Down
11 changes: 11 additions & 0 deletions wyvern/entities/feature_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,17 @@ def get_features(
) -> pl.DataFrame:
# Filter the dataframe by identifier. If the identifier is a composite identifier, use the primary identifier
identifier_keys = [get_identifier_key(identifier) for identifier in identifiers]
return self.get_features_by_identifier_keys(
identifier_keys=identifier_keys,
feature_names=feature_names,
)

def get_features_by_identifier_keys(
self,
identifier_keys: List[str],
feature_names: List[str],
) -> pl.DataFrame:
# Filter the dataframe by identifier
df = self.df.filter(pl.col(IDENTIFIER).is_in(identifier_keys))

# Process feature names, adding identifier to the selection
Expand Down

0 comments on commit ed98532

Please sign in to comment.