Skip to content

Commit

Permalink
Fix for model-build failure due to presence of survey inputs as a dic…
Browse files Browse the repository at this point in the history
…tionary (#954)

* Filtered out rows with dictionary in user_label_df

The idea is to check the data type using isinstance() and then apply this check on the entire data frame as a whole instead of doing it iteratively on each row which is much slower.

These rows are then filtered out of the original dataframe leaving behind only the non-dict rows.

* Added log statements to indicate dataframe filtering done

Added log statement to the greedy_similarity_binning, to indicate filtering is being done for the dictionary elements in the dataframe if the column is 'trip_user_input'.

* Modified filtering of survey inputs

Now filtering survey inputs before dataframe itself is created by checking whether each dictionary value is a value or a nested dictionary.

* Add TODO so I can merge this

---------

Co-authored-by: Mahadik, Mukul Chandrakant <[email protected]>
Co-authored-by: K. Shankari <[email protected]>
  • Loading branch information
3 people authored Feb 10, 2024
1 parent 911e1ec commit 174bfb1
Showing 1 changed file with 5 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,11 @@ def _generate_predictions(self):
probability is estimated with label_count / total_labels.
"""
for _, bin_record in self.bins.items():
user_label_df = pd.DataFrame(bin_record['labels'])
# TODO: Revisit after we have unified label and survey inputs (https://github.com/e-mission/e-mission-docs/issues/1045)
logging.debug("Filtering out any nested dictionaries from the list of dictionary labels")
filtered_label_dicts = [label_dict for label_dict in bin_record['labels'] if not any(isinstance(x, dict) for x in label_dict.values())]
logging.debug("Number of entries after filtering changed %s -> %s" % (len(bin_record['labels']), len(filtered_label_dicts)))
user_label_df = pd.DataFrame(filtered_label_dicts)
user_label_df = lp.map_labels(user_label_df).dropna()
# compute the sum of trips in this cluster
sum_trips = len(user_label_df)
Expand Down

0 comments on commit 174bfb1

Please sign in to comment.