Fixes to qsar.py #116
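As I read the diff, this PR: extracts the assay-label encoding into a new assign_assay_values() helper, threads a human_readable flag through expand_maccs_column(), makeQsarDataset(), and the SHAP plotting path, passes classifier constructor kwargs through the training functions, promotes the dataset summaries into standalone describe_dataset()/describe_datasets() helpers, flips the inverted n < 0 checks in select_models()/select_assays(), fixes the f-string quoting error in makeQsarDataset(), and adds output_dir handling to predictQsar()/validate_all_models().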

Open · wants to merge 2 commits into master
103 changes: 55 additions & 48 deletions comptox_ai/utils/qsar.py
@@ -16,17 +16,23 @@ def grab_graph_data(listAcronym = "CPDAT"):
    print("Grabbed graph data...")
    return db.run_cypher('MATCH (cl:ChemicalList {listAcronym: "' + listAcronym + '"})-[r1]->(chem:Chemical)-[r2]->(a:Assay) RETURN chem.commonName AS chemName, chem.maccs AS maccs, r2 AS relation, a.commonName AS assayName')

+def assign_assay_values(output):
+    if pd.isna(output):
+        return -1
+    elif "CHEMICALHASACTIVEASSAY" in output:
+        return 1
+    else:
+        return 0
+
def build_formatted_data(output):
    chemicals_df = pd.DataFrame(output)
    pivot_table = chemicals_df.pivot_table(index=['chemName', 'maccs'], columns='assayName', values='relation', aggfunc='first')
    for col in pivot_table.columns:
-        pivot_table[col] = pivot_table[col].apply(lambda x: -1 if pd.isna(x)
-                                                            else 1 if x[1] == "CHEMICALHASACTIVEASSAY"
-                                                            else 0)
+        pivot_table[col] = pivot_table[col].apply(lambda x: assign_assay_values(x))
    pivot_table.reset_index(inplace=True)
    return pivot_table
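Reviewer note: a minimal sketch of the label encoding the new helper applies to each pivoted relation value, assuming the value supports the `in` check (e.g. a string containing the relationship type, or a tuple of names); the sample inputs below are hypothetical:

```python
from comptox_ai.utils.qsar import assign_assay_values

print(assign_assay_values(float("nan")))              # -1 -> chemical untested in this assay
print(assign_assay_values("CHEMICALHASACTIVEASSAY"))  #  1 -> active
print(assign_assay_values("CHEMICALHASINACTIVEASSAY"))#  0 -> anything else counts as inactive
```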

-def maccs_key_conversion(maccs):
+def maccs_key_conversion(maccs, human_readable = False):
    maccs_keys_descriptions = {
        0: 'Padding',
        1: 'ISOTOPE',
@@ -196,33 +202,32 @@ def maccs_key_conversion(maccs):
        165: 'Ring',
        166: 'Fragments FIX: this cant be done in SMARTS'
    }
-    return [maccs_keys_descriptions[int(key)] for key in maccs]
+    return [maccs_keys_descriptions[int(key)] for key in maccs] if human_readable else [maccs_keys_descriptions[int(key) - 1] for key in maccs]
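Reviewer note: my reading of the dual indexing here (not stated by the author): with human_readable=True the callers pass 0-based positions, while the default path expects 1-based MACCS key numbers such as the numeric column names produced in expand_maccs_column(), so both resolve against the same dictionary:

```python
maccs_key_conversion([0, 1], human_readable=True)  # positions 0..166 -> ['Padding', 'ISOTOPE']
maccs_key_conversion(['1', '2'])                   # key numbers 1..167 -> ['Padding', 'ISOTOPE']
```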

-def expand_maccs_column(datalist):
+def expand_maccs_column(datalist, human_readable = False):
    # maccs key is 167 rather than 166
    # Expand MACCS binary strings into separate columns
    print(datalist['maccs'].iloc[0])
    print(len(datalist['maccs'].iloc[0]))
    maccs_expanded = datalist['maccs'].apply(lambda x: pd.Series(list(x)).astype(int))
-    maccs_expanded.columns = maccs_key_conversion([i for i in range(maccs_expanded.shape[1])])
-    print(maccs_expanded.columns)
+    maccs_expanded.columns = maccs_key_conversion([i for i in range(maccs_expanded.shape[1])], human_readable=human_readable) if human_readable else [f'{i+1}' for i in range(maccs_expanded.shape[1])]
    datalist = datalist.drop(columns=['maccs'])
    final_df = pd.concat([datalist[['chemName']], maccs_expanded, datalist.drop(columns=['chemName'])], axis=1)
    print("Expanding MACCS binary strings into separate columns...")
    print("Converting from maccs key to human readable names...")
    return final_df

-def makeQsarDataset(listAcronym="PFASMASTER", output_dir = "./output"):
+def makeQsarDataset(listAcronym="PFASMASTER", output_dir = "./output", human_readable=False):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    graphdata = grab_graph_data(listAcronym=listAcronym)
    cleaned_data = build_formatted_data(graphdata)
-    formatted_data = expand_maccs_column(cleaned_data)
-    formatted_data.to_csv(f"{output_dir.rstrip("/")}/comptox_{listAcronym}_all_assay.tsv", sep="\t")
-    return formatted_data
+    formatted_data = expand_maccs_column(cleaned_data, human_readable = human_readable)
+    formatted_data.to_csv(f"{output_dir.rstrip('/')}/comptox_{listAcronym}_all_assay.tsv", sep="\t")
+    return formatted_data.set_index('chemName')
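Reviewer note: a hypothetical end-to-end call of the dataset builder as changed here (the list acronym and output path are just examples):

```python
from comptox_ai.utils.qsar import makeQsarDataset

# Rows are now indexed by chemName; the first 167 columns are MACCS bits,
# and the remaining columns are assay labels in {-1, 0, 1}
df = makeQsarDataset(listAcronym="PFASMASTER", output_dir="./output", human_readable=False)
```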

# Model Training Functions

-def train_generic_model(clf, X, y):
+def train_generic_model(clf, X, y, kwargs):
    # Scikit-like model interface
+    clf = clf(**kwargs)
    return clf.fit(X, y)
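Reviewer note: kwargs is a plain positional dict here rather than **kwargs, so a call would look like this sketch (the estimator, parameters, and train_X/train_y split are hypothetical):

```python
from sklearn.ensemble import RandomForestClassifier

# Pass the classifier class (not an instance) plus a dict of constructor parameters
model = train_generic_model(RandomForestClassifier, train_X, train_y,
                            {"n_estimators": 100, "random_state": 42})
```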

def format_data_for_model(df, y_col_name, end_index = 167):
@@ -263,8 +268,27 @@ def makeDiscoveryDatasets(data, assays, output_dir = "./output"):

    return discovery_data

+def describe_dataset(data, y_col_name):
+    X, y = data
+    num_chemicals = X.shape[0]
+    num_toxic = y[y == 1].shape[0]
+    num_nontoxic = y[y == 0].shape[0]
+    untested = y[y == -1].shape[0]
+    return {'assay': y_col_name, 'num_chemicals': num_chemicals, 'num_toxic': num_toxic, 'num_nontoxic': num_nontoxic, 'untested': untested}
+
+def describe_datasets(data, assays, output_dir = "./output"):
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
+    data_description = []
+    for assay in assays:
+        data_description.append(describe_dataset(format_data_for_model(data, assay), assay))
+    dd_df = pd.DataFrame.from_dict(data_description)
+    dd_df.to_csv(f"{output_dir}/data_description.tsv", sep = "\t")
+    return dd_df
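Reviewer note: a sketch of the per-assay summary these relocated helpers now produce on their own (the assay name is hypothetical; df is a table from makeQsarDataset):

```python
# One row per assay: assay, num_chemicals, num_toxic, num_nontoxic, untested
dd_df = describe_datasets(df, ["SomeAssayName"], output_dir="./output")
```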

def trainQsarModel(
-    df, y_col_name, clf, suffix = 1, output_dir = "./output", seed = 42,
+    df, y_col_name, clf, kwargs, suffix = 1, output_dir = "./output", seed = 42,
    save_model = True, save_discovery = True):
    # Expects that the first 167 columns are the macc columns
    focus_df = df[df[y_col_name] >= 0] # Only keep rows that have assay values
@@ -282,7 +306,7 @@ def trainQsarModel(
    train_X, test_X, train_y, test_y = train_test_split(X, y, random_state=seed, stratify = y)

    # Train model and save
-    model = train_generic_model(clf, train_X, train_y)
+    model = train_generic_model(clf, train_X, train_y, kwargs)

    if save_model:
        pickle.dump(model, open(output_dir + "/models/model_" + str(suffix) + ".pkl", "wb"))
@@ -313,8 +337,8 @@ def trainQsarModel(
    return model, test_X, row

def trainQsarModels(
-    data, assays, clf, output_dir = "./output", seed = None, remove_existing_output_folder = False,
-    write_log = True, save_model = True):
+    data, assays, clf, kwargs, output_dir = "./output", seed = None, remove_existing_output_folder = False,
+    write_log = True, save_model = True, human_readable = False):
    if os.path.exists(f"{output_dir}/log.txt") and remove_existing_output_folder:
        os.remove(f"{output_dir}/log.txt")

@@ -341,7 +365,7 @@
    for idx, assay in enumerate(assays):
        write_to_log(f"Training model for assay {assay} with seed {seed}.", f"{output_dir}/log.txt", write_log)
        model, X, row = trainQsarModel(
-            data, assay, clf, suffix=idx, seed = seed, output_dir = output_dir, save_model = save_model)
+            data, assay, clf, kwargs, suffix=idx, seed = seed, output_dir = output_dir, save_model = save_model)
        if model is None or row is None:
            write_to_log(f"Warning: There are not enough of both classes (y) for assay {assay} to train a model.", f"{output_dir}/log.txt", write_log)
            skipped_models += 1
@@ -355,39 +379,20 @@
        model, X = pair
        explainer = shap.TreeExplainer(model)
        shap_values = explainer.shap_values(X)
-        shap.summary_plot(shap_values, X, feature_names = maccs_key_conversion(X.columns), show = False)
+        shap.summary_plot(shap_values, X, feature_names = X.columns, show = False) if human_readable else shap.summary_plot(shap_values, X, feature_names = maccs_key_conversion(X.columns), show = False)
        plt.savefig(f"{output_dir}/shap_plots/shap_plot_{idx}.png")
        plt.clf()
    write_to_log(f"Saved {len(models)} models and skipped {skipped_models} assays due to insufficient data.", f"{output_dir}/log.txt", write_log)
-    dd_df = pd.DataFrame.from_dict(data_description)
-    dd_df.to_csv(f"{output_dir}/data_description.tsv", sep = "\t")
    return models, evaluation
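Reviewer note: a hypothetical training run with the new kwargs passthrough (assay names and parameters are examples; the SHAP step assumes a tree-based classifier since it uses TreeExplainer):

```python
from sklearn.ensemble import RandomForestClassifier

# df comes from makeQsarDataset; one model is trained per assay column
models, evaluation = trainQsarModels(
    data=df, assays=["AssayA", "AssayB"], clf=RandomForestClassifier,
    kwargs={"n_estimators": 200}, output_dir="./output", seed=42)
```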

-def describe_datasets(data, assays, output_dir = "./output"):
-    if not os.path.exists(output_dir):
-        os.makedirs(output_dir)
-
-    data_description = []
-    for assay in assays:
-        data_description.append(describe_dataset(format_data_for_model(data, assay), assay))
-    dd_df = pd.DataFrame.from_dict(data_description)
-    dd_df.to_csv(f"{output_dir}/data_description.tsv", sep = "\t")
-    return dd_df
-
-def describe_dataset(data, y_col_name):
-    X, y = data
-    num_chemicals = X.shape[0]
-    num_toxic = y[y == 1].shape[0]
-    num_nontoxic = num_chemicals - num_toxic
-    untested = y[y == -1].shape[0]
-    return {'assay': y_col_name, 'num_chemicals': num_chemicals, 'num_toxic': num_toxic, 'num_nontoxic': num_nontoxic, 'untested': untested}

# Model Evaluation Functions

def select_models(models, evaluation, by_rocauc = True, by_f1 = False, n = 10):
    target_col = 'rocauc' if by_rocauc else 'f1'
    best_models = evaluation.sort_values(by=[target_col], ascending=False)
-    if n < 0:
+    if n > 0:
        best_models = best_models.head(n=n)
    selected_models = []
    for idx, row in best_models.iterrows():
@@ -397,11 +402,11 @@ def select_models(models, evaluation, by_rocauc = True, by_f1 = False, n = 10):
def select_assays(evaluation, by_rocauc = True, by_f1 = False, n = 10):
    target_col = 'rocauc' if by_rocauc else 'f1'
    best_models = evaluation.sort_values(by=[target_col], ascending=False)
-    if n < 0:
+    if n > 0:
        best_models = best_models.head(n=n)
    return best_models['assay'].tolist()
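Reviewer note: with the comparison fixed to n > 0, selection now truncates only for positive n (n <= 0 returns everything, still ranked); a sketch of typical calls:

```python
top_models = select_models(models, evaluation, by_rocauc=True, n=5)  # five best-ranked model entries
top_assays = select_assays(evaluation, by_rocauc=True, n=5)          # names of the five best-scoring assays
```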

-def predictQsar(model, data, y_col_name):
+def predictQsar(model, data, y_col_name, output_dir = "./output"):
    data = data[data[y_col_name] >= 0] # Only keep rows that have assay values
    X, y = format_data_for_model(data, y_col_name)
    y_pred = model.predict(X)
@@ -429,13 +434,15 @@ def predictQsar(model, data, y_col_name):
"num_nontoxic": num_nontoxic
}

-def validate_all_models(models, data, assays):
+def validate_all_models(models, data, assays, output_dir = "./output"):
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)

    # For each model, we check predictions on all assays
    validations = []
    for idx, model in enumerate(models):
-        validations.append(predictQsar(model, data, assays[idx]))
+        validations.append(predictQsar(model, data, assays[idx], output_dir = output_dir))
    return pd.DataFrame.from_dict(validations)

def display_results(results, toxic_cutoff = 1, sort_by = "f1"):
    return results[results['num_toxic'] > toxic_cutoff].sort_values(by=[sort_by], ascending=False)
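Reviewer note: a hypothetical validation pass using the new output_dir plumbing, continuing from the training sketch above:

```python
# Score each trained model against its own assay column, then rank the results
results = validate_all_models(models, df, ["AssayA", "AssayB"], output_dir="./output")
display_results(results, toxic_cutoff=1, sort_by="f1")  # keep assays with >1 toxic chemical, best F1 first
```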