diff --git a/comptox_ai/utils/qsar.py b/comptox_ai/utils/qsar.py
index 433f570..6609438 100644
--- a/comptox_ai/utils/qsar.py
+++ b/comptox_ai/utils/qsar.py
@@ -16,17 +16,23 @@ def grab_graph_data(listAcronym = "CPDAT"):
     print("Grabbed graph data...")
     return db.run_cypher('MATCH (cl:ChemicalList {listAcronym: "' + listAcronym + '"})-[r1]->(chem:Chemical)-[r2]->(a:Assay) RETURN chem.commonName AS chemName, chem.maccs AS maccs, r2 AS relation, a.commonName AS assayName')
 
+def assign_assay_values(output):
+    if pd.isna(output):
+        return -1
+    elif "CHEMICALHASACTIVEASSAY" in output:
+        return 1
+    else:
+        return 0
+
 def build_formatted_data(output):
     chemicals_df = pd.DataFrame(output)
     pivot_table = chemicals_df.pivot_table(index=['chemName', 'maccs'], columns='assayName', values='relation', aggfunc='first')
     for col in pivot_table.columns:
-        pivot_table[col] = pivot_table[col].apply(lambda x: -1 if pd.isna(x)
-                                                  else 1 if x[1] == "CHEMICALHASACTIVEASSAY"
-                                                  else 0)
+        pivot_table[col] = pivot_table[col].apply(lambda x: assign_assay_values(x))
     pivot_table.reset_index(inplace=True)
     return pivot_table
 
-def maccs_key_conversion(maccs):
+def maccs_key_conversion(maccs, human_readable = False):
     maccs_keys_descriptions = {
         0: 'Padding',
         1: 'ISOTOPE',
@@ -196,33 +202,32 @@ def maccs_key_conversion(maccs):
         165: 'Ring',
         166: 'Fragments FIX: this cant be done in SMARTS'
     }
-    return [maccs_keys_descriptions[int(key)] for key in maccs]
+    return [maccs_keys_descriptions[int(key)] for key in maccs] if human_readable else [maccs_keys_descriptions[int(key) - 1] for key in maccs]
 
-def expand_maccs_column(datalist):
+def expand_maccs_column(datalist, human_readable = False):
     # maccs key is 167 rather than 166
     # Expand MACCS binary strings into separate columns
-    print(datalist['maccs'].iloc[0])
-    print(len(datalist['maccs'].iloc[0]))
     maccs_expanded = datalist['maccs'].apply(lambda x: pd.Series(list(x)).astype(int))
-    maccs_expanded.columns = maccs_key_conversion([i for i in range(maccs_expanded.shape[1])])
-    print(maccs_expanded.columns)
+    maccs_expanded.columns = maccs_key_conversion([i for i in range(maccs_expanded.shape[1])], human_readable=human_readable) if human_readable else [f'{i+1}' for i in range(maccs_expanded.shape[1])]
     datalist = datalist.drop(columns=['maccs'])
     final_df = pd.concat([datalist[['chemName']], maccs_expanded, datalist.drop(columns=['chemName'])], axis=1)
-    print("Expanding MACCS binary strings into separate columns...")
-    print("Converting from maccs key to human readable names...")
     return final_df
 
-def makeQsarDataset(listAcronym="PFASMASTER", output_dir = "./output"):
+def makeQsarDataset(listAcronym="PFASMASTER", output_dir = "./output", human_readable=False):
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
     graphdata = grab_graph_data(listAcronym=listAcronym)
     cleaned_data = build_formatted_data(graphdata)
-    formatted_data = expand_maccs_column(cleaned_data)
-    formatted_data.to_csv(f"{output_dir.rstrip("/")}/comptox_{listAcronym}_all_assay.tsv", sep="\t")
-    return formatted_data
+    formatted_data = expand_maccs_column(cleaned_data, human_readable = human_readable)
+    formatted_data.to_csv(f"{output_dir.rstrip('/')}/comptox_{listAcronym}_all_assay.tsv", sep="\t")
+    return formatted_data.set_index('chemName')
 
 
 # Model Training Functions
-def train_generic_model(clf, X, y):
+def train_generic_model(clf, X, y, kwargs):
     # Scikit-like model interface
+    clf = clf(**kwargs)
     return clf.fit(X, y)
 
 def format_data_for_model(df, y_col_name, end_index = 167):
@@ -263,8 +268,27 @@ def makeDiscoveryDatasets(data, assays, output_dir = "./output"):
 
     return discovery_data
 
+def describe_dataset(data, y_col_name):
+    X, y = data
+    num_chemicals = X.shape[0]
+    num_toxic = y[y == 1].shape[0]
+    num_nontoxic = y[y == 0].shape[0]
+    untested = y[y == -1].shape[0]
+    return {'assay': y_col_name, 'num_chemicals': num_chemicals, 'num_toxic': num_toxic, 'num_nontoxic': num_nontoxic, 'untested': untested}
+
+def describe_datasets(data, assays, output_dir = "./output"):
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
+    data_description = []
+    for assay in assays:
+        data_description.append(describe_dataset(format_data_for_model(data, assay), assay))
+    dd_df = pd.DataFrame.from_dict(data_description)
+    dd_df.to_csv(f"{output_dir}/data_description.tsv", sep = "\t")
+    return dd_df
+
 def trainQsarModel(
-        df, y_col_name, clf, suffix = 1, output_dir = "./output", seed = 42,
+        df, y_col_name, clf, kwargs, suffix = 1, output_dir = "./output", seed = 42,
         save_model = True, save_discovery = True):
     # Expects that the first 167 columns are the macc columns
     focus_df = df[df[y_col_name] >= 0] # Only keep rows that have assay values
@@ -282,7 +306,7 @@ def trainQsarModel(
     train_X, test_X, train_y, test_y = train_test_split(X, y, random_state=seed, stratify = y)
 
     # Train model and save
-    model = train_generic_model(clf, train_X, train_y)
+    model = train_generic_model(clf, train_X, train_y, kwargs)
     if save_model:
         pickle.dump(model, open(output_dir + "/models/model_" + str(suffix) + ".pkl", "wb"))
 
@@ -313,8 +337,8 @@ def trainQsarModel(
     return model, test_X, row
 
 def trainQsarModels(
-        data, assays, clf, output_dir = "./output", seed = None, remove_existing_output_folder = False,
-        write_log = True, save_model = True):
+        data, assays, clf, kwargs, output_dir = "./output", seed = None, remove_existing_output_folder = False,
+        write_log = True, save_model = True, human_readable = False):
     if os.path.exists(f"{output_dir}/log.txt") and remove_existing_output_folder:
         os.remove(f"{output_dir}/log.txt")
 
@@ -341,7 +365,7 @@ def trainQsarModels(
     for idx, assay in enumerate(assays):
         write_to_log(f"Training model for assay {assay} with seed {seed}.", f"{output_dir}/log.txt", write_log)
         model, X, row = trainQsarModel(
-            data, assay, clf, suffix=idx, seed = seed, output_dir = output_dir, save_model = save_model)
+            data, assay, clf, kwargs, suffix=idx, seed = seed, output_dir = output_dir, save_model = save_model)
         if model is None or row is None:
             write_to_log(f"Warning: There are not enough of both classes (y) for assay {assay} to train a model.", f"{output_dir}/log.txt", write_log)
             skipped_models += 1
@@ -355,7 +379,7 @@ def trainQsarModels(
         model, X = pair
         explainer = shap.TreeExplainer(model)
         shap_values = explainer.shap_values(X)
-        shap.summary_plot(shap_values, X, feature_names = maccs_key_conversion(X.columns), show = False)
+        shap.summary_plot(shap_values, X, feature_names = X.columns, show = False) if human_readable else shap.summary_plot(shap_values, X, feature_names = maccs_key_conversion(X.columns), show = False)
         plt.savefig(f"{output_dir}/shap_plots/shap_plot_{idx}.png")
         plt.clf()
     write_to_log(f"Saved {len(models)} models and skipped {skipped_models} assays due to insufficient data.", f"{output_dir}/log.txt", write_log)
@@ -363,31 +387,12 @@ def trainQsarModels(
     dd_df.to_csv(f"{output_dir}/data_description.tsv", sep = "\t")
     return models, evaluation
 
-def describe_datasets(data, assays, output_dir = "./output"):
-    if not os.path.exists(output_dir):
-        os.makedirs(output_dir)
-
-    data_description = []
-    for assay in assays:
-        data_description.append(describe_dataset(format_data_for_model(data, assay), assay))
-    dd_df = pd.DataFrame.from_dict(data_description)
-    dd_df.to_csv(f"{output_dir}/data_description.tsv", sep = "\t")
-    return dd_df
-
-def describe_dataset(data, y_col_name):
-    X, y = data
-    num_chemicals = X.shape[0]
-    num_toxic = y[y == 1].shape[0]
-    num_nontoxic = num_chemicals - num_toxic
-    untested = y[y == -1].shape[0]
-    return {'assay': y_col_name, 'num_chemicals': num_chemicals, 'num_toxic': num_toxic, 'num_nontoxic': num_nontoxic, 'untested': untested}
-
 
 # Model Evaluation Functions
 def select_models(models, evaluation, by_rocauc = True, by_f1 = False, n = 10):
     target_col = 'rocauc' if by_rocauc else 'f1'
     best_models = evaluation.sort_values(by=[target_col], ascending=False)
-    if n < 0:
+    if n > 0:
         best_models = best_models.head(n=n)
     selected_models = []
     for idx, row in best_models.iterrows():
@@ -397,11 +402,11 @@ def select_models(models, evaluation, by_rocauc = True, by_f1 = False, n = 10):
 def select_assays(evaluation, by_rocauc = True, by_f1 = False, n = 10):
     target_col = 'rocauc' if by_rocauc else 'f1'
     best_models = evaluation.sort_values(by=[target_col], ascending=False)
-    if n < 0:
+    if n > 0:
         best_models = best_models.head(n=n)
     return best_models['assay'].tolist()
 
-def predictQsar(model, data, y_col_name):
+def predictQsar(model, data, y_col_name, output_dir = "./output"):
     data = data[data[y_col_name] >= 0] # Only keep rows that have assay values
     X, y = format_data_for_model(data, y_col_name)
     y_pred = model.predict(X)
@@ -429,13 +434,15 @@ def predictQsar(model, data, y_col_name):
         "num_nontoxic": num_nontoxic
     }
 
-def validate_all_models(models, data, assays):
+def validate_all_models(models, data, assays, output_dir = "./output"):
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
     # For each model, we check predictions on all assays
     validations = []
     for idx, model in enumerate(models):
-        validations.append(predictQsar(model, data, assays[idx]))
+        validations.append(predictQsar(model, data, assays[idx], output_dir = output_dir))
     return pd.DataFrame.from_dict(validations)
 
 def display_results(results, toxic_cutoff = 1, sort_by = "f1"):
     return results[results['num_toxic'] > toxic_cutoff].sort_values(by=[sort_by], ascending=False)
-    
\ No newline at end of file
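For reference, below is a minimal usage sketch of the API as modified by this patch. It is not part of the diff: the classifier class and hyperparameters (RandomForestClassifier, n_estimators, random_state) are illustrative assumptions only. Any scikit-learn-style estimator class that works with shap.TreeExplainer should fit, since train_generic_model now instantiates clf from kwargs before calling fit.

# Hypothetical usage of the updated qsar.py API; classifier choice and
# hyperparameters below are illustrative, not part of this patch.
from sklearn.ensemble import RandomForestClassifier

from comptox_ai.utils.qsar import (
    makeQsarDataset, describe_datasets, trainQsarModels,
    select_models, validate_all_models, display_results,
)

# Build the dataset; human_readable=True would label the MACCS columns with key
# descriptions instead of the numeric names '1'..'167'.
data = makeQsarDataset(listAcronym="PFASMASTER", output_dir="./output", human_readable=False)

# The first 167 columns are MACCS bits; the remaining columns are assay labels.
assays = list(data.columns[167:])

# Summarize class balance (toxic / nontoxic / untested) per assay.
describe_datasets(data, assays, output_dir="./output")

# Train one model per assay. clf is passed as a class and instantiated from
# kwargs inside train_generic_model; a tree model keeps shap.TreeExplainer valid.
models, evaluation = trainQsarModels(
    data, assays,
    clf=RandomForestClassifier,
    kwargs={"n_estimators": 100, "random_state": 42},
    output_dir="./output", seed=42, human_readable=False,
)

# Keep the ten best models by ROC AUC and check predictions across assays.
best_models = select_models(models, evaluation, by_rocauc=True, n=10)
results = validate_all_models(models, data, assays, output_dir="./output")
print(display_results(results))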