Skip to content

Commit

Permalink
Merge pull request #932 from dssg/kasun_pf_update
Browse files Browse the repository at this point in the history
Updating predict forward with existing model
  • Loading branch information
kasunamare authored Jul 21, 2023
2 parents d8c35d3 + fa71501 commit 2f2a548
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 6 deletions.
29 changes: 24 additions & 5 deletions src/triage/predictlist/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
save_retrain_and_get_hash,
get_retrain_config_from_model_id,
temporal_params_from_matrix_metadata,
cohort_config_from_label_config
)


Expand Down Expand Up @@ -72,11 +73,17 @@ def predict_forward_with_existed_model(db_engine, project_path, model_id, as_of_
upgrade_db(db_engine=db_engine)
project_storage = ProjectStorage(project_path)
matrix_storage_engine = project_storage.matrix_storage_engine()

# 1. Get feature and cohort config from database
(train_matrix_uuid, matrix_metadata) = train_matrix_info_from_model_id(db_engine, model_id)
experiment_config = experiment_config_from_model_id(db_engine, model_id)

# 2. Generate cohort
if experiment_config.get('cohort_config') is None:
# If a separate cohort_config is not defined in the config
logger.info('Experiment config does not contain a cohort config. Using the label query')
experiment_config['cohort_config'] = cohort_config_from_label_config(experiment_config['label_config'])

cohort_table_name = f"triage_production.cohort_{experiment_config['cohort_config']['name']}"
cohort_table_generator = EntityDateTableGenerator(
db_engine=db_engine,
Expand Down Expand Up @@ -150,7 +157,7 @@ def predict_forward_with_existed_model(db_engine, project_path, model_id, as_of_
label_name = experiment_config['label_config']['name']
label_type = 'binary'
cohort_name = experiment_config['cohort_config']['name']
user_metadata = experiment_config['user_metadata']
user_metadata = experiment_config.get('user_metadata', {})

# Use timechop to get the time definition for production
temporal_config = experiment_config["temporal_config"]
Expand All @@ -162,8 +169,16 @@ def predict_forward_with_existed_model(db_engine, project_path, model_id, as_of_
test_label_timespan=temporal_config['test_label_timespans'][0]
)

last_split_definition = prod_definitions[-1]

# formating the datetimes as strings to be saved as JSON
last_split_definition['first_as_of_time'] = str(last_split_definition['first_as_of_time'])
last_split_definition['last_as_of_time'] = str(last_split_definition['last_as_of_time'])
last_split_definition['matrix_info_end_time'] = str(last_split_definition['matrix_info_end_time'])
last_split_definition['as_of_times'] = [str(last_split_definition['as_of_times'][0])]

matrix_metadata = Planner.make_metadata(
prod_definitions[-1],
last_split_definition,
reconstructed_feature_dict,
label_name,
label_type,
Expand Down Expand Up @@ -234,6 +249,10 @@ def __init__(self, db_engine, project_path, model_group_id):
self.test_duration = self.experiment_config['temporal_config']['test_durations'][0]
self.feature_start_time=self.experiment_config['temporal_config']['feature_start_time']

# Handling the case where a separate cohort_config is not defined
if self.experiment_config.get('cohort_config') is None:
self.experiment_config['cohort_config'] = cohort_config_from_label_config(self.experiment_config['label_config'])

self.label_name = self.experiment_config['label_config']['name']
self.cohort_name = self.experiment_config['cohort_config']['name']
self.user_metadata = self.experiment_config['user_metadata']
Expand Down Expand Up @@ -361,9 +380,9 @@ def retrain(self, prediction_date):
chops_train_matrix = chops[0]['train_matrix']
as_of_date = datetime.strftime(chops_train_matrix['last_as_of_time'], "%Y-%m-%d")
retrain_definition = {
'first_as_of_time': chops_train_matrix['first_as_of_time'],
'last_as_of_time': chops_train_matrix['last_as_of_time'],
'matrix_info_end_time': chops_train_matrix['matrix_info_end_time'],
'first_as_of_time': str(chops_train_matrix['first_as_of_time']),
'last_as_of_time': str(chops_train_matrix['last_as_of_time']),
'matrix_info_end_time': str(chops_train_matrix['matrix_info_end_time']),
'as_of_times': [as_of_date],
'training_label_timespan': chops_train_matrix['training_label_timespan'],
'max_training_history': chops_train_matrix['max_training_history'],
Expand Down
21 changes: 20 additions & 1 deletion src/triage/predictlist/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def get_feature_names(aggregation, matrix_metadata):
logger.spam("Feature prefix = %s", feature_prefix)
feature_group = aggregation.get_table_name(imputed=True).split('.')[1].replace('"', '')
logger.spam("Feature group = %s", feature_group)
feature_names_in_group = [f for f in matrix_metadata['feature_names'] if re.match(f'\\A{feature_prefix}_', f)]
feature_names_in_group = [f for f in matrix_metadata['feature_names'] if re.match(f'\\A{feature_prefix}_entity_id', f)]
logger.spam("Feature names in group = %s", feature_names_in_group)

return feature_group, feature_names_in_group
Expand Down Expand Up @@ -206,3 +206,22 @@ def save_retrain_and_get_hash(config, db_engine):
return retrain_hash


def cohort_config_from_label_config(label_config):
    """Build a default cohort config when none is specified in the experiment config.

    Handles the case where a separate cohort query is not defined by reusing
    the label query: the label query is wrapped as a subquery and only the
    entity ids are selected from it.

    Args:
        label_config (dict): label configuration containing at least a
            ``'query'`` key with the label-generating SQL. The query may
            contain a ``{label_timespan}`` placeholder.

    Returns:
        dict: a cohort config with ``'name'`` set to ``'default'`` and
        ``'query'`` set to a SQL statement selecting ``entity_id`` from the
        (timespan-substituted) label query.
    """
    # The cohort query cannot carry the {label_timespan} placeholder, so
    # substitute an arbitrary fixed timespan; the actual value is irrelevant
    # because only the entity ids are extracted from the result.
    label_query = label_config['query'].replace('{label_timespan}', '1week')

    # Use the label query as a subquery and extract the entity ids
    return {
        'name': 'default',
        'query': f"""
    select
        entity_id
    from ({label_query}) as lq
    """,
    }

0 comments on commit 2f2a548

Please sign in to comment.