
inference helper columns for online
davitbzh committed Oct 13, 2023
1 parent 3eb9c1d commit 6ade109
Showing 5 changed files with 579 additions and 191 deletions.
21 changes: 15 additions & 6 deletions python/hsfs/core/feature_view_api.py
@@ -128,8 +128,10 @@ def get_batch_query(
         end_time,
         training_dataset_version=None,
         with_label=False,
-        with_inference_helper_columns=False,
-        with_training_helper_columns=False,
+        primary_keys=False,
+        event_time=False,
+        inference_helper_columns=False,
+        training_helper_columns=False,
         is_python_engine=False,
     ):
         path = self._base_path + [
@@ -147,23 +149,30 @@ def get_batch_query(
                     "start_time": start_time,
                     "end_time": end_time,
                     "with_label": with_label,
-                    "with_inference_helper_columns": with_inference_helper_columns,
-                    "with_training_helper_columns": with_training_helper_columns,
+                    "with_primary_keys": primary_keys,
+                    "with_event_time": event_time,
+                    "inference_helper_columns": inference_helper_columns,
+                    "training_helper_columns": training_helper_columns,
                     "is_hive_engine": is_python_engine,
                     "td_version": training_dataset_version,
                 },
             )
         )

-    def get_serving_prepared_statement(self, name, version, batch):
+    def get_serving_prepared_statement(
+        self, name, version, batch, inference_helper_columns
+    ):
         path = self._base_path + [
             name,
             self._VERSION,
             version,
             self._PREPARED_STATEMENT,
         ]
         headers = {"content-type": "application/json"}
-        query_params = {"batch": batch}
+        query_params = {
+            "batch": batch,
+            "inference_helper_columns": inference_helper_columns,
+        }
         return serving_prepared_statement.ServingPreparedStatement.from_response_json(
             self._client._send_request("GET", path, query_params, headers=headers)
         )
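For orientation, here is a minimal hypothetical sketch of how the updated FeatureViewApi methods could be called after this change. Only the new parameter names (primary_keys, event_time, inference_helper_columns, training_helper_columns) come from the diff above; the constructor argument, the leading name/version arguments of get_batch_query, and every placeholder value are assumptions.

# Hypothetical usage sketch -- not part of this commit.
from hsfs.core.feature_view_api import FeatureViewApi

fv_api = FeatureViewApi(67)  # placeholder feature store id; constructor signature assumed

# Batch query that also keeps primary key, event time and inference helper columns.
batch_query = fv_api.get_batch_query(
    "transactions_view",       # placeholder feature view name (assumed leading argument)
    1,                         # placeholder feature view version (assumed leading argument)
    start_time=1696118400000,  # placeholder epoch-millis time range
    end_time=1697328000000,
    training_dataset_version=None,
    with_label=False,
    primary_keys=True,
    event_time=True,
    inference_helper_columns=True,
    training_helper_columns=False,
    is_python_engine=True,
)

# Prepared statements for online serving, now optionally including
# the inference helper columns.
statements = fv_api.get_serving_prepared_statement(
    "transactions_view", 1, batch=False, inference_helper_columns=True
)
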
66 changes: 47 additions & 19 deletions python/hsfs/core/feature_view_engine.py
@@ -153,8 +153,10 @@ def get_batch_query(
         start_time,
         end_time,
         with_label=False,
-        with_inference_helper_columns=False,
-        with_training_helper_columns=False,
+        primary_keys=False,
+        event_time=False,
+        inference_helper_columns=False,
+        training_helper_columns=False,
         training_dataset_version=None,
         spine=None,
     ):
@@ -167,8 +169,10 @@ def get_batch_query(
             training_dataset_version=training_dataset_version,
             is_python_engine=engine.get_type() == "python",
             with_label=with_label,
-            with_inference_helper_columns=with_inference_helper_columns,
-            with_training_helper_columns=with_training_helper_columns,
+            primary_keys=primary_keys,
+            event_time=event_time,
+            inference_helper_columns=inference_helper_columns,
+            training_helper_columns=training_helper_columns,
         )
         # verify whatever is passed 1. spine group with dataframe contained, or 2. dataframe
         # the schema has to be consistent
@@ -252,7 +256,9 @@ def create_training_dataset(
         training_dataset_obj,
         user_write_options,
         spine=None,
-        with_training_helper_columns=False,
+        primary_keys=False,
+        event_time=False,
+        training_helper_columns=False,
     ):
         self._set_event_time(feature_view_obj, training_dataset_obj)
         updated_instance = self._create_training_data_metadata(
@@ -263,7 +269,9 @@ def create_training_dataset(
             user_write_options,
             training_dataset_obj=training_dataset_obj,
             spine=spine,
-            with_training_helper_columns=with_training_helper_columns,
+            with_primary_keys=primary_keys,
+            event_time=event_time,
+            training_helper_columns=training_helper_columns,
         )
         return updated_instance, td_job

@@ -275,7 +283,9 @@ def get_training_data(
         training_dataset_obj=None,
         training_dataset_version=None,
         spine=None,
-        with_training_helper_columns=False,
+        primary_keys=False,
+        event_time=False,
+        training_helper_columns=False,
     ):
         # check if provided td version has already existed.
         if training_dataset_version:
@@ -308,7 +318,9 @@ def get_training_data(
                 td_updated,
                 td_updated.splits,
                 read_options,
-                with_training_helper_columns,
+                primary_keys,
+                event_time,
+                training_helper_columns,
                 feature_view_obj.training_helper_columns,
             )
         else:
@@ -319,8 +331,10 @@ def get_training_data(
                 start_time=td_updated.event_start_time,
                 end_time=td_updated.event_end_time,
                 with_label=True,
-                with_inference_helper_columns=False,
-                with_training_helper_columns=with_training_helper_columns,
+                inference_helper_columns=False,
+                primary_keys=primary_keys,
+                event_time=event_time,
+                training_helper_columns=training_helper_columns,
                 spine=spine,
             )
             split_df = engine.get_instance().get_training_data(
@@ -399,6 +413,8 @@ def _read_from_storage_connector(
         training_data_obj,
         splits,
         read_options,
+        primary_keys,
+        event_time,
         with_training_helper_columns,
         training_helper_columns,
     ):
@@ -411,6 +427,8 @@ def _read_from_storage_connector(
                 training_data_obj,
                 path,
                 read_options,
+                primary_keys,
+                event_time,
                 with_training_helper_columns,
                 training_helper_columns,
             )
@@ -421,6 +439,8 @@ def _read_from_storage_connector(
                 training_data_obj,
                 path,
                 read_options,
+                primary_keys,
+                event_time,
                 with_training_helper_columns,
                 training_helper_columns,
             )
@@ -438,6 +458,8 @@ def _read_dir_from_storage_connector(
         training_data_obj,
         path,
         read_options,
+        primary_keys,
+        event_time,
         with_training_helper_columns,
         training_helper_columns,
     ):
@@ -477,7 +499,9 @@ def compute_training_dataset(
         training_dataset_obj=None,
         training_dataset_version=None,
         spine=None,
-        with_training_helper_columns=False,
+        primary_keys=False,
+        event_time=False,
+        training_helper_columns=False,
     ):
         if training_dataset_obj:
             pass
@@ -493,16 +517,16 @@ def compute_training_dataset(
             training_dataset_obj.event_start_time,
             training_dataset_obj.event_end_time,
             with_label=True,
-            with_inference_helper_columns=False,
-            with_training_helper_columns=with_training_helper_columns,
+            primary_keys=primary_keys,
+            event_time=event_time,
+            inference_helper_columns=False,
+            training_helper_columns=training_helper_columns,
             training_dataset_version=training_dataset_obj.version,
             spine=spine,
         )

         # for spark job
-        user_write_options[
-            "with_training_helper_columns"
-        ] = with_training_helper_columns
+        user_write_options["training_helper_columns"] = training_helper_columns

         td_job = engine.get_instance().write_training_dataset(
             training_dataset_obj,
@@ -610,7 +634,9 @@ def get_batch_data(
         transformation_functions,
         read_options=None,
         spine=None,
-        with_inference_helper_columns=False,
+        primary_keys=False,
+        event_time=False,
+        inference_helper_columns=False,
     ):
         self._check_feature_group_accessibility(feature_view_obj)

@@ -619,8 +645,10 @@ def get_batch_data(
             start_time,
             end_time,
             with_label=False,
-            with_inference_helper_columns=with_inference_helper_columns,
-            with_training_helper_columns=False,
+            primary_keys=primary_keys,
+            event_time=event_time,
+            inference_helper_columns=inference_helper_columns,
+            training_helper_columns=False,
             training_dataset_version=training_dataset_version,
             spine=spine,
         ).read(read_options=read_options)
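A similarly hypothetical sketch at the engine level, showing how the new flags would be passed when reading batch inference data. Only the flag names (primary_keys, event_time, inference_helper_columns) come from the get_batch_data signature above; the leading arguments, the FeatureViewEngine construction, and the placeholder values are assumptions.

# Hypothetical usage sketch -- not part of this commit.
from hsfs.core.feature_view_engine import FeatureViewEngine


def read_batch_inference_data(fv_engine: FeatureViewEngine, feature_view_obj):
    # Read batch data that also carries the primary key columns, the event time
    # column and the inference helper columns, using the flags added here.
    return fv_engine.get_batch_data(
        feature_view_obj,             # an hsfs FeatureView instance (assumed leading argument)
        start_time=1696118400000,     # placeholder epoch-millis time range
        end_time=1697328000000,
        training_dataset_version=1,   # placeholder training dataset version
        transformation_functions={},  # assume no attached transformation functions
        read_options=None,
        primary_keys=True,
        event_time=True,
        inference_helper_columns=True,
    )
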
(Diffs for the remaining 3 changed files are not shown here.)
