Skip to content

Commit

Permalink
label as series from polars while stitching
Browse files Browse the repository at this point in the history
  • Loading branch information
silil committed Nov 13, 2023
1 parent c3be958 commit bdcd19d
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions src/triage/component/architect/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,10 +317,10 @@ def build_matrix(
matrix_metadata["label_timespan"],
)

output = self.stitch_csvs(feature_queries, label_query, matrix_store, matrix_uuid)
output, labels = self.stitch_csvs(feature_queries, label_query, matrix_store, matrix_uuid)
logger.debug(f"matrix stitched, pandas DF returned")
matrix_store.metadata = matrix_metadata
labels = output.pop(matrix_store.label_column_name)
#labels = output.pop(matrix_store.label_column_name)
matrix_store.matrix_label_tuple = output, labels
matrix_store.save()

Expand Down Expand Up @@ -554,10 +554,22 @@ def stitch_csvs(self, features_queries, label_query, matrix_store, matrix_uuid):
df_pl = df_pl.with_columns(pl.col("entity_id").cast(pl.Int32, strict=False))
end = time.time()
logger.debug(f"time casting entity_id and as_of_date of matrix with uuid {matrix_uuid} (sec): {(end-start)/60}")

logger.debug(f"getting labels pandas series from polars data frame")
# getting label series
labels_pl = df_pl.select(pl.columns[-1])
# convert into pandas series
labels_df = labels_pl.to_pandas()
labels_series = labels_df.squeeze()

# remove labels from features and return as df
logger.debug(f"removing labels from main polars df")
df_pl_aux = df_pl.drop(df_pl.columns[-1])

# converting from polars to pandas
logger.debug(f"about to convert polars df into pandas df")
start = time.time()
df = df_pl.to_pandas()
df = df_pl_aux.to_pandas()
end = time.time()
logger.debug(f"Time converting from polars to pandas (sec): {(end-start)/60}")
df.set_index(["entity_id", "as_of_date"], inplace=True)
Expand All @@ -569,8 +581,8 @@ def stitch_csvs(self, features_queries, label_query, matrix_store, matrix_uuid):
rm_filenames = generate_list_of_files_to_remove(filenames, matrix_uuid)
self.remove_unnecessary_files(rm_filenames, path_, matrix_uuid)

#return downcast_matrix(df)
return df
return df, labels_series


def remove_unnecessary_files(self, filenames, path_, matrix_uuid):
"""
Expand Down

0 comments on commit bdcd19d

Please sign in to comment.