Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/RTXteam/RTX
Browse files Browse the repository at this point in the history
  • Loading branch information
ecwood committed May 11, 2021
2 parents c0750b1 + f269c86 commit 9f0c6e4
Show file tree
Hide file tree
Showing 13 changed files with 268 additions and 225 deletions.
73 changes: 45 additions & 28 deletions code/ARAX/ARAXQuery/ARAX_resultify.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
'''

import collections
import copy
import math
import os
import sys
Expand Down Expand Up @@ -521,18 +522,23 @@ def _get_results_for_kg_by_qg(kg: KnowledgeGraph, # all nodes *must

# Handle case where QG contains multiple qnodes and no qedges (we'll dump everything in one result)
if not qg.edges and len(qg.nodes) > 1:
nodes_by_qg_key = _get_kg_node_keys_by_qg_key(kg)
result_graph = _create_new_empty_result_graph(qg)
result_graph["nodes"] = nodes_by_qg_key
result_graph["nodes"] = kg_node_keys_by_qg_key
final_result_graphs = [result_graph]
else:
# Build up some indexes for edges in the KG (by their subject/object nodes and qedge keys)
edges_by_qg_id_and_subject_node = collections.defaultdict(lambda: collections.defaultdict(lambda: set()))
edges_by_qg_id_and_object_node = collections.defaultdict(lambda: collections.defaultdict(lambda: set()))
edge_keys_by_subject = collections.defaultdict(lambda: collections.defaultdict(lambda: set()))
edge_keys_by_object = collections.defaultdict(lambda: collections.defaultdict(lambda: set()))
edge_keys_by_node_pair = collections.defaultdict(lambda: collections.defaultdict(lambda: set()))
for edge_key, edge in kg.edges.items():
for qedge_key in edge.qedge_keys:
edges_by_qg_id_and_subject_node[qedge_key][edge.subject].add(edge_key)
edges_by_qg_id_and_object_node[qedge_key][edge.object].add(edge_key)
for qedge_id in edge.qedge_keys:
edge_keys_by_subject[qedge_id][edge.subject].add(edge_key)
edge_keys_by_object[qedge_id][edge.object].add(edge_key)
node_pair_string = f"{edge.subject}--{edge.object}"
edge_keys_by_node_pair[qedge_id][node_pair_string].add(edge_key)
if ignore_edge_direction:
node_pair_other_direction = f"{edge.object}--{edge.subject}"
edge_keys_by_node_pair[qedge_id][node_pair_other_direction].add(edge_key)

# Create results off the "required" portion of the QG (excluding any qnodes/qedges belonging to an "option group")
required_qg = QueryGraph(nodes={qnode_key: qnode for qnode_key, qnode in qg.nodes.items() if not qnode.option_group_id},
Expand All @@ -541,8 +547,9 @@ def _get_results_for_kg_by_qg(kg: KnowledgeGraph, # all nodes *must
if qg_is_disconnected:
raise ValueError(f"Required portion of QG is disconnected. This isn't allowed! 'Required' qnode IDs are: "
f"{[qnode_key for qnode_key in required_qg.nodes]}")
result_graphs_required = _create_result_graphs(kg, required_qg, edges_by_qg_id_and_subject_node,
edges_by_qg_id_and_object_node, ignore_edge_direction)
result_graphs_required = _create_result_graphs(kg, required_qg, kg_node_keys_by_qg_key,
edge_keys_by_subject, edge_keys_by_object,
edge_keys_by_node_pair, ignore_edge_direction)

# Then create results for each of the "option groups" in the QG (including the required portion of the QG with each)
option_groups_in_qg = {qedge.option_group_id for qedge in qg.edges.values() if qedge.option_group_id}
Expand All @@ -558,8 +565,9 @@ def _get_results_for_kg_by_qg(kg: KnowledgeGraph, # all nodes *must
raise ValueError(f"Required + option group {option_group_id} portion of the QG is disconnected. "
f"This isn't allowed! 'Required'/group {option_group_id} qnode IDs are: "
f"{[qnode_key for qnode_key in option_group_qg.nodes]}")
result_graphs_for_option_group = _create_result_graphs(kg, option_group_qg, edges_by_qg_id_and_subject_node,
edges_by_qg_id_and_object_node, ignore_edge_direction)
result_graphs_for_option_group = _create_result_graphs(kg, option_group_qg, kg_node_keys_by_qg_key,
edge_keys_by_subject, edge_keys_by_object,
edge_keys_by_node_pair, ignore_edge_direction)
option_group_results_dict[option_group_id] = result_graphs_for_option_group

# Organize our results for the 'required' portion of the QG by the IDs of their is_set=False nodes
Expand Down Expand Up @@ -702,8 +710,7 @@ def _create_new_empty_result_graph(query_graph: QueryGraph) -> Dict[str, Dict[st


def _copy_result_graph(result_graph: Dict[str, Dict[str, Set[str]]]) -> Dict[str, Dict[str, Set[str]]]:
result_graph_copy = {'nodes': {qnode_key: node_keys for qnode_key, node_keys in result_graph['nodes'].items()},
'edges': {qedge_key: edge_keys for qedge_key, edge_keys in result_graph['edges'].items()}}
result_graph_copy = copy.deepcopy(result_graph)
return result_graph_copy


Expand Down Expand Up @@ -790,7 +797,7 @@ def _find_qnode_connected_to_sub_qg(qnode_keys_to_connect_to: Set[str], qnode_ke
return "", set()


def _get_qg_adj_map_undirected(qg) -> Dict[str, Set[str]]:
def _get_qg_adj_map_undirected(qg: QueryGraph) -> Dict[str, Set[str]]:
"""
This function creates a node adjacency map for a given query graph. Example: {"n0": {"n1"}, "n1": {"n0"}}
"""
Expand Down Expand Up @@ -850,11 +857,12 @@ def _clean_up_dead_ends(result_graph: Dict[str, Dict[str, Set[str]]],

def _create_result_graphs(kg: KnowledgeGraph,
qg: QueryGraph,
edges_by_qg_id_and_subject: DefaultDict[str, DefaultDict[str, set]],
edges_by_qg_id_and_object: DefaultDict[str, DefaultDict[str, set]],
kg_node_keys_by_qg_key: Dict[str, Set[str]],
edge_keys_by_subject: DefaultDict[str, DefaultDict[str, set]],
edge_keys_by_object: DefaultDict[str, DefaultDict[str, set]],
edge_keys_by_node_pair: DefaultDict[str, DefaultDict[str, set]],
ignore_edge_direction: bool = True) -> List[Result]:
result_graphs = []
kg_node_keys_by_qg_key = _get_kg_node_keys_by_qg_key(kg)
kg_node_adj_map_by_qg_key = _get_kg_node_adj_map_by_qg_key(kg_node_keys_by_qg_key, kg, qg)
qg_adj_map = _get_qg_adj_map_undirected(qg)

Expand Down Expand Up @@ -926,17 +934,26 @@ def _create_result_graphs(kg: KnowledgeGraph,
qedge = qg.edges[qedge_key]
qedge_source_node_ids = result_graph['nodes'][qedge.subject]
qedge_target_node_ids = result_graph['nodes'][qedge.object]
edges_with_matching_subject = {edge_key for source_node in qedge_source_node_ids
for edge_key in edges_by_qg_id_and_subject[qedge_key][source_node]}
edges_with_matching_object = {edge_key for target_node in qedge_target_node_ids
for edge_key in edges_by_qg_id_and_object[qedge_key][target_node]}
result_graph['edges'][qedge_key] = edges_with_matching_subject.intersection(edges_with_matching_object)
if ignore_edge_direction:
edges_with_reverse_subject = {edge_key for target_node in qedge_target_node_ids
for edge_key in edges_by_qg_id_and_subject[qedge_key][target_node]}
edges_with_reverse_object = {edge_key for source_node in qedge_source_node_ids
for edge_key in edges_by_qg_id_and_object[qedge_key][source_node]}
result_graph['edges'][qedge_key].update(edges_with_reverse_subject.intersection(edges_with_reverse_object))
# Pick the more efficient method for edge-finding depending on the number of nodes for this result/qedge
if len(qedge_source_node_ids) < 10 or len(qedge_target_node_ids) < 10:
possible_node_pairs = {f"{node_1}--{node_2}" for node_1 in qedge_source_node_ids
for node_2 in qedge_target_node_ids}
for node_pair in possible_node_pairs:
ids_of_matching_edges = edge_keys_by_node_pair[qedge_key].get(node_pair, set())
result_graph['edges'][qedge_key].update(ids_of_matching_edges)
else:
# This technique is more efficient when there are large numbers of both subject and object nodes
edges_with_matching_subject = {edge_key for source_node in qedge_source_node_ids
for edge_key in edge_keys_by_subject[qedge_key][source_node]}
edges_with_matching_object = {edge_key for target_node in qedge_target_node_ids
for edge_key in edge_keys_by_object[qedge_key][target_node]}
result_graph['edges'][qedge_key] = edges_with_matching_subject.intersection(edges_with_matching_object)
if ignore_edge_direction:
edges_with_reverse_subject = {edge_key for target_node in qedge_target_node_ids
for edge_key in edge_keys_by_subject[qedge_key][target_node]}
edges_with_reverse_object = {edge_key for source_node in qedge_source_node_ids
for edge_key in edge_keys_by_object[qedge_key][source_node]}
result_graph['edges'][qedge_key].update(edges_with_reverse_subject.intersection(edges_with_reverse_object))

final_result_graphs = [result_graph for result_graph in result_graphs if _result_graph_is_fulfilled(result_graph, qg)]
return final_result_graphs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self):

def get_drug_curies_from_graph(self):
## Pulls a dataframe of all of the graph drug-associated nodes
query = "match (n {category:'biolink:ChemicalSubstance'}) with distinct n.id as id, n.name as name, n.equivalent_curies as equivalent_curies return id, name, equivalent_curies union match (n {category:'biolink:Drug'}) with distinct n.id as id, n.name as name, n.equivalent_curies as equivalent_curies return id, name, equivalent_curies"
query = "match (n) where n.category in ['biolink:ChemicalSubstance', 'biolink:Drug', 'biolink:Metabolite'] with distinct n.id as id, n.name as name, n.equivalent_curies as equivalent_curies return id, name, equivalent_curies"
session = self.driver.session()
res = session.run(query)
drugs = pd.DataFrame(res.data())
Expand Down Expand Up @@ -359,8 +359,8 @@ def generate_SemmedData(self, mysqldump_path, output_path=os.getcwd()):
if __name__ == "__main__":
dataGenerator = DataGeneration()
drugs = dataGenerator.get_drug_curies_from_graph()
drugs.to_csv('/home/cqm5886/work/RTX/code/reasoningtool/MLDrugRepurposing/Test_graphsage/kg2_5_1/raw_training_data/drugs.txt',sep='\t',index=False)
drugs.to_csv('/home/cqm5886/work/RTX/code/reasoningtool/MLDrugRepurposing/Test_graphsage/kg2_6_3/raw_training_data/drugs.txt',sep='\t',index=False)
dataGenerator.generate_MyChemData(drugs=drugs, output_path='/home/cqm5886/work/RTX/code/reasoningtool/MLDrugRepurposing/Test_graphsage/kg2_5_1/raw_training_data/',dist=2)
## For semmedVER43_2020_R_PREDICATION.sql.gz, you might download it from /data/orangeboard/databases/KG2.3.4/semmedVER43_2020_R_PREDICATION.sql.gz on the arax.ncats.io server or directly download the latest one from the SemMedDB website
# dataGenerator.generate_SemmedData(mysqldump_path='/home/cqm5886/work/RTX/code/reasoningtool/MLDrugRepurposing/Test_graphsage/semmedVER43_2020_R_PREDICATION.sql.gz', output_path='/home/cqm5886/work/RTX/code/reasoningtool/MLDrugRepurposing/Test_graphsage/kg2_5_1/raw_training_data/')
# For semmedVER43_2020_R_PREDICATION.sql.gz, you might download it from /data/orangeboard/databases/KG2.3.4/semmedVER43_2020_R_PREDICATION.sql.gz on the arax.ncats.io server or directly download the latest one from the SemMedDB website
# dataGenerator.generate_SemmedData(mysqldump_path='/home/cqm5886/work/RTX/code/reasoningtool/MLDrugRepurposing/Test_graphsage/semmedVER43_2020_R_PREDICATION.sql.gz', output_path='/home/cqm5886/work/RTX/code/reasoningtool/MLDrugRepurposing/Test_graphsage/kg2_5_1/raw_training_data/')

Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@

######### Please ignore this part until Eric finds a better way to categorize these nodes with ambiguous node type ###########
# !Note: Before running the below code, please first check this DSL query, if there is returned value > 0, report error on github.
# !DSL query: match (z) where (('biolink:Disease' in z.all_categories or 'biolink:PhenotypicFeature' in z.all_categories or 'biolink:DiseaseOrPhenotypicFeature' in z.all_categories) and ('biolink:Drug' in z.all_categories or 'biolink:ChemicalSubstance' in z.all_categories)) return count(distinct z.id)
# !DSL query: match (z) where (('biolink:Disease' in z.all_categories or 'biolink:PhenotypicFeature' in z.all_categories or 'biolink:DiseaseOrPhenotypicFeature' in z.all_categories) and ('biolink:Drug' in z.all_categories or 'biolink:ChemicalSubstance' in z.all_categories or 'biolink:Metabolite' in z.all_categories)) return count(distinct z.id)
##############################################################################################################################


## Pull a dataframe of all of the graph edges excluding:
# the edges with one end node with all_categories including 'drug' and another end node with all_categories including 'disease'
# 'drug' here represents all nodes with category that is either 'biolink:Drug' or 'biolink:ChemicalSubstance'
# 'drug' here represents all nodes with category that is either 'biolink:Drug' or 'biolink:ChemicalSubstance' or 'biolink:Metabolite'
# 'disease' here represents all nodes with category that is either 'biolink:Disease', 'biolink:PhenotypicFeature' or 'biolink:DiseaseOrPhenotypicFeature'
query = "match (disease) where (disease.category='biolink:Disease' or disease.category='biolink:PhenotypicFeature' or disease.category='biolink:DiseaseOrPhenotypicFeature') with collect(distinct disease.id) as disease_ids match (drug) where (drug.category='biolink:Drug' or drug.category='biolink:ChemicalSubstance') with collect(distinct drug.id) as drug_ids, disease_ids as disease_ids match (m1)-[]-(m2) where m1<>m2 and not (m1.id in drug_ids and m2.id in disease_ids) and not (m1.id in disease_ids and m2.id in drug_ids) with distinct m1 as node1, m2 as node2 return node1.id as source, node2.id as target"
query = "match (disease) where (disease.category='biolink:Disease' or disease.category='biolink:PhenotypicFeature' or disease.category='biolink:DiseaseOrPhenotypicFeature') with collect(distinct disease.id) as disease_ids match (drug) where (drug.category='biolink:Drug' or drug.category='biolink:ChemicalSubstance' or drug.category='biolink:Metabolite') with collect(distinct drug.id) as drug_ids, disease_ids as disease_ids match (m1)-[]-(m2) where m1<>m2 and not (m1.id in drug_ids and m2.id in disease_ids) and not (m1.id in disease_ids and m2.id in drug_ids) with distinct m1 as node1, m2 as node2 return node1.id as source, node2.id as target"
res = session.run(query)
KG2_alledges = pd.DataFrame(res.data())
KG2_alledges.to_csv(output_path + 'graph_edges.txt', sep='\t', index=None)
Expand All @@ -42,7 +42,7 @@
KG2_allnodes_label.to_csv(output_path + 'graph_nodes_label_remove_name.txt', sep='\t', index=None)

## Pulls a dataframe of all of the graph drug-associated nodes
query = f"match (n) where (n.category='biolink:Drug') or (n.category='biolink:ChemicalSubstance') with distinct n.id as id, n.name as name return id, name"
query = f"match (n) where (n.category='biolink:Drug') or (n.category='biolink:ChemicalSubstance') or (n.category='biolink:Metabolite') with distinct n.id as id, n.name as name return id, name"
res = session.run(query)
drugs = pd.DataFrame(res.data())
drugs.to_csv(output_path + 'drugs.txt', sep='\t', index=None)
Expand Down
32 changes: 32 additions & 0 deletions code/ARAX/Examples/kg2_api_example.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
{
"message":{
"query_graph":{
"nodes":{
"n00":{
"id":"CHEMBL.COMPOUND:CHEMBL112",
"category":[
"biolink:Drug"
],
"is_set":false
},
"n01":{
"category":[
"biolink:Gene",
"biolink:Protein"
],
"is_set":false
}
},
"edges":{
"e00":{
"predicate":[
"biolink:interacts_with"
],
"subject":"n00",
"object":"n01",
"exclude":false
}
}
}
}
}
11 changes: 11 additions & 0 deletions code/ARAX/Examples/kg2_api_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#!/usr/bin/env python3
"""Minimal example of posting a TRAPI query to the RTX-KG2 API.

Loads the example query graph from kg2_api_example.json (expected in the
current working directory), submits it to the hosted KG2 endpoint, and
pretty-prints the JSON response.
"""

import json
import pprint
import requests

with open("kg2_api_example.json", "r") as query_file:
    query_message = json.load(query_file)

# NOTE(review): requires network access to arax.ncats.io.
response = requests.post("https://arax.ncats.io/api/rtxkg2/v1.0/query?bypass_cache=false",
                         json=query_message)
pprint.pprint(response.json())
57 changes: 30 additions & 27 deletions code/ARAX/NodeSynonymizer/dump_kg2_node_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,31 +33,34 @@ def dump_kg2_node_info(file_name: str, write_mode: str, is_test: bool):
"""
query = f"match (n) return properties(n) as p, labels(n) as l {'limit 20' if is_test else ''}"
res = _run_cypher_query(query)
with open(file_name, write_mode, encoding="utf-8") as fid:
for item in res:
prop_dict = item['p']
labels = item['l']
try:
label = list(set(labels) - {'Base'}).pop()
except:
label = ""
try:
fid.write('%s\t' % prop_dict['id'])
except:
fid.write('\t')
try:
fid.write('%s\t' % remove_tab_newlines.sub(" ", prop_dict['name'])) # better approach
except:
fid.write('\t')
try:
fid.write('%s\t' % remove_tab_newlines.sub(" ", prop_dict['full_name']))
except:
fid.write('\t')
try:
fid.write('%s\n' % label)
except:
fid.write('\n')
print(f"Successfully created file '{file_name}'.")
if res:
with open(file_name, write_mode, encoding="utf-8") as fid:
for item in res:
prop_dict = item['p']
labels = item['l']
try:
label = list(set(labels) - {'Base'}).pop()
except:
label = ""
try:
fid.write('%s\t' % prop_dict['id'])
except:
fid.write('\t')
try:
fid.write('%s\t' % remove_tab_newlines.sub(" ", prop_dict['name'])) # better approach
except:
fid.write('\t')
try:
fid.write('%s\t' % remove_tab_newlines.sub(" ", prop_dict['full_name']))
except:
fid.write('\t')
try:
fid.write('%s\n' % label)
except:
fid.write('\n')
print(f"Successfully created file '{file_name}'.")
else:
raise Exception(f"Failed to get results from Neo4j for {file_name}")
return


Expand Down Expand Up @@ -93,7 +96,7 @@ def dump_kg2_equivalencies(output_file_name: str, is_test: bool):
csv_writer.writerows(list(distinct_pairs))
print(f"Successfully created file '{output_file_name}'.")
else:
print(f"Sorry, couldn't get equivalency data. No file created.")
raise Exception(f"Failed to get results from Neo4j for {output_file_name}")


def dump_kg2_synonym_field(output_file_name: str, is_test: bool):
Expand All @@ -109,7 +112,7 @@ def dump_kg2_synonym_field(output_file_name: str, is_test: bool):
json.dump(synonym_map, output_file)
print(f"Successfully created file '{output_file_name}'.")
else:
print(f"Sorry, couldn't get synonym data. No file created.")
raise Exception(f"Failed to get results from Neo4j for {output_file_name}")


def main():
Expand Down
6 changes: 6 additions & 0 deletions code/UI/interactive/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@
<div class="page">
<br><br><br><br>

<div style="border-left-color:#c40;" class="statushead error">TRAPI 1.0</div>
<div class="status">
<h3 style="display:inline-block;">This is the interface based on <b>TRAPI 1.0</b>. Click <a href="/NewFmt/">here to go to the TRAPI 1.1-based interface</a></h3>
</div><br><br>


<div class="pagesection" id="historyDiv">
<div class="statushead">Session History (<span id="numlistitemsSESSION">-</span>)</div>
<div class="status" id="listdivSESSION"><br>Your query history will be displayed here. It can be edited or re-set.<br><br></div>
Expand Down
Loading

0 comments on commit 9f0c6e4

Please sign in to comment.