Commit

update

striderdu committed Dec 2, 2024
1 parent d2ed47d commit 5d22bd2
Showing 31 changed files with 10,662 additions and 1 deletion.
73 changes: 72 additions & 1 deletion README.md
@@ -1 +1,72 @@
# CSSE-DDI
# Customized Subgraph Selection and Encoding for Drug-drug Interaction Prediction

<p align="left">
  <a href="https://neurips.cc/virtual/2024/poster/94377"><img src="https://img.shields.io/badge/NeurIPS%202024-Poster-brightgreen.svg" alt="neurips paper"></a>
</p>

---

## Requirements

```shell
torch==1.13.0
dgl-cu111==0.6.1
optuna==3.2.0
```
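A quick environment-check sketch (illustrative, not part of the repository; it assumes the pinned packages above are already installed):

```python
# Verify the pinned dependency versions before running experiments.
import torch
import dgl
import optuna

print(torch.__version__, dgl.__version__, optuna.__version__)  # expect 1.13.0 / 0.6.1 / 3.2.0
```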

## Run

### Unpack Dataset
```shell
unzip datasets.zip
```

### Supernet Training
```shell
python run.py --encoder searchgcn --score_func mlp --combine_type concat --n_layer 3 --epoch 400 \
--batch 512 --seed 0 --search_mode joint_search --search_algorithm spos_train_supernet_ps2 --input_type allgraph \
--loss_type ce --dataset drugbank --ss_search_algorithm snas
```
### Sub-Supernet Training
```shell
python run.py --encoder searchgcn --score_func mlp --combine_type concat --n_layer 3 --epoch 400 \
--batch 512 --seed 0 --search_mode joint_search --search_algorithm spos_train_supernet_ps2 --input_type allgraph \
--loss_type ce --dataset drugbank --exp_note spfs --few_shot_op rotate --weight_sharing --ss_search_algorithm snas

python run.py --encoder searchgcn --score_func mlp --combine_type concat --n_layer 3 --epoch 400 \
--batch 512 --seed 0 --search_mode joint_search --search_algorithm spos_train_supernet_ps2 --input_type allgraph \
--loss_type ce --dataset drugbank --exp_note spfs --few_shot_op ccorr --weight_sharing --ss_search_algorithm snas

python run.py --encoder searchgcn --score_func mlp --combine_type concat --n_layer 3 --epoch 400 \
--batch 512 --seed 0 --search_mode joint_search --search_algorithm spos_train_supernet_ps2 --input_type allgraph \
--loss_type ce --dataset drugbank --exp_note spfs --few_shot_op mult --weight_sharing --ss_search_algorithm snas

python run.py --encoder searchgcn --score_func mlp --combine_type concat --n_layer 3 --epoch 400 \
--batch 512 --seed 0 --search_mode joint_search --search_algorithm spos_train_supernet_ps2 --input_type allgraph \
--loss_type ce --dataset drugbank --exp_note spfs --few_shot_op sub --weight_sharing --ss_search_algorithm snas
```
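The four commands above differ only in the `--few_shot_op` argument. A minimal convenience sketch (not part of the repository; the flags are copied verbatim from the commands above) that launches them sequentially:

```python
# Hypothetical helper: run the four sub-supernet trainings in sequence,
# varying only --few_shot_op.
import subprocess

BASE_CMD = [
    "python", "run.py", "--encoder", "searchgcn", "--score_func", "mlp",
    "--combine_type", "concat", "--n_layer", "3", "--epoch", "400",
    "--batch", "512", "--seed", "0", "--search_mode", "joint_search",
    "--search_algorithm", "spos_train_supernet_ps2", "--input_type", "allgraph",
    "--loss_type", "ce", "--dataset", "drugbank", "--exp_note", "spfs",
    "--weight_sharing", "--ss_search_algorithm", "snas",
]

for op in ["rotate", "ccorr", "mult", "sub"]:
    subprocess.run(BASE_CMD + ["--few_shot_op", op], check=True)
```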
### Subgraph Selection and Encoding Function Searching
```shell
python run.py --encoder searchgcn --score_func mlp --combine_type concat --n_layer 3 --epoch 400 \
--batch 512 --seed 0 --search_mode joint_search --search_algorithm spos_train_supernet_ps2 --input_type allgraph \
--loss_type ce --dataset drugbank --exp_note spfs --weight_sharing --ss_search_algorithm snas --arch_search_mode ng
```
## Citation

Readers are welcome to follow our work. Please kindly cite our paper:

```bibtex
@inproceedings{du2024customized,
title={Customized Subgraph Selection and Encoding for Drug-drug Interaction Prediction},
author={Du, Haotong and Yao, Quanming and Zhang, Juzheng and Liu, Yang and Wang, Zhen},
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
year={2024}
}
```

## Contact
If you have any questions, feel free to contact me at [[email protected]](mailto:[email protected]).

## Acknowledgement

The code of this paper is partially based on the code of [SEAL_dgl](https://github.com/Smilexuhc/SEAL_dgl), [PS2](https://github.com/qiaoyu-tan/PS2), and [Interstellar](https://github.com/LARS-research/Interstellar). We thank the authors of the above works.
35 changes: 35 additions & 0 deletions config/logger_config.json
@@ -0,0 +1,35 @@
{
    "version": 1,
    "root": {
        "handlers": [
            "console_handler",
            "file_handler"
        ],
        "level": "DEBUG"
    },
    "handlers": {
        "console_handler": {
            "class": "logging.StreamHandler",
            "level": "DEBUG",
            "formatter": "console_formatter"
        },
        "file_handler": {
            "class": "logging.FileHandler",
            "level": "DEBUG",
            "formatter": "file_formatter",
            "filename": "python_logging.log",
            "encoding": "utf8",
            "mode": "w"
        }
    },
    "formatters": {
        "console_formatter": {
            "format": "%(asctime)s - %(message)s",
            "datefmt": "%Y-%m-%d %H:%M:%S"
        },
        "file_formatter": {
            "format": "%(asctime)s - %(message)s",
            "datefmt": "%Y-%m-%d %H:%M:%S"
        }
    }
}
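A config of this shape is consumed by Python's standard `logging.config.dictConfig` (the `"version": 1` key is required by that schema). A minimal loading sketch, assuming the file is read from `config/logger_config.json` — the repository's actual wiring may differ:

```python
import json
import logging
import logging.config

# Parse the JSON config and install it; both handlers log at DEBUG,
# and the file handler truncates python_logging.log on each run ("mode": "w").
with open("config/logger_config.json", encoding="utf8") as f:
    logging.config.dictConfig(json.load(f))

logging.getLogger(__name__).debug("logging configured")  # reaches console and file
```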
176 changes: 176 additions & 0 deletions data/knowledge_graph.py
@@ -0,0 +1,176 @@
"""
based on the implementation in DGL
(https://github.com/dmlc/dgl/blob/master/python/dgl/contrib/data/knowledge_graph.py)
Knowledge graph dataset for Relational-GCN
Code adapted from authors' implementation of Relational-GCN
https://github.com/tkipf/relational-gcn
https://github.com/MichSchli/RelationPrediction
"""

from __future__ import print_function
from __future__ import absolute_import
import numpy as np
import scipy.sparse as sp
import os

from dgl.data.utils import download, extract_archive, get_download_dir, _get_dgl_url
from utils.dgl_utils import process_files_ddi
from utils.graph_utils import incidence_matrix

# np.random.seed(123)

_download_prefix = _get_dgl_url('dataset/')


def load_data(dataset):
    if dataset in ['drugbank', 'twosides', 'twosides_200', 'drugbank_s1', 'twosides_s1']:
        return load_link(dataset)
    else:
        raise ValueError('Unknown dataset: {}'.format(dataset))


class RGCNLinkDataset(object):

    def __init__(self, name):
        self.name = name
        self.dir = 'datasets'

        # zip_path = os.path.join(self.dir, '{}.zip'.format(self.name))
        self.dir = os.path.join(self.dir, self.name)
        # extract_archive(zip_path, self.dir)

    def load(self):
        entity_path = os.path.join(self.dir, 'entities.dict')
        relation_path = os.path.join(self.dir, 'relations.dict')
        train_path = os.path.join(self.dir, 'train.txt')
        valid_path = os.path.join(self.dir, 'valid.txt')
        test_path = os.path.join(self.dir, 'test.txt')
        entity_dict = _read_dictionary(entity_path)
        relation_dict = _read_dictionary(relation_path)
        self.train = np.asarray(_read_triplets_as_list(
            train_path, entity_dict, relation_dict))
        self.valid = np.asarray(_read_triplets_as_list(
            valid_path, entity_dict, relation_dict))
        self.test = np.asarray(_read_triplets_as_list(
            test_path, entity_dict, relation_dict))
        self.num_nodes = len(entity_dict)
        print("# entities: {}".format(self.num_nodes))
        self.num_rels = len(relation_dict)
        print("# relations: {}".format(self.num_rels))
        print("# training sample: {}".format(len(self.train)))
        print("# valid sample: {}".format(len(self.valid)))
        print("# testing sample: {}".format(len(self.test)))
        file_paths = {
            'train': f'{self.dir}/train_raw.txt',
            'valid': f'{self.dir}/dev_raw.txt',
            'test': f'{self.dir}/test_raw.txt'
        }
        external_kg_file = f'{self.dir}/external_kg.txt'
        adj_list, triplets, entity2id, relation2id, id2entity, id2relation, rel = process_files_ddi(file_paths, external_kg_file)
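        # incidence_matrix() presumably returns a directed global adjacency over
        # all relations; adding its transpose below symmetrizes it for undirected use.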
        A_incidence = incidence_matrix(adj_list)
        A_incidence += A_incidence.T
        self.adj = A_incidence



def load_link(dataset):
    if 'twosides' in dataset or 'ogbl_biokg' in dataset:
        data = MultiLabelDataset(dataset)
    else:
        data = RGCNLinkDataset(dataset)
    data.load()
    return data


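# Format assumption, inferred from the parsing below: each line of
# entities.dict / relations.dict is "<id>\t<name>" (e.g. "0\tDB00001"),
# so _read_dictionary returns a name -> integer-id mapping.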
def _read_dictionary(filename):
    d = {}
    with open(filename, 'r+') as f:
        for line in f:
            line = line.strip().split('\t')
            d[line[1]] = int(line[0])
    return d


def _read_triplets(filename):
    with open(filename, 'r+') as f:
        for line in f:
            processed_line = line.strip().split('\t')
            yield processed_line


def _read_triplets_as_list(filename, entity_dict, relation_dict):
    l = []
    for triplet in _read_triplets(filename):
        s = entity_dict[triplet[0]]
        r = relation_dict[triplet[1]]
        o = entity_dict[triplet[2]]
        l.append([s, r, o])
    return l


def _read_multi_rel_triplets(filename):
    with open(filename, 'r+') as f:
        for line in f:
            processed_line = line.strip().split('\t')
            yield processed_line

def _read_multi_rel_triplets_as_array(filename, entity_dict):
    graph_list = []
    input_list = []
    multi_label_list = []
    pos_neg_list = []
    for triplet in _read_triplets(filename):
        s = entity_dict[triplet[0]]
        o = entity_dict[triplet[1]]
        r_list = list(map(int, triplet[2].split(',')))
        multi_label_list.append(r_list)
        r_label = [i for i, flag in enumerate(r_list) if flag == 1]
        for r in r_label:
            graph_list.append([s, r, o])
        input_list.append([s, -1, o])
        pos_neg = int(triplet[3])
        pos_neg_list.append(pos_neg)
    return np.asarray(graph_list), np.asarray(input_list), np.asarray(multi_label_list), np.asarray(pos_neg_list)
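# Illustrative example (format assumption): a line "DB00001\tDB00002\t0,1,1,0\t1"
# adds [s, 1, o] and [s, 2, o] to graph_list, one [s, -1, o] to input_list,
# the label vector [0, 1, 1, 0] to multi_label_list, and 1 to pos_neg_list.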

class MultiLabelDataset(object):
    def __init__(self, name):
        self.name = name
        self.dir = 'datasets'

        # zip_path = os.path.join(self.dir, '{}.zip'.format(self.name))
        self.dir = os.path.join(self.dir, self.name)
        # extract_archive(zip_path, self.dir)

    def load(self):
        entity_path = os.path.join(self.dir, 'entities.dict')
        train_path = os.path.join(self.dir, 'train.txt')
        valid_path = os.path.join(self.dir, 'valid.txt')
        test_path = os.path.join(self.dir, 'test.txt')
        entity_dict = _read_dictionary(entity_path)
        self.train_graph, self.train_input, self.train_multi_label, self.train_pos_neg = _read_multi_rel_triplets_as_array(
            train_path, entity_dict)
        _, self.valid_input, self.valid_multi_label, self.valid_pos_neg = _read_multi_rel_triplets_as_array(
            valid_path, entity_dict)
        _, self.test_input, self.test_multi_label, self.test_pos_neg = _read_multi_rel_triplets_as_array(
            test_path, entity_dict)
        self.num_nodes = len(entity_dict)
        print("# entities: {}".format(self.num_nodes))
        self.num_rels = self.train_multi_label.shape[1]
        print("# relations: {}".format(self.num_rels))
        print("# training sample: {}".format(self.train_input.shape[0]))
        print("# valid sample: {}".format(self.valid_input.shape[0]))
        print("# testing sample: {}".format(self.test_input.shape[0]))
        # print("# training sample: {}".format(len(self.train)))
        # print("# valid sample: {}".format(len(self.valid)))
        # print("# testing sample: {}".format(len(self.test)))
        # file_paths = {
        #     'train': f'{self.dir}/train_raw.txt',
        #     'valid': f'{self.dir}/dev_raw.txt',
        #     'test': f'{self.dir}/test_raw.txt'
        # }
        # external_kg_file = f'{self.dir}/external_kg.txt'
        # adj_list, triplets, entity2id, relation2id, id2entity, id2relation, rel = process_files_ddi(file_paths,
        #                                                                                             external_kg_file)
        # A_incidence = incidence_matrix(adj_list)
        # A_incidence += A_incidence.T
        # self.adj = A_incidence
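A minimal smoke-test sketch for the loaders above (an illustration, not repository code; it assumes `datasets.zip` has been unpacked into `datasets/` and is run from the repository root):

```python
from data.knowledge_graph import load_data

# 'drugbank' -> RGCNLinkDataset: single-label triplets plus a global adjacency.
data = load_data('drugbank')
print(data.num_nodes, data.num_rels, data.train.shape)

# 'twosides' -> MultiLabelDataset: pairs with comma-separated label vectors.
multi = load_data('twosides')
print(multi.train_input.shape, multi.train_multi_label.shape)
```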
Binary file added datasets.zip
7 changes: 7 additions & 0 deletions model/__init__.py
@@ -0,0 +1,7 @@
from .gcns import GCN_TransE, GCN_DistMult, GCN_ConvE, GCN_ConvE_Rel, GCN_Transformer, GCN_None, GCN_MLP, GCN_MLP_NCN
from .subgraph_selector import SubgraphSelector
from .model_search import SearchGCN_MLP
from .model import SearchedGCN_MLP
from .model_fast import NetworkGNN_MLP
from .model_spos import SearchGCN_MLP_SPOS
from .seal_model import SEAL_GCN
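These re-exports make the model classes importable directly from the package, e.g. (illustrative):

```python
# Enabled by the __init__.py re-exports above.
from model import SearchGCN_MLP, SubgraphSelector
```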