-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
31 changed files
with
10,662 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,72 @@ | ||
# CSSE-DDI | ||
# Customized Subgraph Selection and Encoding for Drug-drug Interaction Prediction | ||
|
||
<p align="left"> | ||
<a href="https://neurips.cc/virtual/2024/poster/94377"><img src="https://img.shields.io/badge/NeurIPS%202024-Poster-brightgreen.svg" alt="neurips paper"> | ||
</p> | ||
|
||
--- | ||
|
||
## Requirements | ||
|
||
```sheel | ||
torch==1.13.0 | ||
dgl-cu111==0.6.1 | ||
optuna==3.2.0 | ||
``` | ||
|
||
## Run | ||
|
||
### Unpack Dataset | ||
```shell | ||
unzip datasets.zip | ||
``` | ||
|
||
### Supernet Training | ||
```shell | ||
python run.py --encoder searchgcn --score_func mlp --combine_type concat --n_layer 3 --epoch 400 \ | ||
--batch 512 --seed 0 --search_mode joint_search --search_algorithm spos_train_supernet_ps2 --input_type allgraph \ | ||
--loss_type ce --dataset drugbank --ss_search_algorithm snas | ||
``` | ||
### Sub-Supernet Training | ||
```shell | ||
python run.py --encoder searchgcn --score_func mlp --combine_type concat --n_layer 3 --epoch 400 \ | ||
--batch 512 --seed 0 --search_mode joint_search --search_algorithm spos_train_supernet_ps2 --input_type allgraph \ | ||
--loss_type ce --dataset drugbank --exp_note spfs --few_shot_op rotate --weight_sharing --ss_search_algorithm snas | ||
|
||
python run.py --encoder searchgcn --score_func mlp --combine_type concat --n_layer 3 --epoch 400 \ | ||
--batch 512 --seed 0 --search_mode joint_search --search_algorithm spos_train_supernet_ps2 --input_type allgraph \ | ||
--loss_type ce --dataset drugbank --exp_note spfs --few_shot_op ccorr --weight_sharing --ss_search_algorithm snas | ||
|
||
python run.py --encoder searchgcn --score_func mlp --combine_type concat --n_layer 3 --epoch 400 \ | ||
--batch 512 --seed 0 --search_mode joint_search --search_algorithm spos_train_supernet_ps2 --input_type allgraph \ | ||
--loss_type ce --dataset drugbank --exp_note spfs --few_shot_op mult --weight_sharing --ss_search_algorithm snas | ||
|
||
python run.py --encoder searchgcn --score_func mlp --combine_type concat --n_layer 3 --epoch 400 \ | ||
--batch 512 --seed 0 --search_mode joint_search --search_algorithm spos_train_supernet_ps2 --input_type allgraph \ | ||
--loss_type ce --dataset drugbank --exp_note spfs --few_shot_op sub --weight_sharing --ss_search_algorithm snas | ||
``` | ||
### Subgraph Selection and Encoding Function Searching | ||
```shell | ||
python run.py --encoder searchgcn --score_func mlp --combine_type concat --n_layer 3 --epoch 400 \ | ||
--batch 512 --seed 0 --search_mode joint_search --search_algorithm spos_train_supernet_ps2 --input_type allgraph \ | ||
--loss_type ce --dataset drugbank --exp_note spfs --weight_sharing --ss_search_algorithm snas --arch_search_mode ng | ||
``` | ||
## Citation | ||
|
||
Readers are welcomed to follow our work. Please kindly cite our paper: | ||
|
||
```bibtex | ||
@inproceedings{du2024customized, | ||
title={Customized Subgraph Selection and Encoding for Drug-drug Interaction Prediction}, | ||
author={Du, Haotong and Yao, Quanming and Zhang, Juzheng and Liu, Yang and Wang, Zhen}, | ||
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems}, | ||
year={2024} | ||
} | ||
``` | ||
|
||
## Contact | ||
If you have any questions, feel free to contact me at [[email protected]](mailto:[email protected]). | ||
|
||
## Acknowledgement | ||
|
||
The codes of this paper are partially based on the codes of [SEAL_dgl](https://github.com/Smilexuhc/SEAL_dgl), [PS2](https://github.com/qiaoyu-tan/PS2), and [Interstellar](https://github.com/LARS-research/Interstellar). We thank the authors of above work. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
{ | ||
"version": 1, | ||
"root": { | ||
"handlers": [ | ||
"console_handler", | ||
"file_handler" | ||
], | ||
"level": "DEBUG" | ||
}, | ||
"handlers": { | ||
"console_handler": { | ||
"class": "logging.StreamHandler", | ||
"level": "DEBUG", | ||
"formatter": "console_formatter" | ||
}, | ||
"file_handler": { | ||
"class": "logging.FileHandler", | ||
"level": "DEBUG", | ||
"formatter": "file_formatter", | ||
"filename": "python_logging.log", | ||
"encoding": "utf8", | ||
"mode": "w" | ||
} | ||
}, | ||
"formatters": { | ||
"console_formatter": { | ||
"format": "%(asctime)s - %(message)s", | ||
"datefmt": "%Y-%m-%d %H:%M:%S" | ||
}, | ||
"file_formatter": { | ||
"format": "%(asctime)s - %(message)s", | ||
"datefmt": "%Y-%m-%d %H:%M:%S" | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
""" | ||
based on the implementation in DGL | ||
(https://github.com/dmlc/dgl/blob/master/python/dgl/contrib/data/knowledge_graph.py) | ||
Knowledge graph dataset for Relational-GCN | ||
Code adapted from authors' implementation of Relational-GCN | ||
https://github.com/tkipf/relational-gcn | ||
https://github.com/MichSchli/RelationPrediction | ||
""" | ||
|
||
from __future__ import print_function | ||
from __future__ import absolute_import | ||
import numpy as np | ||
import scipy.sparse as sp | ||
import os | ||
|
||
from dgl.data.utils import download, extract_archive, get_download_dir, _get_dgl_url | ||
from utils.dgl_utils import process_files_ddi | ||
from utils.graph_utils import incidence_matrix | ||
|
||
# np.random.seed(123) | ||
|
||
_downlaod_prefix = _get_dgl_url('dataset/') | ||
|
||
|
||
def load_data(dataset): | ||
if dataset in ['drugbank', 'twosides', 'twosides_200', 'drugbank_s1', 'twosides_s1']: | ||
return load_link(dataset) | ||
else: | ||
raise ValueError('Unknown dataset: {}'.format(dataset)) | ||
|
||
|
||
class RGCNLinkDataset(object): | ||
|
||
def __init__(self, name): | ||
self.name = name | ||
self.dir = 'datasets' | ||
|
||
# zip_path = os.path.join(self.dir, '{}.zip'.format(self.name)) | ||
self.dir = os.path.join(self.dir, self.name) | ||
# extract_archive(zip_path, self.dir) | ||
|
||
def load(self): | ||
entity_path = os.path.join(self.dir, 'entities.dict') | ||
relation_path = os.path.join(self.dir, 'relations.dict') | ||
train_path = os.path.join(self.dir, 'train.txt') | ||
valid_path = os.path.join(self.dir, 'valid.txt') | ||
test_path = os.path.join(self.dir, 'test.txt') | ||
entity_dict = _read_dictionary(entity_path) | ||
relation_dict = _read_dictionary(relation_path) | ||
self.train = np.asarray(_read_triplets_as_list( | ||
train_path, entity_dict, relation_dict)) | ||
self.valid = np.asarray(_read_triplets_as_list( | ||
valid_path, entity_dict, relation_dict)) | ||
self.test = np.asarray(_read_triplets_as_list( | ||
test_path, entity_dict, relation_dict)) | ||
self.num_nodes = len(entity_dict) | ||
print("# entities: {}".format(self.num_nodes)) | ||
self.num_rels = len(relation_dict) | ||
print("# relations: {}".format(self.num_rels)) | ||
print("# training sample: {}".format(len(self.train))) | ||
print("# valid sample: {}".format(len(self.valid))) | ||
print("# testing sample: {}".format(len(self.test))) | ||
file_paths = { | ||
'train': f'{self.dir}/train_raw.txt', | ||
'valid': f'{self.dir}/dev_raw.txt', | ||
'test': f'{self.dir}/test_raw.txt' | ||
} | ||
external_kg_file = f'{self.dir}/external_kg.txt' | ||
adj_list, triplets, entity2id, relation2id, id2entity, id2relation, rel = process_files_ddi(file_paths, external_kg_file) | ||
A_incidence = incidence_matrix(adj_list) | ||
A_incidence += A_incidence.T | ||
self.adj = A_incidence | ||
|
||
|
||
|
||
def load_link(dataset): | ||
if 'twosides' in dataset or 'ogbl_biokg' in dataset: | ||
data = MultiLabelDataset(dataset) | ||
else: | ||
data = RGCNLinkDataset(dataset) | ||
data.load() | ||
return data | ||
|
||
|
||
def _read_dictionary(filename): | ||
d = {} | ||
with open(filename, 'r+') as f: | ||
for line in f: | ||
line = line.strip().split('\t') | ||
d[line[1]] = int(line[0]) | ||
return d | ||
|
||
|
||
def _read_triplets(filename): | ||
with open(filename, 'r+') as f: | ||
for line in f: | ||
processed_line = line.strip().split('\t') | ||
yield processed_line | ||
|
||
|
||
def _read_triplets_as_list(filename, entity_dict, relation_dict): | ||
l = [] | ||
for triplet in _read_triplets(filename): | ||
s = entity_dict[triplet[0]] | ||
r = relation_dict[triplet[1]] | ||
o = entity_dict[triplet[2]] | ||
l.append([s, r, o]) | ||
return l | ||
|
||
|
||
def _read_multi_rel_triplets(filename): | ||
with open(filename, 'r+') as f: | ||
for line in f: | ||
processed_line = line.strip().split('\t') | ||
yield processed_line | ||
|
||
def _read_multi_rel_triplets_as_array(filename, entity_dict): | ||
graph_list = [] | ||
input_list = [] | ||
multi_label_list = [] | ||
pos_neg_list = [] | ||
for triplet in _read_triplets(filename): | ||
s = entity_dict[triplet[0]] | ||
o = entity_dict[triplet[1]] | ||
r_list = list(map(int, triplet[2].split(','))) | ||
multi_label_list.append(r_list) | ||
r_label = [i for i, _ in enumerate(r_list) if _ == 1] | ||
for r in r_label: | ||
graph_list.append([s, r, o]) | ||
input_list.append([s, -1, o]) | ||
pos_neg = int(triplet[3]) | ||
pos_neg_list.append(pos_neg) | ||
return np.asarray(graph_list), np.asarray(input_list), np.asarray(multi_label_list), np.asarray(pos_neg_list) | ||
|
||
class MultiLabelDataset(object): | ||
def __init__(self, name): | ||
self.name = name | ||
self.dir = 'datasets' | ||
|
||
# zip_path = os.path.join(self.dir, '{}.zip'.format(self.name)) | ||
self.dir = os.path.join(self.dir, self.name) | ||
# extract_archive(zip_path, self.dir) | ||
|
||
def load(self): | ||
entity_path = os.path.join(self.dir, 'entities.dict') | ||
train_path = os.path.join(self.dir, 'train.txt') | ||
valid_path = os.path.join(self.dir, 'valid.txt') | ||
test_path = os.path.join(self.dir, 'test.txt') | ||
entity_dict = _read_dictionary(entity_path) | ||
self.train_graph, self.train_input, self.train_multi_label, self.train_pos_neg = _read_multi_rel_triplets_as_array( | ||
train_path, entity_dict) | ||
_, self.valid_input, self.valid_multi_label, self.valid_pos_neg = _read_multi_rel_triplets_as_array( | ||
valid_path, entity_dict) | ||
_, self.test_input, self.test_multi_label, self.test_pos_neg = _read_multi_rel_triplets_as_array( | ||
test_path, entity_dict) | ||
self.num_nodes = len(entity_dict) | ||
print("# entities: {}".format(self.num_nodes)) | ||
self.num_rels = self.train_multi_label.shape[1] | ||
print("# relations: {}".format(self.num_rels)) | ||
print("# training sample: {}".format(self.train_input.shape[0])) | ||
print("# valid sample: {}".format(self.valid_input.shape[0])) | ||
print("# testing sample: {}".format(self.test_input.shape[0])) | ||
# print("# training sample: {}".format(len(self.train))) | ||
# print("# valid sample: {}".format(len(self.valid))) | ||
# print("# testing sample: {}".format(len(self.test))) | ||
# file_paths = { | ||
# 'train': f'{self.dir}/train_raw.txt', | ||
# 'valid': f'{self.dir}/dev_raw.txt', | ||
# 'test': f'{self.dir}/test_raw.txt' | ||
# } | ||
# external_kg_file = f'{self.dir}/external_kg.txt' | ||
# adj_list, triplets, entity2id, relation2id, id2entity, id2relation, rel = process_files_ddi(file_paths, | ||
# external_kg_file) | ||
# A_incidence = incidence_matrix(adj_list) | ||
# A_incidence += A_incidence.T | ||
# self.adj = A_incidence |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from .gcns import GCN_TransE, GCN_DistMult, GCN_ConvE, GCN_ConvE_Rel, GCN_Transformer, GCN_None, GCN_MLP, GCN_MLP_NCN | ||
from .subgraph_selector import SubgraphSelector | ||
from .model_search import SearchGCN_MLP | ||
from .model import SearchedGCN_MLP | ||
from .model_fast import NetworkGNN_MLP | ||
from .model_spos import SearchGCN_MLP_SPOS | ||
from .seal_model import SEAL_GCN |
Oops, something went wrong.