diff --git a/README.md b/README.md index b7164c5..ec3ae50 100644 --- a/README.md +++ b/README.md @@ -1 +1,72 @@ -# CSSE-DDI \ No newline at end of file +# Customized Subgraph Selection and Encoding for Drug-drug Interaction Prediction + +

+neurips paper +

+ +--- + +## Requirements + +```shell +torch==1.13.0 +dgl-cu111==0.6.1 +optuna==3.2.0 +``` + +## Run + +### Unpack Dataset +```shell +unzip datasets.zip +``` + +### Supernet Training +```shell +python run.py --encoder searchgcn --score_func mlp --combine_type concat --n_layer 3 --epoch 400 \ +--batch 512 --seed 0 --search_mode joint_search --search_algorithm spos_train_supernet_ps2 --input_type allgraph \ +--loss_type ce --dataset drugbank --ss_search_algorithm snas +``` +### Sub-Supernet Training +```shell +python run.py --encoder searchgcn --score_func mlp --combine_type concat --n_layer 3 --epoch 400 \ +--batch 512 --seed 0 --search_mode joint_search --search_algorithm spos_train_supernet_ps2 --input_type allgraph \ +--loss_type ce --dataset drugbank --exp_note spfs --few_shot_op rotate --weight_sharing --ss_search_algorithm snas + +python run.py --encoder searchgcn --score_func mlp --combine_type concat --n_layer 3 --epoch 400 \ +--batch 512 --seed 0 --search_mode joint_search --search_algorithm spos_train_supernet_ps2 --input_type allgraph \ +--loss_type ce --dataset drugbank --exp_note spfs --few_shot_op ccorr --weight_sharing --ss_search_algorithm snas + +python run.py --encoder searchgcn --score_func mlp --combine_type concat --n_layer 3 --epoch 400 \ +--batch 512 --seed 0 --search_mode joint_search --search_algorithm spos_train_supernet_ps2 --input_type allgraph \ +--loss_type ce --dataset drugbank --exp_note spfs --few_shot_op mult --weight_sharing --ss_search_algorithm snas + +python run.py --encoder searchgcn --score_func mlp --combine_type concat --n_layer 3 --epoch 400 \ +--batch 512 --seed 0 --search_mode joint_search --search_algorithm spos_train_supernet_ps2 --input_type allgraph \ +--loss_type ce --dataset drugbank --exp_note spfs --few_shot_op sub --weight_sharing --ss_search_algorithm snas +``` +### Subgraph Selection and Encoding Function Searching +```shell +python run.py --encoder searchgcn --score_func mlp --combine_type concat --n_layer 3 --epoch 400 \ +--batch 512 --seed 0 --search_mode joint_search --search_algorithm spos_train_supernet_ps2 --input_type allgraph \ +--loss_type ce --dataset drugbank --exp_note spfs --weight_sharing --ss_search_algorithm snas --arch_search_mode ng +``` +## Citation + +Readers are welcome to follow our work. Please kindly cite our paper: + +```bibtex +@inproceedings{du2024customized, + title={Customized Subgraph Selection and Encoding for Drug-drug Interaction Prediction}, + author={Du, Haotong and Yao, Quanming and Zhang, Juzheng and Liu, Yang and Wang, Zhen}, + booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems}, + year={2024} +} +``` + +## Contact +If you have any questions, feel free to contact me at [duhaotong@mail.nwpu.edu.cn](mailto:duhaotong@mail.nwpu.edu.cn). + +## Acknowledgement + +The code of this paper is partially based on the code of [SEAL_dgl](https://github.com/Smilexuhc/SEAL_dgl), [PS2](https://github.com/qiaoyu-tan/PS2), and [Interstellar](https://github.com/LARS-research/Interstellar). We thank the authors of the above works. 
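For reference, once `datasets.zip` has been unpacked as described above, the DrugBank split can also be loaded directly through the dataset loader introduced in `data/knowledge_graph.py` further down in this diff. A minimal sketch, assuming the raw files shipped inside `datasets/drugbank/` are in place:

```python
# Minimal illustrative sketch: load the unpacked DrugBank split with the
# loader added in data/knowledge_graph.py (assumes `unzip datasets.zip` was run).
from data.knowledge_graph import load_data

data = load_data('drugbank')          # prints entity/relation/sample counts while loading
print(data.num_nodes, data.num_rels)  # number of entities and relation types
print(data.train.shape)               # [N_train, 3] array of (subject, relation, object) ids
```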
diff --git a/config/logger_config.json b/config/logger_config.json new file mode 100644 index 0000000..f4d5547 --- /dev/null +++ b/config/logger_config.json @@ -0,0 +1,35 @@ +{ + "version": 1, + "root": { + "handlers": [ + "console_handler", + "file_handler" + ], + "level": "DEBUG" + }, + "handlers": { + "console_handler": { + "class": "logging.StreamHandler", + "level": "DEBUG", + "formatter": "console_formatter" + }, + "file_handler": { + "class": "logging.FileHandler", + "level": "DEBUG", + "formatter": "file_formatter", + "filename": "python_logging.log", + "encoding": "utf8", + "mode": "w" + } + }, + "formatters": { + "console_formatter": { + "format": "%(asctime)s - %(message)s", + "datefmt": "%Y-%m-%d %H:%M:%S" + }, + "file_formatter": { + "format": "%(asctime)s - %(message)s", + "datefmt": "%Y-%m-%d %H:%M:%S" + } + } +} \ No newline at end of file diff --git a/data/knowledge_graph.py b/data/knowledge_graph.py new file mode 100644 index 0000000..29aec3b --- /dev/null +++ b/data/knowledge_graph.py @@ -0,0 +1,176 @@ +""" +based on the implementation in DGL +(https://github.com/dmlc/dgl/blob/master/python/dgl/contrib/data/knowledge_graph.py) +Knowledge graph dataset for Relational-GCN +Code adapted from authors' implementation of Relational-GCN +https://github.com/tkipf/relational-gcn +https://github.com/MichSchli/RelationPrediction +""" + +from __future__ import print_function +from __future__ import absolute_import +import numpy as np +import scipy.sparse as sp +import os + +from dgl.data.utils import download, extract_archive, get_download_dir, _get_dgl_url +from utils.dgl_utils import process_files_ddi +from utils.graph_utils import incidence_matrix + +# np.random.seed(123) + +_downlaod_prefix = _get_dgl_url('dataset/') + + +def load_data(dataset): + if dataset in ['drugbank', 'twosides', 'twosides_200', 'drugbank_s1', 'twosides_s1']: + return load_link(dataset) + else: + raise ValueError('Unknown dataset: {}'.format(dataset)) + + +class RGCNLinkDataset(object): + + def __init__(self, name): + self.name = name + self.dir = 'datasets' + + # zip_path = os.path.join(self.dir, '{}.zip'.format(self.name)) + self.dir = os.path.join(self.dir, self.name) + # extract_archive(zip_path, self.dir) + + def load(self): + entity_path = os.path.join(self.dir, 'entities.dict') + relation_path = os.path.join(self.dir, 'relations.dict') + train_path = os.path.join(self.dir, 'train.txt') + valid_path = os.path.join(self.dir, 'valid.txt') + test_path = os.path.join(self.dir, 'test.txt') + entity_dict = _read_dictionary(entity_path) + relation_dict = _read_dictionary(relation_path) + self.train = np.asarray(_read_triplets_as_list( + train_path, entity_dict, relation_dict)) + self.valid = np.asarray(_read_triplets_as_list( + valid_path, entity_dict, relation_dict)) + self.test = np.asarray(_read_triplets_as_list( + test_path, entity_dict, relation_dict)) + self.num_nodes = len(entity_dict) + print("# entities: {}".format(self.num_nodes)) + self.num_rels = len(relation_dict) + print("# relations: {}".format(self.num_rels)) + print("# training sample: {}".format(len(self.train))) + print("# valid sample: {}".format(len(self.valid))) + print("# testing sample: {}".format(len(self.test))) + file_paths = { + 'train': f'{self.dir}/train_raw.txt', + 'valid': f'{self.dir}/dev_raw.txt', + 'test': f'{self.dir}/test_raw.txt' + } + external_kg_file = f'{self.dir}/external_kg.txt' + adj_list, triplets, entity2id, relation2id, id2entity, id2relation, rel = process_files_ddi(file_paths, external_kg_file) + A_incidence 
= incidence_matrix(adj_list) + A_incidence += A_incidence.T + self.adj = A_incidence + + + +def load_link(dataset): + if 'twosides' in dataset or 'ogbl_biokg' in dataset: + data = MultiLabelDataset(dataset) + else: + data = RGCNLinkDataset(dataset) + data.load() + return data + + +def _read_dictionary(filename): + d = {} + with open(filename, 'r+') as f: + for line in f: + line = line.strip().split('\t') + d[line[1]] = int(line[0]) + return d + + +def _read_triplets(filename): + with open(filename, 'r+') as f: + for line in f: + processed_line = line.strip().split('\t') + yield processed_line + + +def _read_triplets_as_list(filename, entity_dict, relation_dict): + l = [] + for triplet in _read_triplets(filename): + s = entity_dict[triplet[0]] + r = relation_dict[triplet[1]] + o = entity_dict[triplet[2]] + l.append([s, r, o]) + return l + + +def _read_multi_rel_triplets(filename): + with open(filename, 'r+') as f: + for line in f: + processed_line = line.strip().split('\t') + yield processed_line + +def _read_multi_rel_triplets_as_array(filename, entity_dict): + graph_list = [] + input_list = [] + multi_label_list = [] + pos_neg_list = [] + for triplet in _read_triplets(filename): + s = entity_dict[triplet[0]] + o = entity_dict[triplet[1]] + r_list = list(map(int, triplet[2].split(','))) + multi_label_list.append(r_list) + r_label = [i for i, _ in enumerate(r_list) if _ == 1] + for r in r_label: + graph_list.append([s, r, o]) + input_list.append([s, -1, o]) + pos_neg = int(triplet[3]) + pos_neg_list.append(pos_neg) + return np.asarray(graph_list), np.asarray(input_list), np.asarray(multi_label_list), np.asarray(pos_neg_list) + +class MultiLabelDataset(object): + def __init__(self, name): + self.name = name + self.dir = 'datasets' + + # zip_path = os.path.join(self.dir, '{}.zip'.format(self.name)) + self.dir = os.path.join(self.dir, self.name) + # extract_archive(zip_path, self.dir) + + def load(self): + entity_path = os.path.join(self.dir, 'entities.dict') + train_path = os.path.join(self.dir, 'train.txt') + valid_path = os.path.join(self.dir, 'valid.txt') + test_path = os.path.join(self.dir, 'test.txt') + entity_dict = _read_dictionary(entity_path) + self.train_graph, self.train_input, self.train_multi_label, self.train_pos_neg = _read_multi_rel_triplets_as_array( + train_path, entity_dict) + _, self.valid_input, self.valid_multi_label, self.valid_pos_neg = _read_multi_rel_triplets_as_array( + valid_path, entity_dict) + _, self.test_input, self.test_multi_label, self.test_pos_neg = _read_multi_rel_triplets_as_array( + test_path, entity_dict) + self.num_nodes = len(entity_dict) + print("# entities: {}".format(self.num_nodes)) + self.num_rels = self.train_multi_label.shape[1] + print("# relations: {}".format(self.num_rels)) + print("# training sample: {}".format(self.train_input.shape[0])) + print("# valid sample: {}".format(self.valid_input.shape[0])) + print("# testing sample: {}".format(self.test_input.shape[0])) + # print("# training sample: {}".format(len(self.train))) + # print("# valid sample: {}".format(len(self.valid))) + # print("# testing sample: {}".format(len(self.test))) + # file_paths = { + # 'train': f'{self.dir}/train_raw.txt', + # 'valid': f'{self.dir}/dev_raw.txt', + # 'test': f'{self.dir}/test_raw.txt' + # } + # external_kg_file = f'{self.dir}/external_kg.txt' + # adj_list, triplets, entity2id, relation2id, id2entity, id2relation, rel = process_files_ddi(file_paths, + # external_kg_file) + # A_incidence = incidence_matrix(adj_list) + # A_incidence += A_incidence.T + # 
self.adj = A_incidence \ No newline at end of file diff --git a/datasets.zip b/datasets.zip new file mode 100644 index 0000000..44d839b Binary files /dev/null and b/datasets.zip differ diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000..07b73fb --- /dev/null +++ b/model/__init__.py @@ -0,0 +1,7 @@ +from .gcns import GCN_TransE, GCN_DistMult, GCN_ConvE, GCN_ConvE_Rel, GCN_Transformer, GCN_None, GCN_MLP, GCN_MLP_NCN +from .subgraph_selector import SubgraphSelector +from .model_search import SearchGCN_MLP +from .model import SearchedGCN_MLP +from .model_fast import NetworkGNN_MLP +from .model_spos import SearchGCN_MLP_SPOS +from .seal_model import SEAL_GCN diff --git a/model/compgcn_layer.py b/model/compgcn_layer.py new file mode 100644 index 0000000..777da9c --- /dev/null +++ b/model/compgcn_layer.py @@ -0,0 +1,158 @@ +import torch +from torch import nn +import dgl +import dgl.function as fn + + +class CompGCNCov(nn.Module): + def __init__(self, in_channels, out_channels, act=lambda x: x, bias=True, drop_rate=0., opn='corr', num_base=-1, + num_rel=None, wni=False, wsi=False, use_bn=True, ltr=True, add_reverse=True): + super(CompGCNCov, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.act = act # activation function + self.device = None + if add_reverse: + self.rel = nn.Parameter(torch.empty([num_rel * 2, in_channels], dtype=torch.float)) + else: + self.rel = nn.Parameter(torch.empty([num_rel, in_channels], dtype=torch.float)) + self.opn = opn + + self.use_bn = use_bn + self.ltr = ltr + + # relation-type specific parameter + self.in_w = self.get_param([in_channels, out_channels]) + self.out_w = self.get_param([in_channels, out_channels]) + self.loop_w = self.get_param([in_channels, out_channels]) + # transform embedding of relations to next layer + self.w_rel = self.get_param([in_channels, out_channels]) + self.loop_rel = self.get_param([1, in_channels]) # self-loop embedding + + self.drop = nn.Dropout(drop_rate) + self.bn = torch.nn.BatchNorm1d(out_channels) + self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None + if num_base > 0: + if add_reverse: + self.rel_wt = self.get_param([num_rel * 2, num_base]) + else: + self.rel_wt = self.get_param([num_rel, num_base]) + else: + self.rel_wt = None + + self.wni = wni + self.wsi = wsi + + def get_param(self, shape): + param = nn.Parameter(torch.Tensor(*shape)) + nn.init.xavier_normal_(param, gain=nn.init.calculate_gain('relu')) + return param + + def message_func(self, edges): + edge_type = edges.data['type'] # [E, 1] + edge_num = edge_type.shape[0] + edge_data = self.comp( + edges.src['h'], self.rel[edge_type]) # [E, in_channel] + # NOTE: first half edges are all in-directions, last half edges are out-directions. 
+ msg = torch.cat([torch.matmul(edge_data[:edge_num // 2, :], self.in_w), + torch.matmul(edge_data[edge_num // 2:, :], self.out_w)]) + msg = msg * edges.data['norm'].reshape(-1, 1) # [E, D] * [E, 1] + return {'msg': msg} + + def reduce_func(self, nodes): + return {'h': self.drop(nodes.data['h'])} + + def comp(self, h, edge_data): + # def com_mult(a, b): + # r1, i1 = a[..., 0], a[..., 1] + # r2, i2 = b[..., 0], b[..., 1] + # return torch.stack([r1 * r2 - i1 * i2, r1 * i2 + i1 * r2], dim=-1) + # + # def conj(a): + # a[..., 1] = -a[..., 1] + # return a + # + # def ccorr(a, b): + # # return torch.irfft(com_mult(conj(torch.rfft(a, 1)), torch.rfft(b, 1)), 1, signal_sizes=(a.shape[-1],)) + # return torch.fft.irfftn(torch.conj(torch.fft.rfftn(a, (-1))) * torch.fft.rfftn(b, (-1)), (-1)) + + def com_mult(a, b): + r1, i1 = a.real, a.imag + r2, i2 = b.real, b.imag + real = r1 * r2 - i1 * i2 + imag = r1 * i2 + i1 * r2 + return torch.complex(real, imag) + + def conj(a): + a.imag = -a.imag + return a + + def ccorr(a, b): + return torch.fft.irfft(com_mult(conj(torch.fft.rfft(a)), torch.fft.rfft(b)), a.shape[-1]) + + def rotate(h, r): + # re: first half, im: second half + # assume embedding dim is the last dimension + d = h.shape[-1] + h_re, h_im = torch.split(h, d // 2, -1) + r_re, r_im = torch.split(r, d // 2, -1) + return torch.cat([h_re * r_re - h_im * r_im, + h_re * r_im + h_im * r_re], dim=-1) + + if self.opn == 'mult': + return h * edge_data + elif self.opn == 'sub': + return h - edge_data + elif self.opn == 'add': + return h + edge_data + elif self.opn == 'corr': + return ccorr(h, edge_data.expand_as(h)) + elif self.opn == 'rotate': + return rotate(h, edge_data) + else: + raise KeyError(f'composition operator {self.opn} not recognized.') + + def forward(self, g: dgl.DGLGraph, x, rel_repr, edge_type, edge_norm): + """ + :param g: dgl Graph, a graph without self-loop + :param x: input node features, [V, in_channel] + :param rel_repr: input relation features: 1. not using bases: [num_rel*2, in_channel] + 2. 
using bases: [num_base, in_channel] + :param edge_type: edge type, [E] + :param edge_norm: edge normalization, [E] + :return: x: output node features: [V, out_channel] + rel: output relation features: [num_rel*2, out_channel] + """ + self.device = x.device + g = g.local_var() + g.ndata['h'] = x + g.edata['type'] = edge_type + g.edata['norm'] = edge_norm + # print(self.rel.data) + if self.rel_wt is None: + self.rel.data = rel_repr + else: + # [num_rel*2, num_base] @ [num_base, in_c] + self.rel.data = torch.mm(self.rel_wt, rel_repr) + g.update_all(self.message_func, fn.sum( + msg='msg', out='h'), self.reduce_func) + + if (not self.wni) and (not self.wsi): + x = (g.ndata.pop('h') + + torch.mm(self.comp(x, self.loop_rel), self.loop_w)) / 3 + else: + if self.wsi: + x = g.ndata.pop('h') / 2 + if self.wni: + x = torch.mm(self.comp(x, self.loop_rel), self.loop_w) + + if self.bias is not None: + x = x + self.bias + + if self.use_bn: + x = self.bn(x) + + if self.ltr: + return self.act(x), torch.matmul(self.rel.data, self.w_rel) + else: + return self.act(x), self.rel.data diff --git a/model/fune_layer.py b/model/fune_layer.py new file mode 100644 index 0000000..2f702c3 --- /dev/null +++ b/model/fune_layer.py @@ -0,0 +1,237 @@ +import torch +from torch import nn +import dgl +import dgl.function as fn +from model.genotypes import COMP_PRIMITIVES, AGG_PRIMITIVES, COMB_PRIMITIVES, ACT_PRIMITIVES +from model.operations import * + + +class CompOp(nn.Module): + def __init__(self, primitive): + super(CompOp, self).__init__() + self._op = COMP_OPS[primitive]() + + def reset_parameters(self): + self._op.reset_parameters() + + def forward(self, src_emb, rel_emb): + return self._op(src_emb, rel_emb) + + +class AggOp(nn.Module): + def __init__(self, primitive): + super(AggOp, self).__init__() + self._op = AGG_OPS[primitive]() + + def reset_parameters(self): + self._op.reset_parameters() + + def forward(self, msg): + return self._op(msg) + + +class CombOp(nn.Module): + def __init__(self, primitive, out_channels): + super(CombOp, self).__init__() + self._op = COMB_OPS[primitive](out_channels) + + def reset_parameters(self): + self._op.reset_parameters() + + def forward(self, self_emb, msg): + return self._op(self_emb, msg) + + +class ActOp(nn.Module): + def __init__(self, primitive): + super(ActOp, self).__init__() + self._op = ACT_OPS[primitive]() + + def reset_parameters(self): + self._op.reset_parameters() + + def forward(self, emb): + return self._op(emb) + +def act_map(act): + if act == "identity": + return lambda x: x + elif act == "elu": + return torch.nn.functional.elu + elif act == "sigmoid": + return torch.sigmoid + elif act == "tanh": + return torch.tanh + elif act == "relu": + return torch.nn.functional.relu + elif act == "relu6": + return torch.nn.functional.relu6 + elif act == "softplus": + return torch.nn.functional.softplus + elif act == "leaky_relu": + return torch.nn.functional.leaky_relu + else: + raise Exception("wrong activate function") + + +class SearchedGCNConv(nn.Module): + def __init__(self, in_channels, out_channels, bias=True, drop_rate=0., num_base=-1, + num_rel=None, wni=False, wsi=False, use_bn=True, ltr=True, comp=None, agg=None, comb=None, act=None): + super(SearchedGCNConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.comp_op = comp + self.act = act_map(act) # activation function + self.agg_op = agg + self.device = None + + self.rel = nn.Parameter(torch.empty([num_rel, in_channels], dtype=torch.float)) + + self.use_bn = use_bn + 
self.ltr = ltr + + # relation-type specific parameter + self.in_w = self.get_param([in_channels, out_channels]) + self.out_w = self.get_param([in_channels, out_channels]) + self.loop_w = self.get_param([in_channels, out_channels]) + # transform embedding of relations to next layer + self.w_rel = self.get_param([in_channels, out_channels]) + self.loop_rel = self.get_param([1, in_channels]) # self-loop embedding + + self.drop = nn.Dropout(drop_rate) + self.bn = torch.nn.BatchNorm1d(out_channels) + self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None + if num_base > 0: + self.rel_wt = self.get_param([num_rel, num_base]) + else: + self.rel_wt = None + + self.wni = wni + self.wsi = wsi + # self.comp = CompOp(comp) + # self.agg = AggOp(agg) + self.comb = CombOp(comb, out_channels) + # self.act = ActOp(act) + + def get_param(self, shape): + param = nn.Parameter(torch.Tensor(*shape)) + nn.init.xavier_normal_(param, gain=nn.init.calculate_gain('relu')) + return param + + def message_func(self, edges): + edge_type = edges.data['type'] # [E, 1] + edge_num = edge_type.shape[0] + edge_data = self.comp( + edges.src['h'], self.rel[edge_type]) # [E, in_channel] + # NOTE: first half edges are all in-directions, last half edges are out-directions. + msg = torch.cat([torch.matmul(edge_data[:edge_num // 2, :], self.in_w), + torch.matmul(edge_data[edge_num // 2:, :], self.out_w)]) + msg = msg * edges.data['norm'].reshape(-1, 1) # [E, D] * [E, 1] + return {'msg': msg} + + def reduce_func(self, nodes): + # return {'h': torch.sum(nodes.mailbox['msg'], dim=1)} + return {'h': self.agg(nodes.mailbox['msg'])} + + def apply_node_func(self, nodes): + return {'h': self.drop(nodes.data['h'])} + + def comp(self, h, edge_data): + # def com_mult(a, b): + # r1, i1 = a[..., 0], a[..., 1] + # r2, i2 = b[..., 0], b[..., 1] + # return torch.stack([r1 * r2 - i1 * i2, r1 * i2 + i1 * r2], dim=-1) + # + # def conj(a): + # a[..., 1] = -a[..., 1] + # return a + # + # def ccorr(a, b): + # # return torch.irfft(com_mult(conj(torch.rfft(a, 1)), torch.rfft(b, 1)), 1, signal_sizes=(a.shape[-1],)) + # return torch.fft.irfftn(torch.conj(torch.fft.rfftn(a, (-1))) * torch.fft.rfftn(b, (-1)), (-1)) + + def com_mult(a, b): + r1, i1 = a.real, a.imag + r2, i2 = b.real, b.imag + real = r1 * r2 - i1 * i2 + imag = r1 * i2 + i1 * r2 + return torch.complex(real, imag) + + def conj(a): + a.imag = -a.imag + return a + + def ccorr(a, b): + return torch.fft.irfft(com_mult(conj(torch.fft.rfft(a)), torch.fft.rfft(b)), a.shape[-1]) + + def rotate(h, r): + # re: first half, im: second half + # assume embedding dim is the last dimension + d = h.shape[-1] + h_re, h_im = torch.split(h, d // 2, -1) + r_re, r_im = torch.split(r, d // 2, -1) + return torch.cat([h_re * r_re - h_im * r_im, + h_re * r_im + h_im * r_re], dim=-1) + + if self.comp_op == 'mult': + return h * edge_data + elif self.comp_op == 'add': + return h + edge_data + elif self.comp_op == 'sub': + return h - edge_data + elif self.comp_op == 'ccorr': + return ccorr(h, edge_data.expand_as(h)) + elif self.comp_op == 'rotate': + return rotate(h, edge_data) + else: + raise KeyError(f'composition operator {self.opn} not recognized.') + + def forward(self, g: dgl.DGLGraph, x, rel_repr, edge_type, edge_norm): + """ + :param g: dgl Graph, a graph without self-loop + :param x: input node features, [V, in_channel] + :param rel_repr: input relation features: 1. not using bases: [num_rel*2, in_channel] + 2. 
using bases: [num_base, in_channel] + :param edge_type: edge type, [E] + :param edge_norm: edge normalization, [E] + :return: x: output node features: [V, out_channel] + rel: output relation features: [num_rel*2, out_channel] + """ + self.device = x.device + g = g.local_var() + g.ndata['h'] = x + g.edata['type'] = edge_type + g.edata['norm'] = edge_norm + if self.rel_wt is None: + self.rel.data = rel_repr + else: + # [num_rel*2, num_base] @ [num_base, in_c] + self.rel.data = torch.mm(self.rel_wt, rel_repr) + if self.agg_op == 'max': + g.update_all(self.message_func, fn.max(msg='msg', out='h'), self.apply_node_func) + elif self.agg_op == 'mean': + g.update_all(self.message_func, fn.mean(msg='msg', out='h'), self.apply_node_func) + elif self.agg_op == 'sum': + g.update_all(self.message_func, fn.sum(msg='msg', out='h'), self.apply_node_func) + # g.update_all(self.message_func, self.reduce_func, self.apply_node_func) + + if (not self.wni) and (not self.wsi): + x = self.comb(g.ndata.pop('h'), torch.mm(self.comp(x, self.loop_rel), self.loop_w))*(1/3) + # x = (g.ndata.pop('h') + + # torch.mm(self.comp(x, self.loop_rel, self.comp_weights), self.loop_w)) / 3 + # else: + # if self.wsi: + # x = g.ndata.pop('h') / 2 + # if self.wni: + # x = torch.mm(self.comp(x, self.loop_rel), self.loop_w) + + if self.bias is not None: + x = x + self.bias + + if self.use_bn: + x = self.bn(x) + + if self.ltr: + return self.act(x), torch.matmul(self.rel.data, self.w_rel) + else: + return self.act(x), self.rel.data \ No newline at end of file diff --git a/model/gcns.py b/model/gcns.py new file mode 100644 index 0000000..4919469 --- /dev/null +++ b/model/gcns.py @@ -0,0 +1,1607 @@ +import torch +from torch import nn +import dgl +from model.rgcn_layer import RelGraphConv +from model.compgcn_layer import CompGCNCov +import torch.nn.functional as F +from dgl import NID, EID, readout_nodes +from dgl.nn.pytorch import GraphConv +import time + + +class GCNs(nn.Module): + def __init__(self, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm, + conv_bias=True, gcn_drop=0., opn='mult', wni=False, wsi=False, encoder='compgcn', use_bn=True, + ltr=True, input_type='subgraph', loss_type='ce', add_reverse=True): + super(GCNs, self).__init__() + self.act = torch.tanh + if loss_type == 'ce': + self.loss = nn.CrossEntropyLoss() + elif loss_type == 'bce': + self.loss = nn.BCELoss(reduce=False) + elif loss_type == 'bce_logits': + self.loss = nn.BCEWithLogitsLoss() + else: + raise NotImplementedError + self.num_ent, self.num_rel, self.num_base = num_ent, num_rel, num_base + self.init_dim, self.gcn_dim, self.embed_dim = init_dim, gcn_dim, embed_dim + self.conv_bias = conv_bias + self.gcn_drop = gcn_drop + self.opn = opn + self.edge_type = edge_type # [E] + self.edge_norm = edge_norm # [E] + self.n_layer = n_layer + self.input_type = input_type + + self.wni = wni + + self.encoder = encoder + if input_type == 'subgraph': + self.init_embed = self.get_param([self.num_ent, self.init_dim]) + # self.init_embed = nn.Embedding(self.num_ent+2, self.init_dim) + else: + self.init_embed = self.get_param([self.num_ent + 1, self.init_dim]) + if add_reverse: + self.init_rel = self.get_param([self.num_rel * 2, self.init_dim]) + self.bias_rel = nn.Parameter(torch.zeros(self.num_rel * 2)) + else: + self.init_rel = self.get_param([self.num_rel, self.init_dim]) + self.bias_rel = nn.Parameter(torch.zeros(self.num_rel)) + + if encoder == 'compgcn': + if n_layer < 3: + self.conv1 = CompGCNCov(self.init_dim, self.gcn_dim, 
self.act, conv_bias, gcn_drop, opn, num_base=-1, + num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr, + add_reverse=add_reverse) + self.conv2 = CompGCNCov(self.gcn_dim, self.embed_dim, self.act, conv_bias, gcn_drop, + opn, num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr, + add_reverse=add_reverse) if n_layer == 2 else None + else: + self.conv1 = CompGCNCov(self.init_dim, self.gcn_dim, self.act, conv_bias, gcn_drop, opn, num_base=-1, + num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr, + add_reverse=add_reverse) + self.conv2 = CompGCNCov(self.gcn_dim, self.gcn_dim, self.act, conv_bias, gcn_drop, opn, num_base=-1, + num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr, + add_reverse=add_reverse) + self.conv3 = CompGCNCov(self.gcn_dim, self.embed_dim, self.act, conv_bias, gcn_drop, + opn, num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr, + add_reverse=add_reverse) + elif encoder == 'rgcn': + if n_layer < 3: + self.conv1 = RelGraphConv(self.init_dim, self.gcn_dim, self.num_rel * 2, "bdd", + num_bases=self.num_base, activation=self.act, self_loop=(not wsi), + dropout=gcn_drop, wni=wni) + self.conv2 = RelGraphConv(self.gcn_dim, self.embed_dim, self.num_rel * 2, "bdd", num_bases=self.num_base, + activation=self.act, self_loop=(not wsi), dropout=gcn_drop, + wni=wni) if n_layer == 2 else None + else: + self.conv1 = RelGraphConv(self.init_dim, self.gcn_dim, self.num_rel * 2, "bdd", + num_bases=self.num_base, activation=self.act, self_loop=(not wsi), + dropout=gcn_drop, wni=wni) + self.conv2 = RelGraphConv(self.gcn_dim, self.gcn_dim, self.num_rel * 2, "bdd", + num_bases=self.num_base, activation=self.act, self_loop=(not wsi), + dropout=gcn_drop, wni=wni) + self.conv3 = RelGraphConv(self.gcn_dim, self.embed_dim, self.num_rel * 2, "bdd", num_bases=self.num_base, + activation=self.act, self_loop=(not wsi), dropout=gcn_drop, + wni=wni) if n_layer == 2 else None + elif encoder == 'gcn': + self.conv1 = GraphConv(self.init_dim, self.gcn_dim, allow_zero_in_degree=True) + self.conv2 = GraphConv(self.gcn_dim, self.gcn_dim, allow_zero_in_degree=True) + + self.bias = nn.Parameter(torch.zeros(self.num_ent)) + # self.bias_rel = nn.Parameter(torch.zeros(self.num_rel*2)) + + def get_param(self, shape): + param = nn.Parameter(torch.Tensor(*shape)) + nn.init.xavier_normal_(param, gain=nn.init.calculate_gain('relu')) + return param + + def calc_loss(self, pred, label, pos_neg=None): + if pos_neg is not None: + m = nn.Sigmoid() + score_pos = m(pred) + targets_pos = pos_neg.unsqueeze(1) + loss = self.loss(score_pos, label * targets_pos) + return torch.sum(loss * label) + return self.loss(pred, label) + + def forward_base(self, g, subj, rel, drop1, drop2): + """ + :param g: graph + :param sub: subjects in a batch [batch] + :param rel: relations in a batch [batch] + :param drop1: dropout rate in first layer + :param drop2: dropout rate in second layer + :return: sub_emb: [batch, D] + rel_emb: [num_rel*2, D] + x: [num_ent, D] + """ + x, r = self.init_embed, self.init_rel # embedding of relations + + if self.n_layer > 0: + if self.encoder == 'compgcn': + if self.n_layer < 3: + x, r = self.conv1(g, x, r, self.edge_type, self.edge_norm) + x = drop1(x) # embeddings of entities [num_ent, dim] + x, r = self.conv2( + g, x, r, self.edge_type, self.edge_norm) if self.n_layer == 2 else (x, r) + x = drop2(x) if self.n_layer == 2 else x + else: + x, r = self.conv1(g, x, r, self.edge_type, self.edge_norm) + x = drop1(x) # embeddings of entities [num_ent, dim] + x, r = 
self.conv2(g, x, r, self.edge_type, self.edge_norm) + x = drop1(x) # embeddings of entities [num_ent, dim] + x, r = self.conv3(g, x, r, self.edge_type, self.edge_norm) + x = drop2(x) + elif self.encoder == 'rgcn': + if self.n_layer < 3: + x = self.conv1(g, x, self.edge_type, self.edge_norm.unsqueeze(-1)) + x = drop1(x) # embeddings of entities [num_ent, dim] + x = self.conv2( + g, x, self.edge_type, self.edge_norm.unsqueeze(-1)) if self.n_layer == 2 else x + x = drop2(x) if self.n_layer == 2 else x + else: + x = self.conv1(g, x, self.edge_type, self.edge_norm.unsqueeze(-1)) + x = drop1(x) # embeddings of entities [num_ent, dim] + x = self.conv2(g, x, self.edge_type, self.edge_norm.unsqueeze(-1)) + x = drop1(x) + x = self.conv3(g, x, self.edge_type, self.edge_norm.unsqueeze(-1)) + x = drop2(x) + + # filter out embeddings of subjects in this batch + sub_emb = torch.index_select(x, 0, subj) + # filter out embeddings of relations in this batch + rel_emb = torch.index_select(r, 0, rel) + + return sub_emb, rel_emb, x + + def forward_base_search(self, g, subj, rel, drop1, drop2): + """ + :param g: graph + :param sub: subjects in a batch [batch] + :param rel: relations in a batch [batch] + :param drop1: dropout rate in first layer + :param drop2: dropout rate in second layer + :return: sub_emb: [batch, D] + rel_emb: [num_rel*2, D] + x: [num_ent, D] + """ + x, r = self.init_embed, self.init_rel # embedding of relations + x_hidden = [] + if self.n_layer > 0: + if self.encoder == 'compgcn': + if self.n_layer < 3: + x, r = self.conv1(g, x, r, self.edge_type, self.edge_norm) + x_hidden.append(x) + x = drop1(x) # embeddings of entities [num_ent, dim] + x, r = self.conv2( + g, x, r, self.edge_type, self.edge_norm) if self.n_layer == 2 else (x, r) + x_hidden.append(x) + x = drop2(x) if self.n_layer == 2 else x + else: + x, r = self.conv1(g, x, r, self.edge_type, self.edge_norm) + x_hidden.append(x) + x = drop1(x) # embeddings of entities [num_ent, dim] + x, r = self.conv2(g, x, r, self.edge_type, self.edge_norm) + x_hidden.append(x) + x = drop1(x) # embeddings of entities [num_ent, dim] + x, r = self.conv3(g, x, r, self.edge_type, self.edge_norm) + x_hidden.append(x) + x = drop2(x) + elif self.encoder == 'rgcn': + x = self.conv1(g, x, self.edge_type, + self.edge_norm.unsqueeze(-1)) + x = drop1(x) # embeddings of entities [num_ent, dim] + x = self.conv2( + g, x, self.edge_type, self.edge_norm.unsqueeze(-1)) if self.n_layer == 2 else x + x = drop2(x) if self.n_layer == 2 else x + + # filter out embeddings of subjects in this batch + sub_emb = torch.index_select(x, 0, subj) + # filter out embeddings of relations in this batch + rel_emb = torch.index_select(r, 0, rel) + + x_hidden = torch.stack(x_hidden, dim=1) + + return x_hidden + + def forward_base_rel(self, g, subj, obj, drop1, drop2): + """ + :param g: graph + :param sub: subjects in a batch [batch] + :param rel: relations in a batch [batch] + :param drop1: dropout rate in first layer + :param drop2: dropout rate in second layer + :return: sub_emb: [batch, D] + rel_emb: [num_rel*2, D] + x: [num_ent, D] + """ + x, r = self.init_embed, self.init_rel # embedding of relations + + if self.n_layer > 0: + if self.encoder == 'compgcn': + if self.n_layer < 3: + x, r = self.conv1(g, x, r, self.edge_type, self.edge_norm) + x = drop1(x) # embeddings of entities [num_ent, dim] + x, r = self.conv2( + g, x, r, self.edge_type, self.edge_norm) if self.n_layer == 2 else (x, r) + x = drop2(x) if self.n_layer == 2 else x + else: + x, r = self.conv1(g, x, r, self.edge_type, 
self.edge_norm) + x = drop1(x) # embeddings of entities [num_ent, dim] + x, r = self.conv2(g, x, r, self.edge_type, self.edge_norm) + x = drop1(x) # embeddings of entities [num_ent, dim] + x, r = self.conv3(g, x, r, self.edge_type, self.edge_norm) + x = drop2(x) + elif self.encoder == 'rgcn': + x = self.conv1(g, x, self.edge_type, + self.edge_norm.unsqueeze(-1)) + x = drop1(x) # embeddings of entities [num_ent, dim] + x = self.conv2( + g, x, self.edge_type, self.edge_norm.unsqueeze(-1)) if self.n_layer == 2 else x + x = drop2(x) if self.n_layer == 2 else x + + # filter out embeddings of subjects in this batch + sub_emb = torch.index_select(x, 0, subj) + # filter out embeddings of objects in this batch + obj_emb = torch.index_select(x, 0, obj) + + return sub_emb, obj_emb, x, r + + def forward_base_rel_vis_hop(self, g, subj, obj, drop1, drop2): + """ + :param g: graph + :param sub: subjects in a batch [batch] + :param rel: relations in a batch [batch] + :param drop1: dropout rate in first layer + :param drop2: dropout rate in second layer + :return: sub_emb: [batch, D] + rel_emb: [num_rel*2, D] + x: [num_ent, D] + """ + x, r = self.init_embed, self.init_rel # embedding of relations + x_hidden = [] + + if self.n_layer > 0: + if self.encoder == 'compgcn': + if self.n_layer < 3: + x, r = self.conv1(g, x, r, self.edge_type, self.edge_norm) + x_hidden.append(x) + x = drop1(x) # embeddings of entities [num_ent, dim] + x, r = self.conv2( + g, x, r, self.edge_type, self.edge_norm) if self.n_layer == 2 else (x, r) + x_hidden.append(x) + x = drop2(x) if self.n_layer == 2 else x + else: + x, r = self.conv1(g, x, r, self.edge_type, self.edge_norm) + x_hidden.append(x) + x = drop1(x) # embeddings of entities [num_ent, dim] + x, r = self.conv2(g, x, r, self.edge_type, self.edge_norm) + x_hidden.append(x) + x = drop1(x) # embeddings of entities [num_ent, dim] + x, r = self.conv3(g, x, r, self.edge_type, self.edge_norm) + x_hidden.append(x) + x = drop2(x) + elif self.encoder == 'rgcn': + x = self.conv1(g, x, self.edge_type, + self.edge_norm.unsqueeze(-1)) + x = drop1(x) # embeddings of entities [num_ent, dim] + x = self.conv2( + g, x, self.edge_type, self.edge_norm.unsqueeze(-1)) if self.n_layer == 2 else x + x = drop2(x) if self.n_layer == 2 else x + + # filter out embeddings of subjects in this batch + sub_emb = torch.index_select(x, 0, subj) + # filter out embeddings of objects in this batch + obj_emb = torch.index_select(x, 0, obj) + x_hidden = torch.stack(x_hidden, dim=1) + + return sub_emb, obj_emb, x, r, x_hidden + + def forward_base_rel_search(self, g, drop1, drop2): + """ + :param g: graph + :param sub: subjects in a batch [batch] + :param rel: relations in a batch [batch] + :param drop1: dropout rate in first layer + :param drop2: dropout rate in second layer + :return: sub_emb: [batch, D] + rel_emb: [num_rel*2, D] + x: [num_ent, D] + """ + x, r = self.init_embed, self.init_rel # embedding of relations + x_hidden = [] + if self.n_layer > 0: + if self.encoder == 'compgcn': + if self.n_layer < 3: + x, r = self.conv1(g, x, r, self.edge_type, self.edge_norm) + x = drop1(x) # embeddings of entities [num_ent, dim] + x_hidden.append(x) + x, r = self.conv2( + g, x, r, self.edge_type, self.edge_norm) if self.n_layer == 2 else (x, r) + x = drop2(x) if self.n_layer == 2 else x + x_hidden.append(x) + else: + x, r = self.conv1(g, x, r, self.edge_type, self.edge_norm) + x = drop1(x) # embeddings of entities [num_ent, dim] + x_hidden.append(x) + x, r = self.conv2(g, x, r, self.edge_type, self.edge_norm) + x = 
drop1(x) # embeddings of entities [num_ent, dim] + x_hidden.append(x) + x, r = self.conv3(g, x, r, self.edge_type, self.edge_norm) + x = drop2(x) + x_hidden.append(x) + elif self.encoder == 'rgcn': + x = self.conv1(g, x, self.edge_type, + self.edge_norm.unsqueeze(-1)) + x = drop1(x) # embeddings of entities [num_ent, dim] + x = self.conv2( + g, x, self.edge_type, self.edge_norm.unsqueeze(-1)) if self.n_layer == 2 else x + x = drop2(x) if self.n_layer == 2 else x + + x_hidden = torch.stack(x_hidden, dim=1) + + return x_hidden, x + + def forward_base_rel_subgraph(self, bg, subj, obj, drop1, drop2): + """ + :param g: graph + :param sub: subjects in a batch [batch] + :param rel: relations in a batch [batch] + :param drop1: dropout rate in first layer + :param drop2: dropout rate in second layer + :return: sub_emb: [batch, D] + rel_emb: [num_rel*2, D] + x: [num_ent, D] + """ + x, r = self.init_embed[bg.ndata[NID]], self.init_rel # embedding of relations + edge_type = self.edge_type[bg.edata[EID]] + edge_norm = self.edge_norm[bg.edata[EID]] + if self.n_layer > 0: + if self.encoder == 'compgcn': + if self.n_layer < 3: + x, r = self.conv1(bg, x, r, edge_type, edge_norm) + x = drop1(x) # embeddings of entities [num_ent, dim] + x, r = self.conv2( + bg, x, r, edge_type, edge_norm) if self.n_layer == 2 else (x, r) + x = drop2(x) if self.n_layer == 2 else x + else: + x, r = self.conv1(bg, x, r, edge_type, edge_norm) + x = drop1(x) # embeddings of entities [num_ent, dim] + x, r = self.conv2(bg, x, r, edge_type, edge_norm) + x = drop1(x) # embeddings of entities [num_ent, dim] + x, r = self.conv3(bg, x, r, edge_type, edge_norm) + x = drop2(x) + elif self.encoder == 'rgcn': + x = self.conv1(bg, x, self.edge_type, + self.edge_norm.unsqueeze(-1)) + x = drop1(x) # embeddings of entities [num_ent, dim] + x = self.conv2( + bg, x, self.edge_type, self.edge_norm.unsqueeze(-1)) if self.n_layer == 2 else x + x = drop2(x) if self.n_layer == 2 else x + + bg.ndata['h'] = x + sub_list = [] + obj_list = [] + for idx, g in enumerate(dgl.unbatch(bg)): + head_idx = torch.where(g.ndata[NID] == subj[idx])[0] + tail_idx = torch.where(g.ndata[NID] == obj[idx])[0] + head_emb = g.ndata['h'][head_idx] + tail_emb = g.ndata['h'][tail_idx] + sub_list.append(head_emb) + obj_list.append(tail_emb) + sub_emb = torch.cat(sub_list, dim=0) + obj_emb = torch.cat(obj_list, dim=0) + # # filter out embeddings of subjects in this batch + # sub_emb = torch.index_select(x, 0, subj) + # # filter out embeddings of objects in this batch + # obj_emb = torch.index_select(x, 0, obj) + + return sub_emb, obj_emb, x, r + + def forward_base_rel_subgraph_trans(self, bg, input_ids, drop1, drop2): + """ + :param g: graph + :param sub: subjects in a batch [batch] + :param rel: relations in a batch [batch] + :param drop1: dropout rate in first layer + :param drop2: dropout rate in second layer + :return: sub_emb: [batch, D] + rel_emb: [num_rel*2, D] + x: [num_ent, D] + """ + x, r = self.init_embed[bg.ndata[NID]], self.init_rel # embedding of relations + edge_type = self.edge_type[bg.edata[EID]] + edge_norm = self.edge_norm[bg.edata[EID]] + if self.n_layer > 0: + if self.encoder == 'compgcn': + if self.n_layer < 3: + x, r = self.conv1(bg, x, r, edge_type, edge_norm) + x = drop1(x) # embeddings of entities [num_ent, dim] + x, r = self.conv2( + bg, x, r, edge_type, edge_norm) if self.n_layer == 2 else (x, r) + x = drop2(x) if self.n_layer == 2 else x + else: + x, r = self.conv1(bg, x, r, edge_type, edge_norm) + x = drop1(x) # embeddings of entities [num_ent, dim] 
+ x, r = self.conv2(bg, x, r, edge_type, edge_norm) + x = drop1(x) # embeddings of entities [num_ent, dim] + x, r = self.conv3(bg, x, r, edge_type, edge_norm) + x = drop2(x) + elif self.encoder == 'rgcn': + x = self.conv1(bg, x, self.edge_type, + self.edge_norm.unsqueeze(-1)) + x = drop1(x) # embeddings of entities [num_ent, dim] + x = self.conv2( + bg, x, self.edge_type, self.edge_norm.unsqueeze(-1)) if self.n_layer == 2 else x + x = drop2(x) if self.n_layer == 2 else x + + bg.ndata['h'] = x + input_emb = self.init_embed[input_ids] + for idx, g in enumerate(dgl.unbatch(bg)): + # print(g.ndata['h'].size()) + # print(input_emb[idx][:].size()) + input_emb[idx][1:g.num_nodes() + 1] = g.ndata['h'] + # sub_list = [] + # obj_list = [] + # for idx, g in enumerate(dgl.unbatch(bg)): + # head_idx = torch.where(g.ndata[NID] == subj[idx])[0] + # tail_idx = torch.where(g.ndata[NID] == obj[idx])[0] + # head_emb = g.ndata['h'][head_idx] + # tail_emb = g.ndata['h'][tail_idx] + # sub_list.append(head_emb) + # obj_list.append(tail_emb) + # sub_emb = torch.cat(sub_list,dim=0) + # obj_emb = torch.cat(obj_list,dim=0) + # # filter out embeddings of subjects in this batch + # sub_emb = torch.index_select(x, 0, subj) + # # filter out embeddings of objects in this batch + # obj_emb = torch.index_select(x, 0, obj) + + return input_emb + + def forward_base_rel_subgraph_trans_new(self, bg, drop1, drop2): + """ + :param g: graph + :param sub: subjects in a batch [batch] + :param rel: relations in a batch [batch] + :param drop1: dropout rate in first layer + :param drop2: dropout rate in second layer + :return: sub_emb: [batch, D] + rel_emb: [num_rel*2, D] + x: [num_ent, D] + """ + x, r = self.init_embed[bg.ndata[NID]], self.init_rel # embedding of relations + # print(bg.ndata[NID]) + # print(self.edge_type.size()) + # exit(0) + edge_type = self.edge_type[bg.edata[EID]] + edge_norm = self.edge_norm[bg.edata[EID]] + if self.n_layer > 0: + if self.encoder == 'compgcn': + if self.n_layer < 3: + # print(bg) + x, r = self.conv1(bg, x, r, edge_type, edge_norm) + x = drop1(x) # embeddings of entities [num_ent, dim] + x, r = self.conv2( + bg, x, r, edge_type, edge_norm) if self.n_layer == 2 else (x, r) + x = drop2(x) if self.n_layer == 2 else x + else: + x, r = self.conv1(bg, x, r, edge_type, edge_norm) + x = drop1(x) # embeddings of entities [num_ent, dim] + x, r = self.conv2(bg, x, r, edge_type, edge_norm) + x = drop1(x) # embeddings of entities [num_ent, dim] + x, r = self.conv3(bg, x, r, edge_type, edge_norm) + x = drop2(x) + elif self.encoder == 'rgcn': + x = self.conv1(bg, x, self.edge_type, + self.edge_norm.unsqueeze(-1)) + x = drop1(x) # embeddings of entities [num_ent, dim] + x = self.conv2( + bg, x, self.edge_type, self.edge_norm.unsqueeze(-1)) if self.n_layer == 2 else x + x = drop2(x) if self.n_layer == 2 else x + + bg.ndata['h'] = x + + return bg, r + + def forward_base_no_transform(self, subj, rel): + x, r = self.init_embed, self.init_rel + sub_emb = torch.index_select(x, 0, subj) + rel_emb = torch.index_select(r, 0, rel) + + return sub_emb, rel_emb, x + + def forward_base_subgraph_search(self, bg, drop1, drop2): + """ + :param g: graph + :param sub: subjects in a batch [batch] + :param rel: relations in a batch [batch] + :param drop1: dropout rate in first layer + :param drop2: dropout rate in second layer + :return: sub_emb: [batch, D] + rel_emb: [num_rel*2, D] + x: [num_ent, D] + """ + x, r = self.init_embed[bg.ndata[NID]], self.init_rel # embedding of relations + edge_type = self.edge_type[bg.edata[EID]] + 
edge_norm = self.edge_norm[bg.edata[EID]] + x_hidden = [] + if self.n_layer > 0: + if self.encoder == 'compgcn': + if self.n_layer < 3: + x, r = self.conv1(bg, x, r, edge_type, edge_norm) + x_hidden.append(x) + x = drop1(x) # embeddings of entities [num_ent, dim] + x, r = self.conv2( + bg, x, r, edge_type, edge_norm) if self.n_layer == 2 else (x, r) + x_hidden.append(x) + x = drop2(x) if self.n_layer == 2 else x + else: + x, r = self.conv1(bg, x, r, edge_type, edge_norm) + x_hidden.append(x) + x = drop1(x) # embeddings of entities [num_ent, dim] + x, r = self.conv2(bg, x, r, edge_type, edge_norm) + x_hidden.append(x) + x = drop1(x) # embeddings of entities [num_ent, dim] + x, r = self.conv3(bg, x, r, edge_type, edge_norm) + x_hidden.append(x) + x = drop2(x) + elif self.encoder == 'rgcn': + x = self.conv1(bg, x, self.edge_type, + self.edge_norm.unsqueeze(-1)) + x = drop1(x) # embeddings of entities [num_ent, dim] + x = self.conv2( + bg, x, self.edge_type, self.edge_norm.unsqueeze(-1)) if self.n_layer == 2 else x + x = drop2(x) if self.n_layer == 2 else x + + x_hidden = torch.stack(x_hidden, dim=1) + + return x_hidden + + + +class GCN_TransE(GCNs): + def __init__(self, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm, + bias=True, gcn_drop=0., opn='mult', hid_drop=0., gamma=9., wni=False, wsi=False, encoder='compgcn', + use_bn=True, ltr=True): + super(GCN_TransE, self).__init__(num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, + edge_type, edge_norm, bias, gcn_drop, opn, wni, wsi, encoder, use_bn, ltr) + self.drop = nn.Dropout(hid_drop) + self.gamma = gamma + + def forward(self, g, subj, rel): + """ + :param g: dgl graph + :param sub: subject in batch [batch_size] + :param rel: relation in batch [batch_size] + :return: score: [batch_size, ent_num], the prob in link-prediction + """ + sub_emb, rel_emb, all_ent = self.forward_base( + g, subj, rel, self.drop, self.drop) + obj_emb = sub_emb + rel_emb + + x = self.gamma - torch.norm(obj_emb.unsqueeze(1) - all_ent, p=1, dim=2) + + score = torch.sigmoid(x) + + return score + + +class GCN_DistMult(GCNs): + def __init__(self, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm, + bias=True, gcn_drop=0., opn='mult', hid_drop=0., wni=False, wsi=False, encoder='compgcn', use_bn=True, + ltr=True): + super(GCN_DistMult, self).__init__(num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, + edge_type, edge_norm, bias, gcn_drop, opn, wni, wsi, encoder, use_bn, ltr) + self.drop = nn.Dropout(hid_drop) + + def forward(self, g, subj, rel): + """ + :param g: dgl graph + :param sub: subject in batch [batch_size] + :param rel: relation in batch [batch_size] + :return: score: [batch_size, ent_num], the prob in link-prediction + """ + sub_emb, rel_emb, all_ent = self.forward_base( + g, subj, rel, self.drop, self.drop) + obj_emb = sub_emb * rel_emb # [batch_size, emb_dim] + x = torch.mm(obj_emb, all_ent.transpose(1, 0)) # [batch_size, ent_num] + x += self.bias.expand_as(x) + score = torch.sigmoid(x) + return score + + +class GCN_ConvE(GCNs): + def __init__(self, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm, + bias=True, gcn_drop=0., opn='mult', hid_drop=0., input_drop=0., conve_hid_drop=0., feat_drop=0., + num_filt=None, ker_sz=None, k_h=None, k_w=None, wni=False, wsi=False, encoder='compgcn', use_bn=True, + ltr=True): + """ + :param num_ent: number of entities + :param num_rel: number of different relations + :param 
num_base: number of bases to use + :param init_dim: initial dimension + :param gcn_dim: dimension after first layer + :param embed_dim: dimension after second layer + :param n_layer: number of layer + :param edge_type: relation type of each edge, [E] + :param bias: weather to add bias + :param gcn_drop: dropout rate in compgcncov + :param opn: combination operator + :param hid_drop: gcn output (embedding of each entity) dropout + :param input_drop: dropout in conve input + :param conve_hid_drop: dropout in conve hidden layer + :param feat_drop: feature dropout in conve + :param num_filt: number of filters in conv2d + :param ker_sz: kernel size in conv2d + :param k_h: height of 2D reshape + :param k_w: width of 2D reshape + """ + super(GCN_ConvE, self).__init__(num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, + edge_type, edge_norm, bias, gcn_drop, opn, wni, wsi, encoder, use_bn, ltr) + self.hid_drop, self.input_drop, self.conve_hid_drop, self.feat_drop = hid_drop, input_drop, conve_hid_drop, feat_drop + self.num_filt = num_filt + self.ker_sz, self.k_w, self.k_h = ker_sz, k_w, k_h + + # one channel, do bn on initial embedding + self.bn0 = torch.nn.BatchNorm2d(1) + self.bn1 = torch.nn.BatchNorm2d( + self.num_filt) # do bn on output of conv + self.bn2 = torch.nn.BatchNorm1d(self.embed_dim) + + self.drop = torch.nn.Dropout(self.hid_drop) # gcn output dropout + self.input_drop = torch.nn.Dropout( + self.input_drop) # stacked input dropout + self.feature_drop = torch.nn.Dropout( + self.feat_drop) # feature map dropout + self.hidden_drop = torch.nn.Dropout( + self.conve_hid_drop) # hidden layer dropout + + self.conv2d = torch.nn.Conv2d(in_channels=1, out_channels=self.num_filt, + kernel_size=(self.ker_sz, self.ker_sz), stride=1, padding=0, bias=bias) + + flat_sz_h = int(2 * self.k_h) - self.ker_sz + 1 # height after conv + flat_sz_w = self.k_w - self.ker_sz + 1 # width after conv + self.flat_sz = flat_sz_h * flat_sz_w * self.num_filt + # fully connected projection + self.fc = torch.nn.Linear(self.flat_sz, self.embed_dim) + self.cat_type = 'multii' + + def concat(self, ent_embed, rel_embed): + """ + :param ent_embed: [batch_size, embed_dim] + :param rel_embed: [batch_size, embed_dim] + :return: stack_input: [B, C, H, W] + """ + ent_embed = ent_embed.view(-1, 1, self.embed_dim) + rel_embed = rel_embed.view(-1, 1, self.embed_dim) + # [batch_size, 2, embed_dim] + stack_input = torch.cat([ent_embed, rel_embed], 1) + + assert self.embed_dim == self.k_h * self.k_w + # reshape to 2D [batch, 1, 2*k_h, k_w] + stack_input = stack_input.reshape(-1, 1, 2 * self.k_h, self.k_w) + return stack_input + + def forward(self, g, subj, rel): + """ + :param g: dgl graph + :param sub: subject in batch [batch_size] + :param rel: relation in batch [batch_size] + :return: score: [batch_size, ent_num], the prob in link-prediction + """ + sub_emb, rel_emb, all_ent = self.forward_base( + g, subj, rel, self.drop, self.input_drop) + # [batch_size, 1, 2*k_h, k_w] + stack_input = self.concat(sub_emb, rel_emb) + x = self.bn0(stack_input) + x = self.conv2d(x) # [batch_size, num_filt, flat_sz_h, flat_sz_w] + x = self.bn1(x) + x = F.relu(x) + x = self.feature_drop(x) + x = x.view(-1, self.flat_sz) # [batch_size, flat_sz] + x = self.fc(x) # [batch_size, embed_dim] + x = self.hidden_drop(x) + x = self.bn2(x) + x = F.relu(x) + x = torch.mm(x, all_ent.transpose(1, 0)) # [batch_size, ent_num] + x += self.bias.expand_as(x) + score = torch.sigmoid(x) + return score + + def forward_search(self, g, subj, rel): + sub_emb, 
rel_emb, all_ent, hidden_all_ent = self.forward_base_search( + g, subj, rel, self.drop, self.input_drop) + + return hidden_all_ent + + def compute_pred(self, hidden_x, subj, obj): + # raise NotImplementedError + h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + print(h.size()) + + def cross_pair(self, x_i, x_all): + x = [] + # print(x_i.size()) + # print(x_all.size()) + x_all = x_all.repeat(x_i.size(0), 1, 1, 1) + # print(x_all.size()) + for i in range(self.n_layer): + for j in range(self.n_layer): + if self.cat_type == 'multi': + # print(x_i[:, i, :].size()) + # print(x_all[:,:,j,:].size()) + test = x_i[:, i, :].unsqueeze(1) * x_all[:, :, j, :] + # print(test.size()) + x.append(test) + else: + test = torch.cat([x_i[:, i, :].unsqueeze(1), x_all[:, :, j, :]], dim=1) + print(test.size()) + x.append(test) + x = torch.stack(x, dim=1) + return x + + +class GCN_ConvE_Rel(GCNs): + def __init__(self, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm, + bias=True, gcn_drop=0., opn='mult', hid_drop=0., input_drop=0., conve_hid_drop=0., feat_drop=0., + num_filt=None, ker_sz=None, k_h=None, k_w=None, wni=False, wsi=False, encoder='compgcn', use_bn=True, + ltr=True, input_type='subgraph'): + """ + :param num_ent: number of entities + :param num_rel: number of different relations + :param num_base: number of bases to use + :param init_dim: initial dimension + :param gcn_dim: dimension after first layer + :param embed_dim: dimension after second layer + :param n_layer: number of layer + :param edge_type: relation type of each edge, [E] + :param bias: weather to add bias + :param gcn_drop: dropout rate in compgcncov + :param opn: combination operator + :param hid_drop: gcn output (embedding of each entity) dropout + :param input_drop: dropout in conve input + :param conve_hid_drop: dropout in conve hidden layer + :param feat_drop: feature dropout in conve + :param num_filt: number of filters in conv2d + :param ker_sz: kernel size in conv2d + :param k_h: height of 2D reshape + :param k_w: width of 2D reshape + """ + super(GCN_ConvE_Rel, self).__init__(num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, + edge_type, edge_norm, bias, gcn_drop, opn, wni, wsi, encoder, use_bn, ltr, + input_type) + self.hid_drop, self.input_drop, self.conve_hid_drop, self.feat_drop = hid_drop, input_drop, conve_hid_drop, feat_drop + self.num_filt = num_filt + self.ker_sz, self.k_w, self.k_h = ker_sz, k_w, k_h + self.n_layer = n_layer + + # one channel, do bn on initial embedding + self.bn0 = torch.nn.BatchNorm2d(1) + self.bn1 = torch.nn.BatchNorm2d( + self.num_filt) # do bn on output of conv + self.bn2 = torch.nn.BatchNorm1d(self.embed_dim) + + self.drop = torch.nn.Dropout(self.hid_drop) # gcn output dropout + self.input_drop = torch.nn.Dropout( + self.input_drop) # stacked input dropout + self.feature_drop = torch.nn.Dropout( + self.feat_drop) # feature map dropout + self.hidden_drop = torch.nn.Dropout( + self.conve_hid_drop) # hidden layer dropout + + self.conv2d = torch.nn.Conv2d(in_channels=1, out_channels=self.num_filt, + kernel_size=(self.ker_sz, self.ker_sz), stride=1, padding=0, bias=bias) + + flat_sz_h = int(2 * self.k_h) - self.ker_sz + 1 # height after conv + flat_sz_w = self.k_w - self.ker_sz + 1 # width after conv + self.flat_sz = flat_sz_h * flat_sz_w * self.num_filt + # fully connected projection + self.fc = torch.nn.Linear(self.flat_sz, self.embed_dim) + self.combine_type = 'concat' + + def concat(self, ent_embed, rel_embed): + """ + :param ent_embed: 
[batch_size, embed_dim] + :param rel_embed: [batch_size, embed_dim] + :return: stack_input: [B, C, H, W] + """ + ent_embed = ent_embed.view(-1, 1, self.embed_dim) + rel_embed = rel_embed.view(-1, 1, self.embed_dim) + # [batch_size, 2, embed_dim] + stack_input = torch.cat([ent_embed, rel_embed], 1) + + assert self.embed_dim == self.k_h * self.k_w + # reshape to 2D [batch, 1, 2*k_h, k_w] + stack_input = stack_input.reshape(-1, 1, 2 * self.k_h, self.k_w) + return stack_input + + def forward(self, g, subj, obj): + """ + :param g: dgl graph + :param sub: subject in batch [batch_size] + :param rel: relation in batch [batch_size] + :return: score: [batch_size, ent_num], the prob in link-prediction + """ + if self.input_type == 'subgraph': + sub_emb, obj_emb, all_ent, all_rel = self.forward_base_rel_subgraph(g, subj, obj, self.drop, + self.input_drop) + else: + sub_emb, obj_emb, all_ent, all_rel = self.forward_base_rel(g, subj, obj, self.drop, self.input_drop) + # [batch_size, 1, 2*k_h, k_w] + stack_input = self.concat(sub_emb, obj_emb) + x = self.bn0(stack_input) + x = self.conv2d(x) # [batch_size, num_filt, flat_sz_h, flat_sz_w] + x = self.bn1(x) + x = F.relu(x) + x = self.feature_drop(x) + x = x.view(-1, self.flat_sz) # [batch_size, flat_sz] + x = self.fc(x) # [batch_size, embed_dim] + x = self.hidden_drop(x) + x = self.bn2(x) + x = F.relu(x) + x = torch.mm(x, all_rel.transpose(1, 0)) # [batch_size, ent_num] + x += self.bias_rel.expand_as(x) + # score = torch.sigmoid(x) + score = x + return score + + def forward_search(self, g, subj, obj): + hidden_all_ent, all_rel = self.forward_base_rel_search( + g, subj, obj, self.drop, self.input_drop) + + return hidden_all_ent, all_rel + + def compute_pred(self, hidden_x, all_rel, subj, obj, subgraph_sampler, mode='search'): + h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + atten_matrix = subgraph_sampler(h, mode) + # print(atten_matrix.size()) # [batch_size, encoder_layer^2] + # print(subgraph_sampler(h,mode='argmax')) + n, c = atten_matrix.shape + h = h * atten_matrix.view(n, c, 1) + # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + h = torch.sum(h, dim=1) + # print(h.size()) # [batch_size, 2*dim] + h = h.reshape(-1, 1, 2 * self.k_h, self.k_w) + x = self.bn0(h) + x = self.conv2d(x) # [batch_size, num_filt, flat_sz_h, flat_sz_w] + x = self.bn1(x) + x = F.relu(x) + x = self.feature_drop(x) + x = x.view(-1, self.flat_sz) # [batch_size, flat_sz] + x = self.fc(x) # [batch_size, embed_dim] + x = self.hidden_drop(x) + x = self.bn2(x) + x = F.relu(x) + x = torch.mm(x, all_rel.transpose(1, 0)) # [batch_size, ent_num] + x += self.bias_rel.expand_as(x) + # score = torch.sigmoid(x) + score = x + return score + # print(h.size()) + + def compute_pred_rs(self, hidden_x, all_rel, subj, obj, random_hops): + h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + atten_matrix = torch.zeros(h.size(0), h.size(1)).to('cuda:0') + for i in range(h.size(0)): + atten_matrix[i][self.n_layer * (random_hops[i][0] - 1) + random_hops[i][1] - 1] = 1 + # print(atten_matrix.size()) # [batch_size, encoder_layer^2] + # print(subgraph_sampler(h,mode='argmax')) + n, c = atten_matrix.shape + h = h * atten_matrix.view(n, c, 1) + # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + h = torch.sum(h, dim=1) + # print(h.size()) # [batch_size, 2*dim] + h = h.reshape(-1, 1, 2 * self.k_h, self.k_w) + x = self.bn0(h) + x = self.conv2d(x) # [batch_size, num_filt, 
flat_sz_h, flat_sz_w] + x = self.bn1(x) + x = F.relu(x) + x = self.feature_drop(x) + x = x.view(-1, self.flat_sz) # [batch_size, flat_sz] + x = self.fc(x) # [batch_size, embed_dim] + x = self.hidden_drop(x) + x = self.bn2(x) + x = F.relu(x) + x = torch.mm(x, all_rel.transpose(1, 0)) # [batch_size, ent_num] + x += self.bias_rel.expand_as(x) + # score = torch.sigmoid(x) + score = x + return score + # print(h.size()) + + def cross_pair(self, x_i, x_j): + x = [] + for i in range(self.n_layer): + for j in range(self.n_layer): + if self.combine_type == 'multi': + x.append(x_i[:, i, :] * x_j[:, j, :]) + elif self.combine_type == 'concat': + x.append(torch.cat([x_i[:, i, :], x_j[:, j, :]], dim=1)) + x = torch.stack(x, dim=1) + return x + + def vis_hop_distribution(self, hidden_x, all_rel, subj, obj, subgraph_sampler, mode='search'): + h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + atten_matrix = subgraph_sampler(h, mode) + return torch.sum(atten_matrix, dim=0) + + def vis_hop_distribution_rs(self, hidden_x, all_rel, subj, obj, random_hops): + h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + atten_matrix = torch.zeros(h.size(0), h.size(1)).to('cuda:0') + for i in range(h.size(0)): + atten_matrix[i][self.n_layer * (random_hops[i][0] - 1) + random_hops[i][1] - 1] = 1 + return torch.sum(atten_matrix, dim=0) + + +class GCN_Transformer(GCNs): + def __init__(self, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm, + bias=True, gcn_drop=0., opn='mult', hid_drop=0., input_drop=0., conve_hid_drop=0., feat_drop=0., + num_filt=None, ker_sz=None, k_h=None, k_w=None, wni=False, wsi=False, encoder='compgcn', use_bn=True, + ltr=True, input_type='subgraph', + d_model=100, num_transformer_layers=2, nhead=8, dim_feedforward=100, transformer_dropout=0.1, + transformer_activation='relu', + graph_pooling='cls', concat_type="gso", max_input_len=100, loss_type='ce'): + super(GCN_Transformer, self).__init__(num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, + edge_type, edge_norm, bias, gcn_drop, opn, wni, wsi, encoder, use_bn, ltr, + input_type, loss_type) + + self.hid_drop, self.input_drop, self.conve_hid_drop, self.feat_drop = hid_drop, input_drop, conve_hid_drop, feat_drop + self.num_filt = num_filt + self.ker_sz, self.k_w, self.k_h = ker_sz, k_w, k_h + + # one channel, do bn on initial embedding + self.bn0 = torch.nn.BatchNorm2d(1) + self.bn1 = torch.nn.BatchNorm2d( + self.num_filt) # do bn on output of conv + self.bn2 = torch.nn.BatchNorm1d(self.embed_dim) + + self.drop = torch.nn.Dropout(self.hid_drop) # gcn output dropout + self.input_drop = torch.nn.Dropout( + self.input_drop) # stacked input dropout + self.feature_drop = torch.nn.Dropout( + self.feat_drop) # feature map dropout + self.hidden_drop = torch.nn.Dropout( + self.conve_hid_drop) # hidden layer dropout + + self.conv2d = torch.nn.Conv2d(in_channels=1, out_channels=self.num_filt, + kernel_size=(self.ker_sz, self.ker_sz), stride=1, padding=0, bias=bias) + + flat_sz_h = int(2 * self.k_h) - self.ker_sz + 1 # height after conv + flat_sz_w = self.k_w - self.ker_sz + 1 # width after conv + self.flat_sz = flat_sz_h * flat_sz_w * self.num_filt + # fully connected projection + self.fc = torch.nn.Linear(self.flat_sz, self.embed_dim) + + self.d_model = d_model + self.num_layer = num_transformer_layers + self.gnn2transformer = nn.Linear(gcn_dim, d_model) + # Creating Transformer Encoder Model + encoder_layer = nn.TransformerEncoderLayer( + 
d_model, nhead, dim_feedforward, transformer_dropout, transformer_activation + ) + encoder_norm = nn.LayerNorm(d_model) + self.transformer = nn.TransformerEncoder(encoder_layer, num_transformer_layers, encoder_norm) + # self.max_input_len = args.max_input_len + self.norm_input = nn.LayerNorm(d_model) + self.graph_pooling = graph_pooling + self.max_input_len = max_input_len + self.concat_type = concat_type + self.cls_embedding = None + if self.graph_pooling == "cls": + self.cls_embedding = nn.Parameter(torch.randn([1, 1, self.d_model], requires_grad=True)) + if self.concat_type == "gso": + self.fc1 = nn.Linear(d_model * 3, 256) + elif self.concat_type == "so": + self.fc1 = nn.Linear(d_model * 2, 256) + elif self.concat_type == "g": + self.fc1 = nn.Linear(d_model, 256) + self.fc2 = nn.Linear(256, num_rel * 2) + + def forward(self, bg, input_ids, subj, obj): + # tokens_emb = self.forward_base_rel_subgraph_trans(bg,input_ids,self.drop,self.input_drop) + # tokens_emb = tokens_emb.permute(1,0,2) # (s, b, d) + # tokens_emb = self.gnn2transformer(tokens_emb) + # padding_mask = get_pad_mask(input_ids, self.num_ent+1) + # tokens_emb = self.norm_input(tokens_emb) + # transformer_out = self.transformer(tokens_emb, src_key_padding_mask=padding_mask) + # h_graph = transformer_out[0] + # output = self.fc_out(h_graph) + # output = torch.sigmoid(output) + # return output + + bg, all_rel = self.forward_base_rel_subgraph_trans_new(bg, self.drop, self.input_drop) + batch_size = bg.batch_size + h_node = self.gnn2transformer(bg.ndata['h']) + padded_h_node, src_padding_mask, subj_idx, obj_idx = pad_batch(h_node, bg, self.max_input_len, subj, obj) + if self.cls_embedding is not None: + expand_cls_embedding = self.cls_embedding.expand(1, padded_h_node.size(1), -1) + padded_h_node = torch.cat([padded_h_node, expand_cls_embedding], dim=0) + zeros = src_padding_mask.data.new(src_padding_mask.size(0), 1).fill_(0) + src_padding_mask = torch.cat([src_padding_mask, zeros], dim=1) + padded_h_node = self.norm_input(padded_h_node) + transformer_out = self.transformer(padded_h_node, src_key_padding_mask=src_padding_mask) + subj_emb_list = [] + obj_emb_list = [] + for batch_idx in range(batch_size): + subj_emb = transformer_out[subj_idx[batch_idx], batch_idx, :] + obj_emb = transformer_out[obj_idx[batch_idx], batch_idx, :] + subj_emb_list.append(subj_emb) + obj_emb_list.append(obj_emb) + subj_emb = torch.stack(subj_emb_list) + obj_emb = torch.stack(obj_emb_list) + if self.graph_pooling in ["last", "cls"]: + h_graph = transformer_out[-1] + if self.concat_type == "gso": + h_repr = torch.cat([h_graph, subj_emb, obj_emb], dim=1) + elif self.concat_type == "g": + h_repr = h_graph + else: + if self.concat_type == "so": + h_repr = torch.cat([subj_emb, obj_emb], dim=1) + h_repr = F.relu(self.fc1(h_repr)) + score = self.fc2(h_repr) + # print(score.size()) + # print(score) + # score = self.conve(subj_emb, obj_emb, all_rel) + + return score + + def conve_rel(self, sub_emb, obj_emb, all_rel): + stack_input = self.concat(sub_emb, obj_emb) + x = self.bn0(stack_input) + x = self.conv2d(x) # [batch_size, num_filt, flat_sz_h, flat_sz_w] + x = self.bn1(x) + x = F.relu(x) + x = self.feature_drop(x) + x = x.view(-1, self.flat_sz) # [batch_size, flat_sz] + x = self.fc(x) # [batch_size, embed_dim] + x = self.hidden_drop(x) + x = self.bn2(x) + x = F.relu(x) + x = torch.mm(x, all_rel.transpose(1, 0)) # [batch_size, ent_num] + x += self.bias_rel.expand_as(x) + score = torch.sigmoid(x) + return score + + def concat(self, ent_embed, rel_embed): + """ + 
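
`GCN_Transformer.forward` above pads each subgraph's node embeddings into a `[seq, batch, dim]` tensor, appends a learned CLS embedding, runs a `nn.TransformerEncoder` with a key-padding mask, and reads the graph representation from the last position. The sketch below reproduces only that pooling pattern with random features and toy sizes; everything not shown in the code above (node features, mask layout) is assumed for illustration.

```python
# Sketch of the CLS-style pooling used by GCN_Transformer above.
# Features and sizes are random/illustrative; True in the mask means "ignore".
import torch
import torch.nn as nn

d_model, nhead, dim_ff, batch, max_nodes = 64, 8, 128, 3, 10
enc_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_ff, dropout=0.1, activation='relu')
encoder = nn.TransformerEncoder(enc_layer, num_layers=2, norm=nn.LayerNorm(d_model))

padded = torch.randn(max_nodes, batch, d_model)            # [S, B, D] padded node embeddings
pad_mask = torch.zeros(batch, max_nodes, dtype=torch.bool)
pad_mask[0, :4] = True                                      # graph 0 uses only the last 6 slots

cls = nn.Parameter(torch.randn(1, 1, d_model))              # learned CLS token
padded = torch.cat([padded, cls.expand(1, batch, -1)], dim=0)
pad_mask = torch.cat([pad_mask, torch.zeros(batch, 1, dtype=torch.bool)], dim=1)

out = encoder(padded, src_key_padding_mask=pad_mask)        # [S+1, B, D]
h_graph = out[-1]                                            # CLS position = graph embedding
print(h_graph.shape)                                         # torch.Size([3, 64])
```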
:param ent_embed: [batch_size, embed_dim] + :param rel_embed: [batch_size, embed_dim] + :return: stack_input: [B, C, H, W] + """ + ent_embed = ent_embed.view(-1, 1, self.embed_dim) + rel_embed = rel_embed.view(-1, 1, self.embed_dim) + # [batch_size, 2, embed_dim] + stack_input = torch.cat([ent_embed, rel_embed], 1) + + assert self.embed_dim == self.k_h * self.k_w + # reshape to 2D [batch, 1, 2*k_h, k_w] + stack_input = stack_input.reshape(-1, 1, 2 * self.k_h, self.k_w) + return stack_input + + def conve_ent(self, sub_emb, rel_emb, all_ent): + stack_input = self.concat(sub_emb, rel_emb) + x = self.bn0(stack_input) + x = self.conv2d(x) # [batch_size, num_filt, flat_sz_h, flat_sz_w] + x = self.bn1(x) + x = F.relu(x) + x = self.feature_drop(x) + x = x.view(-1, self.flat_sz) # [batch_size, flat_sz] + x = self.fc(x) # [batch_size, embed_dim] + x = self.hidden_drop(x) + x = self.bn2(x) + x = F.relu(x) + x = torch.mm(x, all_ent.transpose(1, 0)) # [batch_size, ent_num] + x += self.bias.expand_as(x) + score = torch.sigmoid(x) + return score + + def evaluate(self, subj, obj): + sub_emb, rel_emb, all_ent = self.forward_base_no_transform(subj, obj) + score = self.conve_ent(sub_emb, rel_emb, all_ent) + return score + + +class GCN_None(GCNs): + + def __init__(self, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm, + bias=True, gcn_drop=0., opn='mult', hid_drop=0., input_drop=0., conve_hid_drop=0., feat_drop=0., + num_filt=None, ker_sz=None, k_h=None, k_w=None, wni=False, wsi=False, encoder='compgcn', use_bn=True, + ltr=True, + input_type='subgraph', graph_pooling='mean', concat_type='gso', loss_type='ce', add_reverse=True): + super(GCN_None, self).__init__(num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, + edge_type, edge_norm, bias, gcn_drop, opn, wni, wsi, encoder, use_bn, ltr, + input_type, loss_type, add_reverse) + + self.hid_drop, self.input_drop, self.conve_hid_drop, self.feat_drop = hid_drop, input_drop, conve_hid_drop, feat_drop + self.graph_pooling_type = graph_pooling + self.concat_type = concat_type + self.drop = torch.nn.Dropout(self.hid_drop) # gcn output dropout + self.input_drop = torch.nn.Dropout( + self.input_drop) # stacked input dropout + if self.concat_type == "gso": + self.fc1 = nn.Linear(gcn_dim * 3, 256) + elif self.concat_type == "so": + self.fc1 = nn.Linear(gcn_dim * 2, 256) + elif self.concat_type == "g": + self.fc1 = nn.Linear(gcn_dim, 256) + if add_reverse: + self.fc2 = nn.Linear(256, num_rel * 2) + else: + self.fc2 = nn.Linear(256, num_rel) + + def forward(self, g, subj, obj): + """ + :param g: dgl graph + :param sub: subject in batch [batch_size] + :param rel: relation in batch [batch_size] + :return: score: [batch_size, ent_num], the prob in link-prediction + """ + # start_time = time.time() + bg, all_rel = self.forward_base_rel_subgraph_trans_new(g, self.drop, self.input_drop) + # print(f'{time.time() - start_time:.2f}s') + h_repr_list = [] + for batch_idx, g in enumerate(dgl.unbatch(bg)): + h_graph = readout_nodes(g, 'h', op=self.graph_pooling_type) + sub_idx = torch.where(g.ndata[NID] == subj[batch_idx])[0] + obj_idx = torch.where(g.ndata[NID] == obj[batch_idx])[0] + sub_emb = g.ndata['h'][sub_idx] + obj_emb = g.ndata['h'][obj_idx] + h_repr = torch.cat([h_graph, sub_emb, obj_emb], dim=1) + h_repr_list.append(h_repr) + h_repr = torch.stack(h_repr_list, dim=0).squeeze(1) + # print(f'{time.time() - start_time:.2f}s') + score = F.relu(self.fc1(h_repr)) + score = self.fc2(score) + + return score + + +class 
GCN_MLP(GCNs): + def __init__(self, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm, + bias=True, gcn_drop=0., opn='mult', hid_drop=0., input_drop=0., conve_hid_drop=0., feat_drop=0., + num_filt=None, ker_sz=None, k_h=None, k_w=None, wni=False, wsi=False, encoder='compgcn', use_bn=True, + ltr=True, input_type='subgraph', graph_pooling='mean', combine_type='mult', loss_type='ce', + add_reverse=True): + """ + :param num_ent: number of entities + :param num_rel: number of different relations + :param num_base: number of bases to use + :param init_dim: initial dimension + :param gcn_dim: dimension after first layer + :param embed_dim: dimension after second layer + :param n_layer: number of layer + :param edge_type: relation type of each edge, [E] + :param bias: weather to add bias + :param gcn_drop: dropout rate in compgcncov + :param opn: combination operator + :param hid_drop: gcn output (embedding of each entity) dropout + :param input_drop: dropout in conve input + :param conve_hid_drop: dropout in conve hidden layer + :param feat_drop: feature dropout in conve + :param num_filt: number of filters in conv2d + :param ker_sz: kernel size in conv2d + :param k_h: height of 2D reshape + :param k_w: width of 2D reshape + """ + super(GCN_MLP, self).__init__(num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, + edge_type, edge_norm, bias, gcn_drop, opn, wni, wsi, encoder, + use_bn, ltr, input_type, loss_type, add_reverse) + self.hid_drop, self.input_drop = hid_drop, input_drop + if add_reverse: + self.num_rel = num_rel * 2 + else: + self.num_rel = num_rel + + self.drop = torch.nn.Dropout(self.hid_drop) # gcn output dropout + self.input_drop = torch.nn.Dropout(self.input_drop) # stacked input dropout + + self.combine_type = combine_type + if self.combine_type == 'concat': + self.fc = torch.nn.Linear(2 * self.embed_dim, self.num_rel) + elif self.combine_type == 'mult': + self.fc = torch.nn.Linear(self.embed_dim, self.num_rel) + + def forward(self, g, subj, obj): + """ + :param g: dgl graph + :param sub: subject in batch [batch_size] + :param rel: relation in batch [batch_size] + :return: score: [batch_size, ent_num], the prob in link-prediction + """ + if self.input_type == 'subgraph': + print('1') + # sub_emb, obj_emb, all_ent, all_rel = self.forward_base_rel_subgraph(g, subj, obj, self.drop, self.input_drop) + bg, all_rel = self.forward_base_rel_subgraph_trans_new(g, self.drop, self.input_drop) + sub_ids = (bg.ndata['id'] == 1).nonzero().squeeze(1) + sub_embs = bg.ndata['h'][sub_ids] + obj_ids = (bg.ndata['id'] == 2).nonzero().squeeze(1) + obj_embs = bg.ndata['h'][obj_ids] + h_graph = readout_nodes(bg, 'h', op=self.graph_pooling_type) + # edge_embs = torch.concat([sub_embs, obj_embs], dim=1) + # score = self.fc(edge_embs) + else: + sub_embs, obj_embs, all_ent, all_rel = self.forward_base_rel(g, subj, obj, self.drop, self.input_drop) + if self.combine_type == 'concat': + edge_embs = torch.concat([sub_embs, obj_embs], dim=1) + elif self.combine_type == 'mult': + edge_embs = sub_embs * obj_embs + else: + raise NotImplementedError + score = self.fc(edge_embs) + return score + + def compute_vid_pred(self, hidden_x, subj, obj): + scores = [] + scores_sigmoid = [] + h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + for i in range(h.size()[1]): + # print(h.size()) + # exit(0) + edge_embs = h[:,i,:] + # if self.combine_type == 'concat': + # edge_embs = torch.concat([sub_embs, obj_embs], dim=1) + # elif self.combine_type == 'mult': + # edge_embs = 
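
Outside the subgraph branch, `GCN_MLP.forward` above reduces to a simple decoder: combine the subject and object embeddings (element-wise product or concatenation) and map the pair to interaction-type logits with one linear layer. A short sketch with made-up dimensions:

```python
# Sketch of the GCN_MLP decoder above: pair embeddings are combined
# ('mult' or 'concat') and classified. Dimensions are illustrative.
import torch
import torch.nn as nn

embed_dim, num_rel, batch = 128, 86, 32
combine_type = 'concat'
fc = nn.Linear(2 * embed_dim if combine_type == 'concat' else embed_dim, num_rel)

sub_embs, obj_embs = torch.randn(batch, embed_dim), torch.randn(batch, embed_dim)
edge_embs = (torch.cat([sub_embs, obj_embs], dim=1)
             if combine_type == 'concat' else sub_embs * obj_embs)
score = fc(edge_embs)          # [batch, num_rel] logits, fed to CrossEntropyLoss
print(score.shape)
```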
sub_embs * obj_embs + # else: + # raise NotImplementedError + # print(edge_embs.size()) + score = self.fc(edge_embs) + scores.append(score) + # score_sigmoid = torch.sigmoid(score) + # print(score.size()) + # scores.append(torch.max(score[:6,:],1)) + # scores_sigmoid.append(torch.max(score_sigmoid[:6,:],1)) + # print(scores[-1]) + return scores + + def forward_search(self, g, mode='allgraph'): + # if mode == 'allgraph': + hidden_all_ent, all_ent = self.forward_base_rel_search( + g, self.drop, self.input_drop) + # elif mode == 'subgraph': + # hidden_all_ent = self.forward_base_subgraph_search( + # g, self.drop, self.input_drop) + + return hidden_all_ent, all_ent + + def compute_pred(self, hidden_x, subj, obj, subgraph_sampler, mode='search', search_algorithm='darts'): + h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + atten_matrix = subgraph_sampler(h, mode, search_algorithm) + # print(atten_matrix.size()) # [batch_size, encoder_layer^2] + # print(subgraph_sampler(h,mode='argmax')) + n, c = atten_matrix.shape + h = h * atten_matrix.view(n, c, 1) + # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + h = torch.sum(h, dim=1) + # print(h.size()) # [batch_size, 2*dim] + score = self.fc(h) + return score + + def compute_mix_hop_pred(self, hidden_x, subj, obj, hop_index): + h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + edge_embs = h[:,hop_index,:] + score = self.fc(edge_embs) + return score + # def fine_tune_with_implicit_subgraph(self, all_ent, subgraph, subj, obj): + # sg_list = [] + # for idx in range(subgraph.size(0)): + # sg_list.append(torch.mean(all_ent[subgraph[idx,:]], dim=0).unsqueeze(0)) + # sg_embs = torch.concat(sg_list) + # # print(sg_embs.size()) + # sub_embs = torch.index_select(all_ent, 0, subj) + # # print(sub_embs.size()) + # # filter out embeddings of relations in this batch + # obj_embs = torch.index_select(all_ent, 0, obj) + # # print(obj_embs.size()) + # edge_embs = torch.concat([sub_embs, obj_embs, sg_embs], dim=1) + # score = self.predictor(edge_embs) + # # print(F.embedding(subgraph, all_ent)) + # return score + + # def compute_pred_rs(self, hidden_x, all_rel, subj, obj, random_hops): + # h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + # atten_matrix = torch.zeros(h.size(0),h.size(1)).to('cuda:0') + # for i in range(h.size(0)): + # atten_matrix[i][self.n_layer*(random_hops[i][0]-1)+random_hops[i][1]-1] = 1 + # # print(atten_matrix.size()) # [batch_size, encoder_layer^2] + # # print(subgraph_sampler(h,mode='argmax')) + # n, c = atten_matrix.shape + # h = h * atten_matrix.view(n,c,1) + # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + # h = torch.sum(h,dim=1) + # # print(h.size()) # [batch_size, 2*dim] + # h = h.reshape(-1, 1, 2 * self.k_h, self.k_w) + # # x = self.bn0(h) + # # x = self.conv2d(x) # [batch_size, num_filt, flat_sz_h, flat_sz_w] + # # x = self.bn1(x) + # # x = F.relu(x) + # # x = self.feature_drop(x) + # # x = x.view(-1, self.flat_sz) # [batch_size, flat_sz] + # # x = self.fc(x) # [batch_size, embed_dim] + # # x = self.hidden_drop(x) + # # x = self.bn2(x) + # # x = F.relu(x) + # # x = torch.mm(x, all_rel.transpose(1, 0)) # [batch_size, ent_num] + # # x += self.bias_rel.expand_as(x) + # # # score = torch.sigmoid(x) + # # score = x + # return score + # print(h.size()) + + def cross_pair(self, x_i, x_j): + x = [] + for i in range(self.n_layer): + for j in range(self.n_layer): + if self.combine_type == 'mult': + 
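
`compute_pred` above enumerates all `n_layer x n_layer` pairs of per-layer node embeddings (`cross_pair`), obtains a `[batch, n_layer^2]` weight matrix from the subgraph sampler, and sums the weighted pairs before the final classifier. The sketch below is self-contained; a hard one-hot selection stands in for the learned sampler, which is defined elsewhere in the repository.

```python
# Sketch of cross_pair + weighted hop-pair aggregation from compute_pred above.
# hidden_x is [num_nodes, n_layer, dim]; the one-hot "attention" stands in for
# the learned subgraph sampler (an assumption for illustration only).
import torch

n_nodes, n_layer, dim, batch = 50, 3, 16, 4
hidden_x = torch.randn(n_nodes, n_layer, dim)
subj = torch.randint(0, n_nodes, (batch,))
obj = torch.randint(0, n_nodes, (batch,))

def cross_pair(x_i, x_j, combine_type='concat'):
    pairs = []
    for i in range(n_layer):
        for j in range(n_layer):
            if combine_type == 'mult':
                pairs.append(x_i[:, i, :] * x_j[:, j, :])
            else:
                pairs.append(torch.cat([x_i[:, i, :], x_j[:, j, :]], dim=1))
    return torch.stack(pairs, dim=1)            # [batch, n_layer^2, dim or 2*dim]

h = cross_pair(hidden_x[subj], hidden_x[obj])
atten = torch.zeros(batch, n_layer * n_layer)
atten[:, 4] = 1.0                                # e.g. always keep the (2-hop, 2-hop) pair
n, c = atten.shape
h = (h * atten.view(n, c, 1)).sum(dim=1)         # [batch, 2*dim], ready for the final fc
print(h.shape)
```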
x.append(x_i[:, i, :] * x_j[:, j, :]) + elif self.combine_type == 'concat': + x.append(torch.cat([x_i[:, i, :], x_j[:, j, :]], dim=1)) + x = torch.stack(x, dim=1) + return x + + def vis_hop_distribution(self, hidden_x, subj, obj, subgraph_sampler, mode='search'): + h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + atten_matrix = subgraph_sampler(h, mode) + return torch.sum(atten_matrix, dim=0) + + def vis_hop_distribution_rs(self, hidden_x, subj, obj, random_hops): + h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + atten_matrix = torch.zeros(h.size(0), h.size(1)).to('cuda:0') + for i in range(h.size(0)): + atten_matrix[i][self.n_layer * (random_hops[i][0] - 1) + random_hops[i][1] - 1] = 1 + return torch.sum(atten_matrix, dim=0) + + +class GCN_MLP_NCN(GCNs): + def __init__(self, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm, + bias=True, gcn_drop=0., opn='mult', hid_drop=0., input_drop=0., conve_hid_drop=0., feat_drop=0., + num_filt=None, ker_sz=None, k_h=None, k_w=None, wni=False, wsi=False, encoder='compgcn', use_bn=True, + ltr=True, input_type='subgraph', graph_pooling='mean', combine_type='mult', loss_type='ce', + add_reverse=True): + """ + :param num_ent: number of entities + :param num_rel: number of different relations + :param num_base: number of bases to use + :param init_dim: initial dimension + :param gcn_dim: dimension after first layer + :param embed_dim: dimension after second layer + :param n_layer: number of layer + :param edge_type: relation type of each edge, [E] + :param bias: weather to add bias + :param gcn_drop: dropout rate in compgcncov + :param opn: combination operator + :param hid_drop: gcn output (embedding of each entity) dropout + :param input_drop: dropout in conve input + :param conve_hid_drop: dropout in conve hidden layer + :param feat_drop: feature dropout in conve + :param num_filt: number of filters in conv2d + :param ker_sz: kernel size in conv2d + :param k_h: height of 2D reshape + :param k_w: width of 2D reshape + """ + super(GCN_MLP_NCN, self).__init__(num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, + edge_type, edge_norm, bias, gcn_drop, opn, wni, wsi, encoder, + use_bn, ltr, input_type, loss_type, add_reverse) + self.hid_drop, self.input_drop = hid_drop, input_drop + self.num_rel = num_rel + + self.drop = torch.nn.Dropout(self.hid_drop) # gcn output dropout + self.input_drop = torch.nn.Dropout(self.input_drop) # stacked input dropout + + # fully connected projection + self.combine_type = combine_type + self.graph_pooling_type = graph_pooling + self.lin_layers = 2 + self._init_predictor() + + def _init_predictor(self): + self.lins = torch.nn.ModuleList() + if self.combine_type == 'mult': + input_channels = self.embed_dim + else: + input_channels = self.embed_dim * 2 + self.lins.append(torch.nn.Linear(input_channels + self.embed_dim, self.embed_dim)) + for _ in range(self.lin_layers - 2): + self.lins.append(torch.nn.Linear(self.embed_dim, self.embed_dim)) + self.lins.append(torch.nn.Linear(self.embed_dim, self.num_rel)) + + def forward(self, g, subj, obj, cns): + """ + :param g: dgl graph + :param sub: subject in batch [batch_size] + :param rel: relation in batch [batch_size] + :return: score: [batch_size, ent_num], the prob in link-prediction + """ + sub_embs, obj_embs, all_ent, all_rel = self.forward_base_rel(g, subj, obj, self.drop, self.input_drop) + cn_embs = self.get_common_1hopneighbor_emb(all_ent, cns) + + if 
self.combine_type == 'mult': + edge_embs = torch.concat([sub_embs * obj_embs, cn_embs], dim=1) + elif self.combine_type == 'concat': + edge_embs = torch.concat([sub_embs, obj_embs, cn_embs], dim=1) + x = edge_embs + for lin in self.lins[:-1]: + x = lin(x) + x = F.relu(x) + x = F.dropout(x, p=0.1, training=self.training) + score = self.lins[-1](x) + + # [batch_size, 1, 2*k_h, k_w] + # stack_input = self.concat(sub_emb, obj_emb) + # x = self.bn0(stack_input) + # x = self.conv2d(x) # [batch_size, num_filt, flat_sz_h, flat_sz_w] + # x = self.bn1(x) + # x = F.relu(x) + # x = self.feature_drop(x) + # x = x.view(-1, self.flat_sz) # [batch_size, flat_sz] + # x = self.fc(x) # [batch_size, embed_dim] + # x = self.hidden_drop(x) + # x = self.bn2(x) + # x = F.relu(x) + # x = torch.mm(x, all_rel.transpose(1, 0)) # [batch_size, ent_num] + # x += self.bias_rel.expand_as(x) + # # score = torch.sigmoid(x) + # score = x + return score + + def forward_search(self, g, mode='allgraph'): + hidden_all_ent, all_ent = self.forward_base_rel_search( + g, self.drop, self.input_drop) + + return hidden_all_ent, all_ent + + def compute_pred(self, hidden_x, all_ent, subj, obj, cns, subgraph_sampler, mode='search'): + h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + atten_matrix = subgraph_sampler(h, mode) + # print(atten_matrix.size()) # [batch_size, encoder_layer^2] + # print(subgraph_sampler(h,mode='argmax')) + n, c = atten_matrix.shape + h = h * atten_matrix.view(n, c, 1) + cn_embs = self.get_common_1hopneighbor_emb(all_ent, cns) + # cn = self.get_common_neighbor_emb(hidden_x, cns) + # cn = self.get_common_neighbor_emb_(all_ent, cns) + # cn = cn * atten_matrix.view(n, c, 1) + + # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + h = torch.sum(h, dim=1) + # cn = torch.sum(cn, dim=1) + concat_embs = torch.concat([h, cn_embs], dim=1) + x = concat_embs + for lin in self.lins[:-1]: + x = lin(x) + x = F.relu(x) + x = F.dropout(x, p=0.1, training=self.training) + score = self.lins[-1](x) + # score = self.fc(h) + # if self.combine_type == 'so': + # score = self.fc(h) + # elif self.combine_type == 'gso': + # bg, all_rel = self.forward_base_rel_subgraph_trans_new(g, self.drop, self.input_drop) + # h_graph = readout_nodes(bg, 'h', op=self.graph_pooling_type) + # edge_embs = torch.concat([h_graph, h], dim=1) + # score = self.predictor(edge_embs) + # print(h.size()) # [batch_size, 2*dim] + + return score + + # # print(h.size()) + + def get_common_neighbor_emb(self, hidden_x, cns): + x = [] + for i in range(self.n_layer): + for j in range(self.n_layer): + x_tmp = [] + for idx in range(cns.size(0)): + print(hidden_x[cns[idx, i * 2 + j, :], i, :].size()) + print(hidden_x[cns[idx, i * 2 + j, :], j, :].size()) + x_tmp.append(torch.mean(hidden_x[cns[idx, i * 2 + j, :], i * 2 + j, :], dim=0)) + x.append(torch.stack(x_tmp, dim=0)) + x = torch.stack(x, dim=1) + return x + + def get_common_neighbor_emb_(self, all_ent, cns): + x = [] + for i in range(self.n_layer): + for j in range(self.n_layer): + x_tmp = [] + for idx in range(cns.size(0)): + x_tmp.append(torch.mean(all_ent[cns[idx, i * 2 + j, :]], dim=0)) + x.append(torch.stack(x_tmp, dim=0)) + x = torch.stack(x, dim=1) + return x + + def get_common_1hopneighbor_emb(self, all_ent, cns): + x = [] + for idx in range(cns.size(0)): + x.append(torch.mean(all_ent[cns[idx, :], :], dim=0)) + cn_embs = torch.stack(x, dim=0) + return cn_embs + + # def compute_pred_rs(self, hidden_x, all_rel, subj, obj, random_hops): + # h = 
self.cross_pair(hidden_x[subj], hidden_x[obj]) + # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + # atten_matrix = torch.zeros(h.size(0),h.size(1)).to('cuda:0') + # for i in range(h.size(0)): + # atten_matrix[i][self.n_layer*(random_hops[i][0]-1)+random_hops[i][1]-1] = 1 + # # print(atten_matrix.size()) # [batch_size, encoder_layer^2] + # # print(subgraph_sampler(h,mode='argmax')) + # n, c = atten_matrix.shape + # h = h * atten_matrix.view(n,c,1) + # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + # h = torch.sum(h,dim=1) + # # print(h.size()) # [batch_size, 2*dim] + # h = h.reshape(-1, 1, 2 * self.k_h, self.k_w) + # # x = self.bn0(h) + # # x = self.conv2d(x) # [batch_size, num_filt, flat_sz_h, flat_sz_w] + # # x = self.bn1(x) + # # x = F.relu(x) + # # x = self.feature_drop(x) + # # x = x.view(-1, self.flat_sz) # [batch_size, flat_sz] + # # x = self.fc(x) # [batch_size, embed_dim] + # # x = self.hidden_drop(x) + # # x = self.bn2(x) + # # x = F.relu(x) + # # x = torch.mm(x, all_rel.transpose(1, 0)) # [batch_size, ent_num] + # # x += self.bias_rel.expand_as(x) + # # # score = torch.sigmoid(x) + # # score = x + # return score + # print(h.size()) + + def cross_pair(self, x_i, x_j): + x = [] + for i in range(self.n_layer): + for j in range(self.n_layer): + if self.combine_type == 'mult': + x.append(x_i[:, i, :] * x_j[:, j, :]) + elif self.combine_type == 'concat': + x.append(torch.cat([x_i[:, i, :], x_j[:, j, :]], dim=1)) + x = torch.stack(x, dim=1) + return x + + def vis_hop_distribution(self, hidden_x, subj, obj, subgraph_sampler, mode='search'): + h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + atten_matrix = subgraph_sampler(h, mode) + return torch.sum(atten_matrix, dim=0) + + def vis_hop_distribution_rs(self, hidden_x, subj, obj, random_hops): + h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + atten_matrix = torch.zeros(h.size(0), h.size(1)).to('cuda:0') + for i in range(h.size(0)): + atten_matrix[i][self.n_layer * (random_hops[i][0] - 1) + random_hops[i][1] - 1] = 1 + return torch.sum(atten_matrix, dim=0) + + +def get_pad_mask(seq, pad_idx): + return (seq == pad_idx) + + +def pad_batch(h_node, bg, max_input_len, subj, obj): + subj_list = [] + obj_list = [] + batch_size = bg.batch_size + max_num_nodes = min(max(bg.batch_num_nodes()).item(), max_input_len) + padded_h_node = h_node.data.new(max_num_nodes, batch_size, h_node.size(-1)).fill_(0) + src_padding_mask = h_node.data.new(batch_size, max_num_nodes).fill_(0).bool() + for batch_idx, g in enumerate(dgl.unbatch(bg)): + num_nodes = g.num_nodes() + padded_h_node[-num_nodes:, batch_idx] = g.ndata['h'] + src_padding_mask[batch_idx, : max_num_nodes - num_nodes] = True + subj_idx = torch.where(g.ndata[NID] == subj[batch_idx])[0] + max_num_nodes - num_nodes + obj_idx = torch.where(g.ndata[NID] == obj[batch_idx])[0] + max_num_nodes - num_nodes + subj_list.append(subj_idx) + obj_list.append(obj_idx) + subj_idx = torch.cat(subj_list) + obj_idx = torch.cat(obj_list) + return padded_h_node, src_padding_mask, subj_idx, obj_idx diff --git a/model/genotypes.py b/model/genotypes.py new file mode 100644 index 0000000..abce99e --- /dev/null +++ b/model/genotypes.py @@ -0,0 +1,23 @@ +COMP_PRIMITIVES = [ + 'sub', + 'mult', + 'ccorr', + 'rotate' +] + +AGG_PRIMITIVES = [ + 'mean', + 'sum', + 'max', +] + +COMB_PRIMITIVES = [ + 'mlp', + 'concat' +] + +ACT_PRIMITIVES = [ + 'identity', + 'relu', + 'tanh' +] \ No newline at end of file diff --git a/model/lte_models.py 
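
`pad_batch` above right-aligns each subgraph's node embeddings inside a fixed `[max_nodes, batch, dim]` tensor, builds the matching `src_padding_mask`, and shifts the subject/object node indices by the amount of left padding. The DGL-free sketch below illustrates only that padding scheme with toy sizes; the graph batching itself is assumed away.

```python
# Sketch of the padding scheme in pad_batch above, without DGL: real nodes are
# packed at the end of each column, the mask marks unused left slots, and node
# indices are shifted by the padding offset.
import torch

dim = 8
graphs = [torch.randn(n, dim) for n in (3, 5, 2)]    # toy per-graph node embeddings
subj = [0, 2, 1]                                      # toy subject node index per graph

max_nodes = max(g.size(0) for g in graphs)
batch = len(graphs)
padded = torch.zeros(max_nodes, batch, dim)
mask = torch.zeros(batch, max_nodes, dtype=torch.bool)
subj_idx = []
for b, g in enumerate(graphs):
    n = g.size(0)
    padded[-n:, b] = g                                # right-align the real nodes
    mask[b, : max_nodes - n] = True                   # True = padding slot
    subj_idx.append(subj[b] + max_nodes - n)          # shift index by the padding offset
print(padded.shape, subj_idx)
```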
b/model/lte_models.py new file mode 100644 index 0000000..ca193e3 --- /dev/null +++ b/model/lte_models.py @@ -0,0 +1,171 @@ +import torch +from torch import nn +from torch.nn import functional as F + + +def get_param(shape): + param = nn.Parameter(torch.Tensor(*shape)) + nn.init.xavier_normal_(param.data) + return param + + +class LTEModel(nn.Module): + def __init__(self, num_ents, num_rels, params=None): + super(LTEModel, self).__init__() + + self.bceloss = torch.nn.BCELoss() + + self.p = params + self.init_embed = get_param((num_ents, self.p.init_dim)) + self.device = "cuda" + + self.init_rel = get_param((num_rels * 2, self.p.init_dim)) + + self.bias = nn.Parameter(torch.zeros(num_ents)) + + self.h_ops_dict = nn.ModuleDict({ + 'p': nn.Linear(self.p.init_dim, self.p.gcn_dim, bias=False), + 'b': nn.BatchNorm1d(self.p.gcn_dim), + 'd': nn.Dropout(self.p.hid_drop), + 'a': nn.Tanh(), + }) + + self.t_ops_dict = nn.ModuleDict({ + 'p': nn.Linear(self.p.init_dim, self.p.gcn_dim, bias=False), + 'b': nn.BatchNorm1d(self.p.gcn_dim), + 'd': nn.Dropout(self.p.hid_drop), + 'a': nn.Tanh(), + }) + + self.r_ops_dict = nn.ModuleDict({ + 'p': nn.Linear(self.p.init_dim, self.p.gcn_dim, bias=False), + 'b': nn.BatchNorm1d(self.p.gcn_dim), + 'd': nn.Dropout(self.p.hid_drop), + 'a': nn.Tanh(), + }) + + self.x_ops = self.p.x_ops + self.r_ops = self.p.r_ops + self.diff_ht = False + + def calc_loss(self, pred, label): + return self.loss(pred, label) + + def loss(self, pred, true_label): + return self.bceloss(pred, true_label) + + def exop(self, x, r, x_ops=None, r_ops=None, diff_ht=False): + x_head = x_tail = x + if len(x_ops) > 0: + for x_op in x_ops.split("."): + if diff_ht: + x_head = self.h_ops_dict[x_op](x_head) + x_tail = self.t_ops_dict[x_op](x_tail) + else: + x_head = x_tail = self.h_ops_dict[x_op](x_head) + + if len(r_ops) > 0: + for r_op in r_ops.split("."): + r = self.r_ops_dict[r_op](r) + + return x_head, x_tail, r + + +class TransE(LTEModel): + def __init__(self, num_ents, num_rels, params=None): + super(self.__class__, self).__init__(num_ents, num_rels, params) + self.loop_emb = get_param([1, self.p.init_dim]) + + def forward(self, g, sub, rel): + x = self.init_embed + r = self.init_rel + + x_h, x_t, r = self.exop(x - self.loop_emb, r, self.x_ops, self.r_ops) + + sub_emb = torch.index_select(x_h, 0, sub) + rel_emb = torch.index_select(r, 0, rel) + all_ent = x_t + + obj_emb = sub_emb + rel_emb + x = self.p.gamma - \ + torch.norm(obj_emb.unsqueeze(1) - all_ent, p=1, dim=2) + score = torch.sigmoid(x) + + return score + + +class DistMult(LTEModel): + def __init__(self, num_ents, num_rels, params=None): + super(self.__class__, self).__init__(num_ents, num_rels, params) + + def forward(self, g, sub, rel): + x = self.init_embed + r = self.init_rel + + x_h, x_t, r = self.exop(x, r, self.x_ops, self.r_ops) + + sub_emb = torch.index_select(x_h, 0, sub) + rel_emb = torch.index_select(r, 0, rel) + all_ent = x_t + + obj_emb = sub_emb * rel_emb + x = torch.mm(obj_emb, all_ent.transpose(1, 0)) + x += self.bias.expand_as(x) + score = torch.sigmoid(x) + + return score + + +class ConvE(LTEModel): + def __init__(self, num_ents, num_rels, params=None): + super(self.__class__, self).__init__(num_ents, num_rels, params) + self.bn0 = torch.nn.BatchNorm2d(1) + self.bn1 = torch.nn.BatchNorm2d(self.p.num_filt) + self.bn2 = torch.nn.BatchNorm1d(self.p.embed_dim) + + self.hidden_drop = torch.nn.Dropout(self.p.hid_drop) + self.hidden_drop2 = torch.nn.Dropout(self.p.conve_hid_drop) + self.feature_drop = 
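
`LTEModel.exop` above applies a chain of transforms to the entity and relation embeddings, where the chain is a dotted string over the keys `'p'` (linear projection), `'b'` (batch norm), `'d'` (dropout), and `'a'` (tanh), read from the run configuration as `x_ops` / `r_ops`. A small sketch of that mechanism; the string `"p.b.d"` is an assumed example value, not the repository's default.

```python
# Sketch of the op-string mechanism in LTEModel.exop above: each key in a
# dotted string selects one transform from a ModuleDict. "p.b.d" is an assumed
# example; the real value comes from the run configuration.
import torch
import torch.nn as nn

init_dim, gcn_dim = 100, 200
ops_dict = nn.ModuleDict({
    'p': nn.Linear(init_dim, gcn_dim, bias=False),   # project
    'b': nn.BatchNorm1d(gcn_dim),                    # batch-norm
    'd': nn.Dropout(0.3),                            # dropout
    'a': nn.Tanh(),                                  # activation
})

def apply_ops(x, op_string):
    for op in (op_string.split('.') if op_string else []):
        x = ops_dict[op](x)
    return x

x = torch.randn(16, init_dim)
print(apply_ops(x, "p.b.d").shape)                   # torch.Size([16, 200])
```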
torch.nn.Dropout(self.p.feat_drop) + self.m_conv1 = torch.nn.Conv2d(1, out_channels=self.p.num_filt, kernel_size=(self.p.ker_sz, self.p.ker_sz), + stride=1, padding=0, bias=self.p.bias) + + flat_sz_h = int(2 * self.p.k_w) - self.p.ker_sz + 1 + flat_sz_w = self.p.k_h - self.p.ker_sz + 1 + self.flat_sz = flat_sz_h * flat_sz_w * self.p.num_filt + self.fc = torch.nn.Linear(self.flat_sz, self.p.embed_dim) + + def concat(self, e1_embed, rel_embed): + e1_embed = e1_embed.view(-1, 1, self.p.embed_dim) + rel_embed = rel_embed.view(-1, 1, self.p.embed_dim) + stack_inp = torch.cat([e1_embed, rel_embed], 1) + stack_inp = torch.transpose(stack_inp, 2, 1).reshape( + (-1, 1, 2 * self.p.k_w, self.p.k_h)) + return stack_inp + + def forward(self, g, sub, rel): + x = self.init_embed + r = self.init_rel + + x_h, x_t, r = self.exop(x, r, self.x_ops, self.r_ops) + + sub_emb = torch.index_select(x_h, 0, sub) + rel_emb = torch.index_select(r, 0, rel) + all_ent = x_t + + stk_inp = self.concat(sub_emb, rel_emb) + x = self.bn0(stk_inp) + x = self.m_conv1(x) + x = self.bn1(x) + x = F.relu(x) + x = self.feature_drop(x) + x = x.view(-1, self.flat_sz) + x = self.fc(x) + x = self.hidden_drop2(x) + x = self.bn2(x) + x = F.relu(x) + + x = torch.mm(x, all_ent.transpose(1, 0)) + x += self.bias.expand_as(x) + + score = torch.sigmoid(x) + return score diff --git a/model/model.py b/model/model.py new file mode 100644 index 0000000..a99da4b --- /dev/null +++ b/model/model.py @@ -0,0 +1,262 @@ +import torch +from torch import nn +import dgl +import dgl.function as fn +from torch.autograd import Variable +from model.genotypes import COMP_PRIMITIVES, AGG_PRIMITIVES, COMB_PRIMITIVES, ACT_PRIMITIVES +from model.operations import * +from model.fune_layer import SearchedGCNConv +from torch.nn.functional import softmax +from pprint import pprint + + +class NetworkGNN(nn.Module): + def __init__(self, args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm, + conv_bias=True, gcn_drop=0., wni=False, wsi=False, use_bn=True, ltr=True, loss_type='ce', genotype=None): + super(NetworkGNN, self).__init__() + self.act = torch.tanh + self.args = args + self.loss_type = loss_type + self.num_ent, self.num_rel, self.num_base = num_ent, num_rel, num_base + self.init_dim, self.gcn_dim, self.embed_dim = init_dim, gcn_dim, embed_dim + + self.gcn_drop = gcn_drop + self.edge_type = edge_type # [E] + self.edge_norm = edge_norm # [E] + self.n_layer = n_layer + + self.init_embed = self.get_param([self.num_ent + 1, self.init_dim]) + self.init_rel = self.get_param([self.num_rel, self.init_dim]) + self.bias_rel = nn.Parameter(torch.zeros(self.num_rel)) + + self._initialize_loss() + self.gnn_layers = nn.ModuleList() + ops = genotype.split('||') + for idx in range(self.args.n_layer): + if idx == 0: + self.gnn_layers.append( + SearchedGCNConv(self.init_dim, self.gcn_dim, conv_bias, gcn_drop, num_base=-1, + num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr, + comp=ops[4*idx], agg=ops[4*idx+1], comb=ops[4*idx+2], act=ops[4*idx+3])) + elif idx == self.args.n_layer-1: + self.gnn_layers.append( + SearchedGCNConv(self.gcn_dim, self.embed_dim, conv_bias, gcn_drop, num_base=-1, + num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr, + comp=ops[4*idx], agg=ops[4*idx+1], comb=ops[4*idx+2], act=ops[4*idx+3])) + else: + self.gnn_layers.append( + SearchedGCNConv(self.gcn_dim, self.gcn_dim, conv_bias, gcn_drop, num_base=-1, + num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr, + comp=ops[4*idx], 
agg=ops[4*idx+1], comb=ops[4*idx+2], act=ops[4*idx+3])) + + def get_param(self, shape): + param = nn.Parameter(torch.Tensor(*shape)) + nn.init.xavier_normal_(param, gain=nn.init.calculate_gain('relu')) + return param + + + def _initialize_loss(self): + if self.loss_type == 'ce': + self.loss = nn.CrossEntropyLoss() + elif self.loss_type == 'bce': + self.loss = nn.BCELoss(reduce=False) + elif self.loss_type == 'bce_logits': + self.loss = nn.BCEWithLogitsLoss() + else: + raise NotImplementedError + + def calc_loss(self, pred, label, pos_neg=None): + if pos_neg is not None: + m = nn.Sigmoid() + score_pos = m(pred) + targets_pos = pos_neg.unsqueeze(1) + loss = self.loss(score_pos, label * targets_pos) + return torch.sum(loss * label) + return self.loss(pred, label) + + def forward_base(self, g, subj, obj, drop1, drop2, mode=None): + x, r = self.init_embed, self.init_rel # embedding of relations + + for i, layer in enumerate(self.gnn_layers): + if i != self.args.n_layer-1: + x, r = layer(g, x, r, self.edge_type, self.edge_norm) + x = drop1(x) + else: + x, r = layer(g, x, r, self.edge_type, self.edge_norm) + x = drop2(x) + + sub_emb = torch.index_select(x, 0, subj) + # filter out embeddings of objects in this batch + obj_emb = torch.index_select(x, 0, obj) + + return sub_emb, obj_emb, x, r + + def forward_base_search(self, g, drop1, drop2): + x, r = self.init_embed, self.init_rel # embedding of relations + x_hidden = [] + for i, layer in enumerate(self.gnn_layers): + if i != self.args.n_layer-1: + x, r = layer(g, x, r, self.edge_type, self.edge_norm) + x_hidden.append(x) + x = drop1(x) + else: + x, r = layer(g, x, r, self.edge_type, self.edge_norm) + x_hidden.append(x) + x = drop2(x) + x_hidden = torch.stack(x_hidden, dim=1) + + return x_hidden, x + + +class SearchedGCN_MLP(NetworkGNN): + def __init__(self, args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm, + bias=True, gcn_drop=0., hid_drop=0., input_drop=0., wni=False, wsi=False, use_bn=True, + ltr=True, combine_type='mult', loss_type='ce', genotype=None): + """ + :param num_ent: number of entities + :param num_rel: number of different relations + :param num_base: number of bases to use + :param init_dim: initial dimension + :param gcn_dim: dimension after first layer + :param embed_dim: dimension after second layer + :param n_layer: number of layer + :param edge_type: relation type of each edge, [E] + :param bias: weather to add bias + :param gcn_drop: dropout rate in compgcncov + :param opn: combination operator + :param hid_drop: gcn output (embedding of each entity) dropout + :param input_drop: dropout in conve input + :param conve_hid_drop: dropout in conve hidden layer + :param feat_drop: feature dropout in conve + :param num_filt: number of filters in conv2d + :param ker_sz: kernel size in conv2d + :param k_h: height of 2D reshape + :param k_w: width of 2D reshape + """ + super(SearchedGCN_MLP, self).__init__(args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, + edge_type, edge_norm, bias, gcn_drop, wni, wsi, use_bn, ltr, loss_type, genotype) + self.hid_drop, self.input_drop = hid_drop, input_drop + + self.num_rel = num_rel + + self.drop = torch.nn.Dropout(self.hid_drop) # gcn output dropout + self.input_drop = torch.nn.Dropout(self.input_drop) # stacked input dropout + + self.combine_type = combine_type + if self.combine_type == 'concat': + self.fc = torch.nn.Linear(2 * self.embed_dim, self.num_rel) + elif self.combine_type == 'mult': + self.fc = 
torch.nn.Linear(self.embed_dim, self.num_rel) + + def forward(self, g, subj, obj, mode=None): + """ + :param g: dgl graph + :param sub: subject in batch [batch_size] + :param rel: relation in batch [batch_size] + :return: score: [batch_size, ent_num], the prob in link-prediction + """ + + sub_embs, obj_embs, all_ent, all_rel = self.forward_base(g, subj, obj, self.drop, self.input_drop, mode) + if self.combine_type == 'concat': + edge_embs = torch.concat([sub_embs, obj_embs], dim=1) + elif self.combine_type == 'mult': + edge_embs = sub_embs * obj_embs + else: + raise NotImplementedError + score = self.fc(edge_embs) + return score + + def forward_search(self, g, mode='allgraph'): + # if mode == 'allgraph': + hidden_all_ent, all_ent = self.forward_base_search( + g, self.drop, self.input_drop) + # elif mode == 'subgraph': + # hidden_all_ent = self.forward_base_subgraph_search( + # g, self.drop, self.input_drop) + + return hidden_all_ent, all_ent + + def compute_pred(self, hidden_x, subj, obj, subgraph_sampler, mode='search', search_algorithm='darts'): + h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + atten_matrix = subgraph_sampler(h, mode, search_algorithm) + # print(atten_matrix.size()) # [batch_size, encoder_layer^2] + # print(subgraph_sampler(h,mode='argmax')) + n, c = atten_matrix.shape + h = h * atten_matrix.view(n, c, 1) + # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + h = torch.sum(h, dim=1) + # print(h.size()) # [batch_size, 2*dim] + score = self.fc(h) + return score + + # def fine_tune_with_implicit_subgraph(self, all_ent, subgraph, subj, obj): + # sg_list = [] + # for idx in range(subgraph.size(0)): + # sg_list.append(torch.mean(all_ent[subgraph[idx,:]], dim=0).unsqueeze(0)) + # sg_embs = torch.concat(sg_list) + # # print(sg_embs.size()) + # sub_embs = torch.index_select(all_ent, 0, subj) + # # print(sub_embs.size()) + # # filter out embeddings of relations in this batch + # obj_embs = torch.index_select(all_ent, 0, obj) + # # print(obj_embs.size()) + # edge_embs = torch.concat([sub_embs, obj_embs, sg_embs], dim=1) + # score = self.predictor(edge_embs) + # # print(F.embedding(subgraph, all_ent)) + # return score + + # def compute_pred_rs(self, hidden_x, all_rel, subj, obj, random_hops): + # h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + # atten_matrix = torch.zeros(h.size(0),h.size(1)).to('cuda:0') + # for i in range(h.size(0)): + # atten_matrix[i][self.n_layer*(random_hops[i][0]-1)+random_hops[i][1]-1] = 1 + # # print(atten_matrix.size()) # [batch_size, encoder_layer^2] + # # print(subgraph_sampler(h,mode='argmax')) + # n, c = atten_matrix.shape + # h = h * atten_matrix.view(n,c,1) + # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + # h = torch.sum(h,dim=1) + # # print(h.size()) # [batch_size, 2*dim] + # h = h.reshape(-1, 1, 2 * self.k_h, self.k_w) + # # x = self.bn0(h) + # # x = self.conv2d(x) # [batch_size, num_filt, flat_sz_h, flat_sz_w] + # # x = self.bn1(x) + # # x = F.relu(x) + # # x = self.feature_drop(x) + # # x = x.view(-1, self.flat_sz) # [batch_size, flat_sz] + # # x = self.fc(x) # [batch_size, embed_dim] + # # x = self.hidden_drop(x) + # # x = self.bn2(x) + # # x = F.relu(x) + # # x = torch.mm(x, all_rel.transpose(1, 0)) # [batch_size, ent_num] + # # x += self.bias_rel.expand_as(x) + # # # score = torch.sigmoid(x) + # # score = x + # return score + # print(h.size()) + + def cross_pair(self, x_i, x_j): + x = [] + for i 
in range(self.n_layer): + for j in range(self.n_layer): + if self.combine_type == 'mult': + x.append(x_i[:, i, :] * x_j[:, j, :]) + elif self.combine_type == 'concat': + x.append(torch.cat([x_i[:, i, :], x_j[:, j, :]], dim=1)) + x = torch.stack(x, dim=1) + return x + + def vis_hop_distribution(self, hidden_x, subj, obj, subgraph_sampler, mode='search', search_algorithm='darts'): + h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + atten_matrix = subgraph_sampler(h, mode, search_algorithm) + return torch.sum(atten_matrix, dim=0) + + def vis_hop_distribution_rs(self, hidden_x, subj, obj, random_hops): + h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + atten_matrix = torch.zeros(h.size(0), h.size(1)).to('cuda:0') + for i in range(h.size(0)): + atten_matrix[i][self.n_layer * (random_hops[i][0] - 1) + random_hops[i][1] - 1] = 1 + return torch.sum(atten_matrix, dim=0) diff --git a/model/model_fast.py b/model/model_fast.py new file mode 100644 index 0000000..193743b --- /dev/null +++ b/model/model_fast.py @@ -0,0 +1,284 @@ +import torch +from torch import nn +import dgl +import dgl.function as fn +from torch.autograd import Variable +from model.genotypes import COMP_PRIMITIVES, AGG_PRIMITIVES, COMB_PRIMITIVES, ACT_PRIMITIVES +from model.operations import * +from model.fune_layer import SearchedGCNConv +from torch.nn.functional import softmax +from pprint import pprint + + +class NetworkGNN_MLP(nn.Module): + def __init__(self, args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm, + conv_bias=True, gcn_drop=0., hid_drop=0., input_drop=0, wni=False, wsi=False, use_bn=True, ltr=True, + combine_type='mult', loss_type='ce', genotype=None): + super(NetworkGNN_MLP, self).__init__() + self.act = torch.tanh + self.args = args + self.loss_type = loss_type + self.num_ent, self.num_rel, self.num_base = num_ent, num_rel, num_base + self.init_dim, self.gcn_dim, self.embed_dim = init_dim, gcn_dim, embed_dim + self.hid_drop, self.input_drop = hid_drop, input_drop + self.gcn_drop = gcn_drop + self.edge_type = edge_type # [E] + self.edge_norm = edge_norm # [E] + self.n_layer = n_layer + + self.init_embed = self.get_param([self.num_ent + 1, self.init_dim]) + self.init_rel = self.get_param([self.num_rel, self.init_dim]) + self.bias_rel = nn.Parameter(torch.zeros(self.num_rel)) + + self._initialize_loss() + self.gnn_layers = nn.ModuleList() + ops = genotype.split('||') + for idx in range(self.args.n_layer): + if idx == 0: + self.gnn_layers.append( + SearchedGCNConv(self.init_dim, self.gcn_dim, conv_bias, gcn_drop, num_base=-1, + num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr, + comp=ops[4*idx], agg=ops[4*idx+1], comb=ops[4*idx+2], act=ops[4*idx+3])) + elif idx == self.args.n_layer-1: + self.gnn_layers.append( + SearchedGCNConv(self.gcn_dim, self.embed_dim, conv_bias, gcn_drop, num_base=-1, + num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr, + comp=ops[4*idx], agg=ops[4*idx+1], comb=ops[4*idx+2], act=ops[4*idx+3])) + else: + self.gnn_layers.append( + SearchedGCNConv(self.gcn_dim, self.gcn_dim, conv_bias, gcn_drop, num_base=-1, + num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr, + comp=ops[4*idx], agg=ops[4*idx+1], comb=ops[4*idx+2], act=ops[4*idx+3])) + + self.drop = torch.nn.Dropout(self.hid_drop) # gcn output dropout + self.input_drop = torch.nn.Dropout(self.input_drop) # stacked input dropout + + self.combine_type = combine_type + if self.combine_type == 'concat': + 
self.fc = torch.nn.Linear(2 * self.embed_dim, self.num_rel) + elif self.combine_type == 'mult': + self.fc = torch.nn.Linear(self.embed_dim, self.num_rel) + + def get_param(self, shape): + param = nn.Parameter(torch.Tensor(*shape)) + nn.init.xavier_normal_(param, gain=nn.init.calculate_gain('relu')) + return param + + + def _initialize_loss(self): + if self.loss_type == 'ce': + self.loss = nn.CrossEntropyLoss() + elif self.loss_type == 'bce': + self.loss = nn.BCELoss() + elif self.loss_type == 'bce_logits': + self.loss = nn.BCEWithLogitsLoss() + else: + raise NotImplementedError + + def calc_loss(self, pred, label): + return self.loss(pred, label) + + def forward_base(self, g, subj, obj, drop1, drop2, mode=None): + x, r = self.init_embed, self.init_rel # embedding of relations + + for i, layer in enumerate(self.gnn_layers): + if i != self.args.n_layer-1: + x, r = layer(g, x, r, self.edge_type, self.edge_norm) + x = drop1(x) + else: + x, r = layer(g, x, r, self.edge_type, self.edge_norm) + x = drop2(x) + + sub_emb = torch.index_select(x, 0, subj) + # filter out embeddings of objects in this batch + obj_emb = torch.index_select(x, 0, obj) + + return sub_emb, obj_emb, x, r + + def forward_base_search(self, g, drop1, drop2): + x, r = self.init_embed, self.init_rel # embedding of relations + x_hidden = [] + for i, layer in enumerate(self.gnn_layers): + if i != self.args.n_layer-1: + x, r = layer(g, x, r, self.edge_type, self.edge_norm) + x_hidden.append(x) + x = drop1(x) + else: + x, r = layer(g, x, r, self.edge_type, self.edge_norm) + x_hidden.append(x) + x = drop2(x) + x_hidden = torch.stack(x_hidden, dim=1) + + return x_hidden, x + + def forward(self, g, subj, obj, mode=None): + """ + :param g: dgl graph + :param sub: subject in batch [batch_size] + :param rel: relation in batch [batch_size] + :return: score: [batch_size, ent_num], the prob in link-prediction + """ + + sub_embs, obj_embs, all_ent, all_rel = self.forward_base(g, subj, obj, self.drop, self.input_drop, mode) + if self.combine_type == 'concat': + edge_embs = torch.concat([sub_embs, obj_embs], dim=1) + elif self.combine_type == 'mult': + edge_embs = sub_embs * obj_embs + else: + raise NotImplementedError + score = self.fc(edge_embs) + return score + + +# class SearchedGCN_MLP(NetworkGNN): +# def __init__(self, args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm, +# bias=True, gcn_drop=0., hid_drop=0., input_drop=0., wni=False, wsi=False, use_bn=True, +# ltr=True, combine_type='mult', loss_type='ce', genotype=None): +# """ +# :param num_ent: number of entities +# :param num_rel: number of different relations +# :param num_base: number of bases to use +# :param init_dim: initial dimension +# :param gcn_dim: dimension after first layer +# :param embed_dim: dimension after second layer +# :param n_layer: number of layer +# :param edge_type: relation type of each edge, [E] +# :param bias: weather to add bias +# :param gcn_drop: dropout rate in compgcncov +# :param opn: combination operator +# :param hid_drop: gcn output (embedding of each entity) dropout +# :param input_drop: dropout in conve input +# :param conve_hid_drop: dropout in conve hidden layer +# :param feat_drop: feature dropout in conve +# :param num_filt: number of filters in conv2d +# :param ker_sz: kernel size in conv2d +# :param k_h: height of 2D reshape +# :param k_w: width of 2D reshape +# """ +# super(SearchedGCN_MLP, self).__init__(args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, +# edge_type, 
edge_norm, bias, gcn_drop, wni, wsi, use_bn, ltr, loss_type, genotype) +# self.hid_drop, self.input_drop = hid_drop, input_drop +# +# self.num_rel = num_rel +# +# self.drop = torch.nn.Dropout(self.hid_drop) # gcn output dropout +# self.input_drop = torch.nn.Dropout(self.input_drop) # stacked input dropout +# +# self.combine_type = combine_type +# if self.combine_type == 'concat': +# self.fc = torch.nn.Linear(2 * self.embed_dim, self.num_rel) +# elif self.combine_type == 'mult': +# self.fc = torch.nn.Linear(self.embed_dim, self.num_rel) +# +# def forward(self, g, subj, obj, mode=None): +# """ +# :param g: dgl graph +# :param sub: subject in batch [batch_size] +# :param rel: relation in batch [batch_size] +# :return: score: [batch_size, ent_num], the prob in link-prediction +# """ +# +# sub_embs, obj_embs, all_ent, all_rel = self.forward_base(g, subj, obj, self.drop, self.input_drop, mode) +# if self.combine_type == 'concat': +# edge_embs = torch.concat([sub_embs, obj_embs], dim=1) +# elif self.combine_type == 'mult': +# edge_embs = sub_embs * obj_embs +# else: +# raise NotImplementedError +# score = self.fc(edge_embs) +# return score +# +# def forward_search(self, g, mode='allgraph'): +# # if mode == 'allgraph': +# hidden_all_ent, all_ent = self.forward_base_rel_search( +# g, self.drop, self.input_drop) +# # elif mode == 'subgraph': +# # hidden_all_ent = self.forward_base_subgraph_search( +# # g, self.drop, self.input_drop) +# +# return hidden_all_ent, all_ent +# +# def compute_pred(self, hidden_x, subj, obj, subgraph_sampler, mode='search', search_algorithm='darts'): +# h = self.cross_pair(hidden_x[subj], hidden_x[obj]) +# # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] +# atten_matrix = subgraph_sampler(h, mode, search_algorithm) +# # print(atten_matrix.size()) # [batch_size, encoder_layer^2] +# # print(subgraph_sampler(h,mode='argmax')) +# n, c = atten_matrix.shape +# h = h * atten_matrix.view(n, c, 1) +# # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] +# h = torch.sum(h, dim=1) +# # print(h.size()) # [batch_size, 2*dim] +# score = self.fc(h) +# return score +# +# # def fine_tune_with_implicit_subgraph(self, all_ent, subgraph, subj, obj): +# # sg_list = [] +# # for idx in range(subgraph.size(0)): +# # sg_list.append(torch.mean(all_ent[subgraph[idx,:]], dim=0).unsqueeze(0)) +# # sg_embs = torch.concat(sg_list) +# # # print(sg_embs.size()) +# # sub_embs = torch.index_select(all_ent, 0, subj) +# # # print(sub_embs.size()) +# # # filter out embeddings of relations in this batch +# # obj_embs = torch.index_select(all_ent, 0, obj) +# # # print(obj_embs.size()) +# # edge_embs = torch.concat([sub_embs, obj_embs, sg_embs], dim=1) +# # score = self.predictor(edge_embs) +# # # print(F.embedding(subgraph, all_ent)) +# # return score +# +# # def compute_pred_rs(self, hidden_x, all_rel, subj, obj, random_hops): +# # h = self.cross_pair(hidden_x[subj], hidden_x[obj]) +# # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] +# # atten_matrix = torch.zeros(h.size(0),h.size(1)).to('cuda:0') +# # for i in range(h.size(0)): +# # atten_matrix[i][self.n_layer*(random_hops[i][0]-1)+random_hops[i][1]-1] = 1 +# # # print(atten_matrix.size()) # [batch_size, encoder_layer^2] +# # # print(subgraph_sampler(h,mode='argmax')) +# # n, c = atten_matrix.shape +# # h = h * atten_matrix.view(n,c,1) +# # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] +# # h = torch.sum(h,dim=1) +# # # print(h.size()) # [batch_size, 2*dim] +# # h = h.reshape(-1, 1, 2 * self.k_h, self.k_w) +# # # x = 
self.bn0(h) +# # # x = self.conv2d(x) # [batch_size, num_filt, flat_sz_h, flat_sz_w] +# # # x = self.bn1(x) +# # # x = F.relu(x) +# # # x = self.feature_drop(x) +# # # x = x.view(-1, self.flat_sz) # [batch_size, flat_sz] +# # # x = self.fc(x) # [batch_size, embed_dim] +# # # x = self.hidden_drop(x) +# # # x = self.bn2(x) +# # # x = F.relu(x) +# # # x = torch.mm(x, all_rel.transpose(1, 0)) # [batch_size, ent_num] +# # # x += self.bias_rel.expand_as(x) +# # # # score = torch.sigmoid(x) +# # # score = x +# # return score +# # print(h.size()) +# +# def cross_pair(self, x_i, x_j): +# x = [] +# for i in range(self.n_layer): +# for j in range(self.n_layer): +# if self.combine_type == 'mult': +# x.append(x_i[:, i, :] * x_j[:, j, :]) +# elif self.combine_type == 'concat': +# x.append(torch.cat([x_i[:, i, :], x_j[:, j, :]], dim=1)) +# x = torch.stack(x, dim=1) +# return x +# +# def vis_hop_distribution(self, hidden_x, subj, obj, subgraph_sampler, mode='search'): +# h = self.cross_pair(hidden_x[subj], hidden_x[obj]) +# atten_matrix = subgraph_sampler(h, mode) +# return torch.sum(atten_matrix, dim=0) +# +# def vis_hop_distribution_rs(self, hidden_x, subj, obj, random_hops): +# h = self.cross_pair(hidden_x[subj], hidden_x[obj]) +# # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] +# atten_matrix = torch.zeros(h.size(0), h.size(1)).to('cuda:0') +# for i in range(h.size(0)): +# atten_matrix[i][self.n_layer * (random_hops[i][0] - 1) + random_hops[i][1] - 1] = 1 +# return torch.sum(atten_matrix, dim=0) diff --git a/model/model_search.py b/model/model_search.py new file mode 100644 index 0000000..da6a662 --- /dev/null +++ b/model/model_search.py @@ -0,0 +1,448 @@ +import torch +from torch import nn +import dgl +import dgl.function as fn +from torch.autograd import Variable +from model.genotypes import COMP_PRIMITIVES, AGG_PRIMITIVES, COMB_PRIMITIVES, ACT_PRIMITIVES +from model.operations import * +from model.search_layer import SearchGCNConv +from torch.nn.functional import softmax +from pprint import pprint + + +class CompMixOP(nn.Module): + def __init__(self): + super(CompMixOP, self).__init__() + self._ops = nn.ModuleList() + for primitive in COMP_PRIMITIVES: + op = COMP_OPS[primitive]() + self._ops.append(op) + + def reset_parameters(self): + for op in self._ops: + op.reset_parameters() + + def forward(self, src_emb, rel_emb, weights): + mixed_res = [] + for w, op in zip(weights, self._ops): + mixed_res.append(w * op(src_emb, rel_emb)) + return sum(mixed_res) + + +class CompOp(nn.Module): + + def __init__(self, op_name): + super(CompOp, self).__init__() + self.op = COMP_OPS[op_name]() + + def reset_parameters(self): + self.op.reset_parameters() + + def forward(self, src_emb, rel_emb): + return self.op(src_emb, rel_emb) + + +class Network(nn.Module): + def __init__(self, args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm, + conv_bias=True, gcn_drop=0., wni=False, wsi=False, use_bn=True, ltr=True, loss_type='ce'): + super(Network, self).__init__() + self.act = torch.tanh + self.args = args + self.loss_type = loss_type + self.num_ent, self.num_rel, self.num_base = num_ent, num_rel, num_base + self.init_dim, self.gcn_dim, self.embed_dim = init_dim, gcn_dim, embed_dim + + self.gcn_drop = gcn_drop + self.edge_type = edge_type # [E] + self.edge_norm = edge_norm # [E] + self.n_layer = n_layer + + self.init_embed = self.get_param([self.num_ent + 1, self.init_dim]) + self.init_rel = self.get_param([self.num_rel, self.init_dim]) + self.bias_rel = 
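
`CompMixOP` above relaxes the discrete choice among composition operators (`sub`, `mult`, `ccorr`, `rotate`) into a weighted sum, so architecture weights can be trained by gradient descent alongside the model weights. The sketch below uses two simple stand-in operators; the real `COMP_OPS` implementations live in `model/operations.py`, which is not shown here.

```python
# Sketch of the mixed-operation idea behind CompMixOP above: candidate
# composition operators are blended by architecture weights. Only 'sub' and
# 'mult' are implemented as stand-ins here.
import torch
import torch.nn as nn
import torch.nn.functional as F

CANDIDATES = {
    'sub': lambda h, r: h - r,
    'mult': lambda h, r: h * r,
}

class MixedComp(nn.Module):
    def __init__(self):
        super().__init__()
        self.ops = list(CANDIDATES.values())
        # one architecture logit per candidate operator
        self.alphas = nn.Parameter(1e-3 * torch.randn(len(self.ops)))

    def forward(self, src_emb, rel_emb):
        weights = F.softmax(self.alphas, dim=-1)
        return sum(w * op(src_emb, rel_emb) for w, op in zip(weights, self.ops))

out = MixedComp()(torch.randn(5, 32), torch.randn(5, 32))
print(out.shape)                                      # torch.Size([5, 32])
```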
nn.Parameter(torch.zeros(self.num_rel)) + + self._initialize_alphas() + self._initialize_loss() + self.gnn_layers = nn.ModuleList() + for idx in range(self.args.n_layer): + if idx == 0: + self.gnn_layers.append( + SearchGCNConv(self.init_dim, self.gcn_dim, self.act, conv_bias, gcn_drop, num_base=-1, + num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr, + comp_weights=self.comp_alphas[idx])) + elif idx == self.args.n_layer-1: + self.gnn_layers.append(SearchGCNConv(self.gcn_dim, self.embed_dim, self.act, conv_bias, gcn_drop, num_base=-1, + num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr, + comp_weights=self.comp_alphas[idx])) + else: + self.gnn_layers.append( + SearchGCNConv(self.gcn_dim, self.gcn_dim, self.act, conv_bias, gcn_drop, num_base=-1, + num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr, + comp_weights=self.comp_alphas[idx])) + + def arch_parameters(self): + return self._arch_parameters + + def genotype(self): + gene = [] + comp_max, comp_indices = torch.max(softmax(self.comp_alphas, dim=-1).data.cpu(), dim=-1) + agg_max, agg_indices = torch.max(softmax(self.agg_alphas, dim=-1).data.cpu(), dim=-1) + comb_max, comb_indices = torch.max(softmax(self.comb_alphas, dim=-1).data.cpu(), dim=-1) + act_max, act_indices = torch.max(softmax(self.act_alphas, dim=-1).data.cpu(), dim=-1) + pprint(comp_max) + pprint(agg_max) + pprint(comb_max) + pprint(act_max) + for idx in range(self.args.n_layer): + gene.append(COMP_PRIMITIVES[comp_indices[idx]]) + gene.append(AGG_PRIMITIVES[agg_indices[idx]]) + gene.append(COMB_PRIMITIVES[comb_indices[idx]]) + gene.append(ACT_PRIMITIVES[act_indices[idx]]) + return "||".join(gene) + + # self.in_channels = in_channels + # self.out_channels = out_channels + # self.act = act # activation function + # self.device = None + # if add_reverse: + # self.rel = nn.Parameter(torch.empty([num_rel * 2, in_channels], dtype=torch.float)) + # else: + # self.rel = nn.Parameter(torch.empty([num_rel, in_channels], dtype=torch.float)) + # self.opn = opn + # + # self.use_bn = use_bn + # self.ltr = ltr + # + # # relation-type specific parameter + # self.in_w = self.get_param([in_channels, out_channels]) + # self.out_w = self.get_param([in_channels, out_channels]) + # self.loop_w = self.get_param([in_channels, out_channels]) + # # transform embedding of relations to next layer + # self.w_rel = self.get_param([in_channels, out_channels]) + # self.loop_rel = self.get_param([1, in_channels]) # self-loop embedding + # + # self.drop = nn.Dropout(drop_rate) + # self.bn = torch.nn.BatchNorm1d(out_channels) + # self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None + + # if num_base > 0: + # if add_reverse: + # self.rel_wt = self.get_param([num_rel * 2, num_base]) + # else: + # self.rel_wt = self.get_param([num_rel, num_base]) + # else: + # self.rel_wt = None + # + # self.wni = wni + # self.wsi = wsi + + def get_param(self, shape): + param = nn.Parameter(torch.Tensor(*shape)) + nn.init.xavier_normal_(param, gain=nn.init.calculate_gain('relu')) + return param + + def _initialize_alphas(self): + comp_ops_num = len(COMP_PRIMITIVES) + agg_ops_num = len(AGG_PRIMITIVES) + comb_ops_num = len(COMB_PRIMITIVES) + act_ops_num = len(ACT_PRIMITIVES) + if self.args.search_algorithm == "darts": + self.comp_alphas = Variable(1e-3 * torch.randn(self.args.n_layer, comp_ops_num).cuda(), requires_grad=True) + self.agg_alphas = Variable(1e-3 * torch.randn(self.args.n_layer, agg_ops_num).cuda(), requires_grad=True) + self.comb_alphas = Variable(1e-3 * 
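
`genotype()` above turns the four alpha matrices into a readable architecture string by taking the argmax operator per layer and joining with `"||"`; `NetworkGNN` later splits that string back into per-layer `(comp, agg, comb, act)` choices. A sketch of both directions, using the primitive lists from `model/genotypes.py` and random alphas for illustration.

```python
# Sketch of genotype encoding/decoding used above. Primitive lists mirror
# model/genotypes.py; alpha values are random for illustration.
import torch
from torch.nn.functional import softmax

COMP = ['sub', 'mult', 'ccorr', 'rotate']
AGG = ['mean', 'sum', 'max']
COMB = ['mlp', 'concat']
ACT = ['identity', 'relu', 'tanh']
n_layer = 3

alphas = {k: torch.randn(n_layer, len(v)) for k, v in
          {'comp': COMP, 'agg': AGG, 'comb': COMB, 'act': ACT}.items()}

def encode(alphas):
    gene = []
    for idx in range(n_layer):
        for key, prims in (('comp', COMP), ('agg', AGG), ('comb', COMB), ('act', ACT)):
            choice = softmax(alphas[key][idx], dim=-1).argmax().item()
            gene.append(prims[choice])
    return "||".join(gene)

genotype = encode(alphas)
print(genotype)                        # e.g. mult||sum||mlp||relu||...
ops = genotype.split('||')
print(ops[0:4])                        # (comp, agg, comb, act) for layer 0, as consumed by NetworkGNN
```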
torch.randn(self.args.n_layer, comb_ops_num).cuda(), requires_grad=True) + self.act_alphas = Variable(1e-3 * torch.randn(self.args.n_layer, act_ops_num).cuda(), requires_grad=True) + elif self.args.search_algorithm == "snas": + # self.comp_alphas = Variable(1e-3 * torch.randn(self.args.n_layer, comp_ops_num).cuda(), requires_grad=True) + self.comp_alphas = Variable( + torch.ones(self.args.n_layer, comp_ops_num).normal_(self.args.loc_mean, self.args.loc_std).cuda(), + requires_grad=True) + self.agg_alphas = Variable( + torch.ones(self.args.n_layer, agg_ops_num).normal_(self.args.loc_mean, self.args.loc_std).cuda(), + requires_grad=True) + self.comb_alphas = Variable( + torch.ones(self.args.n_layer, comb_ops_num).normal_(self.args.loc_mean, self.args.loc_std).cuda(), + requires_grad=True) + self.act_alphas = Variable( + torch.ones(self.args.n_layer, act_ops_num).normal_(self.args.loc_mean, self.args.loc_std).cuda(), + requires_grad=True) + # self.la_alphas = Variable(torch.ones(1, la_ops_num).normal_(self.args.loc_mean, self.args.loc_std).cuda(), + # requires_grad=True) + # self.seq_alphas = Variable( + # torch.ones(self.layer_num, seq_ops_num).normal_(self.args.loc_mean, self.args.loc_std).cuda(), + # requires_grad=True) + else: + raise NotImplementedError + self._arch_parameters = [ + self.comp_alphas, + self.agg_alphas, + self.comb_alphas, + self.act_alphas + ] + + def _initialize_loss(self): + if self.loss_type == 'ce': + self.loss = nn.CrossEntropyLoss() + elif self.loss_type == 'bce': + self.loss = nn.BCELoss(reduce=False) + elif self.loss_type == 'bce_logits': + self.loss = nn.BCEWithLogitsLoss() + else: + raise NotImplementedError + + def calc_loss(self, pred, label, pos_neg=None): + if pos_neg is not None: + m = nn.Sigmoid() + score_pos = m(pred) + targets_pos = pos_neg.unsqueeze(1) + loss = self.loss(score_pos, label * targets_pos) + return torch.sum(loss * label) + return self.loss(pred, label) + + def _get_categ_mask(self, alpha): + # log_alpha = torch.log(alpha) + log_alpha = alpha + u = torch.zeros_like(log_alpha).uniform_() + softmax = torch.nn.Softmax(-1) + one_hot = softmax((log_alpha + (-((-(u.log())).log()))) / self.args.temperature) + return one_hot + + def _get_categ_mask_new(self, alpha): + # log_alpha = torch.log(alpha) + print(alpha) + m = torch.distributions.relaxed_categorical.RelaxedOneHotCategorical( + torch.tensor([self.args.temperature]).cuda(), alpha) + print(m.sample()) + print(m.rsample()) + + def get_one_hot_alpha(self, alpha): + one_hot_alpha = torch.zeros_like(alpha, device=alpha.device) + idx = torch.argmax(alpha, dim=-1) + + for i in range(one_hot_alpha.size(0)): + one_hot_alpha[i, idx[i]] = 1.0 + + return one_hot_alpha + + def forward_base(self, g, subj, obj, drop1, drop2, mode=None): + if self.args.search_algorithm == "darts": + comp_weights = softmax(self.comp_alphas / self.args.temperature, dim=-1) + agg_weights = softmax(self.agg_alphas / self.args.temperature, dim=-1) + comb_weights = softmax(self.comb_alphas / self.args.temperature, dim=-1) + act_weights = softmax(self.act_alphas / self.args.temperature, dim=-1) + elif self.args.search_algorithm == "snas": + # comp_weights = self._get_categ_mask_new(self.comp_alphas) + comp_weights = self._get_categ_mask(self.comp_alphas) + agg_weights = self._get_categ_mask(self.agg_alphas) + comb_weights = self._get_categ_mask(self.comb_alphas) + act_weights = self._get_categ_mask(self.act_alphas) + else: + raise NotImplementedError + if mode == 'evaluate_single_path': + comp_weights = 
self.get_one_hot_alpha(comp_weights) + agg_weights = self.get_one_hot_alpha(agg_weights) + comb_weights = self.get_one_hot_alpha(comb_weights) + act_weights = self.get_one_hot_alpha(act_weights) + # weights = dict() + # weights['comp'] = comp_weights + x, r = self.init_embed, self.init_rel # embedding of relations + + for i, layer in enumerate(self.gnn_layers): + if i != self.args.n_layer-1: + x, r = layer(g, x, r, self.edge_type, self.edge_norm, comp_weights[i], agg_weights[i], comb_weights[i], act_weights[i]) + x = drop1(x) + else: + x, r = layer(g, x, r, self.edge_type, self.edge_norm, comp_weights[i], agg_weights[i], comb_weights[i], act_weights[i]) + x = drop2(x) + + sub_emb = torch.index_select(x, 0, subj) + # filter out embeddings of objects in this batch + obj_emb = torch.index_select(x, 0, obj) + + return sub_emb, obj_emb, x, r + + def forward_base_search(self, g, drop1, drop2, mode=None): + if self.args.search_algorithm == "darts": + comp_weights = softmax(self.comp_alphas / self.args.temperature, dim=-1) + agg_weights = softmax(self.agg_alphas / self.args.temperature, dim=-1) + comb_weights = softmax(self.comb_alphas / self.args.temperature, dim=-1) + act_weights = softmax(self.act_alphas / self.args.temperature, dim=-1) + elif self.args.search_algorithm == "snas": + comp_weights = self._get_categ_mask(self.comp_alphas) + agg_weights = self._get_categ_mask(self.agg_alphas) + comb_weights = self._get_categ_mask(self.comb_alphas) + act_weights = self._get_categ_mask(self.act_alphas) + else: + raise NotImplementedError + if mode == 'evaluate_single_path': + comp_weights = self.get_one_hot_alpha(comp_weights) + agg_weights = self.get_one_hot_alpha(agg_weights) + comb_weights = self.get_one_hot_alpha(comb_weights) + act_weights = self.get_one_hot_alpha(act_weights) + x, r = self.init_embed, self.init_rel # embedding of relations + x_hidden = [] + for i, layer in enumerate(self.gnn_layers): + if i != self.args.n_layer-1: + x, r = layer(g, x, r, self.edge_type, self.edge_norm, comp_weights[i], agg_weights[i], comb_weights[i], act_weights[i]) + x_hidden.append(x) + x = drop1(x) + else: + x, r = layer(g, x, r, self.edge_type, self.edge_norm, comp_weights[i], agg_weights[i], comb_weights[i], act_weights[i]) + x_hidden.append(x) + x = drop2(x) + x_hidden = torch.stack(x_hidden, dim=1) + + return x_hidden, x + + +class SearchGCN_MLP(Network): + def __init__(self, args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm, + bias=True, gcn_drop=0., hid_drop=0., input_drop=0., wni=False, wsi=False, use_bn=True, + ltr=True, combine_type='mult', loss_type='ce'): + """ + :param num_ent: number of entities + :param num_rel: number of different relations + :param num_base: number of bases to use + :param init_dim: initial dimension + :param gcn_dim: dimension after first layer + :param embed_dim: dimension after second layer + :param n_layer: number of layer + :param edge_type: relation type of each edge, [E] + :param bias: weather to add bias + :param gcn_drop: dropout rate in compgcncov + :param opn: combination operator + :param hid_drop: gcn output (embedding of each entity) dropout + :param input_drop: dropout in conve input + :param conve_hid_drop: dropout in conve hidden layer + :param feat_drop: feature dropout in conve + :param num_filt: number of filters in conv2d + :param ker_sz: kernel size in conv2d + :param k_h: height of 2D reshape + :param k_w: width of 2D reshape + """ + super(SearchGCN_MLP, self).__init__(args, num_ent, num_rel, num_base, 
init_dim, gcn_dim, embed_dim, n_layer, + edge_type, edge_norm, bias, gcn_drop, wni, wsi, use_bn, ltr, loss_type) + self.hid_drop, self.input_drop = hid_drop, input_drop + + self.num_rel = num_rel + + self.drop = torch.nn.Dropout(self.hid_drop) # gcn output dropout + self.input_drop = torch.nn.Dropout(self.input_drop) # stacked input dropout + + self.combine_type = combine_type + if self.combine_type == 'concat': + self.fc = torch.nn.Linear(2 * self.embed_dim, self.num_rel) + elif self.combine_type == 'mult': + self.fc = torch.nn.Linear(self.embed_dim, self.num_rel) + + def forward(self, g, subj, obj, mode=None): + """ + :param g: dgl graph + :param sub: subject in batch [batch_size] + :param rel: relation in batch [batch_size] + :return: score: [batch_size, ent_num], the prob in link-prediction + """ + + sub_embs, obj_embs, all_ent, all_rel = self.forward_base(g, subj, obj, self.drop, self.input_drop, mode) + if self.combine_type == 'concat': + edge_embs = torch.concat([sub_embs, obj_embs], dim=1) + elif self.combine_type == 'mult': + edge_embs = sub_embs * obj_embs + else: + raise NotImplementedError + score = self.fc(edge_embs) + return score + + def forward_search(self, g, mode=None): + # if mode == 'allgraph': + hidden_all_ent, all_ent = self.forward_base_search( + g, self.drop, self.input_drop, mode) + # elif mode == 'subgraph': + # hidden_all_ent = self.forward_base_subgraph_search( + # g, self.drop, self.input_drop) + + return hidden_all_ent, all_ent + + def compute_pred(self, hidden_x, subj, obj, subgraph_sampler, mode='search', search_algorithm='darts'): + h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + atten_matrix = subgraph_sampler(h, mode, search_algorithm) + # print(atten_matrix.size()) # [batch_size, encoder_layer^2] + # print(subgraph_sampler(h,mode='argmax')) + n, c = atten_matrix.shape + h = h * atten_matrix.view(n, c, 1) + # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + h = torch.sum(h, dim=1) + # print(h.size()) # [batch_size, 2*dim] + score = self.fc(h) + return score + + # def fine_tune_with_implicit_subgraph(self, all_ent, subgraph, subj, obj): + # sg_list = [] + # for idx in range(subgraph.size(0)): + # sg_list.append(torch.mean(all_ent[subgraph[idx,:]], dim=0).unsqueeze(0)) + # sg_embs = torch.concat(sg_list) + # # print(sg_embs.size()) + # sub_embs = torch.index_select(all_ent, 0, subj) + # # print(sub_embs.size()) + # # filter out embeddings of relations in this batch + # obj_embs = torch.index_select(all_ent, 0, obj) + # # print(obj_embs.size()) + # edge_embs = torch.concat([sub_embs, obj_embs, sg_embs], dim=1) + # score = self.predictor(edge_embs) + # # print(F.embedding(subgraph, all_ent)) + # return score + + # def compute_pred_rs(self, hidden_x, all_rel, subj, obj, random_hops): + # h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + # atten_matrix = torch.zeros(h.size(0),h.size(1)).to('cuda:0') + # for i in range(h.size(0)): + # atten_matrix[i][self.n_layer*(random_hops[i][0]-1)+random_hops[i][1]-1] = 1 + # # print(atten_matrix.size()) # [batch_size, encoder_layer^2] + # # print(subgraph_sampler(h,mode='argmax')) + # n, c = atten_matrix.shape + # h = h * atten_matrix.view(n,c,1) + # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + # h = torch.sum(h,dim=1) + # # print(h.size()) # [batch_size, 2*dim] + # h = h.reshape(-1, 1, 2 * self.k_h, self.k_w) + # # x = self.bn0(h) + # # x = self.conv2d(x) # [batch_size, 
num_filt, flat_sz_h, flat_sz_w] + # # x = self.bn1(x) + # # x = F.relu(x) + # # x = self.feature_drop(x) + # # x = x.view(-1, self.flat_sz) # [batch_size, flat_sz] + # # x = self.fc(x) # [batch_size, embed_dim] + # # x = self.hidden_drop(x) + # # x = self.bn2(x) + # # x = F.relu(x) + # # x = torch.mm(x, all_rel.transpose(1, 0)) # [batch_size, ent_num] + # # x += self.bias_rel.expand_as(x) + # # # score = torch.sigmoid(x) + # # score = x + # return score + # print(h.size()) + + def cross_pair(self, x_i, x_j): + x = [] + for i in range(self.n_layer): + for j in range(self.n_layer): + if self.combine_type == 'mult': + x.append(x_i[:, i, :] * x_j[:, j, :]) + elif self.combine_type == 'concat': + x.append(torch.cat([x_i[:, i, :], x_j[:, j, :]], dim=1)) + x = torch.stack(x, dim=1) + return x + + def vis_hop_distribution(self, hidden_x, subj, obj, subgraph_sampler, mode='search'): + h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + atten_matrix = subgraph_sampler(h, mode) + return torch.sum(atten_matrix, dim=0) + + def vis_hop_distribution_rs(self, hidden_x, subj, obj, random_hops): + h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + atten_matrix = torch.zeros(h.size(0), h.size(1)).to('cuda:0') + for i in range(h.size(0)): + atten_matrix[i][self.n_layer * (random_hops[i][0] - 1) + random_hops[i][1] - 1] = 1 + return torch.sum(atten_matrix, dim=0) diff --git a/model/model_spos.py b/model/model_spos.py new file mode 100644 index 0000000..97a8e3f --- /dev/null +++ b/model/model_spos.py @@ -0,0 +1,421 @@ +import torch +from torch import nn +import dgl +import dgl.function as fn +from torch.autograd import Variable +from model.genotypes import COMP_PRIMITIVES, AGG_PRIMITIVES, COMB_PRIMITIVES, ACT_PRIMITIVES +from model.operations import * +from model.search_layer_spos import SearchSPOSGCNConv +from torch.nn.functional import softmax +from numpy.random import choice +from pprint import pprint + + +class NetworkSPOS(nn.Module): + def __init__(self, args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm, + conv_bias=True, gcn_drop=0., wni=False, wsi=False, use_bn=True, ltr=True, loss_type='ce'): + super(NetworkSPOS, self).__init__() + self.act = torch.tanh + self.args = args + self.loss_type = loss_type + self.num_ent, self.num_rel, self.num_base = num_ent, num_rel, num_base + self.init_dim, self.gcn_dim, self.embed_dim = init_dim, gcn_dim, embed_dim + + self.gcn_drop = gcn_drop + self.edge_type = edge_type # [E] + self.edge_norm = edge_norm # [E] + self.n_layer = n_layer + + self.init_embed = self.get_param([self.num_ent + 1, self.init_dim]) + self.init_rel = self.get_param([self.num_rel, self.init_dim]) + self.bias_rel = nn.Parameter(torch.zeros(self.num_rel)) + + self._initialize_loss() + self.ops = None + self.gnn_layers = nn.ModuleList() + for idx in range(self.args.n_layer): + if idx == 0: + self.gnn_layers.append( + SearchSPOSGCNConv(self.init_dim, self.gcn_dim, self.act, conv_bias, gcn_drop, num_base=-1, + num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr)) + elif idx == self.args.n_layer-1: + self.gnn_layers.append(SearchSPOSGCNConv(self.gcn_dim, self.embed_dim, self.act, conv_bias, gcn_drop, num_base=-1, + num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr)) + else: + self.gnn_layers.append( + SearchSPOSGCNConv(self.gcn_dim, self.gcn_dim, self.act, conv_bias, gcn_drop, num_base=-1, + num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr)) + + def 
generate_single_path(self, op_subsupernet=''): + single_path = [] + for i in range(self.args.n_layer): + single_path.append(choice(COMP_PRIMITIVES)) + single_path.append(choice(AGG_PRIMITIVES)) + single_path.append(choice(COMB_PRIMITIVES)) + single_path.append(choice(ACT_PRIMITIVES)) + return single_path + + def generate_single_path_ccorr(self, op_subsupernet=''): + single_path = [] + for i in range(self.args.n_layer): + single_path.append('ccorr') + single_path.append(choice(AGG_PRIMITIVES)) + single_path.append(choice(COMB_PRIMITIVES)) + single_path.append(choice(ACT_PRIMITIVES)) + return single_path + + def generate_single_path_act(self, op_subsupernet=''): + single_path = [] + for i in range(self.args.n_layer): + single_path.append('ccorr') + single_path.append('sum') + single_path.append('add') + single_path.append(choice(ACT_PRIMITIVES)) + return single_path + + def generate_single_path_comb(self, op_subsupernet=''): + single_path = [] + for i in range(self.args.n_layer): + single_path.append('ccorr') + single_path.append('sum') + single_path.append(choice(COMB_PRIMITIVES)) + single_path.append('tanh') + return single_path + + def generate_single_path_comp(self, op_subsupernet=''): + single_path = [] + for i in range(self.args.n_layer): + single_path.append(choice(COMP_PRIMITIVES)) + single_path.append('sum') + single_path.append('add') + single_path.append('tanh') + return single_path + + def generate_single_path_agg(self, op_subsupernet=''): + single_path = [] + for i in range(self.args.n_layer): + single_path.append('ccorr') + single_path.append(choice(AGG_PRIMITIVES)) + single_path.append('add') + single_path.append('tanh') + return single_path + + def generate_single_path_agg_comb(self, op_subsupernet=''): + single_path = [] + for i in range(self.args.n_layer): + single_path.append('ccorr') + single_path.append(choice(AGG_PRIMITIVES)) + single_path.append(choice(COMB_PRIMITIVES)) + single_path.append('tanh') + return single_path + + def generate_single_path_agg_comb_comp(self, op_subsupernet=''): + single_path = [] + for i in range(self.args.n_layer): + single_path.append(choice(COMP_PRIMITIVES)) + single_path.append(choice(AGG_PRIMITIVES)) + single_path.append(choice(COMB_PRIMITIVES)) + single_path.append('relu') + return single_path + + def generate_single_path_agg_comb_act_rotate(self, op_subsupernet=''): + single_path = [] + for i in range(self.args.n_layer): + single_path.append('rotate') + single_path.append(choice(AGG_PRIMITIVES)) + single_path.append(choice(COMB_PRIMITIVES)) + single_path.append(choice(ACT_PRIMITIVES)) + return single_path + + def generate_single_path_agg_comb_act_ccorr(self, op_subsupernet=''): + single_path = [] + for i in range(self.args.n_layer): + single_path.append('ccorr') + single_path.append(choice(AGG_PRIMITIVES)) + single_path.append(choice(COMB_PRIMITIVES)) + single_path.append(choice(ACT_PRIMITIVES)) + return single_path + + def generate_single_path_agg_comb_act_mult(self, op_subsupernet=''): + single_path = [] + for i in range(self.args.n_layer): + single_path.append('mult') + single_path.append(choice(AGG_PRIMITIVES)) + single_path.append(choice(COMB_PRIMITIVES)) + single_path.append(choice(ACT_PRIMITIVES)) + return single_path + + def generate_single_path_agg_comb_act_sub(self, op_subsupernet=''): + single_path = [] + for i in range(self.args.n_layer): + single_path.append('sub') + single_path.append(choice(AGG_PRIMITIVES)) + single_path.append(choice(COMB_PRIMITIVES)) + single_path.append(choice(ACT_PRIMITIVES)) + return single_path + + 
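+    # --- Illustrative note added by the editor; a sketch, not part of the original commit ---
+    # A sampled path is a flat list of length 4 * n_layer: each consecutive group of
+    # four entries selects the composition, aggregation, combination and activation
+    # operator of one SearchSPOSGCNConv layer, and forward_base consumes it as
+    # self.ops[4*i], self.ops[4*i+1], self.ops[4*i+2], self.ops[4*i+3].
+    # Assuming the *_PRIMITIVES lists in model/genotypes.py reuse the operator keys
+    # defined in model/operations.py, a 2-layer path could look like:
+    #   ['mult', 'sum', 'add', 'tanh',    # layer 0: comp, agg, comb, act
+    #    'ccorr', 'mean', 'mlp', 'relu']  # layer 1: comp, agg, comb, act
+    # The generate_single_path_agg_comb_act_{rotate,ccorr,mult,sub} helpers above pin
+    # the composition operator of every layer, which restricts sampling to one
+    # composition-specific sub-supernet.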
def generate_single_path_agg_comb_act_1mult(self, op_subsupernet=''): + single_path = [] + for i in range(self.args.n_layer): + if i == 0: + single_path.append('mult') + else: + single_path.append(choice(COMP_PRIMITIVES)) + single_path.append(choice(AGG_PRIMITIVES)) + single_path.append(choice(COMB_PRIMITIVES)) + single_path.append(choice(ACT_PRIMITIVES)) + return single_path + + def generate_single_path_agg_comb_act_1ccorr(self, op_subsupernet=''): + single_path = [] + for i in range(self.args.n_layer): + if i == 0: + single_path.append('ccorr') + else: + single_path.append(choice(COMP_PRIMITIVES)) + single_path.append(choice(AGG_PRIMITIVES)) + single_path.append(choice(COMB_PRIMITIVES)) + single_path.append(choice(ACT_PRIMITIVES)) + return single_path + + def generate_single_path_agg_comb_act_few_shot_comp(self, op_subsupernet=''): + single_path = [] + for i in range(self.args.n_layer): + if i == 0: + single_path.append(op_subsupernet) + else: + single_path.append(choice(COMP_PRIMITIVES)) + single_path.append(choice(AGG_PRIMITIVES)) + single_path.append(choice(COMB_PRIMITIVES)) + single_path.append(choice(ACT_PRIMITIVES)) + return single_path + + def get_param(self, shape): + param = nn.Parameter(torch.Tensor(*shape)) + nn.init.xavier_normal_(param, gain=nn.init.calculate_gain('relu')) + return param + + def _initialize_loss(self): + if self.loss_type == 'ce': + self.loss = nn.CrossEntropyLoss() + elif self.loss_type == 'bce': + self.loss = nn.BCELoss() + elif self.loss_type == 'bce_logits': + self.loss = nn.BCEWithLogitsLoss() + else: + raise NotImplementedError + + def calc_loss(self, pred, label, pos_neg=None): + if pos_neg is not None: + m = nn.Sigmoid() + score_pos = m(pred) + targets_pos = pos_neg.unsqueeze(1) + loss = self.loss(score_pos, label * targets_pos) + return torch.sum(loss * label) + return self.loss(pred, label) + + def forward_base(self, g, subj, obj, drop1, drop2, mode=None): + x, r = self.init_embed, self.init_rel # embedding of relations + + for i, layer in enumerate(self.gnn_layers): + if i != self.args.n_layer-1: + x, r = layer(g, x, r, self.edge_type, self.edge_norm, self.ops[4*i], self.ops[4*i+1], self.ops[4*i+2], self.ops[4*i+3]) + x = drop1(x) + else: + x, r = layer(g, x, r, self.edge_type, self.edge_norm, self.ops[4*i], self.ops[4*i+1], self.ops[4*i+2], self.ops[4*i+3]) + x = drop2(x) + + sub_emb = torch.index_select(x, 0, subj) + # filter out embeddings of objects in this batch + obj_emb = torch.index_select(x, 0, obj) + + return sub_emb, obj_emb, x, r + + def forward_base_search(self, g, drop1, drop2, mode=None): + # if self.args.search_algorithm == "darts": + # comp_weights = softmax(self.comp_alphas, dim=-1) + # agg_weights = softmax(self.agg_alphas, dim=-1) + # comb_weights = softmax(self.comb_alphas, dim=-1) + # act_weights = softmax(self.act_alphas, dim=-1) + # elif self.args.search_algorithm == "snas": + # comp_weights = self._get_categ_mask(self.comp_alphas) + # agg_weights = self._get_categ_mask(self.agg_alphas) + # comb_weights = self._get_categ_mask(self.comb_alphas) + # act_weights = self._get_categ_mask(self.act_alphas) + # else: + # raise NotImplementedError + # if mode == 'evaluate_single_path': + # comp_weights = self.get_one_hot_alpha(comp_weights) + # agg_weights = self.get_one_hot_alpha(agg_weights) + # comb_weights = self.get_one_hot_alpha(comb_weights) + # act_weights = self.get_one_hot_alpha(act_weights) + x, r = self.init_embed, self.init_rel # embedding of relations + x_hidden = [] + for i, layer in enumerate(self.gnn_layers): + if i 
!= self.args.n_layer-1: + x, r = layer(g, x, r, self.edge_type, self.edge_norm, self.ops[4 * i], self.ops[4 * i + 1], + self.ops[4 * i + 2], self.ops[4 * i + 3]) + x_hidden.append(x) + x = drop1(x) + else: + x, r = layer(g, x, r, self.edge_type, self.edge_norm, self.ops[4 * i], self.ops[4 * i + 1], + self.ops[4 * i + 2], self.ops[4 * i + 3]) + x_hidden.append(x) + x = drop2(x) + x_hidden = torch.stack(x_hidden, dim=1) + + return x_hidden, x + + +class SearchGCN_MLP_SPOS(NetworkSPOS): + def __init__(self, args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm, + bias=True, gcn_drop=0., hid_drop=0., input_drop=0., wni=False, wsi=False, use_bn=True, + ltr=True, combine_type='mult', loss_type='ce'): + """ + :param num_ent: number of entities + :param num_rel: number of different relations + :param num_base: number of bases to use + :param init_dim: initial dimension + :param gcn_dim: dimension after first layer + :param embed_dim: dimension after second layer + :param n_layer: number of layer + :param edge_type: relation type of each edge, [E] + :param bias: weather to add bias + :param gcn_drop: dropout rate in compgcncov + :param opn: combination operator + :param hid_drop: gcn output (embedding of each entity) dropout + :param input_drop: dropout in conve input + :param conve_hid_drop: dropout in conve hidden layer + :param feat_drop: feature dropout in conve + :param num_filt: number of filters in conv2d + :param ker_sz: kernel size in conv2d + :param k_h: height of 2D reshape + :param k_w: width of 2D reshape + """ + super(SearchGCN_MLP_SPOS, self).__init__(args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, + edge_type, edge_norm, bias, gcn_drop, wni, wsi, use_bn, ltr, loss_type) + self.hid_drop, self.input_drop = hid_drop, input_drop + + self.num_rel = num_rel + + self.drop = torch.nn.Dropout(self.hid_drop) # gcn output dropout + self.input_drop = torch.nn.Dropout(self.input_drop) # stacked input dropout + + self.combine_type = combine_type + if self.combine_type == 'concat': + self.fc = torch.nn.Linear(2 * self.embed_dim, self.num_rel) + elif self.combine_type == 'mult': + self.fc = torch.nn.Linear(self.embed_dim, self.num_rel) + + def forward(self, g, subj, obj, mode=None): + """ + :param g: dgl graph + :param sub: subject in batch [batch_size] + :param rel: relation in batch [batch_size] + :return: score: [batch_size, ent_num], the prob in link-prediction + """ + + sub_embs, obj_embs, all_ent, all_rel = self.forward_base(g, subj, obj, self.drop, self.input_drop, mode) + if self.combine_type == 'concat': + edge_embs = torch.concat([sub_embs, obj_embs], dim=1) + elif self.combine_type == 'mult': + edge_embs = sub_embs * obj_embs + else: + raise NotImplementedError + score = self.fc(edge_embs) + return score + + def forward_search(self, g, mode=None): + # if mode == 'allgraph': + hidden_all_ent, all_ent = self.forward_base_search( + g, self.drop, self.input_drop, mode) + # elif mode == 'subgraph': + # hidden_all_ent = self.forward_base_subgraph_search( + # g, self.drop, self.input_drop) + + return hidden_all_ent, all_ent + + def compute_pred(self, hidden_x, subj, obj, subgraph_sampler, mode='search', search_algorithm='darts'): + h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + atten_matrix = subgraph_sampler(h, mode, search_algorithm) + # print(atten_matrix.size()) # [batch_size, encoder_layer^2] + # print(subgraph_sampler(h,mode='argmax')) + n, c = 
atten_matrix.shape + h = h * atten_matrix.view(n, c, 1) + # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + h = torch.sum(h, dim=1) + # print(h.size()) # [batch_size, 2*dim] + score = self.fc(h) + return score + + # def fine_tune_with_implicit_subgraph(self, all_ent, subgraph, subj, obj): + # sg_list = [] + # for idx in range(subgraph.size(0)): + # sg_list.append(torch.mean(all_ent[subgraph[idx,:]], dim=0).unsqueeze(0)) + # sg_embs = torch.concat(sg_list) + # # print(sg_embs.size()) + # sub_embs = torch.index_select(all_ent, 0, subj) + # # print(sub_embs.size()) + # # filter out embeddings of relations in this batch + # obj_embs = torch.index_select(all_ent, 0, obj) + # # print(obj_embs.size()) + # edge_embs = torch.concat([sub_embs, obj_embs, sg_embs], dim=1) + # score = self.predictor(edge_embs) + # # print(F.embedding(subgraph, all_ent)) + # return score + + # def compute_pred_rs(self, hidden_x, all_rel, subj, obj, random_hops): + # h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + # atten_matrix = torch.zeros(h.size(0),h.size(1)).to('cuda:0') + # for i in range(h.size(0)): + # atten_matrix[i][self.n_layer*(random_hops[i][0]-1)+random_hops[i][1]-1] = 1 + # # print(atten_matrix.size()) # [batch_size, encoder_layer^2] + # # print(subgraph_sampler(h,mode='argmax')) + # n, c = atten_matrix.shape + # h = h * atten_matrix.view(n,c,1) + # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + # h = torch.sum(h,dim=1) + # # print(h.size()) # [batch_size, 2*dim] + # h = h.reshape(-1, 1, 2 * self.k_h, self.k_w) + # # x = self.bn0(h) + # # x = self.conv2d(x) # [batch_size, num_filt, flat_sz_h, flat_sz_w] + # # x = self.bn1(x) + # # x = F.relu(x) + # # x = self.feature_drop(x) + # # x = x.view(-1, self.flat_sz) # [batch_size, flat_sz] + # # x = self.fc(x) # [batch_size, embed_dim] + # # x = self.hidden_drop(x) + # # x = self.bn2(x) + # # x = F.relu(x) + # # x = torch.mm(x, all_rel.transpose(1, 0)) # [batch_size, ent_num] + # # x += self.bias_rel.expand_as(x) + # # # score = torch.sigmoid(x) + # # score = x + # return score + # print(h.size()) + + def cross_pair(self, x_i, x_j): + x = [] + for i in range(self.n_layer): + for j in range(self.n_layer): + if self.combine_type == 'mult': + x.append(x_i[:, i, :] * x_j[:, j, :]) + elif self.combine_type == 'concat': + x.append(torch.cat([x_i[:, i, :], x_j[:, j, :]], dim=1)) + x = torch.stack(x, dim=1) + return x + + def vis_hop_distribution(self, hidden_x, subj, obj, subgraph_sampler, mode='search', search_algorithm='darts'): + h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + atten_matrix = subgraph_sampler(h, mode, search_algorithm) + return torch.sum(atten_matrix, dim=0) + + def vis_hop_distribution_rs(self, hidden_x, subj, obj, random_hops): + h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + atten_matrix = torch.zeros(h.size(0), h.size(1)).to('cuda:0') + for i in range(h.size(0)): + atten_matrix[i][self.n_layer * (random_hops[i][0] - 1) + random_hops[i][1] - 1] = 1 + return torch.sum(atten_matrix, dim=0) diff --git a/model/model_spos_fast.py b/model/model_spos_fast.py new file mode 100644 index 0000000..545752f --- /dev/null +++ b/model/model_spos_fast.py @@ -0,0 +1,287 @@ +import torch +from torch import nn +import dgl +import dgl.function as fn +from torch.autograd import Variable +from model.genotypes import COMP_PRIMITIVES, AGG_PRIMITIVES, COMB_PRIMITIVES, ACT_PRIMITIVES +from model.operations import * 
+from model.search_layer_spos import SearchSPOSGCNConv +from torch.nn.functional import softmax +from numpy.random import choice +from pprint import pprint + + +class NetworkSPOS(nn.Module): + def __init__(self, args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm, + conv_bias=True, gcn_drop=0., wni=False, wsi=False, use_bn=True, ltr=True, loss_type='ce'): + super(NetworkSPOS, self).__init__() + self.act = torch.tanh + self.args = args + self.loss_type = loss_type + self.num_ent, self.num_rel, self.num_base = num_ent, num_rel, num_base + self.init_dim, self.gcn_dim, self.embed_dim = init_dim, gcn_dim, embed_dim + + self.gcn_drop = gcn_drop + self.edge_type = edge_type # [E] + self.edge_norm = edge_norm # [E] + self.n_layer = n_layer + + self.init_embed = self.get_param([self.num_ent + 1, self.init_dim]) + self.init_rel = self.get_param([self.num_rel, self.init_dim]) + self.bias_rel = nn.Parameter(torch.zeros(self.num_rel)) + + self._initialize_loss() + self.ops = None + self.gnn_layers = nn.ModuleList() + for idx in range(self.args.n_layer): + if idx == 0: + self.gnn_layers.append( + SearchSPOSGCNConv(self.init_dim, self.gcn_dim, self.act, conv_bias, gcn_drop, num_base=-1, + num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr)) + elif idx == self.args.n_layer-1: + self.gnn_layers.append(SearchSPOSGCNConv(self.gcn_dim, self.embed_dim, self.act, conv_bias, gcn_drop, num_base=-1, + num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr)) + else: + self.gnn_layers.append( + SearchSPOSGCNConv(self.gcn_dim, self.gcn_dim, self.act, conv_bias, gcn_drop, num_base=-1, + num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr)) + + def generate_single_path(self, op_subsupernet=''): + single_path = [] + for i in range(self.args.n_layer): + single_path.append(choice(COMP_PRIMITIVES)) + single_path.append(choice(AGG_PRIMITIVES)) + single_path.append(choice(COMB_PRIMITIVES)) + single_path.append(choice(ACT_PRIMITIVES)) + return single_path + + def generate_single_path_ccorr(self, op_subsupernet=''): + single_path = [] + for i in range(self.args.n_layer): + single_path.append('ccorr') + single_path.append(choice(AGG_PRIMITIVES)) + single_path.append(choice(COMB_PRIMITIVES)) + single_path.append(choice(ACT_PRIMITIVES)) + return single_path + + def get_param(self, shape): + param = nn.Parameter(torch.Tensor(*shape)) + nn.init.xavier_normal_(param, gain=nn.init.calculate_gain('relu')) + return param + + def _initialize_loss(self): + if self.loss_type == 'ce': + self.loss = nn.CrossEntropyLoss() + elif self.loss_type == 'bce': + self.loss = nn.BCELoss() + elif self.loss_type == 'bce_logits': + self.loss = nn.BCEWithLogitsLoss() + else: + raise NotImplementedError + + def calc_loss(self, pred, label): + return self.loss(pred, label) + + def forward_base(self, g, subj, obj, drop1, drop2, mode=None): + x, r = self.init_embed, self.init_rel # embedding of relations + + for i, layer in enumerate(self.gnn_layers): + if i != self.args.n_layer-1: + x, r = layer(g, x, r, self.edge_type, self.edge_norm, self.ops[4*i], self.ops[4*i+1], self.ops[4*i+2], self.ops[4*i+3]) + x = drop1(x) + else: + x, r = layer(g, x, r, self.edge_type, self.edge_norm, self.ops[4*i], self.ops[4*i+1], self.ops[4*i+2], self.ops[4*i+3]) + x = drop2(x) + + sub_emb = torch.index_select(x, 0, subj) + # filter out embeddings of objects in this batch + obj_emb = torch.index_select(x, 0, obj) + + return sub_emb, obj_emb, x, r + + def forward_base_search(self, g, drop1, drop2, 
mode=None): + if self.args.search_algorithm == "darts": + comp_weights = softmax(self.comp_alphas, dim=-1) + agg_weights = softmax(self.agg_alphas, dim=-1) + comb_weights = softmax(self.comb_alphas, dim=-1) + act_weights = softmax(self.act_alphas, dim=-1) + elif self.args.search_algorithm == "snas": + comp_weights = self._get_categ_mask(self.comp_alphas) + agg_weights = self._get_categ_mask(self.agg_alphas) + comb_weights = self._get_categ_mask(self.comb_alphas) + act_weights = self._get_categ_mask(self.act_alphas) + else: + raise NotImplementedError + if mode == 'evaluate_single_path': + comp_weights = self.get_one_hot_alpha(comp_weights) + agg_weights = self.get_one_hot_alpha(agg_weights) + comb_weights = self.get_one_hot_alpha(comb_weights) + act_weights = self.get_one_hot_alpha(act_weights) + x, r = self.init_embed, self.init_rel # embedding of relations + x_hidden = [] + for i, layer in enumerate(self.gnn_layers): + if i != self.args.n_layer-1: + x, r = layer(g, x, r, self.edge_type, self.edge_norm, comp_weights[i], agg_weights[i], comb_weights[i], act_weights[i]) + x_hidden.append(x) + x = drop1(x) + else: + x, r = layer(g, x, r, self.edge_type, self.edge_norm, comp_weights[i], agg_weights[i], comb_weights[i], act_weights[i]) + x_hidden.append(x) + x = drop2(x) + x_hidden = torch.stack(x_hidden, dim=1) + + return x_hidden, x + + +class SearchGCN_MLP_SPOS(NetworkSPOS): + def __init__(self, args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm, + bias=True, gcn_drop=0., hid_drop=0., input_drop=0., wni=False, wsi=False, use_bn=True, + ltr=True, combine_type='mult', loss_type='ce'): + """ + :param num_ent: number of entities + :param num_rel: number of different relations + :param num_base: number of bases to use + :param init_dim: initial dimension + :param gcn_dim: dimension after first layer + :param embed_dim: dimension after second layer + :param n_layer: number of layer + :param edge_type: relation type of each edge, [E] + :param bias: weather to add bias + :param gcn_drop: dropout rate in compgcncov + :param opn: combination operator + :param hid_drop: gcn output (embedding of each entity) dropout + :param input_drop: dropout in conve input + :param conve_hid_drop: dropout in conve hidden layer + :param feat_drop: feature dropout in conve + :param num_filt: number of filters in conv2d + :param ker_sz: kernel size in conv2d + :param k_h: height of 2D reshape + :param k_w: width of 2D reshape + """ + super(SearchGCN_MLP_SPOS, self).__init__(args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, + edge_type, edge_norm, bias, gcn_drop, wni, wsi, use_bn, ltr, loss_type) + self.hid_drop, self.input_drop = hid_drop, input_drop + + self.num_rel = num_rel + + self.drop = torch.nn.Dropout(self.hid_drop) # gcn output dropout + self.input_drop = torch.nn.Dropout(self.input_drop) # stacked input dropout + + self.combine_type = combine_type + if self.combine_type == 'concat': + self.fc = torch.nn.Linear(2 * self.embed_dim, self.num_rel) + elif self.combine_type == 'mult': + self.fc = torch.nn.Linear(self.embed_dim, self.num_rel) + + def forward(self, g, subj, obj, mode=None): + """ + :param g: dgl graph + :param sub: subject in batch [batch_size] + :param rel: relation in batch [batch_size] + :return: score: [batch_size, ent_num], the prob in link-prediction + """ + + sub_embs, obj_embs, all_ent, all_rel = self.forward_base(g, subj, obj, self.drop, self.input_drop, mode) + if self.combine_type == 'concat': + edge_embs = 
torch.concat([sub_embs, obj_embs], dim=1) + elif self.combine_type == 'mult': + edge_embs = sub_embs * obj_embs + else: + raise NotImplementedError + score = self.fc(edge_embs) + return score + + def forward_search(self, g, mode=None): + # if mode == 'allgraph': + hidden_all_ent, all_ent = self.forward_base_search( + g, self.drop, self.input_drop, mode) + # elif mode == 'subgraph': + # hidden_all_ent = self.forward_base_subgraph_search( + # g, self.drop, self.input_drop) + + return hidden_all_ent, all_ent + + def compute_pred(self, hidden_x, subj, obj, subgraph_sampler, mode='search', search_algorithm='darts'): + h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + atten_matrix = subgraph_sampler(h, mode, search_algorithm) + # print(atten_matrix.size()) # [batch_size, encoder_layer^2] + # print(subgraph_sampler(h,mode='argmax')) + n, c = atten_matrix.shape + h = h * atten_matrix.view(n, c, 1) + # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + h = torch.sum(h, dim=1) + # print(h.size()) # [batch_size, 2*dim] + score = self.fc(h) + return score + + # def fine_tune_with_implicit_subgraph(self, all_ent, subgraph, subj, obj): + # sg_list = [] + # for idx in range(subgraph.size(0)): + # sg_list.append(torch.mean(all_ent[subgraph[idx,:]], dim=0).unsqueeze(0)) + # sg_embs = torch.concat(sg_list) + # # print(sg_embs.size()) + # sub_embs = torch.index_select(all_ent, 0, subj) + # # print(sub_embs.size()) + # # filter out embeddings of relations in this batch + # obj_embs = torch.index_select(all_ent, 0, obj) + # # print(obj_embs.size()) + # edge_embs = torch.concat([sub_embs, obj_embs, sg_embs], dim=1) + # score = self.predictor(edge_embs) + # # print(F.embedding(subgraph, all_ent)) + # return score + + # def compute_pred_rs(self, hidden_x, all_rel, subj, obj, random_hops): + # h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + # atten_matrix = torch.zeros(h.size(0),h.size(1)).to('cuda:0') + # for i in range(h.size(0)): + # atten_matrix[i][self.n_layer*(random_hops[i][0]-1)+random_hops[i][1]-1] = 1 + # # print(atten_matrix.size()) # [batch_size, encoder_layer^2] + # # print(subgraph_sampler(h,mode='argmax')) + # n, c = atten_matrix.shape + # h = h * atten_matrix.view(n,c,1) + # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + # h = torch.sum(h,dim=1) + # # print(h.size()) # [batch_size, 2*dim] + # h = h.reshape(-1, 1, 2 * self.k_h, self.k_w) + # # x = self.bn0(h) + # # x = self.conv2d(x) # [batch_size, num_filt, flat_sz_h, flat_sz_w] + # # x = self.bn1(x) + # # x = F.relu(x) + # # x = self.feature_drop(x) + # # x = x.view(-1, self.flat_sz) # [batch_size, flat_sz] + # # x = self.fc(x) # [batch_size, embed_dim] + # # x = self.hidden_drop(x) + # # x = self.bn2(x) + # # x = F.relu(x) + # # x = torch.mm(x, all_rel.transpose(1, 0)) # [batch_size, ent_num] + # # x += self.bias_rel.expand_as(x) + # # # score = torch.sigmoid(x) + # # score = x + # return score + # print(h.size()) + + def cross_pair(self, x_i, x_j): + x = [] + for i in range(self.n_layer): + for j in range(self.n_layer): + if self.combine_type == 'mult': + x.append(x_i[:, i, :] * x_j[:, j, :]) + elif self.combine_type == 'concat': + x.append(torch.cat([x_i[:, i, :], x_j[:, j, :]], dim=1)) + x = torch.stack(x, dim=1) + return x + + def vis_hop_distribution(self, hidden_x, subj, obj, subgraph_sampler, mode='search'): + h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + atten_matrix = subgraph_sampler(h, 
mode) + return torch.sum(atten_matrix, dim=0) + + def vis_hop_distribution_rs(self, hidden_x, subj, obj, random_hops): + h = self.cross_pair(hidden_x[subj], hidden_x[obj]) + # print(h.size()) # [batch_size, encoder_layer^2, 2*dim] + atten_matrix = torch.zeros(h.size(0), h.size(1)).to('cuda:0') + for i in range(h.size(0)): + atten_matrix[i][self.n_layer * (random_hops[i][0] - 1) + random_hops[i][1] - 1] = 1 + return torch.sum(atten_matrix, dim=0) diff --git a/model/operations.py b/model/operations.py new file mode 100644 index 0000000..7f21f40 --- /dev/null +++ b/model/operations.py @@ -0,0 +1,148 @@ +import torch +from torch.nn import Module, Linear + +COMP_OPS = { + 'mult': lambda: MultOp(), + 'sub': lambda: SubOp(), + 'add': lambda: AddOp(), + 'ccorr': lambda: CcorrOp(), + 'rotate': lambda: RotateOp() +} + +AGG_OPS = { + 'max': lambda: MaxOp(), + 'sum': lambda: SumOp(), + 'mean': lambda: MeanOp() +} + +COMB_OPS = { + 'add': lambda out_channels: CombAddOp(out_channels), + 'mlp': lambda out_channels: CombMLPOp(out_channels), + 'concat': lambda out_channels: CombConcatOp(out_channels) +} + +ACT_OPS = { + 'identity': lambda: torch.nn.Identity(), + 'relu': lambda: torch.nn.ReLU(), + 'tanh': lambda: torch.nn.Tanh(), +} + +class CcorrOp(Module): + def __init__(self): + super(CcorrOp, self).__init__() + + def forward(self, src_emb, rel_emb): + return self.comp(src_emb, rel_emb) + + def comp(self, h, edge_data): + def com_mult(a, b): + r1, i1 = a.real, a.imag + r2, i2 = b.real, b.imag + real = r1 * r2 - i1 * i2 + imag = r1 * i2 + i1 * r2 + return torch.complex(real, imag) + + def conj(a): + a.imag = -a.imag + return a + + def ccorr(a, b): + return torch.fft.irfft(com_mult(conj(torch.fft.rfft(a)), torch.fft.rfft(b)), a.shape[-1]) + + return ccorr(h, edge_data.expand_as(h)) + + +class MultOp(Module): + + def __init__(self): + super(MultOp, self).__init__() + + def forward(self, src_emb, rel_emb): + # print('hr.shape', hr.shape) + return src_emb * rel_emb + + +class SubOp(Module): + def __init__(self): + super(SubOp, self).__init__() + + def forward(self, src_emb, rel_emb): + # print('hr.shape', hr.shape) + return src_emb - rel_emb + + +class AddOp(Module): + def __init__(self): + super(AddOp, self).__init__() + + def forward(self, src_emb, rel_emb): + # print('hr.shape', hr.shape) + return src_emb + rel_emb + + +class RotateOp(Module): + def __init__(self): + super(RotateOp, self).__init__() + + def forward(self, src_emb, rel_emb): + # print('hr.shape', hr.shape) + return self.rotate(src_emb, rel_emb) + + def rotate(self, h, r): + # re: first half, im: second half + # assume embedding dim is the last dimension + d = h.shape[-1] + h_re, h_im = torch.split(h, d // 2, -1) + r_re, r_im = torch.split(r, d // 2, -1) + return torch.cat([h_re * r_re - h_im * r_im, + h_re * r_im + h_im * r_re], dim=-1) + + +class MaxOp(Module): + def __init__(self): + super(MaxOp, self).__init__() + + def forward(self, msg): + return torch.max(msg, dim=1)[0] + + +class SumOp(Module): + def __init__(self): + super(SumOp, self).__init__() + + def forward(self, msg): + return torch.sum(msg, dim=1) + + +class MeanOp(Module): + def __init__(self): + super(MeanOp, self).__init__() + + def forward(self, msg): + return torch.mean(msg, dim=1) + + +class CombAddOp(Module): + def __init__(self, out_channels): + super(CombAddOp, self).__init__() + + def forward(self, self_emb, msg): + return self_emb + msg + + +class CombMLPOp(Module): + def __init__(self, out_channels): + super(CombMLPOp, self).__init__() + self.linear = 
Linear(out_channels, out_channels) + + def forward(self, self_emb, msg): + return self.linear(self_emb + msg) + + +class CombConcatOp(Module): + def __init__(self, out_channels): + super(CombConcatOp, self).__init__() + self.linear = Linear(2*out_channels, out_channels) + + def forward(self, self_emb, msg): + return self.linear(torch.concat([self_emb,msg],dim=1)) \ No newline at end of file diff --git a/model/rgcn_layer.py b/model/rgcn_layer.py new file mode 100644 index 0000000..25161f5 --- /dev/null +++ b/model/rgcn_layer.py @@ -0,0 +1,417 @@ +""" +based on the implementation in DGL +(https://github.com/dmlc/dgl/blob/master/python/dgl/nn/pytorch/conv/relgraphconv.py) +""" + + +"""Torch Module for Relational graph convolution layer""" +# pylint: disable= no-member, arguments-differ, invalid-name + + + + +import functools +import numpy as np +import torch as th +from torch import nn +import dgl.function as fn +from dgl.nn.pytorch import utils +from dgl.base import DGLError +from dgl.subgraph import edge_subgraph +class RelGraphConv(nn.Module): + r"""Relational graph convolution layer. + + Relational graph convolution is introduced in "`Modeling Relational Data with Graph + Convolutional Networks `__" + and can be described in DGL as below: + + .. math:: + + h_i^{(l+1)} = \sigma(\sum_{r\in\mathcal{R}} + \sum_{j\in\mathcal{N}^r(i)}e_{j,i}W_r^{(l)}h_j^{(l)}+W_0^{(l)}h_i^{(l)}) + + where :math:`\mathcal{N}^r(i)` is the neighbor set of node :math:`i` w.r.t. relation + :math:`r`. :math:`e_{j,i}` is the normalizer. :math:`\sigma` is an activation + function. :math:`W_0` is the self-loop weight. + + The basis regularization decomposes :math:`W_r` by: + + .. math:: + + W_r^{(l)} = \sum_{b=1}^B a_{rb}^{(l)}V_b^{(l)} + + where :math:`B` is the number of bases, :math:`V_b^{(l)}` are linearly combined + with coefficients :math:`a_{rb}^{(l)}`. + + The block-diagonal-decomposition regularization decomposes :math:`W_r` into :math:`B` + number of block diagonal matrices. We refer :math:`B` as the number of bases. + + The block regularization decomposes :math:`W_r` by: + + .. math:: + + W_r^{(l)} = \oplus_{b=1}^B Q_{rb}^{(l)} + + where :math:`B` is the number of bases, :math:`Q_{rb}^{(l)}` are block + bases with shape :math:`R^{(d^{(l+1)}/B)*(d^{l}/B)}`. + + Parameters + ---------- + in_feat : int + Input feature size; i.e, the number of dimensions of :math:`h_j^{(l)}`. + out_feat : int + Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`. + num_rels : int + Number of relations. . + regularizer : str + Which weight regularizer to use "basis" or "bdd". + "basis" is short for basis-diagonal-decomposition. + "bdd" is short for block-diagonal-decomposition. + num_bases : int, optional + Number of bases. If is none, use number of relations. Default: ``None``. + bias : bool, optional + True if bias is added. Default: ``True``. + activation : callable, optional + Activation function. Default: ``None``. + self_loop : bool, optional + True to include self loop message. Default: ``True``. + low_mem : bool, optional + True to use low memory implementation of relation message passing function. Default: False. + This option trades speed with memory consumption, and will slowdown the forward/backward. + Turn it on when you encounter OOM problem during training or evaluation. Default: ``False``. + dropout : float, optional + Dropout rate. Default: ``0.0`` + layer_norm: float, optional + Add layer norm. 
Default: ``False`` + + Examples + -------- + >>> import dgl + >>> import numpy as np + >>> import torch as th + >>> from dgl.nn import RelGraphConv + >>> + >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) + >>> feat = th.ones(6, 10) + >>> conv = RelGraphConv(10, 2, 3, regularizer='basis', num_bases=2) + >>> conv.weight.shape + torch.Size([2, 10, 2]) + >>> etype = th.tensor(np.array([0,1,2,0,1,2]).astype(np.int64)) + >>> res = conv(g, feat, etype) + >>> res + tensor([[ 0.3996, -2.3303], + [-0.4323, -0.1440], + [ 0.3996, -2.3303], + [ 2.1046, -2.8654], + [-0.4323, -0.1440], + [-0.1309, -1.0000]], grad_fn=) + + >>> # One-hot input + >>> one_hot_feat = th.tensor(np.array([0,1,2,3,4,5]).astype(np.int64)) + >>> res = conv(g, one_hot_feat, etype) + >>> res + tensor([[ 0.5925, 0.0985], + [-0.3953, 0.8408], + [-0.9819, 0.5284], + [-1.0085, -0.1721], + [ 0.5962, 1.2002], + [ 0.0365, -0.3532]], grad_fn=) + """ + + def __init__(self, + in_feat, + out_feat, + num_rels, + regularizer="basis", + num_bases=None, + bias=True, + activation=None, + self_loop=True, + low_mem=False, + dropout=0.0, + layer_norm=False, + wni=False): + super(RelGraphConv, self).__init__() + self.in_feat = in_feat + self.out_feat = out_feat + self.num_rels = num_rels + self.regularizer = regularizer + self.num_bases = num_bases + if self.num_bases is None or self.num_bases > self.num_rels or self.num_bases <= 0: + self.num_bases = self.num_rels + self.bias = bias + self.activation = activation + self.self_loop = self_loop + self.low_mem = low_mem + self.layer_norm = layer_norm + + self.wni = wni + + if regularizer == "basis": + # add basis weights + self.weight = nn.Parameter( + th.Tensor(self.num_bases, self.in_feat, self.out_feat)) + if self.num_bases < self.num_rels: + # linear combination coefficients + self.w_comp = nn.Parameter( + th.Tensor(self.num_rels, self.num_bases)) + nn.init.xavier_uniform_( + self.weight, gain=nn.init.calculate_gain('relu')) + if self.num_bases < self.num_rels: + nn.init.xavier_uniform_(self.w_comp, + gain=nn.init.calculate_gain('relu')) + # message func + self.message_func = self.basis_message_func + elif regularizer == "bdd": + print(in_feat) + print(out_feat) + if in_feat % self.num_bases != 0 or out_feat % self.num_bases != 0: + raise ValueError( + 'Feature size must be a multiplier of num_bases (%d).' + % self.num_bases + ) + # add block diagonal weights + self.submat_in = in_feat // self.num_bases + self.submat_out = out_feat // self.num_bases + + # assuming in_feat and out_feat are both divisible by num_bases + self.weight = nn.Parameter(th.Tensor( + self.num_rels, self.num_bases * self.submat_in * self.submat_out)) + nn.init.xavier_uniform_( + self.weight, gain=nn.init.calculate_gain('relu')) + # message func + self.message_func = self.bdd_message_func + else: + raise ValueError("Regularizer must be either 'basis' or 'bdd'") + + # bias + if self.bias: + self.h_bias = nn.Parameter(th.Tensor(out_feat)) + nn.init.zeros_(self.h_bias) + + # layer norm + if self.layer_norm: + self.layer_norm_weight = nn.LayerNorm( + out_feat, elementwise_affine=True) + + # weight for self loop + if self.self_loop: + self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat)) + nn.init.xavier_uniform_(self.loop_weight, + gain=nn.init.calculate_gain('relu')) + + self.dropout = nn.Dropout(dropout) + + def basis_message_func(self, edges, etypes): + """Message function for basis regularizer. + + Parameters + ---------- + edges : dgl.EdgeBatch + Input to DGL message UDF. 
+ etypes : torch.Tensor or list[int] + Edge type data. Could be either: + + * An :math:`(|E|,)` dense tensor. Each element corresponds to the edge's type ID. + Preferred format if ``lowmem == False``. + * An integer list. The i^th element is the number of edges of the i^th type. + This requires the input graph to store edges sorted by their type IDs. + Preferred format if ``lowmem == True``. + """ + if self.num_bases < self.num_rels: + # generate all weights from bases + weight = self.weight.view(self.num_bases, + self.in_feat * self.out_feat) + weight = th.matmul(self.w_comp, weight).view( + self.num_rels, self.in_feat, self.out_feat) + else: + weight = self.weight + + h = edges.src['h'] + device = h.device + + if h.dtype == th.int64 and h.ndim == 1: + # Each element is the node's ID. Use index select: weight[etypes, h, :] + # The following is a faster version of it. + if isinstance(etypes, list): + etypes = th.repeat_interleave(th.arange(len(etypes), device=device), + th.tensor(etypes, device=device)) + idim = weight.shape[1] + weight = weight.view(-1, weight.shape[2]) + flatidx = etypes * idim + h + msg = weight.index_select(0, flatidx) + elif self.low_mem: + # A more memory-friendly implementation. + # Calculate msg @ W_r before put msg into edge. + assert isinstance(etypes, list) + h_t = th.split(h, etypes) + msg = [] + for etype in range(self.num_rels): + if h_t[etype].shape[0] == 0: + continue + msg.append(th.matmul(h_t[etype], weight[etype])) + msg = th.cat(msg) + else: + # Use batched matmult + if isinstance(etypes, list): + etypes = th.repeat_interleave(th.arange(len(etypes), device=device), + th.tensor(etypes, device=device)) + weight = weight.index_select(0, etypes) + msg = th.bmm(h.unsqueeze(1), weight).squeeze(1) + + if 'norm' in edges.data: + msg = msg * edges.data['norm'] + return {'msg': msg} + + def bdd_message_func(self, edges, etypes): + """Message function for block-diagonal-decomposition regularizer. + + Parameters + ---------- + edges : dgl.EdgeBatch + Input to DGL message UDF. + etypes : torch.Tensor or list[int] + Edge type data. Could be either: + + * An :math:`(|E|,)` dense tensor. Each element corresponds to the edge's type ID. + Preferred format if ``lowmem == False``. + * An integer list. The i^th element is the number of edges of the i^th type. + This requires the input graph to store edges sorted by their type IDs. + Preferred format if ``lowmem == True``. + """ + h = edges.src['h'] + device = h.device + + if h.dtype == th.int64 and h.ndim == 1: + raise TypeError( + 'Block decomposition does not allow integer ID feature.') + + if self.low_mem: + # A more memory-friendly implementation. + # Calculate msg @ W_r before put msg into edge. 
+ assert isinstance(etypes, list) + h_t = th.split(h, etypes) + msg = [] + for etype in range(self.num_rels): + if h_t[etype].shape[0] == 0: + continue + tmp_w = self.weight[etype].view( + self.num_bases, self.submat_in, self.submat_out) + tmp_h = h_t[etype].view(-1, self.num_bases, self.submat_in) + msg.append(th.einsum('abc,bcd->abd', tmp_h, + tmp_w).reshape(-1, self.out_feat)) + msg = th.cat(msg) + else: + # Use batched matmult + if isinstance(etypes, list): + etypes = th.repeat_interleave(th.arange(len(etypes), device=device), + th.tensor(etypes, device=device)) + weight = self.weight.index_select(0, etypes).view( + -1, self.submat_in, self.submat_out) + node = h.view(-1, 1, self.submat_in) + msg = th.bmm(node, weight).view(-1, self.out_feat) + if 'norm' in edges.data: + msg = msg * edges.data['norm'] + return {'msg': msg} + + def forward(self, g, feat, etypes, norm=None): + """Forward computation. + + Parameters + ---------- + g : DGLGraph + The graph. + feat : torch.Tensor + Input node features. Could be either + + * :math:`(|V|, D)` dense tensor + * :math:`(|V|,)` int64 vector, representing the categorical values of each + node. It then treat the input feature as an one-hot encoding feature. + etypes : torch.Tensor or list[int] + Edge type data. Could be either + + * An :math:`(|E|,)` dense tensor. Each element corresponds to the edge's type ID. + Preferred format if ``lowmem == False``. + * An integer list. The i^th element is the number of edges of the i^th type. + This requires the input graph to store edges sorted by their type IDs. + Preferred format if ``lowmem == True``. + norm : torch.Tensor, optional + Edge normalizer. Could be either + + * An :math:`(|E|, 1)` tensor storing the normalizer on each edge. + + Returns + ------- + torch.Tensor + New node features. + + Notes + ----- + Under the ``low_mem`` mode, DGL will sort the graph based on the edge types + and compute message passing one type at a time. DGL recommends sorts the + graph beforehand (and cache it if possible) and provides the integer list + format to the ``etypes`` argument. Use DGL's :func:`~dgl.to_homogeneous` API + to get a sorted homogeneous graph from a heterogeneous graph. Pass ``return_count=True`` + to it to get the ``etypes`` in integer list. + """ + if isinstance(etypes, th.Tensor): + if len(etypes) != g.num_edges(): + raise DGLError('"etypes" tensor must have length equal to the number of edges' + ' in the graph. But got {} and {}.'.format( + len(etypes), g.num_edges())) + if self.low_mem and not (feat.dtype == th.int64 and feat.ndim == 1): + # Low-mem optimization is not enabled for node ID input. When enabled, + # it first sorts the graph based on the edge types (the sorting will not + # change the node IDs). It then converts the etypes tensor to an integer + # list, where each element is the number of edges of the type. + # Sort the graph based on the etypes + sorted_etypes, index = th.sort(etypes) + g = edge_subgraph(g, index, relabel_nodes=False) + # Create a new etypes to be an integer list of number of edges. 
+ pos = _searchsorted(sorted_etypes, th.arange( + self.num_rels, device=g.device)) + num = th.tensor([len(etypes)], device=g.device) + etypes = (th.cat([pos[1:], num]) - pos).tolist() + if norm is not None: + norm = norm[index] + + with g.local_scope(): + g.srcdata['h'] = feat + if norm is not None: + g.edata['norm'] = norm + if self.self_loop: + loop_message = utils.matmul_maybe_select(feat[:g.number_of_dst_nodes()], + self.loop_weight) + + if not self.wni: + # message passing + g.update_all(functools.partial(self.message_func, etypes=etypes), + fn.sum(msg='msg', out='h')) + # apply bias and activation + node_repr = g.dstdata['h'] + if self.layer_norm: + node_repr = self.layer_norm_weight(node_repr) + if self.bias: + node_repr = node_repr + self.h_bias + else: + node_repr = 0 + + if self.self_loop: + node_repr = node_repr + loop_message + if self.activation: + node_repr = self.activation(node_repr) + node_repr = self.dropout(node_repr) + return node_repr + + +_TORCH_HAS_SEARCHSORTED = getattr(th, 'searchsorted', None) + + +def _searchsorted(sorted_sequence, values): + # searchsorted is introduced to PyTorch in 1.6.0 + if _TORCH_HAS_SEARCHSORTED: + return th.searchsorted(sorted_sequence, values) + else: + device = values.device + return th.from_numpy(np.searchsorted(sorted_sequence.cpu().numpy(), + values.cpu().numpy())).to(device) diff --git a/model/seal_model.py b/model/seal_model.py new file mode 100644 index 0000000..a61c59f --- /dev/null +++ b/model/seal_model.py @@ -0,0 +1,62 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from dgl.nn.pytorch import SortPooling, SumPooling +from dgl.nn.pytorch import GraphConv, SAGEConv +from dgl import NID, EID + + +class SEAL_GCN(nn.Module): + def __init__(self, num_ent, num_rel, init_dim, gcn_dim, embed_dim, n_layer, loss_type, max_z=1000): + super(SEAL_GCN, self).__init__() + + if loss_type == 'ce': + self.loss = nn.CrossEntropyLoss() + elif loss_type == 'bce': + self.loss = nn.BCELoss(reduce=False) + elif loss_type == 'bce_logits': + self.loss = nn.BCEWithLogitsLoss() + self.init_embed = self.get_param([num_ent, init_dim]) + self.init_rel = self.get_param([num_rel, init_dim]) + self.z_embedding = nn.Embedding(max_z, init_dim) + init_dim += init_dim + + self.layers = nn.ModuleList() + self.layers.append(GraphConv(init_dim, gcn_dim, allow_zero_in_degree=True)) + for _ in range(n_layer - 1): + self.layers.append(GraphConv(gcn_dim, gcn_dim, allow_zero_in_degree=True)) + + self.linear_1 = nn.Linear(embed_dim, embed_dim) + self.linear_2 = nn.Linear(embed_dim, num_rel) + self.pooling = SumPooling() + + def get_param(self, shape): + param = nn.Parameter(torch.Tensor(*shape)) + nn.init.xavier_normal_(param, gain=nn.init.calculate_gain('relu')) + return param + + def calc_loss(self, pred, label, pos_neg=None): + if pos_neg is not None: + m = nn.Sigmoid() + score_pos = m(pred) + targets_pos = pos_neg.unsqueeze(1) + loss = self.loss(score_pos, label * targets_pos) + return torch.sum(loss * label) + return self.loss(pred, label) + + def forward(self, g, z): + x, r = self.init_embed[g.ndata[NID]], self.init_rel + z_emb = self.z_embedding(z) + x = torch.cat([x, z_emb], 1) + for layer in self.layers[:-1]: + x = layer(g, x) + x = F.relu(x) + x = F.dropout(x, p=0.5, training=self.training) + x = self.layers[-1](g, x) + + x = self.pooling(g, x) + x = F.relu(self.linear_1(x)) + F.dropout(x, p=0.5, training=self.training) + x = self.linear_2(x) + + return x \ No newline at end of file diff --git a/model/search_layer.py 
b/model/search_layer.py new file mode 100644 index 0000000..ef76eee --- /dev/null +++ b/model/search_layer.py @@ -0,0 +1,192 @@ +import torch +from torch import nn +import dgl +import dgl.function as fn +from model.genotypes import COMP_PRIMITIVES, AGG_PRIMITIVES, COMB_PRIMITIVES, ACT_PRIMITIVES +from model.operations import * + + +class CompMixOp(nn.Module): + def __init__(self): + super(CompMixOp, self).__init__() + self._ops = nn.ModuleList() + for primitive in COMP_PRIMITIVES: + op = COMP_OPS[primitive]() + self._ops.append(op) + + def reset_parameters(self): + for op in self._ops: + op.reset_parameters() + + def forward(self, src_emb, rel_emb, weights): + mixed_res = [] + for w, op in zip(weights, self._ops): + mixed_res.append(w * op(src_emb, rel_emb)) + return sum(mixed_res) + + +class AggMixOp(nn.Module): + def __init__(self): + super(AggMixOp, self).__init__() + self._ops = nn.ModuleList() + for primitive in AGG_PRIMITIVES: + op = AGG_OPS[primitive]() + self._ops.append(op) + + def reset_parameters(self): + for op in self._ops: + op.reset_parameters() + + def forward(self, msg, weights): + mixed_res = [] + for w, op in zip(weights, self._ops): + mixed_res.append(w * op(msg)) + return sum(mixed_res) + + +class CombMixOp(nn.Module): + def __init__(self, out_channels): + super(CombMixOp, self).__init__() + self._ops = nn.ModuleList() + for primitive in COMB_PRIMITIVES: + op = COMB_OPS[primitive](out_channels) + self._ops.append(op) + + def reset_parameters(self): + for op in self._ops: + op.reset_parameters() + + def forward(self, self_emb, msg, weights): + mixed_res = [] + for w, op in zip(weights, self._ops): + mixed_res.append(w * op(self_emb, msg)) + return sum(mixed_res) + + +class ActMixOp(nn.Module): + def __init__(self): + super(ActMixOp, self).__init__() + self._ops = nn.ModuleList() + for primitive in ACT_PRIMITIVES: + op = ACT_OPS[primitive]() + self._ops.append(op) + + def reset_parameters(self): + for op in self._ops: + op.reset_parameters() + + def forward(self, emb, weights): + mixed_res = [] + for w, op in zip(weights, self._ops): + mixed_res.append(w * op(emb)) + return sum(mixed_res) + +class SearchGCNConv(nn.Module): + def __init__(self, in_channels, out_channels, act=lambda x: x, bias=True, drop_rate=0., num_base=-1, + num_rel=None, wni=False, wsi=False, use_bn=True, ltr=True, comp_weights=None, agg_weights=None): + super(SearchGCNConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + # self.act = act # activation function + self.device = None + + self.rel = nn.Parameter(torch.empty([num_rel, in_channels], dtype=torch.float)) + + self.use_bn = use_bn + self.ltr = ltr + + # relation-type specific parameter + self.in_w = self.get_param([in_channels, out_channels]) + self.out_w = self.get_param([in_channels, out_channels]) + self.loop_w = self.get_param([in_channels, out_channels]) + # transform embedding of relations to next layer + self.w_rel = self.get_param([in_channels, out_channels]) + self.loop_rel = self.get_param([1, in_channels]) # self-loop embedding + + self.drop = nn.Dropout(drop_rate) + self.bn = torch.nn.BatchNorm1d(out_channels) + self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None + if num_base > 0: + self.rel_wt = self.get_param([num_rel, num_base]) + else: + self.rel_wt = None + + self.wni = wni + self.wsi = wsi + self.comp_weights = comp_weights + self.agg_weights = agg_weights + self.comp = CompMixOp() + self.agg = AggMixOp() + self.comb = CombMixOp(out_channels) + self.act = ActMixOp() + + 
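+    # Illustrative note (added comment, not part of the original code): each *MixOp built above
+    # is a DARTS-style continuous relaxation of its candidate primitives, returning
+    # sum_i w_i * op_i(inputs); the architecture weights (comp/agg/comb/act) are not stored in
+    # this layer but are passed in through forward() on every call.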
def get_param(self, shape): + param = nn.Parameter(torch.Tensor(*shape)) + nn.init.xavier_normal_(param, gain=nn.init.calculate_gain('relu')) + return param + + def message_func(self, edges): + edge_type = edges.data['type'] # [E, 1] + edge_num = edge_type.shape[0] + edge_data = self.comp( + edges.src['h'], self.rel[edge_type], self.comp_weights) # [E, in_channel] + # NOTE: first half edges are all in-directions, last half edges are out-directions. + msg = torch.cat([torch.matmul(edge_data[:edge_num // 2, :], self.in_w), + torch.matmul(edge_data[edge_num // 2:, :], self.out_w)]) + msg = msg * edges.data['norm'].reshape(-1, 1) # [E, D] * [E, 1] + return {'msg': msg} + + def reduce_func(self, nodes): + # return {'h': torch.sum(nodes.mailbox['msg'], dim=1)} + return {'h': self.agg(nodes.mailbox['msg'], self.agg_weights)} + + def apply_node_func(self, nodes): + return {'h': self.drop(nodes.data['h'])} + + def forward(self, g: dgl.DGLGraph, x, rel_repr, edge_type, edge_norm, comp_weights, agg_weights, comb_weights, act_weights): + """ + :param g: dgl Graph, a graph without self-loop + :param x: input node features, [V, in_channel] + :param rel_repr: input relation features: 1. not using bases: [num_rel*2, in_channel] + 2. using bases: [num_base, in_channel] + :param edge_type: edge type, [E] + :param edge_norm: edge normalization, [E] + :return: x: output node features: [V, out_channel] + rel: output relation features: [num_rel*2, out_channel] + """ + self.device = x.device + g = g.local_var() + g.ndata['h'] = x + g.edata['type'] = edge_type + g.edata['norm'] = edge_norm + self.comp_weights = comp_weights + self.agg_weights = agg_weights + self.comb_weights = comb_weights + self.act_weights = act_weights + if self.rel_wt is None: + self.rel.data = rel_repr + else: + # [num_rel*2, num_base] @ [num_base, in_c] + self.rel.data = torch.mm(self.rel_wt, rel_repr) + g.update_all(self.message_func, self.reduce_func, self.apply_node_func) + + if (not self.wni) and (not self.wsi): + x = self.comb(g.ndata.pop('h'), torch.mm(self.comp(x, self.loop_rel, self.comp_weights), self.loop_w), self.comb_weights)*(1/3) + # x = (g.ndata.pop('h') + + # torch.mm(self.comp(x, self.loop_rel, self.comp_weights), self.loop_w)) / 3 + # else: + # if self.wsi: + # x = g.ndata.pop('h') / 2 + # if self.wni: + # x = torch.mm(self.comp(x, self.loop_rel), self.loop_w) + + if self.bias is not None: + x = x + self.bias + + if self.use_bn: + x = self.bn(x) + + if self.ltr: + return self.act(x, self.act_weights), torch.matmul(self.rel.data, self.w_rel) + else: + return self.act(x, self.act_weights), self.rel.data \ No newline at end of file diff --git a/model/search_layer_spos.py b/model/search_layer_spos.py new file mode 100644 index 0000000..919b762 --- /dev/null +++ b/model/search_layer_spos.py @@ -0,0 +1,185 @@ +import torch +from torch import nn +import dgl +import dgl.function as fn +from model.genotypes import COMP_PRIMITIVES, AGG_PRIMITIVES, COMB_PRIMITIVES, ACT_PRIMITIVES +from model.operations import * + + +class CompOpBlock(nn.Module): + def __init__(self): + super(CompOpBlock, self).__init__() + self._ops = nn.ModuleList() + for primitive in COMP_PRIMITIVES: + op = COMP_OPS[primitive]() + self._ops.append(op) + + def reset_parameters(self): + for op in self._ops: + op.reset_parameters() + + def forward(self, src_emb, rel_emb, primitive): + return self._ops[COMP_PRIMITIVES.index(primitive)](src_emb, rel_emb) + + +class AggOpBlock(nn.Module): + def __init__(self): + super(AggOpBlock, self).__init__() + self._ops = 
nn.ModuleList() + for primitive in AGG_PRIMITIVES: + op = AGG_OPS[primitive]() + self._ops.append(op) + + def reset_parameters(self): + for op in self._ops: + op.reset_parameters() + + def forward(self, msg, primitive): + return self._ops[AGG_PRIMITIVES.index(primitive)](msg) + + +class CombOpBlock(nn.Module): + def __init__(self, out_channels): + super(CombOpBlock, self).__init__() + self._ops = nn.ModuleList() + for primitive in COMB_PRIMITIVES: + op = COMB_OPS[primitive](out_channels) + self._ops.append(op) + + def reset_parameters(self): + for op in self._ops: + op.reset_parameters() + + def forward(self, self_emb, msg, primitive): + return self._ops[COMB_PRIMITIVES.index(primitive)](self_emb, msg) + + +class ActOpBlock(nn.Module): + def __init__(self): + super(ActOpBlock, self).__init__() + self._ops = nn.ModuleList() + for primitive in ACT_PRIMITIVES: + op = ACT_OPS[primitive]() + self._ops.append(op) + + def reset_parameters(self): + for op in self._ops: + op.reset_parameters() + + def forward(self, emb, primitive): + return self._ops[ACT_PRIMITIVES.index(primitive)](emb) + + +class SearchSPOSGCNConv(nn.Module): + def __init__(self, in_channels, out_channels, act=lambda x: x, bias=True, drop_rate=0., num_base=-1, + num_rel=None, wni=False, wsi=False, use_bn=True, ltr=True): + super(SearchSPOSGCNConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + # self.act = act # activation function + self.device = None + + self.rel = nn.Parameter(torch.empty([num_rel, in_channels], dtype=torch.float)) + + self.use_bn = use_bn + self.ltr = ltr + + # relation-type specific parameter + self.in_w = self.get_param([in_channels, out_channels]) + self.out_w = self.get_param([in_channels, out_channels]) + self.loop_w = self.get_param([in_channels, out_channels]) + # transform embedding of relations to next layer + self.w_rel = self.get_param([in_channels, out_channels]) + self.loop_rel = self.get_param([1, in_channels]) # self-loop embedding + + self.drop = nn.Dropout(drop_rate) + self.bn = torch.nn.BatchNorm1d(out_channels) + self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None + if num_base > 0: + self.rel_wt = self.get_param([num_rel, num_base]) + else: + self.rel_wt = None + + self.wni = wni + self.wsi = wsi + self.comp = CompOpBlock() + # self.agg = AggOpBlock() + self.comb = CombOpBlock(out_channels) + self.act = ActOpBlock() + + def get_param(self, shape): + param = nn.Parameter(torch.Tensor(*shape)) + nn.init.xavier_normal_(param, gain=nn.init.calculate_gain('relu')) + return param + + def message_func(self, edges): + edge_type = edges.data['type'] # [E, 1] + edge_num = edge_type.shape[0] + edge_data = self.comp( + edges.src['h'], self.rel[edge_type], self.comp_primitive) # [E, in_channel] + # NOTE: first half edges are all in-directions, last half edges are out-directions. 
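+        # Clarifying comment (assumption inferred from how the training graph is built): the
+        # split below relies on the graph holding all original edges first and an equal number
+        # of reverse edges appended afterwards, so edge_num // 2 separates the two directions;
+        # in_w then transforms in-direction messages and out_w the reversed ones.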
+ msg = torch.cat([torch.matmul(edge_data[:edge_num // 2, :], self.in_w), + torch.matmul(edge_data[edge_num // 2:, :], self.out_w)]) + msg = msg * edges.data['norm'].reshape(-1, 1) # [E, D] * [E, 1] + return {'msg': msg} + + def reduce_func(self, nodes): + # return {'h': torch.sum(nodes.mailbox['msg'], dim=1)} + return {'h': self.agg(nodes.mailbox['msg'], self.agg_primitive)} + + def apply_node_func(self, nodes): + return {'h': self.drop(nodes.data['h'])} + + def forward(self, g: dgl.DGLGraph, x, rel_repr, edge_type, edge_norm, comp_primitive, agg_primitive, comb_primitive, act_primitive): + """ + :param g: dgl Graph, a graph without self-loop + :param x: input node features, [V, in_channel] + :param rel_repr: input relation features: 1. not using bases: [num_rel*2, in_channel] + 2. using bases: [num_base, in_channel] + :param edge_type: edge type, [E] + :param edge_norm: edge normalization, [E] + :return: x: output node features: [V, out_channel] + rel: output relation features: [num_rel*2, out_channel] + """ + self.device = x.device + g = g.local_var() + g.ndata['h'] = x + g.edata['type'] = edge_type + g.edata['norm'] = edge_norm + self.comp_primitive = comp_primitive + self.agg_primitive = agg_primitive + self.comb_primitive = comb_primitive + self.act_primitive = act_primitive + if self.rel_wt is None: + self.rel.data = rel_repr + else: + # [num_rel*2, num_base] @ [num_base, in_c] + self.rel.data = torch.mm(self.rel_wt, rel_repr) + if self.agg_primitive == 'max': + g.update_all(self.message_func, fn.max(msg='msg', out='h'), self.apply_node_func) + elif self.agg_primitive == 'mean': + g.update_all(self.message_func, fn.mean(msg='msg', out='h'), self.apply_node_func) + elif self.agg_primitive == 'sum': + g.update_all(self.message_func, fn.sum(msg='msg', out='h'), self.apply_node_func) + # g.update_all(self.message_func, self.reduce_func, self.apply_node_func) + + if (not self.wni) and (not self.wsi): + x = self.comb(g.ndata.pop('h'), torch.mm(self.comp(x, self.loop_rel, self.comp_primitive), self.loop_w), self.comb_primitive)*(1/3) + # x = (g.ndata.pop('h') + + # torch.mm(self.comp(x, self.loop_rel, self.comp_weights), self.loop_w)) / 3 + # else: + # if self.wsi: + # x = g.ndata.pop('h') / 2 + # if self.wni: + # x = torch.mm(self.comp(x, self.loop_rel), self.loop_w) + + if self.bias is not None: + x = x + self.bias + + if self.use_bn: + x = self.bn(x) + + if self.ltr: + return self.act(x, self.act_primitive), torch.matmul(self.rel.data, self.w_rel) + else: + return self.act(x, self.act_primitive), self.rel.data \ No newline at end of file diff --git a/model/subgraph_selector.py b/model/subgraph_selector.py new file mode 100644 index 0000000..ecb6583 --- /dev/null +++ b/model/subgraph_selector.py @@ -0,0 +1,51 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +class SubgraphSelector(nn.Module): + def __init__(self, args): + super(SubgraphSelector,self).__init__() + self.args = args + self.temperature = self.args.temperature + self.num_layers = self.args.ss_num_layer + self.cat_type = self.args.combine_type + if self.cat_type == 'mult': + in_channels = self.args.ss_input_dim + else: + in_channels = self.args.ss_input_dim * 2 + hidden_channels = self.args.ss_hidden_dim + self.trans = nn.ModuleList() + for i in range(self.num_layers - 1): + if i == 0: + self.trans.append(nn.Linear(in_channels, hidden_channels, bias=False)) + else: + self.trans.append(nn.Linear(hidden_channels, hidden_channels, bias=False)) + 
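+        # Descriptive comment (inferred from forward() below): the final linear layer maps each
+        # candidate's hidden vector to a single logit, so the selector emits one score per
+        # candidate subgraph scope, which forward() turns into a soft or one-hot selection.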
self.trans.append(nn.Linear(hidden_channels, 1, bias=False)) + + def forward(self, x, mode='argmax', search_algorithm='darts'): + for layer in self.trans[:-1]: + x = layer(x) + x= F.relu(x) + x = self.trans[-1](x) + x = torch.squeeze(x,dim=2) + if search_algorithm == 'darts': + arch_set = torch.softmax(x/self.temperature,dim=1) + elif search_algorithm == 'snas': + arch_set = self._get_categ_mask(x) + if mode == 'argmax': + device = arch_set.device + n, c = arch_set.shape + eyes_atten = torch.eye(c).to(device) + atten_ , atten_indice = torch.max(arch_set, dim=1) + arch_set = eyes_atten[atten_indice] + return arch_set + # raise NotImplementedError + + def _get_categ_mask(self, alpha): + # log_alpha = torch.log(alpha) + log_alpha = alpha + u = torch.zeros_like(log_alpha).uniform_() + softmax = torch.nn.Softmax(-1) + one_hot = softmax((log_alpha + (-((-(u.log())).log()))) / self.temperature) + return one_hot \ No newline at end of file diff --git a/run.py b/run.py new file mode 100644 index 0000000..5dd1ae8 --- /dev/null +++ b/run.py @@ -0,0 +1,3089 @@ +import csv +import os +import argparse +import time +import logging +from pprint import pprint +import numpy as np +import random + +import pandas as pd +import torch +from torch.utils.data import DataLoader +import dgl +from data.knowledge_graph import load_data +from model import GCN_TransE, GCN_DistMult, GCN_ConvE, SubgraphSelector, GCN_ConvE_Rel, GCN_Transformer, GCN_None, \ + GCN_MLP, GCN_MLP_NCN, SearchGCN_MLP, SearchedGCN_MLP, NetworkGNN_MLP, SearchGCN_MLP_SPOS, SEAL_GCN +from model.lte_models import TransE, DistMult, ConvE +from utils import process, process_multi_label, TrainDataset, TestDataset, get_logger, GraphTrainDataset, GraphTestDataset, \ + get_f1_score_list, get_acc_list, get_neighbor_nodes, NCNDataset, Temp_Scheduler +import wandb +from os.path import exists +from os import mkdir, makedirs +from dgl.dataloading import GraphDataLoader +from sklearn import metrics +import matplotlib.pyplot as plt +import pickle +from tqdm import tqdm +from torch.nn.utils import clip_grad_norm_ +import setproctitle +from hyperopt import fmin, tpe, hp, Trials, partial, STATUS_OK, rand, space_eval +from model.genotypes import * +import optuna +from optuna.samplers import RandomSampler +from utils import CategoricalASNG +import itertools +from sortedcontainers import SortedDict +from dgl import NID, EID + +torch.multiprocessing.set_sharing_strategy('file_system') + + +class Runner(object): + def __init__(self, params): + self.p = params + self.prj_path = os.getcwd() + self.data = load_data(self.p.dataset) + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + self.num_ent, self.train_data, self.num_rels = self.data.num_nodes, self.data.train_graph, self.data.num_rels + # self.train_input, self.valid_rel, self.test_rel = self.data.train_rel, self.data.valid_rel, self.data.test_rel + # self.train_pos_neg, self.valid_pos_neg, self.test_pos_neg = self.data.train_pos_neg, self.data.valid_pos_neg, self.data.test_pos_neg + self.triplets = process_multi_label( + {'train': self.data.train_input, 'valid': self.data.valid_input, 'test': self.data.test_input}, + {'train': self.data.train_multi_label, 'valid': self.data.valid_multi_label, 'test': self.data.test_multi_label}, + {'train': self.data.train_pos_neg, 'valid': self.data.valid_pos_neg, 'test': self.data.test_pos_neg} + ) + else: + self.num_ent, self.train_data, self.valid_data, self.test_data, self.num_rels = self.data.num_nodes, self.data.train, self.data.valid, self.data.test, 
self.data.num_rels + self.triplets, self.class2num = process( + {'train': self.train_data, 'valid': self.valid_data, 'test': self.test_data}, + self.num_rels, self.p.n_layer, self.p.add_reverse) + self.p.embed_dim = self.p.k_w * \ + self.p.k_h if self.p.embed_dim is None else self.p.embed_dim # output dim of gnn + self.g = self.build_graph() + self.edge_type, self.edge_norm = self.get_edge_dir_and_norm() + if self.p.input_type == "subgraph": + self.get_subgraph() + self.data_iter = self.get_data_iter() + + if (self.p.search_mode != 'arch_random' or self.p.search_mode != 'arch_search') and self.p.search_algorithm!='random_ps2': + self.model = self.get_model() + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.p.lr, weight_decay=self.p.l2) + self.best_val_f1, self.best_val_auroc, self.best_epoch, self.best_val_results, self.best_test_results = 0., 0., 0., {}, {} + self.best_test_f1, self.best_test_auroc = 0., 0. + self.early_stop_cnt = 0 + os.makedirs(f'logs/{self.p.dataset}/', exist_ok=True) + if self.p.train_mode == 'tune': + tmp_name = self.p.name + '_tune' + self.logger = get_logger(f'logs/{self.p.dataset}/', tmp_name) + else: + self.logger = get_logger(f'logs/{self.p.dataset}/', self.p.name) + pprint(vars(self.p)) + + def save_model(self, path): + """ + Function to save a model. It saves the model parameters, best validation scores, + best epoch corresponding to best validation, state of the optimizer and all arguments for the run. + :param path: path where the model is saved + :return: + """ + state = { + 'model': self.model.state_dict(), + 'best_val': self.best_val_results, + 'best_epoch': self.best_epoch, + 'optimizer': self.optimizer.state_dict(), + 'args': vars(self.p) + } + torch.save(state, path) + + def save_search_model(self, path): + """ + Function to save a model. It saves the model parameters, best validation scores, + best epoch corresponding to best validation, state of the optimizer and all arguments for the run. 
+ :param path: path where the model is saved + :return: + """ + state = { + 'model': self.model.state_dict(), + 'best_val': self.best_val_results, + 'best_epoch': self.best_epoch, + 'optimizer': self.optimizer.state_dict(), + 'args': vars(self.p) + } + torch.save(state, path) + + def load_model(self, path): + """ + Function to load a saved model + :param path: path where model is loaded + :return: + """ + state = torch.load(path) + self.best_val_results = state['best_val'] + self.best_epoch = state['best_epoch'] + self.model.load_state_dict(state['model']) + self.optimizer.load_state_dict(state['optimizer']) + + def build_graph(self): + g = dgl.DGLGraph() + g.add_nodes(self.num_ent + 1) + g.add_edges(self.train_data[:, 0], self.train_data[:, 2]) + if self.p.add_reverse: + g.add_edges(self.train_data[:, 2], self.train_data[:, 0]) + return g + + def get_data_iter(self): + + def get_data_loader(dataset_class, split, shuffle=True): + return DataLoader( + dataset_class(self.triplets[split], self.num_ent, self.num_rels, self.p), + batch_size=self.p.batch_size, + shuffle=shuffle, + num_workers=self.p.num_workers + ) + + def get_graph_data_loader(dataset_class, split, db_name_pos=None): + return GraphDataLoader( + dataset_class(self.triplets[split], self.num_ent, self.num_rels, self.p, self.g, db_name_pos), + batch_size=self.p.batch_size, + shuffle=True, + num_workers=self.p.num_workers + ) + + def get_ncndata_loader(dataset_class, split, db_name_pos=None): + return DataLoader( + dataset_class(self.triplets[split], self.num_ent, self.num_rels, self.p, self.adj, db_name_pos), + batch_size=self.p.batch_size, + shuffle=True, + num_workers=self.p.num_workers + ) + + if self.p.input_type == 'subgraph' or self.p.fine_tune_with_implicit_subgraph: + if self.p.add_reverse: + return { + 'train_rel': get_graph_data_loader(GraphTrainDataset, 'train_rel'), + 'valid_rel': get_graph_data_loader(GraphTestDataset, 'valid_rel'), + 'valid_rel_inv': get_graph_data_loader(GraphTestDataset, 'valid_rel_inv'), + 'test_rel': get_graph_data_loader(GraphTestDataset, 'test_rel'), + 'test_rel_inv': get_graph_data_loader(GraphTestDataset, 'test_rel_inv'), + # 'valid_head': get_data_loader(TestDataset, 'valid_head'), + # 'valid_tail': get_data_loader(TestDataset, 'valid_tail'), + # 'test_head': get_data_loader(TestDataset, 'test_head'), + # 'test_tail': get_data_loader(TestDataset, 'test_tail'), + } + else: + return { + 'train_rel': get_graph_data_loader(GraphTrainDataset, 'train_rel', 'train_pos'), + 'valid_rel': get_graph_data_loader(GraphTestDataset, 'valid_rel', 'valid_pos'), + 'test_rel': get_graph_data_loader(GraphTestDataset, 'test_rel', 'test_pos'), + } + elif self.p.input_type == 'allgraph' and self.p.score_func == 'mlp_ncn': + return { + 'train_rel': get_ncndata_loader(NCNDataset, 'train_rel', 'train_pos'), + 'valid_rel': get_ncndata_loader(NCNDataset, 'valid_rel', 'valid_pos'), + 'test_rel': get_ncndata_loader(NCNDataset, 'test_rel', 'test_pos'), + } + else: + if self.p.add_reverse: + return { + 'train_rel': get_data_loader(TrainDataset, 'train_rel'), + 'valid_rel': get_data_loader(TestDataset, 'valid_rel'), + 'valid_rel_inv': get_data_loader(TestDataset, 'valid_rel_inv'), + 'test_rel': get_data_loader(TestDataset, 'test_rel'), + 'test_rel_inv': get_data_loader(TestDataset, 'test_rel_inv') + } + else: + return { + 'train_rel': get_data_loader(TrainDataset, 'train_rel'), + 'valid_rel': get_data_loader(TestDataset, 'valid_rel'), + 'test_rel': get_data_loader(TestDataset, 'test_rel') + } + + def 
get_edge_dir_and_norm(self): + """ + :return: edge_type: indicates type of each edge: [E] + """ + in_deg = self.g.in_degrees(range(self.g.number_of_nodes())).float() + norm = in_deg ** -0.5 + norm[torch.isinf(norm).bool()] = 0 + self.g.ndata['xxx'] = norm + self.g.apply_edges( + lambda edges: {'xxx': edges.dst['xxx'] * edges.src['xxx']}) + if self.p.gpu >= 0: + norm = self.g.edata.pop('xxx').squeeze().to("cuda:0") + if self.p.add_reverse: + edge_type = torch.tensor(np.concatenate( + [self.train_data[:, 1], self.train_data[:, 1] + self.num_rels])).to("cuda:0") + else: + edge_type = torch.tensor(self.train_data[:, 1]).to("cuda:0") + else: + norm = self.g.edata.pop('xxx').squeeze() + edge_type = torch.tensor(np.concatenate( + [self.train_data[:, 1], self.train_data[:, 1] + self.num_rels])) + return edge_type, norm + + def get_model(self): + if self.p.n_layer > 0: + if self.p.score_func.lower() == 'transe': + model = GCN_TransE(num_ent=self.num_ent, num_rel=self.num_rels, num_base=self.p.num_bases, + init_dim=self.p.init_dim, gcn_dim=self.p.gcn_dim, embed_dim=self.p.embed_dim, + n_layer=self.p.n_layer, edge_type=self.edge_type, edge_norm=self.edge_norm, + bias=self.p.bias, gcn_drop=self.p.gcn_drop, opn=self.p.opn, + hid_drop=self.p.hid_drop, gamma=self.p.gamma, wni=self.p.wni, wsi=self.p.wsi, + encoder=self.p.encoder, use_bn=(not self.p.nobn), ltr=(not self.p.noltr)) + elif self.p.encoder == 'gcn': + model = SEAL_GCN(self.num_ent, self.num_rels, self.p.init_dim, self.p.gcn_dim, self.p.embed_dim, self.p.n_layer, loss_type=self.p.loss_type) + elif self.p.score_func.lower() == 'distmult': + model = GCN_DistMult(num_ent=self.num_ent, num_rel=self.num_rels, num_base=self.p.num_bases, + init_dim=self.p.init_dim, gcn_dim=self.p.gcn_dim, embed_dim=self.p.embed_dim, + n_layer=self.p.n_layer, edge_type=self.edge_type, edge_norm=self.edge_norm, + bias=self.p.bias, gcn_drop=self.p.gcn_drop, opn=self.p.opn, + hid_drop=self.p.hid_drop, wni=self.p.wni, wsi=self.p.wsi, encoder=self.p.encoder, + use_bn=(not self.p.nobn), ltr=(not self.p.noltr)) + elif self.p.score_func.lower() == 'conve': + model = GCN_ConvE(num_ent=self.num_ent, num_rel=self.num_rels, num_base=self.p.num_bases, + init_dim=self.p.init_dim, gcn_dim=self.p.gcn_dim, embed_dim=self.p.embed_dim, + n_layer=self.p.n_layer, edge_type=self.edge_type, edge_norm=self.edge_norm, + bias=self.p.bias, gcn_drop=self.p.gcn_drop, opn=self.p.opn, + hid_drop=self.p.hid_drop, input_drop=self.p.input_drop, + conve_hid_drop=self.p.conve_hid_drop, feat_drop=self.p.feat_drop, + num_filt=self.p.num_filt, ker_sz=self.p.ker_sz, k_h=self.p.k_h, k_w=self.p.k_w, + wni=self.p.wni, wsi=self.p.wsi, encoder=self.p.encoder, use_bn=(not self.p.nobn), + ltr=(not self.p.noltr)) + elif self.p.score_func.lower() == 'conve_rel': + model = GCN_ConvE_Rel(num_ent=self.num_ent, num_rel=self.num_rels, num_base=self.p.num_bases, + init_dim=self.p.init_dim, gcn_dim=self.p.gcn_dim, embed_dim=self.p.embed_dim, + n_layer=self.p.n_layer, edge_type=self.edge_type, edge_norm=self.edge_norm, + bias=self.p.bias, gcn_drop=self.p.gcn_drop, opn=self.p.opn, + hid_drop=self.p.hid_drop, input_drop=self.p.input_drop, + conve_hid_drop=self.p.conve_hid_drop, feat_drop=self.p.feat_drop, + num_filt=self.p.num_filt, ker_sz=self.p.ker_sz, k_h=self.p.k_h, k_w=self.p.k_w, + wni=self.p.wni, wsi=self.p.wsi, encoder=self.p.encoder, use_bn=(not self.p.nobn), + ltr=(not self.p.noltr), input_type=self.p.input_type) + elif self.p.score_func.lower() == 'transformer': + model = GCN_Transformer(num_ent=self.num_ent, 
num_rel=self.num_rels, num_base=self.p.num_bases, + init_dim=self.p.init_dim, gcn_dim=self.p.gcn_dim, embed_dim=self.p.embed_dim, + n_layer=self.p.n_layer, edge_type=self.edge_type, edge_norm=self.edge_norm, + bias=self.p.bias, gcn_drop=self.p.gcn_drop, opn=self.p.opn, + hid_drop=self.p.hid_drop, input_drop=self.p.input_drop, + conve_hid_drop=self.p.conve_hid_drop, feat_drop=self.p.feat_drop, + num_filt=self.p.num_filt, ker_sz=self.p.ker_sz, k_h=self.p.k_h, k_w=self.p.k_w, + wni=self.p.wni, wsi=self.p.wsi, encoder=self.p.encoder, + use_bn=(not self.p.nobn), + ltr=(not self.p.noltr), input_type=self.p.input_type, + d_model=self.p.d_model, num_transformer_layers=self.p.num_transformer_layers, + nhead=self.p.nhead, dim_feedforward=self.p.dim_feedforward, + transformer_dropout=self.p.transformer_dropout, + transformer_activation=self.p.transformer_activation, + graph_pooling=self.p.graph_pooling_type, concat_type=self.p.concat_type, + max_input_len=self.p.subgraph_max_num_nodes, loss_type=self.p.loss_type) + elif self.p.score_func.lower() == 'none': + model = GCN_None(num_ent=self.num_ent, num_rel=self.num_rels, num_base=self.p.num_bases, + init_dim=self.p.init_dim, gcn_dim=self.p.gcn_dim, embed_dim=self.p.embed_dim, + n_layer=self.p.n_layer, edge_type=self.edge_type, edge_norm=self.edge_norm, + bias=self.p.bias, gcn_drop=self.p.gcn_drop, opn=self.p.opn, + hid_drop=self.p.hid_drop, input_drop=self.p.input_drop, + conve_hid_drop=self.p.conve_hid_drop, feat_drop=self.p.feat_drop, + num_filt=self.p.num_filt, ker_sz=self.p.ker_sz, k_h=self.p.k_h, k_w=self.p.k_w, + wni=self.p.wni, wsi=self.p.wsi, encoder=self.p.encoder, + use_bn=(not self.p.nobn), + ltr=(not self.p.noltr), input_type=self.p.input_type, + graph_pooling=self.p.graph_pooling_type, concat_type=self.p.concat_type, + loss_type=self.p.loss_type, add_reverse=self.p.add_reverse) + elif self.p.score_func.lower() == 'mlp' and self.p.encoder != 'searchgcn' and self.p.genotype is None: + model = GCN_MLP(num_ent=self.num_ent, num_rel=self.num_rels, num_base=self.p.num_bases, + init_dim=self.p.init_dim, gcn_dim=self.p.gcn_dim, embed_dim=self.p.embed_dim, + n_layer=self.p.n_layer, edge_type=self.edge_type, edge_norm=self.edge_norm, + bias=self.p.bias, gcn_drop=self.p.gcn_drop, opn=self.p.opn, + hid_drop=self.p.hid_drop, input_drop=self.p.input_drop, + conve_hid_drop=self.p.conve_hid_drop, feat_drop=self.p.feat_drop, + num_filt=self.p.num_filt, ker_sz=self.p.ker_sz, k_h=self.p.k_h, k_w=self.p.k_w, + wni=self.p.wni, wsi=self.p.wsi, encoder=self.p.encoder, + use_bn=(not self.p.nobn), + ltr=(not self.p.noltr), input_type=self.p.input_type, + graph_pooling=self.p.graph_pooling_type, combine_type=self.p.combine_type, + loss_type=self.p.loss_type, add_reverse=self.p.add_reverse) + elif self.p.score_func.lower() == 'mlp_ncn': + model = GCN_MLP_NCN(num_ent=self.num_ent, num_rel=self.num_rels, num_base=self.p.num_bases, + init_dim=self.p.init_dim, gcn_dim=self.p.gcn_dim, embed_dim=self.p.embed_dim, + n_layer=self.p.n_layer, edge_type=self.edge_type, edge_norm=self.edge_norm, + bias=self.p.bias, gcn_drop=self.p.gcn_drop, opn=self.p.opn, + hid_drop=self.p.hid_drop, input_drop=self.p.input_drop, + conve_hid_drop=self.p.conve_hid_drop, feat_drop=self.p.feat_drop, + num_filt=self.p.num_filt, ker_sz=self.p.ker_sz, k_h=self.p.k_h, k_w=self.p.k_w, + wni=self.p.wni, wsi=self.p.wsi, encoder=self.p.encoder, + use_bn=(not self.p.nobn), + ltr=(not self.p.noltr), input_type=self.p.input_type, + graph_pooling=self.p.graph_pooling_type, combine_type=self.p.combine_type, + 
loss_type=self.p.loss_type, add_reverse=self.p.add_reverse) + elif self.p.genotype is not None: + model = SearchedGCN_MLP(args=self.p, num_ent=self.num_ent, num_rel=self.num_rels, + num_base=self.p.num_bases, + init_dim=self.p.init_dim, gcn_dim=self.p.gcn_dim, embed_dim=self.p.embed_dim, + n_layer=self.p.n_layer, edge_type=self.edge_type, edge_norm=self.edge_norm, + bias=self.p.bias, gcn_drop=self.p.gcn_drop, hid_drop=self.p.hid_drop, + input_drop=self.p.input_drop, + wni=self.p.wni, wsi=self.p.wsi, use_bn=(not self.p.nobn), + ltr=(not self.p.noltr), + combine_type=self.p.combine_type, loss_type=self.p.loss_type, + genotype=self.p.genotype) + elif "spos" in self.p.search_algorithm or self.p.train_mode=='vis_hop' or (self.p.train_mode=='spos_tune' and self.p.weight_sharing == True): + model = SearchGCN_MLP_SPOS(args=self.p, num_ent=self.num_ent, num_rel=self.num_rels, + num_base=self.p.num_bases, + init_dim=self.p.init_dim, gcn_dim=self.p.gcn_dim, embed_dim=self.p.embed_dim, + n_layer=self.p.n_layer, edge_type=self.edge_type, edge_norm=self.edge_norm, + bias=self.p.bias, gcn_drop=self.p.gcn_drop, hid_drop=self.p.hid_drop, + input_drop=self.p.input_drop, + wni=self.p.wni, wsi=self.p.wsi, use_bn=(not self.p.nobn), ltr=(not self.p.noltr), + combine_type=self.p.combine_type, loss_type=self.p.loss_type) + elif self.p.score_func.lower() == 'mlp' and self.p.genotype is None: + model = SearchGCN_MLP(args=self.p, num_ent=self.num_ent, num_rel=self.num_rels, + num_base=self.p.num_bases, + init_dim=self.p.init_dim, gcn_dim=self.p.gcn_dim, embed_dim=self.p.embed_dim, + n_layer=self.p.n_layer, edge_type=self.edge_type, edge_norm=self.edge_norm, + bias=self.p.bias, gcn_drop=self.p.gcn_drop, hid_drop=self.p.hid_drop, + input_drop=self.p.input_drop, + wni=self.p.wni, wsi=self.p.wsi, use_bn=(not self.p.nobn), ltr=(not self.p.noltr), + combine_type=self.p.combine_type, loss_type=self.p.loss_type) + else: + raise KeyError( + f'score function {self.p.score_func} not recognized.') + else: + if self.p.score_func.lower() == 'transe': + model = TransE(self.num_ent, self.num_rels, params=self.p) + elif self.p.score_func.lower() == 'distmult': + model = DistMult(self.num_ent, self.num_rels, params=self.p) + elif self.p.score_func.lower() == 'conve': + model = ConvE(self.num_ent, self.num_rels, params=self.p) + else: + raise NotImplementedError + + if self.p.gpu >= 0: + model.to("cuda:0") + return model + + def get_subgraph(self): + subgraph_dir = f'subgraph/{args.dataset}/{self.p.subgraph_type}_{self.p.subgraph_hop}_{self.p.subgraph_max_num_nodes}_{self.p.subgraph_sample_type}_{self.p.seed}' + if not exists(subgraph_dir): + makedirs(subgraph_dir) + + for mode in ['train_rel', 'valid_rel', 'test_rel']: + if self.p.subgraph_is_saved: + if self.p.save_mode == 'pickle': + with open( + f'{subgraph_dir}/{mode}_{self.p.subgraph_type}_{self.p.subgraph_hop}_{self.p.subgraph_max_num_nodes}_{self.p.subgraph_sample_type}_{self.p.seed}.pkl', + 'rb') as f: + sample_nodes = pickle.load(f) + # graph_list = dgl.load_graphs(f'{subgraph_dir}/{mode}_{self.p.subgraph_type}_{self.p.subgraph_hop}_{self.p.subgraph_max_num_nodes}_{self.p.subgraph_sample_type}_{self.p.seed}.bin')[0] + for idx, _ in enumerate(self.triplets[mode]): + self.triplets[mode][idx]['sample_nodes'] = sample_nodes[idx][0] + elif self.p.save_mode == 'graph': + if self.p.add_reverse: + graph_list = dgl.load_graphs( + f'{subgraph_dir}/{mode}_add_reverse.bin')[0] + else: + graph_list = dgl.load_graphs( + f'{subgraph_dir}/{mode}.bin')[0] + for idx, _ in 
enumerate(self.triplets[mode]): + self.triplets[mode][idx]['subgraph'] = graph_list[idx] + def random_search(self): + save_root = f'{self.prj_path}/checkpoints/{self.p.dataset}' + os.makedirs(save_root, exist_ok=True) + save_path = f'{save_root}/{self.p.name}_random.pt' + for epoch in range(self.p.max_epochs): + start_time = time.time() + train_loss = self.train_epoch_rs() + val_results, valid_loss = self.evaluate_epoch('valid', mode='random') + if self.p.dataset == 'drugbank': + if val_results['macro_f1'] > self.best_val_f1: + self.best_val_results = val_results + self.best_val_f1 = val_results['macro_f1'] + self.best_epoch = epoch + self.save_model(save_path) + self.early_stop_cnt = 0 + else: + self.early_stop_cnt += 1 + self.logger.info( + f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Cost: {time.time() - start_time:.2f}s") + if self.p.dataset == 'drugbank': + self.logger.info( + f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}") + self.logger.info( + f"[Epoch {epoch}]: Best Macro F1: {self.best_val_f1:.5}") + wandb.log({"train_loss": train_loss, "best_valid_f1": self.best_val_f1}) + if self.early_stop_cnt == 15: + self.logger.info("Early stop!") + break + self.load_model(save_path) + self.logger.info( + f'Loading best model in {self.best_epoch} epoch, Evaluating on Test data') + start = time.time() + test_results, test_loss = self.evaluate_epoch('test', mode='random') + end = time.time() + if self.p.dataset == 'drugbank': + self.logger.info( + f"f1: Rel {test_results['left_f1']:.5}, Rel_rev {test_results['right_f1']:.5}, Avg {test_results['macro_f1']:.5}") + self.logger.info( + f"[Inference]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}") + wandb.log({ + "test_acc": test_results['acc'], + "test_f1": test_results['macro_f1'], + "test_cohen": test_results['kappa'] + }) + + def train(self): + save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.train_mode}/{self.p.name}/' + os.makedirs(save_root, exist_ok=True) + save_path = f'{save_root}/model_weight.pt' + for epoch in range(self.p.max_epochs): + start_time = time.time() + train_loss = self.train_epoch() + val_results, valid_loss = self.evaluate_epoch('valid', mode='normal') + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + if val_results['auroc'] > self.best_val_auroc: + self.best_val_results = val_results + self.best_val_auroc = val_results['auroc'] + self.best_epoch = epoch + self.save_model(save_path) + self.early_stop_cnt = 0 + else: + self.early_stop_cnt += 1 + else: + if val_results['macro_f1'] > self.best_val_f1: + self.best_val_results = val_results + self.best_val_f1 = val_results['macro_f1'] + self.best_epoch = epoch + self.save_model(save_path) + self.early_stop_cnt = 0 + else: + self.early_stop_cnt += 1 + self.logger.info( + f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Cost: {time.time() - start_time:.2f}s") + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + self.logger.info( + f"[Epoch {epoch}]: Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}") + self.logger.info( + f"[Epoch {epoch}]: Best AUROC: {self.best_val_auroc:.5}") + # wandb.log({"train_loss": train_loss, "valid_loss": valid_loss, + # "best_valid_auroc": self.best_val_auroc}) + else: + self.logger.info( + f"[Epoch 
{epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}") + self.logger.info( + f"[Epoch {epoch}]: Best Macro F1: {self.best_val_f1:.5}") + # wandb.log({"train_loss": train_loss, "valid_loss": valid_loss, + # "best_valid_f1": self.best_val_f1}) + if self.early_stop_cnt == 10: + self.logger.info("Early stop!") + break + # self.logger.info(vars(self.p)) + self.load_model(save_path) + self.logger.info( + f'Loading best model in {self.best_epoch} epoch, Evaluating on Test data') + start = time.time() + test_results, test_loss = self.evaluate_epoch('test', mode='normal') + end = time.time() + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + self.logger.info( + f"[Inference]: Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}") + # wandb.log({ + # "test_auroc": test_results['auroc'], + # "test_auprc": test_results['auprc'], + # "test_ap": test_results['ap'] + # }) + else: + self.logger.info( + f"[Inference]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}") + # wandb.log({ + # "test_acc": test_results['acc'], + # "test_f1": test_results['macro_f1'], + # "test_cohen": test_results['kappa'] + # }) + + def train_epoch(self): + self.model.train() + losses = [] + train_iter = self.data_iter['train_rel'] + # train_bar = tqdm(train_iter, ncols=0) + for step, batch in enumerate(train_iter): + if self.p.input_type == 'subgraph': + g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0") + else: + triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0") + g = self.g.to("cuda:0") + subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2] + if self.p.score_func == 'mlp_ncn': + cns = batch[2].to("cuda:0") + pred = self.model(g, subj, obj, cns) + else: + if self.p.encoder == 'gcn': + pred = self.model(g, g.ndata['z']) + else: + pred = self.model(g, subj, obj) + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + pos_neg = batch[2].to("cuda:0") + loss = self.model.calc_loss(pred, labels, pos_neg) + else: + loss = self.model.calc_loss(pred, labels) + self.optimizer.zero_grad() + loss.backward() + if self.p.clip_grad: + clip_grad_norm_(self.model.parameters(), max_norm=10, norm_type=2) + self.optimizer.step() + losses.append(loss.item()) + loss = np.mean(losses) + return loss + + def evaluate_epoch(self, split, mode='normal'): + + def get_combined_results(left, right): + results = dict() + results['acc'] = round((left['acc'] + right['acc']) / 2, 5) + results['left_f1'] = round(left['macro_f1'], 5) + results['right_f1'] = round(right['macro_f1'], 5) + results['macro_f1'] = round((left['macro_f1'] + right['macro_f1']) / 2, 5) + results['kappa'] = round((left['kappa'] + right['kappa']) / 2, 5) + results['macro_f1_per_class'] = (np.array(left['macro_f1_per_class']) + np.array( + right['macro_f1_per_class'])) / 2.0 + results['acc_per_class'] = (np.array(left['acc_per_class']) + np.array(right['acc_per_class'])) / 2.0 + return results + + def get_results(left): + results = dict() + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + results['auroc'] = round(left['auroc'], 5) + results['auprc'] = round(left['auprc'], 5) + results['ap'] = round(left['ap'], 5) + else: + results['acc'] = round(left['acc'], 5) + # results['auc_pr'] = round((left['auc_pr'] + right['auc_pr']) / 2, 5) + # results['micro_f1'] = 
round((left['micro_f1'] + right['micro_f1']) / 2, 5) + results['macro_f1'] = round(left['macro_f1'], 5) + results['kappa'] = round(left['kappa'], 5) + return results + + self.model.eval() + if mode == 'normal': + if self.p.add_reverse: + left_result, left_loss = self.predict(split, '') + right_result, right_loss = self.predict(split, '_inv') + else: + left_result, left_loss = self.predict(split, '') + elif mode == 'normal_mix_hop': + left_result, left_loss = self.predict_mix_hop(split, '') + elif mode == 'random': + left_result, left_loss = self.predict_rs(split, '') + right_result, right_loss = self.predict_rs(split, '_inv') + elif mode == 'ps2': + if self.p.add_reverse: + left_result, left_loss = self.predict_search(split, '') + right_result, right_loss = self.predict_search(split, '_inv') + else: + left_result, left_loss = self.predict_search(split, '') + elif mode == 'arch_search': + left_result, left_loss = self.predict_arch_search(split, '') + elif mode == 'arch_search_s': + left_result, left_loss = self.predict_arch_search(split, '', 'evaluate_single_path') + elif mode == 'joint_search': + left_result, left_loss = self.predict_joint_search(split, '') + elif mode == 'joint_search_s': + left_result, left_loss = self.predict_joint_search(split, '', 'evaluate_single_path') + elif mode == 'spos_train_supernet': + left_result, left_loss = self.predict_spos_search(split, '') + elif mode == 'spos_arch_search': + left_result, left_loss = self.predict_spos_search(split, '', spos_mode='arch_search') + elif mode == 'spos_train_supernet_ps2': + left_result, left_loss = self.predict_spos_search_ps2(split, '') + elif mode == 'spos_arch_search_ps2': + left_result, left_loss = self.predict_spos_search_ps2(split, '', spos_mode='arch_search') + # res = get_results(left_result) + # return res, left_loss + if self.p.add_reverse: + res = get_combined_results(left_result, right_result) + return res, (left_loss + right_loss) / 2.0 + else: + return left_result, left_loss + + def predict(self, split, mode): + loss_list = [] + pos_scores = [] + pos_labels = [] + pred_class = {} + self.model.eval() + with torch.no_grad(): + results = dict() + eval_iter = self.data_iter[f'{split}_rel{mode}'] + for step, batch in enumerate(eval_iter): + if self.p.input_type == 'subgraph': + g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0") + else: + triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0") + g = self.g.to("cuda:0") + subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2] + if self.p.score_func == 'mlp_ncn': + cns = batch[2].to("cuda:0") + pred = self.model(g, subj, obj, cns) + else: + if self.p.encoder == 'gcn': + pred = self.model(g, g.ndata['z']) + else: + pred = self.model(g, subj, obj) + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + pos_neg = batch[2].to("cuda:0") + eval_loss = self.model.calc_loss(pred, labels, pos_neg) + m = torch.nn.Sigmoid() + pred = m(pred) + labels = labels.detach().to('cpu').numpy() + preds = pred.detach().to('cpu').numpy() + pos_neg = pos_neg.detach().to('cpu').numpy() + for (label_ids, pred, label_t) in zip(labels, preds, pos_neg): + for i, (l, p) in enumerate(zip(label_ids, pred)): + if l == 1: + if i in pred_class: + pred_class[i]['pred'] += [p] + pred_class[i]['l'] += [label_t] + pred_class[i]['pred_label'] += [1 if p > 0.5 else 0] + else: + pred_class[i] = {'pred': [p], 'l': [label_t], 'pred_label': [1 if p > 0.5 else 0]} + else: + eval_loss = self.model.calc_loss(pred, labels) + pos_labels += 
rel.to('cpu').numpy().flatten().tolist() + pos_scores += torch.argmax(pred, dim=1).cpu().flatten().tolist() + loss_list.append(eval_loss.item()) + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + roc_auc = [metrics.roc_auc_score(pred_class[l]['l'], pred_class[l]['pred']) for l in pred_class] + prc_auc = [metrics.average_precision_score(pred_class[l]['l'], pred_class[l]['pred']) for l in + pred_class] + ap = [metrics.accuracy_score(pred_class[l]['l'], pred_class[l]['pred_label']) for l in pred_class] + results['auroc'] = np.mean(roc_auc) + results['auprc'] = np.mean(prc_auc) + results['ap'] = np.mean(ap) + else: + results['acc'] = metrics.accuracy_score(pos_labels, pos_scores) + results['macro_f1'] = metrics.f1_score(pos_labels, pos_scores, average='macro') + results['kappa'] = metrics.cohen_kappa_score(pos_labels, pos_scores) + # dict_res = metrics.classification_report(pos_labels, pos_scores, output_dict=True, zero_division=1) + # results['macro_f1_per_class'] = get_f1_score_list(dict_res) + # results['acc_per_class'] = get_acc_list(dict_res) + # self.logger.info(f'Macro f1 per class: {results["macro_f1_per_class"]}') + # self.logger.info(f'Acc per class: {results["acc_per_class"]}') + loss = np.mean(loss_list) + return results, loss + + def train_epoch_rs(self): + self.model.train() + loss_list = [] + train_iter = self.data_iter['train_rel'] + for step, batch in enumerate(train_iter): + if self.p.input_type == 'subgraph': + g, triplets, labels, input_ids = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0"), \ + batch[3].to("cuda:0") + else: + triplets, labels, random_hops = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to('cuda:0') + g = self.g.to("cuda:0") + subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2] + hidden_all_ent, all_rel = self.model.forward_search(g, subj, obj) + pred = self.model.compute_pred_rs(hidden_all_ent, all_rel, subj, obj, random_hops) + # pred = self.model(g, subj, obj, random_hops) # [batch_size, num_ent] + loss = self.model.calc_loss(pred, labels) + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + loss_list.append(loss.item()) + loss = np.mean(loss_list) + return loss + + def predict_rs(self, split, mode): + loss_list = [] + pos_scores = [] + pos_labels = [] + self.model.eval() + with torch.no_grad(): + results = dict() + eval_iter = self.data_iter[f'{split}_rel{mode}'] + for step, batch in enumerate(eval_iter): + if self.p.input_type == 'subgraph': + g, triplets, labels, input_ids = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to( + "cuda:0"), batch[3].to("cuda:0") + else: + triplets, labels, random_hops = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to('cuda:0') + g = self.g.to("cuda:0") + subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2] + hidden_all_ent, all_rel = self.model.forward_search(g, subj, obj) + pred = self.model.compute_pred_rs(hidden_all_ent, all_rel, subj, obj, random_hops) + loss = self.model.calc_loss(pred, labels) + loss_list.append(loss.item()) + pos_labels += rel.to('cpu').numpy().flatten().tolist() + pos_scores += torch.argmax(pred, dim=1).cpu().flatten().tolist() + results['acc'] = metrics.accuracy_score(pos_labels, pos_scores) + results['macro_f1'] = metrics.f1_score(pos_labels, pos_scores, average='macro') + results['kappa'] = metrics.cohen_kappa_score(pos_labels, pos_scores) + loss = np.mean(loss_list) + return results, loss + + def fine_tune(self): + self.subgraph_selector = SubgraphSelector(self.p).to("cuda:0") + 
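+        # Descriptive comment (summarising the code below): fine-tuning reloads the model and
+        # selector weights saved by the search stage, keeps the selector fixed (it is only put
+        # in eval mode and queried with mode='argmax'), and continues training the encoder and
+        # predictor weights alone before the final test evaluation.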
save_root = f'{self.prj_path}/checkpoints/{self.p.dataset}' + save_model_path = f'{save_root}/{self.p.name}.pt' + save_ss_path = f'{save_root}/{self.p.name}_ss.pt' + save_tune_path = f'{save_root}/{self.p.name}_tune.pt' + self.model.load_state_dict(torch.load(str(save_model_path))) + self.subgraph_selector.load_state_dict(torch.load(str(save_ss_path))) + self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='max', factor=0.2, patience=10, verbose=True) + val_results, val_loss = self.evaluate_epoch('valid', 'ps2') + test_results, test_loss = self.evaluate_epoch('test', 'ps2') + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + self.logger.info( + f"[Validation]: Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}") + self.logger.info( + f"[Inference]: Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}") + wandb.log({ + "init_test_auroc": test_results['auroc'], + "init_test_auprc": test_results['auprc'], + "init_test_ap": test_results['ap'] + }) + else: + self.logger.info( + f"[Validation]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}") + self.logger.info( + f"[Inference]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}") + wandb.log({ + "init_test_acc": test_results['acc'], + "init_test_f1": test_results['macro_f1'], + "init_test_cohen": test_results['kappa'] + }) + for epoch in range(self.p.max_epochs): + start_time = time.time() + train_loss = self.train_epoch_fine_tune() + val_results, valid_loss = self.evaluate_epoch('valid', 'ps2') + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + if val_results['auroc'] > self.best_val_auroc: + self.best_val_results = val_results + self.best_val_auroc = val_results['auroc'] + self.best_epoch = epoch + self.save_model(save_tune_path) + self.early_stop_cnt = 0 + else: + self.early_stop_cnt += 1 + else: + if val_results['macro_f1'] > self.best_val_f1: + self.best_val_results = val_results + self.best_val_f1 = val_results['macro_f1'] + self.best_epoch = epoch + self.save_model(save_tune_path) + self.early_stop_cnt = 0 + else: + self.early_stop_cnt += 1 + self.logger.info( + f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Cost: {time.time() - start_time:.2f}s") + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + self.logger.info( + f"[Epoch {epoch}]: Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}") + self.logger.info( + f"[Epoch {epoch}]: Best AUROC: {self.best_val_auroc:.5}") + wandb.log({"train_loss": train_loss, "valid_loss": valid_loss, + "best_valid_auroc": self.best_val_auroc}) + self.scheduler.step(self.best_val_auroc) + else: + self.logger.info( + f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}") + self.logger.info( + f"[Epoch {epoch}]: Best Macro F1: {self.best_val_f1:.5}") + wandb.log({"train_loss": train_loss, "valid_loss": valid_loss, + "best_valid_f1": self.best_val_f1}) + self.scheduler.step(self.best_val_f1) + if self.early_stop_cnt == 15: + self.logger.info("Early stop!") + break + self.load_model(save_tune_path) + self.logger.info( + f'Loading best model in {self.best_epoch} epoch, Evaluating 
on Test data') + test_results, test_loss = self.evaluate_epoch('test', 'ps2') + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + self.logger.info( + f"[Inference]: Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}") + wandb.log({ + "test_auroc": test_results['auroc'], + "test_auprc": test_results['auprc'], + "test_ap": test_results['ap'] + }) + else: + self.logger.info( + f"[Inference]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}") + wandb.log({ + "test_acc": test_results['acc'], + "test_f1": test_results['macro_f1'], + "test_cohen": test_results['kappa'] + }) + + def train_epoch_fine_tune(self, mode=None): + self.model.train() + self.subgraph_selector.eval() + loss_list = [] + train_iter = self.data_iter['train_rel'] + # train_bar = tqdm(train_iter, ncols=0) + for step, batch in enumerate(train_iter): + if self.p.input_type == 'subgraph': + g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0") + else: + triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0") + g = self.g.to("cuda:0") + subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2] + hidden_all_ent, all_ent = self.model.forward_search(g, mode=mode) # [batch_size, num_ent] + if self.p.score_func == 'mlp_ncn': + cns = batch[2].to("cuda:0") + pred = self.model.compute_pred(hidden_all_ent, all_ent, subj, obj, cns, self.subgraph_selector, + mode='argmax') + else: + pred = self.model.compute_pred(hidden_all_ent, subj, obj, self.subgraph_selector, mode='argmax', search_algorithm=self.p.ss_search_algorithm) + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + pos_neg = batch[2].to("cuda:0") + train_loss = self.model.calc_loss(pred, labels, pos_neg) + else: + train_loss = self.model.calc_loss(pred, labels) + loss_list.append(train_loss.item()) + self.optimizer.zero_grad() + train_loss.backward() + self.optimizer.step() + + loss = np.mean(loss_list) + return loss + + def ps2(self): + save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}' + os.makedirs(save_root, exist_ok=True) + self.logger = get_logger(f'{save_root}/',f'train') + save_model_path = f'{save_root}/weight.pt' + save_ss_path = f'{save_root}/weight_ss.pt' + + self.subgraph_selector = SubgraphSelector(self.p).to("cuda:0") + self.subgraph_selector_optimizer = torch.optim.Adam( + self.subgraph_selector.parameters(), lr=self.p.ss_lr, weight_decay=self.p.l2) + # temp_scheduler = Temp_Scheduler(self.p.max_epochs, self.p.temperature, self.p.temperature, temp_min=self.p.temperature_min) + for epoch in range(self.p.max_epochs): + start_time = time.time() + # if self.p.cos_temp: + # self.p.temperature = temp_scheduler.step() + # else: + # self.p.temperature = self.p.temperature + train_loss = self.search_epoch() + val_results, valid_loss = self.evaluate_epoch('valid', 'ps2') + test_results, test_loss = self.evaluate_epoch('test', 'ps2') + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + if val_results['auroc'] > self.best_val_auroc: + self.best_val_results = val_results + self.best_test_results = test_results + self.best_val_auroc = val_results['auroc'] + self.best_test_auroc = test_results['auroc'] + self.best_epoch = epoch + torch.save(self.model.state_dict(), str(save_model_path)) + torch.save(self.subgraph_selector.state_dict(), str(save_ss_path)) + self.early_stop_cnt = 0 + else: + self.early_stop_cnt 
+= 1 + else: + if val_results['macro_f1'] > self.best_val_f1: + self.best_val_results = val_results + self.best_test_results = test_results + self.best_val_f1 = val_results['macro_f1'] + self.best_test_f1 = test_results['macro_f1'] + self.best_epoch = epoch + torch.save(self.model.state_dict(), str(save_model_path)) + torch.save(self.subgraph_selector.state_dict(), str(save_ss_path)) + self.early_stop_cnt = 0 + else: + self.early_stop_cnt += 1 + self.logger.info( + f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Cost: {time.time() - start_time:.2f}s") + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + self.logger.info( + f"[Epoch {epoch}]: Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}") + self.logger.info( + f"[Epoch {epoch}]: Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}") + self.logger.info( + f"[Epoch {epoch}]: Best AUROC: {self.best_val_auroc:.5}") + wandb.log({"train_loss": train_loss, "valid_loss": valid_loss, + "best_valid_auroc": self.best_val_auroc}) + wandb.log({'best_test_auroc': self.best_test_auroc}) + else: + self.logger.info( + f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}") + self.logger.info( + f"[Epoch {epoch}]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}") + self.logger.info( + f"[Epoch {epoch}]: Best Macro F1: {self.best_val_f1:.5}") + # wandb.log({"train_loss": train_loss, "valid_loss": valid_loss, + # "best_valid_f1": self.best_val_f1}) + # wandb.log({'best_test_f1': self.best_test_f1}) + if self.early_stop_cnt == 10: + self.logger.info("Early stop!") + break + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + # self.logger.info( + # f"[Inference]: Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}") + wandb.log({ + "test_auroc": self.best_test_results['auroc'], + "test_auprc": self.best_test_results['auprc'], + "test_ap": self.best_test_results['ap'] + }) + else: + # self.logger.info( + # f"[Inference]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}") + wandb.log({ + "test_acc": self.best_test_results['acc'], + "test_f1": self.best_test_results['macro_f1'], + "test_cohen": self.best_test_results['kappa'] + }) + + + def search_epoch(self): + self.model.train() + self.subgraph_selector.train() + loss_list = [] + train_iter = self.data_iter['train_rel'] + valid_iter = self.data_iter['valid_rel'] + # train_bar = tqdm(train_iter, ncols=0) + for step, batch in enumerate(train_iter): + if self.p.input_type == 'subgraph': + g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0") + else: + triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0") + g = self.g.to("cuda:0") + subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2] + hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent] + # print(hidden_all_ent.size()) # [num_ent, encoder_layer, dim] + if self.p.score_func == 'mlp_ncn': + cns = batch[2].to("cuda:0") + pred = self.model.compute_pred(hidden_all_ent, all_ent, subj, obj, cns, self.subgraph_selector) + else: + pred = self.model.compute_pred(hidden_all_ent, 
subj, obj, self.subgraph_selector, search_algorithm=self.p.ss_search_algorithm) + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + pos_neg = batch[2].to("cuda:0") + train_loss = self.model.calc_loss(pred, labels, pos_neg) + else: + train_loss = self.model.calc_loss(pred, labels) + loss_list.append(train_loss.item()) + self.optimizer.zero_grad() + train_loss.backward() + self.optimizer.step() + self.optimizer.zero_grad() + self.subgraph_selector_optimizer.zero_grad() + + batch = next(iter(valid_iter)) + if self.p.input_type == 'subgraph': + g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0") + else: + triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0") + g = self.g.to("cuda:0") + subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2] + hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent] + # print(hidden_all_ent.size()) # [num_ent, encoder_layer, dim] + if self.p.score_func == 'mlp_ncn': + cns = batch[2].to("cuda:0") + pred = self.model.compute_pred(hidden_all_ent, all_ent, subj, obj, cns, self.subgraph_selector) + else: + pred = self.model.compute_pred(hidden_all_ent, subj, obj, self.subgraph_selector, search_algorithm=self.p.ss_search_algorithm) + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + pos_neg = batch[2].to("cuda:0") + valid_loss = self.model.calc_loss(pred, labels, pos_neg) + else: + valid_loss = self.model.calc_loss(pred, labels) + valid_loss.backward() + self.subgraph_selector_optimizer.step() + loss = np.mean(loss_list) + return loss + + def predict_search(self, split, mode): + loss_list = [] + pos_scores = [] + pos_labels = [] + pred_class = {} + self.model.eval() + self.subgraph_selector.eval() + with torch.no_grad(): + results = dict() + test_iter = self.data_iter[f'{split}_rel{mode}'] + for step, batch in enumerate(test_iter): + if self.p.input_type == 'subgraph': + g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0") + else: + triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0") + g = self.g.to("cuda:0") + subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2] + hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent] + # print(hidden_all_ent.size()) # [num_ent, encoder_layer, dim] + if self.p.score_func == 'mlp_ncn': + cns = batch[2].to("cuda:0") + pred = self.model.compute_pred(hidden_all_ent, all_ent, subj, obj, cns, self.subgraph_selector, + mode='argmax') + else: + pred = self.model.compute_pred(hidden_all_ent, subj, obj, self.subgraph_selector, mode='argmax', search_algorithm=self.p.ss_search_algorithm) + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + pos_neg = batch[2].to("cuda:0") + eval_loss = self.model.calc_loss(pred, labels, pos_neg) + m = torch.nn.Sigmoid() + pred = m(pred) + labels = labels.detach().to('cpu').numpy() + preds = pred.detach().to('cpu').numpy() + pos_neg = pos_neg.detach().to('cpu').numpy() + for (label_ids, pred, label_t) in zip(labels, preds, pos_neg): + for i, (l, p) in enumerate(zip(label_ids, pred)): + if l == 1: + if i in pred_class: + pred_class[i]['pred'] += [p] + pred_class[i]['l'] += [label_t] + pred_class[i]['pred_label'] += [1 if p > 0.5 else 0] + else: + pred_class[i] = {'pred': [p], 'l': [label_t], 'pred_label': [1 if p > 0.5 else 0]} + else: + eval_loss = self.model.calc_loss(pred, labels) + pos_labels += rel.to('cpu').numpy().flatten().tolist() + pos_scores 
+= torch.argmax(pred, dim=1).cpu().flatten().tolist() + loss_list.append(eval_loss.item()) + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + roc_auc = [metrics.roc_auc_score(pred_class[l]['l'], pred_class[l]['pred']) for l in pred_class] + prc_auc = [metrics.average_precision_score(pred_class[l]['l'], pred_class[l]['pred']) for l in + pred_class] + ap = [metrics.accuracy_score(pred_class[l]['l'], pred_class[l]['pred_label']) for l in pred_class] + results['auroc'] = np.mean(roc_auc) + results['auprc'] = np.mean(prc_auc) + results['ap'] = np.mean(ap) + else: + results['acc'] = metrics.accuracy_score(pos_labels, pos_scores) + results['macro_f1'] = metrics.f1_score(pos_labels, pos_scores, average='macro') + results['kappa'] = metrics.cohen_kappa_score(pos_labels, pos_scores) + # dict_res = metrics.classification_report(pos_labels, pos_scores, output_dict=True, zero_division=1) + # results['macro_f1_per_class'] = get_f1_score_list(dict_res) + # results['acc_per_class'] = get_acc_list(dict_res) + loss = np.mean(loss_list) + return results, loss + + def architecture_search(self): + save_root = f'{self.prj_path}/checkpoints/{self.p.dataset}' + os.makedirs(save_root, exist_ok=True) + save_model_path = f'{save_root}/{self.p.name}.pt' + + self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, self.p.max_epochs, eta_min=self.p.lr_min) + self.arch_optimizer = torch.optim.Adam(self.model.arch_parameters(), lr=self.p.arch_lr, weight_decay=self.p.arch_weight_decay) + self.arch_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.arch_optimizer, self.p.max_epochs, eta_min=self.p.arch_lr_min) + temp_scheduler = Temp_Scheduler(self.p.max_epochs, self.p.temperature, self.p.temperature, temp_min=self.p.temperature_min) + for epoch in range(self.p.max_epochs): + start_time = time.time() + # genotype = self.model.genotype() + # self.logger.info(f'Genotype: {genotype}') + if self.p.cos_temp: + self.p.temperature = temp_scheduler.step() + else: + self.p.temperature = self.p.temperature + # print(self.p.temperature) + train_loss = self.arch_search_epoch() + self.scheduler.step() + self.arch_scheduler.step() + val_results, valid_loss = self.evaluate_epoch('valid', 'arch_search') + s_val_results, s_valid_loss = self.evaluate_epoch('valid', 'arch_search_s') + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + if val_results['auroc'] > self.best_val_auroc: + self.best_val_results = val_results + self.best_val_auroc = val_results['auroc'] + self.best_epoch = epoch + # self.save_model(save_path) + self.early_stop_cnt = 0 + else: + self.early_stop_cnt += 1 + else: + if val_results['macro_f1'] > self.best_val_f1: + self.best_val_results = val_results + self.best_val_f1 = val_results['macro_f1'] + self.best_epoch = epoch + # self.save_model(save_path) + self.early_stop_cnt = 0 + else: + self.early_stop_cnt += 1 + genotype = self.model.genotype() + self.logger.info(f'[Epoch {epoch}]: LR: {self.scheduler.get_last_lr()[0]}, TEMP: {self.p.temperature}, Genotype: {genotype}') + self.logger.info( + f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Valid_S Loss: {s_valid_loss:.5}, Cost: {time.time() - start_time:.2f}s") + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + self.logger.info( + f"[Epoch {epoch}]: Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}") + self.logger.info( + f"[Epoch {epoch}]: Valid_S AUROC: {s_val_results['auroc']:.5}, Valid_S 
AUPRC: {s_val_results['auprc']:.5}, Valid_S AP@50: {s_val_results['ap']:.5}") + else: + self.logger.info( + f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}") + self.logger.info( + f"[Epoch {epoch}]: Valid_S ACC: {s_val_results['acc']:.5}, Valid_S Macro F1: {s_val_results['macro_f1']:.5}, Valid_S Cohen: {s_val_results['kappa']:.5}") + # if self.early_stop_cnt == 15: + # self.logger.info("Early stop!") + # break + + def arch_search_epoch(self): + self.model.train() + loss_list = [] + train_iter = self.data_iter['train_rel'] + valid_iter = self.data_iter['valid_rel'] + # train_bar = tqdm(train_iter, ncols=0) + for step, batch in enumerate(train_iter): + for update_idx in range(self.p.w_update_epoch): + self.optimizer.zero_grad() + self.arch_optimizer.zero_grad() + if self.p.input_type == 'subgraph': + g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0") + else: + triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0") + g = self.g.to("cuda:0") + subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2] + # hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent] + # print(hidden_all_ent.size()) # [num_ent, encoder_layer, dim] + pred = self.model(g, subj, obj) + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + pos_neg = batch[2].to("cuda:0") + loss = self.model.calc_loss(pred, labels, pos_neg) + else: + loss = self.model.calc_loss(pred, labels) + self.arch_optimizer.zero_grad() + loss.backward(retain_graph=True) + if self.p.clip_grad: + clip_grad_norm_(self.model.parameters(), max_norm=10, norm_type=2) + self.optimizer.step() + loss_list.append(loss.item()) + if self.p.alpha_mode == 'train_loss': + self.arch_optimizer.step() + elif self.p.alpha_mode == 'valid_loss': + self.optimizer.zero_grad() + self.arch_optimizer.zero_grad() + batch = next(iter(valid_iter)) + if self.p.input_type == 'subgraph': + g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0") + else: + triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0") + g = self.g.to("cuda:0") + subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2] + pred = self.model(g, subj, obj) + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + pos_neg = batch[2].to("cuda:0") + valid_loss = self.model.calc_loss(pred, labels, pos_neg) + else: + valid_loss = self.model.calc_loss(pred, labels) + valid_loss.backward() + self.arch_optimizer.step() + loss = np.mean(loss_list) + return loss + + def predict_arch_search(self, split, mode, eval_mode=None): + loss_list = [] + pos_scores = [] + pos_labels = [] + pred_class = {} + self.model.eval() + with torch.no_grad(): + results = dict() + test_iter = self.data_iter[f'{split}_rel{mode}'] + for step, batch in enumerate(test_iter): + if self.p.input_type == 'subgraph': + g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0") + else: + triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0") + g = self.g.to("cuda:0") + subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2] + # hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent] + # print(hidden_all_ent.size()) # [num_ent, encoder_layer, dim] + pred = self.model(g, subj, obj, mode=eval_mode) + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + pos_neg = batch[2].to("cuda:0") 
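+                    # Multi-label evaluation (twosides / ogbl_biokg): after the loss is computed,
+                    # sigmoid scores are grouped per interaction type i in pred_class, together with
+                    # the pos/neg flag and the 0.5-thresholded label, so AUROC, AUPRC and accuracy
+                    # can be computed per type and averaged further below.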
+ eval_loss = self.model.calc_loss(pred, labels, pos_neg) + m = torch.nn.Sigmoid() + pred = m(pred) + labels = labels.detach().to('cpu').numpy() + preds = pred.detach().to('cpu').numpy() + pos_neg = pos_neg.detach().to('cpu').numpy() + for (label_ids, pred, label_t) in zip(labels, preds, pos_neg): + for i, (l, p) in enumerate(zip(label_ids, pred)): + if l == 1: + if i in pred_class: + pred_class[i]['pred'] += [p] + pred_class[i]['l'] += [label_t] + pred_class[i]['pred_label'] += [1 if p > 0.5 else 0] + else: + pred_class[i] = {'pred': [p], 'l': [label_t], 'pred_label': [1 if p > 0.5 else 0]} + else: + eval_loss = self.model.calc_loss(pred, labels) + pos_labels += rel.to('cpu').numpy().flatten().tolist() + pos_scores += torch.argmax(pred, dim=1).cpu().flatten().tolist() + loss_list.append(eval_loss.item()) + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + roc_auc = [metrics.roc_auc_score(pred_class[l]['l'], pred_class[l]['pred']) for l in pred_class] + prc_auc = [metrics.average_precision_score(pred_class[l]['l'], pred_class[l]['pred']) for l in + pred_class] + ap = [metrics.accuracy_score(pred_class[l]['l'], pred_class[l]['pred_label']) for l in pred_class] + results['auroc'] = np.mean(roc_auc) + results['auprc'] = np.mean(prc_auc) + results['ap'] = np.mean(ap) + else: + results['acc'] = metrics.accuracy_score(pos_labels, pos_scores) + results['macro_f1'] = metrics.f1_score(pos_labels, pos_scores, average='macro') + results['kappa'] = metrics.cohen_kappa_score(pos_labels, pos_scores) + # dict_res = metrics.classification_report(pos_labels, pos_scores, output_dict=True, zero_division=1) + # results['macro_f1_per_class'] = get_f1_score_list(dict_res) + # results['acc_per_class'] = get_acc_list(dict_res) + loss = np.mean(loss_list) + return results, loss + + def arch_random_search_each(self, trial): + genotype_space = [] + for i in range(self.p.n_layer): + genotype_space.append(trial.suggest_categorical("mess"+ str(i), COMP_PRIMITIVES)) + genotype_space.append(trial.suggest_categorical("agg"+ str(i), AGG_PRIMITIVES)) + genotype_space.append(trial.suggest_categorical("comb"+ str(i), COMB_PRIMITIVES)) + genotype_space.append(trial.suggest_categorical("act"+ str(i), ACT_PRIMITIVES)) + self.best_val_f1 = 0.0 + self.best_val_auroc = 0.0 + self.early_stop_cnt = 0 + # self.best_valid_metric, self.best_test_metric = 0.0, {} + self.p.genotype = "||".join(genotype_space) + # run = self.reinit_wandb() + self.model = self.get_model().to("cuda:0") + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.p.lr, weight_decay=self.p.l2) + save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}' + os.makedirs(save_root, exist_ok=True) + save_path = f'{save_root}/random_search.pt' + for epoch in range(self.p.max_epochs): + start_time = time.time() + train_loss = self.train_epoch() + val_results, valid_loss = self.evaluate_epoch('valid', mode='normal') + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + if val_results['auroc'] > self.best_val_auroc: + self.best_val_results = val_results + self.best_val_auroc = val_results['auroc'] + self.best_epoch = epoch + self.save_model(save_path) + self.early_stop_cnt = 0 + # self.logger.info("Update best valid auroc!") + else: + self.early_stop_cnt += 1 + else: + if val_results['macro_f1'] > self.best_val_f1: + self.best_val_results = val_results + self.best_val_f1 = val_results['macro_f1'] + self.best_epoch = epoch + self.save_model(save_path) + self.early_stop_cnt = 0 + # 
self.logger.info("Update best valid f1!") + else: + self.early_stop_cnt += 1 + self.logger.info( + f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Cost: {time.time() - start_time:.2f}s") + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + self.logger.info( + f"[Epoch {epoch}]: Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}") + self.logger.info( + f"[Epoch {epoch}]: Best AUROC: {self.best_val_auroc:.5}") + # wandb.log({"train_loss": train_loss, "valid_loss": valid_loss, + # "best_valid_auroc": self.best_val_auroc}) + else: + self.logger.info( + f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}") + self.logger.info( + f"[Epoch {epoch}]: Best Macro F1: {self.best_val_f1:.5}") + # wandb.log({"train_loss": train_loss, "valid_loss": valid_loss, + # "best_valid_f1": self.best_val_f1}) + if self.early_stop_cnt == 10: + self.logger.info("Early stop!") + break + # self.logger.info(vars(self.p)) + self.load_model(save_path) + self.logger.info( + f'Loading best model in {self.best_epoch} epoch, Evaluating on Test data') + self.logger.info(f'{self.p.genotype}') + start = time.time() + test_results, test_loss = self.evaluate_epoch('test', mode='normal') + end = time.time() + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + self.logger.info( + f"[Inference]: Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}") + # wandb.log({ + # "test_auroc": test_results['auroc'], + # "test_auprc": test_results['auprc'], + # "test_ap": test_results['ap'] + # }) + # run.finish() + if self.best_val_auroc > self.best_valid_metric: + self.best_valid_metric = self.best_val_auroc + self.best_test_metric = test_results + with open(f'{save_root}/random_search_arch_list.csv', "a") as f: + writer = csv.writer(f) + writer.writerow([self.p.genotype, self.best_val_auroc, test_results['auroc']]) + return self.best_val_auroc + else: + self.logger.info( + f"[Inference]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}") + # wandb.log({ + # "test_acc": test_results['acc'], + # "test_f1": test_results['macro_f1'], + # "test_cohen": test_results['kappa'] + # }) + # run.finish() + if self.best_val_f1 > self.best_valid_metric: + self.best_valid_metric = self.best_val_f1 + self.best_test_metric = test_results + with open(f'{save_root}/random_search_arch_list.csv', "a") as f: + writer = csv.writer(f) + writer.writerow([self.p.genotype, self.best_val_f1, test_results['macro_f1']]) + return self.best_val_f1 + + def arch_random_search(self): + self.best_valid_metric = 0.0 + self.best_test_metric = {} + save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}/' + os.makedirs(save_root, exist_ok=True) + self.logger = get_logger(f'{save_root}/', f'random_search') + study = optuna.create_study(directions=["maximize"], sampler=RandomSampler()) + study.optimize(self.arch_random_search_each, n_trials=self.p.baseline_sample_num) + save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}/' + with open(f'{save_root}/random_search_res.txt', "w") as f1: + f1.write(f'{self.p.__dict__}\n') + f1.write(f'{self.p.genotype}\n') + f1.write(f'Valid performance: {study.best_value}\n') + f1.write(f'Test performance: {self.best_test_metric}') + + def 
train_parameter(self, parameter): + self.best_val_f1 = 0.0 + self.best_val_auroc = 0.0 + self.early_stop_cnt = 0 + self.p.genotype = "||".join(parameter) + run = self.reinit_wandb() + self.model = self.get_model().to("cuda:0") + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.p.lr, weight_decay=self.p.l2) + save_root = f'{self.prj_path}/checkpoints/{self.p.dataset}' + os.makedirs(save_root, exist_ok=True) + save_path = f'{save_root}/{self.p.name}.pt' + for epoch in range(self.p.max_epochs): + start_time = time.time() + train_loss = self.train_epoch() + val_results, valid_loss = self.evaluate_epoch('valid', mode='normal') + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + if val_results['auroc'] > self.best_val_auroc: + self.best_val_results = val_results + self.best_val_auroc = val_results['auroc'] + self.best_epoch = epoch + self.save_model(save_path) + self.early_stop_cnt = 0 + self.logger.info("Update best valid auroc!") + else: + self.early_stop_cnt += 1 + else: + if val_results['macro_f1'] > self.best_val_f1: + self.best_val_results = val_results + self.best_val_f1 = val_results['macro_f1'] + self.best_epoch = epoch + self.save_model(save_path) + self.early_stop_cnt = 0 + self.logger.info("Update best valid f1!") + else: + self.early_stop_cnt += 1 + self.logger.info( + f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Cost: {time.time() - start_time:.2f}s") + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + self.logger.info( + f"[Epoch {epoch}]: Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}") + self.logger.info( + f"[Epoch {epoch}]: Best AUROC: {self.best_val_auroc:.5}") + wandb.log({"train_loss": train_loss, "valid_loss": valid_loss, + "best_valid_auroc": self.best_val_auroc}) + else: + self.logger.info( + f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}") + self.logger.info( + f"[Epoch {epoch}]: Best Macro F1: {self.best_val_f1:.5}") + wandb.log({"train_loss": train_loss, "valid_loss": valid_loss, + "best_valid_f1": self.best_val_f1}) + if self.early_stop_cnt == 15: + self.logger.info("Early stop!") + break + # self.logger.info(vars(self.p)) + self.load_model(save_path) + self.logger.info( + f'Loading best model in {self.best_epoch} epoch, Evaluating on Test data') + start = time.time() + test_results, test_loss = self.evaluate_epoch('test', mode='normal') + end = time.time() + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + self.logger.info( + f"[Inference]: Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}") + wandb.log({ + "test_auroc": test_results['auroc'], + "test_auprc": test_results['auprc'], + "test_ap": test_results['ap'] + }) + else: + if self.p.add_reverse: + self.logger.info( + f"f1: Rel {test_results['left_f1']:.5}, Rel_rev {test_results['right_f1']:.5}, Avg {test_results['macro_f1']:.5}") + self.logger.info( + f"[Inference]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}") + wandb.log({ + "test_acc": test_results['acc'], + "test_f1": test_results['macro_f1'], + "test_cohen": test_results['kappa'] + }) + run.finish() + return {'loss': -self.best_val_f1, "status": STATUS_OK} + + def reinit_wandb(self): + if self.p.train_mode == 'spos_tune': + run = wandb.init( + 
reinit=True, + project=self.p.wandb_project, + settings=wandb.Settings(start_method="fork"), + config={ + "dataset": self.p.dataset, + "encoder": self.p.encoder, + "score_function": self.p.score_func, + "batch_size": self.p.batch_size, + "learning_rate": self.p.lr, + "weight_decay":self.p.l2, + "encoder_layer_num": self.p.n_layer, + "epochs": self.p.max_epochs, + "seed": self.p.seed, + "train_mode": self.p.train_mode, + "init_dim": self.p.init_dim, + "embed_dim": self.p.embed_dim, + "input_type": self.p.input_type, + "loss_type": self.p.loss_type, + "search_mode": self.p.search_mode, + "combine_type": self.p.combine_type, + "genotype": self.p.genotype, + "exp_note": self.p.exp_note, + "alpha_mode": self.p.alpha_mode, + "few_shot_op": self.p.few_shot_op, + "tune_sample_num": self.p.tune_sample_num + } + ) + elif self.p.search_mode == 'arch_random': + run = wandb.init( + reinit=True, + project=self.p.wandb_project, + settings=wandb.Settings(start_method="fork"), + config={ + "dataset": self.p.dataset, + "encoder": self.p.encoder, + "score_function": self.p.score_func, + "batch_size": self.p.batch_size, + "learning_rate": self.p.lr, + "weight_decay":self.p.l2, + "encoder_layer_num": self.p.n_layer, + "epochs": self.p.max_epochs, + "seed": self.p.seed, + "train_mode": self.p.train_mode, + "init_dim": self.p.init_dim, + "embed_dim": self.p.embed_dim, + "input_type": self.p.input_type, + "loss_type": self.p.loss_type, + "search_mode": self.p.search_mode, + "combine_type": self.p.combine_type, + "genotype": self.p.genotype, + "exp_note": self.p.exp_note, + "tune_sample_num": self.p.tune_sample_num + } + ) + return run + + def joint_search(self): + save_root = f'{self.prj_path}/checkpoints/{self.p.dataset}' + os.makedirs(save_root, exist_ok=True) + save_model_path = f'{save_root}/{self.p.name}.pt' + save_ss_path = f'{save_root}/{self.p.name}_ss.pt' + self.subgraph_selector = SubgraphSelector(self.p).to("cuda:0") + self.upper_optimizer = torch.optim.Adam([{'params': self.model.arch_parameters()}, + {'params': self.subgraph_selector.parameters(), 'lr': self.p.ss_lr}], + lr=self.p.arch_lr, weight_decay=self.p.arch_weight_decay) + for epoch in range(self.p.max_epochs): + start_time = time.time() + # genotype = self.model.genotype() + # self.logger.info(f'Genotype: {genotype}') + if self.p.cos_temp and epoch % 5 == 0 and epoch != 0: + self.p.temperature = self.p.temperature * 0.5 + else: + self.p.temperature = self.p.temperature + train_loss = self.joint_search_epoch() + # self.scheduler.step() + # self.arch_scheduler.step() + val_results, valid_loss = self.evaluate_epoch('valid', 'joint_search') + s_val_results, s_valid_loss = self.evaluate_epoch('valid', 'joint_search_s') + if val_results['macro_f1'] > self.best_val_f1: + self.best_val_results = val_results + self.best_val_f1 = val_results['macro_f1'] + self.best_epoch = epoch + torch.save(self.model.state_dict(), str(save_model_path)) + torch.save(self.subgraph_selector.state_dict(), str(save_ss_path)) + self.early_stop_cnt = 0 + else: + self.early_stop_cnt += 1 + genotype = self.model.genotype() + # self.logger.info(f'[Epoch {epoch}]: LR: {self.scheduler.get_last_lr()[0]}, TEMP: {self.p.temperature}, Genotype: {genotype}') + self.logger.info( + f'[Epoch {epoch}]: TEMP: {self.p.temperature}, Genotype: {genotype}') + self.logger.info( + f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Valid_S Loss: {s_valid_loss:.5}, Cost: {time.time() - start_time:.2f}s") + if self.p.dataset == 'drugbank': + self.logger.info( + 
f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}") + self.logger.info( + f"[Epoch {epoch}]: Valid_S ACC: {s_val_results['acc']:.5}, Valid Macro F1: {s_val_results['macro_f1']:.5}, Valid Cohen: {s_val_results['kappa']:.5}") + if self.early_stop_cnt == 50: + self.logger.info("Early stop!") + break + + def joint_search_epoch(self): + self.model.train() + self.subgraph_selector.train() + loss_list = [] + train_iter = self.data_iter['train_rel'] + valid_iter = self.data_iter['valid_rel'] + # train_bar = tqdm(train_iter, ncols=0) + for step, batch in enumerate(train_iter): + if self.p.input_type == 'subgraph': + g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0") + else: + triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0") + g = self.g.to("cuda:0") + subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2] + hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent] + # print(hidden_all_ent.size()) # [num_ent, encoder_layer, dim] + if self.p.score_func == 'mlp_ncn': + cns = batch[2].to("cuda:0") + pred = self.model.compute_pred(hidden_all_ent, all_ent, subj, obj, cns, self.subgraph_selector) + else: + pred = self.model.compute_pred(hidden_all_ent, subj, obj, self.subgraph_selector, search_algorithm=self.p.search_algorithm) + train_loss = self.model.calc_loss(pred, labels) + loss_list.append(train_loss.item()) + self.optimizer.zero_grad() + train_loss.backward() + self.optimizer.step() + self.optimizer.zero_grad() + self.upper_optimizer.zero_grad() + + batch = next(iter(valid_iter)) + if self.p.input_type == 'subgraph': + g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0") + else: + triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0") + g = self.g.to("cuda:0") + subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2] + hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent] + # print(hidden_all_ent.size()) # [num_ent, encoder_layer, dim] + if self.p.score_func == 'mlp_ncn': + cns = batch[2].to("cuda:0") + pred = self.model.compute_pred(hidden_all_ent, all_ent, subj, obj, cns, self.subgraph_selector) + else: + pred = self.model.compute_pred(hidden_all_ent, subj, obj, self.subgraph_selector, search_algorithm=self.p.search_algorithm) + valid_loss = self.model.calc_loss(pred, labels) + valid_loss.backward() + self.upper_optimizer.step() + loss = np.mean(loss_list) + return loss + + def predict_joint_search(self, split, mode, eval_mode=None): + loss_list = [] + pos_scores = [] + pos_labels = [] + self.model.eval() + self.subgraph_selector.eval() + with torch.no_grad(): + results = dict() + test_iter = self.data_iter[f'{split}_rel{mode}'] + for step, batch in enumerate(test_iter): + if self.p.input_type == 'subgraph': + g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0") + else: + triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0") + g = self.g.to("cuda:0") + subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2] + hidden_all_ent, all_ent = self.model.forward_search(g, mode=eval_mode) # [batch_size, num_ent] + if self.p.score_func == 'mlp_ncn': + cns = batch[2].to("cuda:0") + pred = self.model.compute_pred(hidden_all_ent, all_ent, subj, obj, cns, self.subgraph_selector, + mode='argmax') + else: + pred = self.model.compute_pred(hidden_all_ent, subj, obj, 
self.subgraph_selector, mode='argmax', search_algorithm=self.p.search_algorithm) + eval_loss = self.model.calc_loss(pred, labels) + loss_list.append(eval_loss.item()) + pos_labels += rel.to('cpu').numpy().flatten().tolist() + pos_scores += torch.argmax(pred, dim=1).cpu().flatten().tolist() + results['acc'] = metrics.accuracy_score(pos_labels, pos_scores) + results['macro_f1'] = metrics.f1_score(pos_labels, pos_scores, average='macro') + results['kappa'] = metrics.cohen_kappa_score(pos_labels, pos_scores) + loss = np.mean(loss_list) + return results, loss + + def joint_tune(self): + self.subgraph_selector = SubgraphSelector(self.p).to("cuda:0") + save_root = f'{self.prj_path}/checkpoints/{self.p.dataset}' + save_model_path = f'{save_root}/{self.p.name}.pt' + save_ss_path = f'{save_root}/{self.p.name}_ss.pt' + save_tune_path = f'{save_root}/{self.p.name}_tune.pt' + self.model.load_state_dict(torch.load(str(save_model_path))) + self.subgraph_selector.load_state_dict(torch.load(str(save_ss_path))) + self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='max', factor=0.2, patience=10, verbose=True) + + val_results, val_loss = self.evaluate_epoch('valid', 'joint_search_s') + test_results, test_loss = self.evaluate_epoch('test', 'joint_search_s') + if self.p.dataset == 'drugbank': + # if self.p.add_reverse: + # self.logger.info( + # f"f1: Rel {test_results['left_f1']:.5}, Rel_rev {test_results['right_f1']:.5}, Avg {test_results['macro_f1']:.5}") + self.logger.info( + f"[Validation]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}") + self.logger.info( + f"[Inference]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}") + wandb.log({ + "init_test_acc": test_results['acc'], + "init_test_f1": test_results['macro_f1'], + "init_test_cohen": test_results['kappa'] + }) + for epoch in range(self.p.max_epochs): + start_time = time.time() + train_loss = self.train_epoch_fine_tune(mode='evaluate_single_path') + val_results, valid_loss = self.evaluate_epoch('valid', 'joint_search_s') + if val_results['macro_f1'] > self.best_val_f1: + self.best_val_results = val_results + self.best_val_f1 = val_results['macro_f1'] + self.best_epoch = epoch + self.save_model(save_tune_path) + self.early_stop_cnt = 0 + else: + self.early_stop_cnt += 1 + # torch.save(self.model.state_dict(), str(save_model_path)) + # torch.save(self.subgraph_selector.state_dict(), str(save_ss_path)) + genotype = self.model.genotype() + self.logger.info( + f'[Epoch {epoch}]: Genotype: {genotype}') + self.logger.info( + f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Cost: {time.time() - start_time:.2f}s") + self.logger.info( + f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}") + self.logger.info( + f"[Epoch {epoch}]: Best Macro F1: {self.best_val_f1:.5}") + wandb.log({"train_loss": train_loss, "best_valid_f1": self.best_val_f1}) + if self.early_stop_cnt == 15: + self.logger.info("Early stop!") + break + self.scheduler.step(self.best_val_f1) + self.load_model(save_tune_path) + self.logger.info( + f'Loading best model in {self.best_epoch} epoch, Evaluating on Test data') + test_results, test_loss = self.evaluate_epoch('test', 'joint_search_s') + test_acc = test_results['acc'] + test_f1 = test_results['macro_f1'] + test_cohen = test_results['kappa'] + wandb.log({ + "test_acc": test_acc, + 
"test_f1": test_f1, + "test_cohen": test_cohen + }) + if self.p.dataset == 'drugbank': + test_acc = test_results['acc'] + test_f1 = test_results['macro_f1'] + test_cohen = test_results['kappa'] + wandb.log({ + "test_acc": test_acc, + "test_f1": test_f1, + "test_cohen": test_cohen + }) + if self.p.add_reverse: + self.logger.info( + f"f1: Rel {test_results['left_f1']:.5}, Rel_rev {test_results['right_f1']:.5}, Avg {test_results['macro_f1']:.5}") + self.logger.info( + f"[Inference]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}") + + def spos_train_supernet(self): + save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}' + print(save_root) + os.makedirs(save_root, exist_ok=True) + if self.p.weight_sharing: + self.logger = get_logger(f'{save_root}/', f'train_supernet_ws_{self.p.few_shot_op}') + save_model_path = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{"_".join(save_root.split("/")[-1].split("_")[:-3])}/400.pt' + self.model.load_state_dict(torch.load(save_model_path)) + else: + self.logger = get_logger(f'{save_root}/', f'train_supernet') + for epoch in range(1, self.p.max_epochs+1): + start_time = time.time() + train_loss = self.architecture_search_spos_epoch() + val_results, valid_loss = self.evaluate_epoch('valid', 'spos_train_supernet') + self.logger.info( + f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Cost: {time.time() - start_time:.2f}s") + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + self.logger.info( + f"[Epoch {epoch}]: Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}") + else: + self.logger.info( + f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}") + # self.scheduler.step(train_loss) + wandb.log({ + "train_loss": train_loss, + "valid_loss": valid_loss + }) + if epoch % 100 == 0: + torch.save(self.model.state_dict(), f'{save_root}/{epoch}.pt') + + def architecture_search_spos_epoch(self): + self.model.train() + loss_list = [] + train_iter = self.data_iter['train_rel'] + # train_bar = tqdm(train_iter, ncols=0) + for step, batch in enumerate(train_iter): + self.optimizer.zero_grad() + self.generate_single_path() + if self.p.input_type == 'subgraph': + g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0") + else: + triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0") + g = self.g.to("cuda:0") + subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2] + # hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent] + # print(hidden_all_ent.size()) # [num_ent, encoder_layer, dim] + pred = self.model(g, subj, obj) + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + pos_neg = batch[2].to("cuda:0") + loss = self.model.calc_loss(pred, labels, pos_neg) + else: + loss = self.model.calc_loss(pred, labels) + loss.backward() + if self.p.clip_grad: + clip_grad_norm_(self.model.parameters(), max_norm=10, norm_type=2) + self.optimizer.step() + loss_list.append(loss.item()) + loss = np.mean(loss_list) + return loss + + def predict_spos_search(self, split, mode, eval_mode=None, spos_mode='train_supernet'): + loss_list = [] + pos_scores = [] + pos_labels = [] + pred_class = {} + self.model.eval() + with torch.no_grad(): + results = dict() + test_iter = 
self.data_iter[f'{split}_rel{mode}'] + for step, batch in enumerate(test_iter): + if spos_mode == 'train_supernet': + self.generate_single_path() + if self.p.input_type == 'subgraph': + g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0") + else: + triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0") + g = self.g.to("cuda:0") + subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2] + # hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent] + # print(hidden_all_ent.size()) # [num_ent, encoder_layer, dim] + pred = self.model(g, subj, obj, mode=eval_mode) + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + pos_neg = batch[2].to("cuda:0") + eval_loss = self.model.calc_loss(pred, labels, pos_neg) + m = torch.nn.Sigmoid() + pred = m(pred) + labels = labels.detach().to('cpu').numpy() + preds = pred.detach().to('cpu').numpy() + pos_neg = pos_neg.detach().to('cpu').numpy() + for (label_ids, pred, label_t) in zip(labels, preds, pos_neg): + for i, (l, p) in enumerate(zip(label_ids, pred)): + if l == 1: + if i in pred_class: + pred_class[i]['pred'] += [p] + pred_class[i]['l'] += [label_t] + pred_class[i]['pred_label'] += [1 if p > 0.5 else 0] + else: + pred_class[i] = {'pred': [p], 'l': [label_t], 'pred_label': [1 if p > 0.5 else 0]} + else: + eval_loss = self.model.calc_loss(pred, labels) + pos_labels += rel.to('cpu').numpy().flatten().tolist() + pos_scores += torch.argmax(pred, dim=1).cpu().flatten().tolist() + loss_list.append(eval_loss.item()) + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + roc_auc = [metrics.roc_auc_score(pred_class[l]['l'], pred_class[l]['pred']) for l in pred_class] + prc_auc = [metrics.average_precision_score(pred_class[l]['l'], pred_class[l]['pred']) for l in + pred_class] + ap = [metrics.accuracy_score(pred_class[l]['l'], pred_class[l]['pred_label']) for l in pred_class] + results['auroc'] = np.mean(roc_auc) + results['auprc'] = np.mean(prc_auc) + results['ap'] = np.mean(ap) + else: + results['acc'] = metrics.accuracy_score(pos_labels, pos_scores) + results['macro_f1'] = metrics.f1_score(pos_labels, pos_scores, average='macro') + results['kappa'] = metrics.cohen_kappa_score(pos_labels, pos_scores) + loss = np.mean(loss_list) + return results, loss + + def spos_arch_search(self): + save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}' + os.makedirs(save_root, exist_ok=True) + for save_epoch in [800,700,600,500,400,300,200,100]: + try: + self.model.load_state_dict(torch.load(f'{save_root}/{save_epoch}.pt')) + except: + continue + self.logger = get_logger(f'{save_root}/', f'{save_epoch}_arch_search') + valid_loss_searched_arch_res = dict() + valid_f1_searched_arch_res = dict() + valid_auroc_searched_arch_res = dict() + search_time = 0.0 + t_start = time.time() + for epoch in range(1, self.p.spos_arch_sample_num + 1): + self.generate_single_path() + arch = "||".join(self.model.ops) + val_results, valid_loss = self.evaluate_epoch('valid', 'spos_arch_search') + test_results, test_loss = self.evaluate_epoch('test', 'spos_arch_search') + valid_loss_searched_arch_res.setdefault(arch, valid_loss) + self.logger.info(f'[Epoch {epoch}]: Path:{arch}') + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + self.logger.info( + f"[Epoch {epoch}]: Valid Loss: {valid_loss:.5}, Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}") + 
self.logger.info( + f"[Epoch {epoch}]: Test Loss: {test_loss:.5}, Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}") + valid_auroc_searched_arch_res.setdefault(arch, val_results['auroc']) + else: + self.logger.info( + f"[Epoch {epoch}]: Valid Loss: {valid_loss:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid ACC: {val_results['acc']:.5}, Valid Cohen: {val_results['kappa']:.5}") + self.logger.info( + f"[Epoch {epoch}]: Test Loss: {test_loss:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test ACC: {test_results['acc']:.5}, Test Cohen: {test_results['kappa']:.5}") + valid_f1_searched_arch_res.setdefault(arch, val_results['macro_f1']) + + t_end = time.time() + search_time = (t_end - t_start) + + search_time = search_time / 3600 + self.logger.info(f'The search process costs {search_time:.2f}h.') + import csv + with open(f'{save_root}/valid_loss_{save_epoch}.csv', 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['arch', 'valid loss']) + valid_loss_searched_arch_res_sorted = sorted(valid_loss_searched_arch_res.items(), key=lambda x :x[1]) + res = valid_loss_searched_arch_res_sorted + for i in range(len(res)): + writer.writerow([res[i][0], res[i][1]]) + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + with open(f'{save_root}/valid_auroc_{save_epoch}.csv', 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['arch', 'valid auroc']) + valid_auroc_searched_arch_res_sorted = sorted(valid_auroc_searched_arch_res.items(), key=lambda x: x[1], + reverse=True) + res = valid_auroc_searched_arch_res_sorted + for i in range(len(res)): + writer.writerow([res[i][0], res[i][1]]) + else: + with open(f'{save_root}/valid_f1_{save_epoch}.csv', 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['arch', 'valid f1']) + valid_f1_searched_arch_res_sorted = sorted(valid_f1_searched_arch_res.items(), key=lambda x: x[1], reverse=True) + res = valid_f1_searched_arch_res_sorted + for i in range(len(res)): + writer.writerow([res[i][0], res[i][1]]) + + def generate_single_path(self): + if self.p.exp_note is None: + self.model.ops = self.model.generate_single_path() + elif self.p.exp_note == 'only_search_act': + self.model.ops = self.model.generate_single_path_act() + elif self.p.exp_note == 'only_search_comb': + self.model.ops = self.model.generate_single_path_comb() + elif self.p.exp_note == 'only_search_comp': + self.model.ops = self.model.generate_single_path_comp() + elif self.p.exp_note == 'only_search_agg': + self.model.ops = self.model.generate_single_path_agg() + elif self.p.exp_note == 'only_search_agg_comb': + self.model.ops = self.model.generate_single_path_agg_comb() + elif self.p.exp_note == 'only_search_agg_comb_comp': + self.model.ops = self.model.generate_single_path_agg_comb_comp() + elif self.p.exp_note == 'only_search_agg_comb_act_rotate': + self.model.ops = self.model.generate_single_path_agg_comb_act_rotate() + elif self.p.exp_note == 'only_search_agg_comb_act_mult': + self.model.ops = self.model.generate_single_path_agg_comb_act_mult() + elif self.p.exp_note == 'only_search_agg_comb_act_ccorr': + self.model.ops = self.model.generate_single_path_agg_comb_act_ccorr() + elif self.p.exp_note == 'only_search_agg_comb_act_sub': + self.model.ops = self.model.generate_single_path_agg_comb_act_sub() + elif self.p.exp_note == 'spfs' and self.p.search_algorithm == 'spos_arch_search_ps2': + self.model.ops = self.model.generate_single_path() + elif self.p.exp_note == 'spfs' 
and self.p.few_shot_op is not None: + self.model.ops = self.model.generate_single_path_agg_comb_act_few_shot_comp(self.p.few_shot_op) + # print(1) + + def spos_train_supernet_ps2(self): + save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}' + log_root = f'{save_root}/log' + os.makedirs(log_root, exist_ok=True) + self.subgraph_selector = SubgraphSelector(self.p).to("cuda:0") + self.subgraph_selector_optimizer = torch.optim.Adam( + self.subgraph_selector.parameters(), lr=self.p.ss_lr, weight_decay=self.p.l2) + self.logger = get_logger(f'{log_root}/', f'train_supernet') + if self.p.weight_sharing: + save_model_path = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{"_".join(save_root.split("/")[-1].split("_")[:-3])}/400.pt' + save_ss_path = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{"_".join(save_root.split("/")[-1].split("_")[:-3])}/400_ss.pt' + print(save_model_path) + print(save_ss_path) + self.model.load_state_dict(torch.load(save_model_path)) + self.subgraph_selector.load_state_dict(torch.load(save_ss_path)) + for epoch in range(1, self.p.max_epochs+1): + start_time = time.time() + train_loss = self.architecture_search_spos_ps2_epoch() + val_results, valid_loss = self.evaluate_epoch('valid', 'spos_train_supernet_ps2') + self.logger.info( + f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Cost: {time.time() - start_time:.2f}s") + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + self.logger.info( + f"[Epoch {epoch}]: Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}") + else: + self.logger.info( + f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}") + # self.scheduler.step(train_loss) + wandb.log({ + "train_loss": train_loss, + "valid_loss": valid_loss + }) + if epoch % 100 == 0: + torch.save(self.model.state_dict(), f'{save_root}/{epoch}.pt') + torch.save(self.subgraph_selector.state_dict(), f'{save_root}/{epoch}_ss.pt') + + def architecture_search_spos_ps2_epoch(self): + self.model.train() + self.subgraph_selector.train() + loss_list = [] + train_iter = self.data_iter['train_rel'] + valid_iter = self.data_iter['valid_rel'] + # train_bar = tqdm(train_iter, ncols=0) + for step, batch in enumerate(train_iter): + self.optimizer.zero_grad() + self.generate_single_path() + if self.p.input_type == 'subgraph': + g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0") + else: + triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0") + g = self.g.to("cuda:0") + subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2] + hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent] + pred = self.model.compute_pred(hidden_all_ent, subj, obj, self.subgraph_selector, + search_algorithm=self.p.ss_search_algorithm) + + # print(hidden_all_ent.size()) # [num_ent, encoder_layer, dim] + # pred = self.model(g, subj, obj) + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + pos_neg = batch[2].to("cuda:0") + train_loss = self.model.calc_loss(pred, labels, pos_neg) + else: + train_loss = self.model.calc_loss(pred, labels) + train_loss.backward() + if self.p.clip_grad: + clip_grad_norm_(self.model.parameters(), max_norm=10, norm_type=2) + self.optimizer.step() + self.optimizer.zero_grad() + if self.p.alpha_mode == 'train_loss': + 
self.subgraph_selector_optimizer.step() + elif self.p.alpha_mode == 'valid_loss': + self.optimizer.zero_grad() + self.subgraph_selector_optimizer.zero_grad() + batch = next(iter(valid_iter)) + if self.p.input_type == 'subgraph': + g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0") + else: + triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0") + g = self.g.to("cuda:0") + subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2] + hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent] + pred = self.model.compute_pred(hidden_all_ent, subj, obj, self.subgraph_selector, + search_algorithm=self.p.ss_search_algorithm) + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + pos_neg = batch[2].to("cuda:0") + valid_loss = self.model.calc_loss(pred, labels, pos_neg) + else: + valid_loss = self.model.calc_loss(pred, labels) + valid_loss.backward() + self.subgraph_selector_optimizer.step() + loss_list.append(train_loss.item()) + loss = np.mean(loss_list) + return loss + + def predict_spos_search_ps2(self, split, mode, eval_mode=None, spos_mode='train_supernet'): + loss_list = [] + pos_scores = [] + pos_labels = [] + pred_class = {} + self.model.eval() + self.subgraph_selector.eval() + with torch.no_grad(): + results = dict() + test_iter = self.data_iter[f'{split}_rel{mode}'] + for step, batch in enumerate(test_iter): + if spos_mode == 'train_supernet': + self.generate_single_path() + if self.p.input_type == 'subgraph': + g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0") + else: + triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0") + g = self.g.to("cuda:0") + subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2] + hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent] + pred = self.model.compute_pred(hidden_all_ent, subj, obj, self.subgraph_selector, mode='argmax', + search_algorithm=self.p.ss_search_algorithm) + # print(hidden_all_ent.size()) # [num_ent, encoder_layer, dim] + # pred = self.model(g, subj, obj, mode=eval_mode) + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + pos_neg = batch[2].to("cuda:0") + eval_loss = self.model.calc_loss(pred, labels, pos_neg) + m = torch.nn.Sigmoid() + pred = m(pred) + labels = labels.detach().to('cpu').numpy() + preds = pred.detach().to('cpu').numpy() + pos_neg = pos_neg.detach().to('cpu').numpy() + for (label_ids, pred, label_t) in zip(labels, preds, pos_neg): + for i, (l, p) in enumerate(zip(label_ids, pred)): + if l == 1: + if i in pred_class: + pred_class[i]['pred'] += [p] + pred_class[i]['l'] += [label_t] + pred_class[i]['pred_label'] += [1 if p > 0.5 else 0] + else: + pred_class[i] = {'pred': [p], 'l': [label_t], 'pred_label': [1 if p > 0.5 else 0]} + else: + eval_loss = self.model.calc_loss(pred, labels) + pos_labels += rel.to('cpu').numpy().flatten().tolist() + pos_scores += torch.argmax(pred, dim=1).cpu().flatten().tolist() + loss_list.append(eval_loss.item()) + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + roc_auc = [metrics.roc_auc_score(pred_class[l]['l'], pred_class[l]['pred']) for l in pred_class] + prc_auc = [metrics.average_precision_score(pred_class[l]['l'], pred_class[l]['pred']) for l in + pred_class] + ap = [metrics.accuracy_score(pred_class[l]['l'], pred_class[l]['pred_label']) for l in pred_class] + results['auroc'] = np.mean(roc_auc) + results['auprc'] = np.mean(prc_auc) + 
results['ap'] = np.mean(ap) + else: + results['acc'] = metrics.accuracy_score(pos_labels, pos_scores) + results['macro_f1'] = metrics.f1_score(pos_labels, pos_scores, average='macro') + results['kappa'] = metrics.cohen_kappa_score(pos_labels, pos_scores) + # dict_res = metrics.classification_report(pos_labels, pos_scores, output_dict=True, zero_division=1) + # results['macro_f1_per_class'] = get_f1_score_list(dict_res) + # results['acc_per_class'] = get_acc_list(dict_res) + loss = np.mean(loss_list) + return results, loss + + def spos_arch_search_ps2(self): + res_list = [] + sorted_list = [] + exp_note = '_' + self.p.exp_note if self.p.exp_note is not None else '' + self.subgraph_selector = SubgraphSelector(self.p).to("cuda:0") + # save_root = f'{self.prj_path}/checkpoints/{self.p.dataset}/{self.p.search_mode}' + # save_model_path = f'{save_root}/{self.p.name}.pt' + # save_ss_path = f'{save_root}/{self.p.name}_ss.pt' + save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}' + if self.p.exp_note == 'spfs': + epoch_list = [400] + weight_sharing = '_ws' + else: + epoch_list = [800, 700,600] + weight_sharing = '' + for save_epoch in epoch_list: + # try: + # self.model.load_state_dict(torch.load(f'{save_root}/{save_epoch}.pt')) + # self.subgraph_selector.load_state_dict(torch.load(f'{save_root}/{save_epoch}_ss.pt')) + # except: + # continue + self.logger = get_logger(f'{save_root}/log/', f'{save_epoch}_arch_search{exp_note}') + # res_root = f'{self.prj_path}/search_res/{self.p.dataset}/{self.p.search_mode}/{self.p.name}/' + # os.makedirs(res_root, exist_ok=True) + valid_loss_searched_arch_res = dict() + valid_f1_searched_arch_res = dict() + valid_auroc_searched_arch_res = dict() + t_start = time.time() + for epoch in range(0, self.p.spos_arch_sample_num): + for sample_idx in range(1, self.p.asng_sample_num+1): + self.generate_single_path() + arch = "||".join(self.model.ops) + if self.p.exp_note == 'spfs': + few_shot_op = self.model.ops[0] + else: + few_shot_op = '' + self.model.load_state_dict(torch.load(f'{save_root}{exp_note}_{few_shot_op}{weight_sharing}/{save_epoch}.pt')) + self.subgraph_selector.load_state_dict(torch.load(f'{save_root}{exp_note}_{few_shot_op}{weight_sharing}/{save_epoch}_ss.pt')) + val_results, valid_loss = self.evaluate_epoch('valid', 'spos_arch_search_ps2') + test_results, test_loss = self.evaluate_epoch('test', 'spos_arch_search_ps2') + valid_loss_searched_arch_res.setdefault(arch, valid_loss) + self.logger.info(f'[Epoch {epoch*self.p.asng_sample_num+sample_idx}]: Path:{arch}') + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + self.logger.info( + f"[Epoch {epoch*self.p.asng_sample_num+sample_idx}]: Valid Loss: {valid_loss:.5}, Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}") + self.logger.info( + f"[Epoch {epoch*self.p.asng_sample_num+sample_idx}]: Test Loss: {test_loss:.5}, Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}") + valid_auroc_searched_arch_res.setdefault(arch, val_results['auroc']) + sorted_list.append(val_results['auroc']) + else: + self.logger.info( + f"[Epoch {epoch*self.p.asng_sample_num+sample_idx}]: Valid Loss: {valid_loss:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid ACC: {val_results['acc']:.5}, Valid Cohen: {val_results['kappa']:.5}") + self.logger.info( + f"[Epoch {epoch*self.p.asng_sample_num+sample_idx}]: Test Loss: {test_loss:.5}, Test Macro F1: 
{test_results['macro_f1']:.5}, Test ACC: {test_results['acc']:.5}, Test Cohen: {test_results['kappa']:.5}") + valid_f1_searched_arch_res.setdefault(arch, val_results['macro_f1']) + sorted_list.append(val_results['macro_f1']) + res_list.append(sorted(sorted_list, reverse=True)[:self.p.asng_sample_num]) + with open(f"{save_root}/topK_{save_epoch}{exp_note}.pkl", "wb") as f: + pickle.dump(res_list, f) + t_end = time.time() + search_time = (t_end - t_start) + + search_time = search_time / 3600 + self.logger.info(f'The search process costs {search_time:.2f}h.') + import csv + with open(f'{save_root}/valid_loss_{save_epoch}{exp_note}.csv', 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['arch', 'valid loss']) + valid_loss_searched_arch_res_sorted = sorted(valid_loss_searched_arch_res.items(), key=lambda x :x[1]) + res = valid_loss_searched_arch_res_sorted + for i in range(len(res)): + writer.writerow([res[i][0], res[i][1]]) + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + with open(f'{save_root}/valid_auroc_{save_epoch}{exp_note}.csv', 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['arch', 'valid auroc']) + valid_auroc_searched_arch_res_sorted = sorted(valid_auroc_searched_arch_res.items(), key=lambda x: x[1], + reverse=True) + res = valid_auroc_searched_arch_res_sorted + for i in range(len(res)): + writer.writerow([res[i][0], res[i][1]]) + else: + with open(f'{save_root}/valid_f1_{save_epoch}{exp_note}.csv', 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['arch', 'valid f1']) + valid_f1_searched_arch_res_sorted = sorted(valid_f1_searched_arch_res.items(), key=lambda x: x[1], reverse=True) + res = valid_f1_searched_arch_res_sorted + for i in range(len(res)): + writer.writerow([res[i][0], res[i][1]]) + + def joint_spos_ps2_fine_tune(self): + self.best_valid_metric = 0.0 + self.best_test_metric = {} + arch_rank = 1 + exp_note = '_' + self.p.exp_note if self.p.exp_note is not None else '' + for save_epoch in [400]: + self.save_epoch = save_epoch + save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}' + res_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}/res' + log_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}/log' + tmp_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}/tmp' + os.makedirs(res_root, exist_ok=True) + os.makedirs(log_root, exist_ok=True) + os.makedirs(tmp_root, exist_ok=True) + print(save_root) + self.logger = get_logger(f'{log_root}/', f'fine_tune_e{save_epoch}_vmtop{arch_rank}{exp_note}') + metric = 'auroc' if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset else 'f1' + with open(f'{save_root}/valid_{metric}_{self.save_epoch}{exp_note}_ng.csv', 'r') as csv_file: + reader = csv.reader(csv_file) + for _ in range(arch_rank): + next(reader) # rank 1 is to skip head of csv + self.p.genotype = next(reader)[0] + self.model.ops = self.p.genotype.split("||") + if self.p.exp_note == 'spfs': + weight_sharing = '_ws' + few_shot_op = self.model.ops[0] + else: + weight_sharing = '' + few_shot_op = '' + self.ss_path = f'{save_root}{exp_note}_{few_shot_op}{weight_sharing}/{self.save_epoch}_ss.pt' + study = optuna.create_study(directions=["maximize"]) + study.optimize(self.spfs_fine_tune_each, n_trials=self.p.tune_sample_num) + self.p.lr = 10 ** study.best_params["learning_rate"] + self.p.l2 = 10 ** study.best_params["weight_decay"] + with 
open(f'{res_root}/fine_tune_e{save_epoch}_vmtop{arch_rank}{exp_note}.txt', "w") as f1: + f1.write(f'{self.p.__dict__}\n') + f1.write(f'Valid performance: {study.best_value}\n') + f1.write(f'Test performance: {self.best_test_metric}') + + def joint_spos_fine_tune(self, parameter): + self.best_val_f1 = 0.0 + self.best_val_auroc = 0.0 + self.early_stop_cnt = 0 + self.p.lr = 10 ** parameter['learning_rate'] + self.p.l2 = 10 ** parameter['weight_decay'] + run = self.reinit_wandb() + run.finish() + return {'loss': -self.p.lr, 'test_metric':self.p.l2, "status": STATUS_OK} + + def spfs_fine_tune_each(self, trial): + exp_note = '_' + self.p.exp_note if self.p.exp_note is not None else '' + self.best_val_f1 = 0.0 + self.best_val_auroc = 0.0 + self.early_stop_cnt = 0 + learning_rate = trial.suggest_float("learning_rate", -3.05, -2.95) + weight_decay = trial.suggest_float("weight_decay", -5, -3) + self.p.lr = 10 ** learning_rate + self.p.l2 = 10 ** weight_decay + self.model = self.get_model().to("cuda:0") + print(type(self.model).__name__) + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.p.lr, weight_decay=self.p.l2) + self.subgraph_selector = SubgraphSelector(self.p).to("cuda:0") + save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}' + save_model_path = f'{save_root}/tmp/fine_tune_e{self.save_epoch}{exp_note}.pt' + self.subgraph_selector.load_state_dict(torch.load(str(self.ss_path))) + self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='max', factor=0.2, patience=10, verbose=True) + for epoch in range(self.p.max_epochs): + start_time = time.time() + train_loss = self.train_epoch_fine_tune() + val_results, valid_loss = self.evaluate_epoch('valid', 'spos_arch_search_ps2') + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + if val_results['auroc'] > self.best_val_auroc: + self.best_val_results = val_results + self.best_val_auroc = val_results['auroc'] + self.best_epoch = epoch + self.save_model(save_model_path) + self.early_stop_cnt = 0 + else: + self.early_stop_cnt += 1 + else: + if val_results['macro_f1'] > self.best_val_f1: + self.best_val_results = val_results + self.best_val_f1 = val_results['macro_f1'] + self.best_epoch = epoch + self.save_model(save_model_path) + self.early_stop_cnt = 0 + else: + self.early_stop_cnt += 1 + self.logger.info( + f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Cost: {time.time() - start_time:.2f}s") + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + self.logger.info( + f"[Epoch {epoch}]: Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}") + self.logger.info( + f"[Epoch {epoch}]: Best AUROC: {self.best_val_auroc:.5}") + # wandb.log({"train_loss": train_loss, "valid_loss": valid_loss, + # "best_valid_auroc": self.best_val_auroc}) + self.scheduler.step(self.best_val_auroc) + else: + self.logger.info( + f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}") + self.logger.info( + f"[Epoch {epoch}]: Best Macro F1: {self.best_val_f1:.5}") + # wandb.log({"train_loss": train_loss, "valid_loss": valid_loss, + # "best_valid_f1": self.best_val_f1}) + self.scheduler.step(self.best_val_f1) + if self.early_stop_cnt == 15: + self.logger.info("Early stop!") + break + self.load_model(save_model_path) + self.logger.info( + f'Loading best model in {self.best_epoch} epoch, Evaluating 
on Test data') + test_results, test_loss = self.evaluate_epoch('test', 'spos_arch_search_ps2') + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + self.logger.info( + f"[Inference]: Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}") + # wandb.log({ + # "test_auroc": test_results['auroc'], + # "test_auprc": test_results['auprc'], + # "test_ap": test_results['ap'] + # }) + # run.finish() + if self.best_val_auroc > self.best_valid_metric: + self.best_valid_metric = self.best_val_auroc + self.best_test_metric = test_results + return self.best_val_auroc + else: + self.logger.info( + f"[Inference]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}") + # wandb.log({ + # "test_acc": test_results['acc'], + # "test_f1": test_results['macro_f1'], + # "test_cohen": test_results['kappa'] + # }) + # run.finish() + if self.best_val_f1 > self.best_valid_metric: + self.best_valid_metric = self.best_val_f1 + self.best_test_metric = test_results + return self.best_val_f1 + + def spos_fine_tune(self): + self.best_valid_metric = 0.0 + self.best_test_metric = {} + arch_rank = 1 + for save_epoch in [800]: + self.save_epoch = save_epoch + save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}/' + print(save_root) + self.logger = get_logger(f'{save_root}/', f'{save_epoch}_finu_tune_vmtop{arch_rank}') + metric = 'auroc' if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset else 'f1' + with open(f'{save_root}/valid_{metric}_{self.save_epoch}.csv', 'r') as csv_file: + reader = csv.reader(csv_file) + for _ in range(arch_rank): + next(reader) # rank 1 is to skip head of csv + self.p.genotype = next(reader)[0] + self.model.ops = self.p.genotype.split("||") + study = optuna.create_study(directions=["maximize"]) + study.optimize(self.spos_fine_tune_each, n_trials=self.p.tune_sample_num) + self.p.lr = 10 ** study.best_params["learning_rate"] + self.p.l2 = 10 ** study.best_params["weight_decay"] + save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}/' + with open(f'{save_root}/{save_epoch}_tune_res_vmtop{arch_rank}.txt', "w") as f1: + f1.write(f'{self.p.__dict__}\n') + f1.write(f'Valid performance: {study.best_value}\n') + f1.write(f'Test performance: {self.best_test_metric}') + + def spos_fine_tune_each(self, trial): + self.best_val_f1 = 0.0 + self.best_val_auroc = 0.0 + self.early_stop_cnt = 0 + learning_rate = trial.suggest_float("learning_rate", -3.05, -2.95) + weight_decay = trial.suggest_float("weight_decay", -5, -3) + self.p.lr = 10 ** learning_rate + self.p.l2 = 10 ** weight_decay + # run = self.reinit_wandb() + self.model = self.get_model().to("cuda:0") + print(type(self.model).__name__) + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.p.lr, weight_decay=self.p.l2) + save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}/' + # for save_epoch in [100, 200, 300, 400]: + save_model_path = f'{save_root}/{self.save_epoch}_tune.pt' + # self.model.load_state_dict(torch.load(str(supernet_path))) + for epoch in range(self.p.max_epochs): + start_time = time.time() + train_loss = self.train_epoch() + val_results, valid_loss = self.evaluate_epoch('valid', 'spos_arch_search') + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + if val_results['auroc'] > self.best_val_auroc: + self.best_val_results = val_results + 
self.best_val_auroc = val_results['auroc'] + self.best_epoch = epoch + self.save_model(save_model_path) + self.early_stop_cnt = 0 + else: + self.early_stop_cnt += 1 + else: + if val_results['macro_f1'] > self.best_val_f1: + self.best_val_results = val_results + self.best_val_f1 = val_results['macro_f1'] + self.best_epoch = epoch + self.save_model(save_model_path) + self.early_stop_cnt = 0 + else: + self.early_stop_cnt += 1 + self.logger.info( + f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Cost: {time.time() - start_time:.2f}s") + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + self.logger.info( + f"[Epoch {epoch}]: Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}") + self.logger.info( + f"[Epoch {epoch}]: Best AUROC: {self.best_val_auroc:.5}") + # wandb.log({"train_loss": train_loss, "valid_loss": valid_loss, + # "best_valid_auroc": self.best_val_auroc}) + else: + self.logger.info( + f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}") + self.logger.info( + f"[Epoch {epoch}]: Best Macro F1: {self.best_val_f1:.5}") + # wandb.log({"train_loss": train_loss, "valid_loss": valid_loss, + # "best_valid_f1": self.best_val_f1}) + if self.early_stop_cnt == 10: + self.logger.info("Early stop!") + break + self.load_model(save_model_path) + self.logger.info( + f'Loading best model in {self.best_epoch} epoch, Evaluating on Test data') + test_results, test_loss = self.evaluate_epoch('test', 'spos_arch_search') + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + self.logger.info( + f"[Inference]: Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}") + # wandb.log({ + # "test_auroc": test_results['auroc'], + # "test_auprc": test_results['auprc'], + # "test_ap": test_results['ap'] + # }) + # run.finish() + if self.best_val_auroc > self.best_valid_metric: + self.best_valid_metric = self.best_val_auroc + self.best_test_metric = test_results + return self.best_val_auroc + else: + self.logger.info( + f"[Inference]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}") + # wandb.log({ + # "test_acc": test_results['acc'], + # "test_f1": test_results['macro_f1'], + # "test_cohen": test_results['kappa'] + # }) + # run.finish() + if self.best_val_f1 > self.best_valid_metric: + self.best_valid_metric = self.best_val_f1 + self.best_test_metric = test_results + return self.best_val_f1 + + def train_mix_hop(self): + save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.train_mode}/{self.p.name}/' + os.makedirs(save_root, exist_ok=True) + self.logger = get_logger(f'{save_root}/', f'train') + save_path = f'{save_root}/model_weight.pt' + for epoch in range(self.p.max_epochs): + start_time = time.time() + train_loss = self.train_epoch_mix_hop() + val_results, valid_loss = self.evaluate_epoch('valid', mode='normal_mix_hop') + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + if val_results['auroc'] > self.best_val_auroc: + self.best_val_results = val_results + self.best_val_auroc = val_results['auroc'] + self.best_epoch = epoch + self.save_model(save_path) + self.early_stop_cnt = 0 + else: + self.early_stop_cnt += 1 + else: + if val_results['macro_f1'] > self.best_val_f1: + self.best_val_results = val_results + self.best_val_f1 = 
val_results['macro_f1'] + self.best_epoch = epoch + self.save_model(save_path) + self.early_stop_cnt = 0 + else: + self.early_stop_cnt += 1 + self.logger.info( + f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Cost: {time.time() - start_time:.2f}s") + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + self.logger.info( + f"[Epoch {epoch}]: Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}") + self.logger.info( + f"[Epoch {epoch}]: Best AUROC: {self.best_val_auroc:.5}") + wandb.log({"train_loss": train_loss, "valid_loss": valid_loss, + "best_valid_auroc": self.best_val_auroc}) + else: + self.logger.info( + f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}") + self.logger.info( + f"[Epoch {epoch}]: Best Macro F1: {self.best_val_f1:.5}") + wandb.log({"train_loss": train_loss, "valid_loss": valid_loss, + "best_valid_f1": self.best_val_f1}) + if self.early_stop_cnt == 10: + self.logger.info("Early stop!") + break + # self.logger.info(vars(self.p)) + self.load_model(save_path) + self.logger.info( + f'Loading best model in {self.best_epoch} epoch, Evaluating on Test data') + start = time.time() + test_results, test_loss = self.evaluate_epoch('test', mode='normal_mix_hop') + end = time.time() + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + self.logger.info( + f"[Inference]: Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}") + wandb.log({ + "test_auroc": test_results['auroc'], + "test_auprc": test_results['auprc'], + "test_ap": test_results['ap'] + }) + else: + self.logger.info( + f"[Inference]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}") + wandb.log({ + "test_acc": test_results['acc'], + "test_f1": test_results['macro_f1'], + "test_cohen": test_results['kappa'] + }) + + def train_epoch_mix_hop(self): + self.model.train() + losses = [] + train_iter = self.data_iter['train_rel'] + # train_bar = tqdm(train_iter, ncols=0) + for step, batch in enumerate(train_iter): + if self.p.input_type == 'subgraph': + g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0") + else: + triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0") + g = self.g.to("cuda:0") + subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2] + hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent] + mix_hop_index = self.transform_hop_index() + pred = self.model.compute_mix_hop_pred(hidden_all_ent, subj, obj, mix_hop_index) + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + pos_neg = batch[2].to("cuda:0") + loss = self.model.calc_loss(pred, labels, pos_neg) + else: + loss = self.model.calc_loss(pred, labels) + self.optimizer.zero_grad() + loss.backward() + if self.p.clip_grad: + clip_grad_norm_(self.model.parameters(), max_norm=10, norm_type=2) + self.optimizer.step() + losses.append(loss.item()) + loss = np.mean(losses) + return loss + + def transform_hop_index(self): + ij = self.p.exp_note.split("_") + return self.p.n_layer * (int(ij[0]) - 1) + int(ij[1]) - 1 + + def predict_mix_hop(self, split, mode): + loss_list = [] + pos_scores = [] + pos_labels = [] + pred_class = {} + self.model.eval() + with torch.no_grad(): + results = dict() + eval_iter = 
self.data_iter[f'{split}_rel{mode}'] + for step, batch in enumerate(eval_iter): + if self.p.input_type == 'subgraph': + g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0") + else: + triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0") + g = self.g.to("cuda:0") + subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2] + hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent] + mix_hop_index = self.transform_hop_index() + pred = self.model.compute_mix_hop_pred(hidden_all_ent, subj, obj, mix_hop_index) + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + pos_neg = batch[2].to("cuda:0") + eval_loss = self.model.calc_loss(pred, labels, pos_neg) + m = torch.nn.Sigmoid() + pred = m(pred) + labels = labels.detach().to('cpu').numpy() + preds = pred.detach().to('cpu').numpy() + pos_neg = pos_neg.detach().to('cpu').numpy() + for (label_ids, pred, label_t) in zip(labels, preds, pos_neg): + for i, (l, p) in enumerate(zip(label_ids, pred)): + if l == 1: + if i in pred_class: + pred_class[i]['pred'] += [p] + pred_class[i]['l'] += [label_t] + pred_class[i]['pred_label'] += [1 if p > 0.5 else 0] + else: + pred_class[i] = {'pred': [p], 'l': [label_t], 'pred_label': [1 if p > 0.5 else 0]} + else: + eval_loss = self.model.calc_loss(pred, labels) + pos_labels += rel.to('cpu').numpy().flatten().tolist() + pos_scores += torch.argmax(pred, dim=1).cpu().flatten().tolist() + loss_list.append(eval_loss.item()) + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + roc_auc = [metrics.roc_auc_score(pred_class[l]['l'], pred_class[l]['pred']) for l in pred_class] + prc_auc = [metrics.average_precision_score(pred_class[l]['l'], pred_class[l]['pred']) for l in + pred_class] + ap = [metrics.accuracy_score(pred_class[l]['l'], pred_class[l]['pred_label']) for l in pred_class] + results['auroc'] = np.mean(roc_auc) + results['auprc'] = np.mean(prc_auc) + results['ap'] = np.mean(ap) + else: + results['acc'] = metrics.accuracy_score(pos_labels, pos_scores) + results['macro_f1'] = metrics.f1_score(pos_labels, pos_scores, average='macro') + results['kappa'] = metrics.cohen_kappa_score(pos_labels, pos_scores) + loss = np.mean(loss_list) + return results, loss + + def joint_random_ps2_each(self, trial): + genotype_space = [] + for i in range(self.p.n_layer): + genotype_space.append(trial.suggest_categorical("mess"+ str(i), COMP_PRIMITIVES)) + genotype_space.append(trial.suggest_categorical("agg"+ str(i), AGG_PRIMITIVES)) + genotype_space.append(trial.suggest_categorical("comb"+ str(i), COMB_PRIMITIVES)) + genotype_space.append(trial.suggest_categorical("act"+ str(i), ACT_PRIMITIVES)) + self.best_val_f1 = 0.0 + self.best_val_auroc = 0.0 + self.early_stop_cnt = 0 + # self.best_valid_metric, self.best_test_metric = 0.0, {} + self.p.genotype = "||".join(genotype_space) + self.model = self.get_model().to("cuda:0") + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.p.lr, weight_decay=self.p.l2) + self.subgraph_selector = SubgraphSelector(self.p).to("cuda:0") + self.subgraph_selector_optimizer = torch.optim.Adam( + self.subgraph_selector.parameters(), lr=self.p.ss_lr, weight_decay=self.p.l2) + save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}' + os.makedirs(save_root, exist_ok=True) + save_model_path = f'{save_root}/model.pt' + save_ss_path = f'{save_root}/model_ss.pt' + save_model_best_path = f'{save_root}/model_best.pt' + save_ss_best_path = 
f'{save_root}/model_best_ss.pt' + for epoch in range(self.p.max_epochs): + start_time = time.time() + train_loss = self.search_epoch() + val_results, valid_loss = self.evaluate_epoch('valid', mode='ps2') + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + if val_results['auroc'] > self.best_val_auroc: + self.best_val_results = val_results + self.best_val_auroc = val_results['auroc'] + self.best_epoch = epoch + torch.save(self.model.state_dict(), str(save_model_path)) + torch.save(self.subgraph_selector.state_dict(), str(save_ss_path)) + self.early_stop_cnt = 0 + # self.logger.info("Update best valid auroc!") + else: + self.early_stop_cnt += 1 + else: + if val_results['macro_f1'] > self.best_val_f1: + self.best_val_results = val_results + self.best_val_f1 = val_results['macro_f1'] + self.best_epoch = epoch + torch.save(self.model.state_dict(), str(save_model_path)) + torch.save(self.subgraph_selector.state_dict(), str(save_ss_path)) + self.early_stop_cnt = 0 + # self.logger.info("Update best valid f1!") + else: + self.early_stop_cnt += 1 + self.logger.info( + f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Cost: {time.time() - start_time:.2f}s") + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + self.logger.info( + f"[Epoch {epoch}]: Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}") + self.logger.info( + f"[Epoch {epoch}]: Best AUROC: {self.best_val_auroc:.5}") + else: + self.logger.info( + f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}") + self.logger.info( + f"[Epoch {epoch}]: Best Macro F1: {self.best_val_f1:.5}") + if self.early_stop_cnt == 10: + self.logger.info("Early stop!") + break + self.model.load_state_dict(torch.load(str(save_model_path))) + self.subgraph_selector.load_state_dict(torch.load(str(save_ss_path))) + self.logger.info( + f'Loading best model in {self.best_epoch} epoch, Evaluating on Test data') + self.logger.info(f'{self.p.genotype}') + start = time.time() + test_results, test_loss = self.evaluate_epoch('test', mode='ps2') + end = time.time() + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + self.logger.info( + f"[Inference]: Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}") + if self.best_val_auroc > self.best_valid_metric: + self.best_valid_metric = self.best_val_auroc + self.best_test_metric = test_results + torch.save(self.model.state_dict(), str(save_model_best_path)) + torch.save(self.subgraph_selector.state_dict(), str(save_ss_best_path)) + with open(f'{save_root}/random_ps2_arch_list.csv', "a") as f: + writer = csv.writer(f) + writer.writerow([self.p.genotype, self.best_val_auroc, test_results['auroc']]) + return self.best_val_auroc + else: + self.logger.info( + f"[Inference]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}") + if self.best_val_f1 > self.best_valid_metric: + self.best_valid_metric = self.best_val_f1 + self.best_test_metric = test_results + torch.save(self.model.state_dict(), str(save_model_best_path)) + torch.save(self.subgraph_selector.state_dict(), str(save_ss_best_path)) + with open(f'{save_root}/random_ps2_arch_list.csv', "a") as f: + writer = csv.writer(f) + writer.writerow([self.p.genotype, self.best_val_f1, test_results['macro_f1']]) + return 
self.best_val_f1 + + def joint_random_ps2(self): + self.best_valid_metric = 0.0 + self.best_test_metric = {} + save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}/' + os.makedirs(save_root, exist_ok=True) + print(save_root) + self.logger = get_logger(f'{save_root}/', f'random_ps2') + study = optuna.create_study(directions=["maximize"], sampler=RandomSampler()) + study.optimize(self.joint_random_ps2_each, n_trials=self.p.baseline_sample_num) + with open(f'{save_root}/random_ps2_res.txt', "w") as f1: + f1.write(f'{self.p.__dict__}\n') + f1.write(f'{self.p.genotype}\n') + f1.write(f'Valid performance: {study.best_value}\n') + f1.write(f'Test performance: {self.best_test_metric}') + + def spos_arch_search_ps2_ng(self): + res_list = [] + sorted_list = [] + arch_list = [] + for _ in range(self.p.n_layer): + arch_list.append(len(COMP_PRIMITIVES)) + arch_list.append(len(AGG_PRIMITIVES)) + arch_list.append(len(COMB_PRIMITIVES)) + arch_list.append(len(ACT_PRIMITIVES)) + asng = CategoricalASNG(np.array(arch_list), alpha=1.5, delta_init=1) + exp_note = '_' + self.p.exp_note if self.p.exp_note is not None else '' + self.subgraph_selector = SubgraphSelector(self.p).to("cuda:0") + save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}' + if self.p.exp_note == 'spfs': + epoch_list = [400] + weight_sharing = '_ws' + else: + epoch_list = [800, 700,600] + weight_sharing = '' + for save_epoch in epoch_list: + valid_metric = 0.0 + self.logger = get_logger(f'{save_root}/', f'{save_epoch}_arch_search_ng{exp_note}') + valid_loss_searched_arch_res = dict() + valid_f1_searched_arch_res = dict() + valid_auroc_searched_arch_res = dict() + search_time = 0.0 + t_start = time.time() + for epoch in range(1, self.p.spos_arch_sample_num + 1): + Ms = [] + ma_structs = [] + scores = [] + for i in range(self.p.asng_sample_num): + M = asng.sampling() + struct = np.argmax(M, axis=1) + Ms.append(M) + ma_structs.append(list(struct)) + # print(list(struct)) + self.generate_single_path_ng(list(struct)) + arch = "||".join(self.model.ops) + if self.p.exp_note == 'spfs': + few_shot_op = self.model.ops[0] + else: + few_shot_op = '' + self.model.load_state_dict( + torch.load(f'{save_root}{exp_note}_{few_shot_op}{weight_sharing}/{save_epoch}.pt')) + self.subgraph_selector.load_state_dict( + torch.load(f'{save_root}{exp_note}_{few_shot_op}{weight_sharing}/{save_epoch}_ss.pt')) + val_results, valid_loss = self.evaluate_epoch('valid', 'spos_arch_search_ps2') + test_results, test_loss = self.evaluate_epoch('test', 'spos_arch_search_ps2') + valid_loss_searched_arch_res.setdefault(arch, valid_loss) + self.logger.info(f'[Epoch {epoch}, {i}-th arch]: Path:{arch}') + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + self.logger.info( + f"[Epoch {epoch}, {i}-th arch]: Valid Loss: {valid_loss:.5}, Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}") + self.logger.info( + f"[Epoch {epoch}, {i}-th arch]: Test Loss: {test_loss:.5}, Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}") + scores.append(val_results['auroc']) + sorted_list.append(val_results['auroc']) + valid_auroc_searched_arch_res.setdefault(arch, val_results['auroc']) + else: + self.logger.info( + f"[Epoch {epoch}, {i}-th arch]: Valid Loss: {valid_loss:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid ACC: {val_results['acc']:.5}, Valid Cohen: 
{val_results['kappa']:.5}") + self.logger.info( + f"[Epoch {epoch}, {i}-th arch]: Test Loss: {test_loss:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test ACC: {test_results['acc']:.5}, Test Cohen: {test_results['kappa']:.5}") + valid_f1_searched_arch_res.setdefault(arch, val_results['macro_f1']) + scores.append(val_results['macro_f1']) + sorted_list.append(val_results['macro_f1']) + res_list.append(sorted(sorted_list, reverse=True)[:self.p.asng_sample_num]) + asng.update(np.array(Ms), -np.array(scores), True) + best_struct = list(asng.theta.argmax(axis=1)) + self.generate_single_path_ng(best_struct) + arch = "||".join(self.model.ops) + val_results, valid_loss = self.evaluate_epoch('valid', 'spos_arch_search_ps2') + test_results, test_loss = self.evaluate_epoch('test', 'spos_arch_search_ps2') + self.logger.info(f'Path:{arch}') + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + self.logger.info( + f"Valid Loss: {valid_loss:.5}, Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}") + self.logger.info( + f"Test Loss: {test_loss:.5}, Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}") + else: + self.logger.info( + f"Valid Loss: {valid_loss:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid ACC: {val_results['acc']:.5}, Valid Cohen: {val_results['kappa']:.5}") + self.logger.info( + f"Test Loss: {test_loss:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test ACC: {test_results['acc']:.5}, Test Cohen: {test_results['kappa']:.5}") + with open(f"{save_root}/topK_{save_epoch}{exp_note}_ng.pkl", "wb") as f: + pickle.dump(res_list, f) + t_end = time.time() + search_time = (t_end - t_start) + + search_time = search_time / 3600 + self.logger.info(f'The search process costs {search_time:.2f}h.') + import csv + with open(f'{save_root}/valid_loss_{save_epoch}{exp_note}_ng.csv', 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['arch', 'valid loss']) + valid_loss_searched_arch_res_sorted = sorted(valid_loss_searched_arch_res.items(), key=lambda x :x[1]) + res = valid_loss_searched_arch_res_sorted + for i in range(len(res)): + writer.writerow([res[i][0], res[i][1]]) + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + with open(f'{save_root}/valid_auroc_{save_epoch}{exp_note}_ng.csv', 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['arch', 'valid auroc']) + valid_auroc_searched_arch_res_sorted = sorted(valid_auroc_searched_arch_res.items(), key=lambda x: x[1], + reverse=True) + res = valid_auroc_searched_arch_res_sorted + for i in range(len(res)): + writer.writerow([res[i][0], res[i][1]]) + else: + with open(f'{save_root}/valid_f1_{save_epoch}{exp_note}_ng.csv', 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['arch', 'valid f1']) + valid_f1_searched_arch_res_sorted = sorted(valid_f1_searched_arch_res.items(), key=lambda x: x[1], reverse=True) + res = valid_f1_searched_arch_res_sorted + for i in range(len(res)): + writer.writerow([res[i][0], res[i][1]]) + def generate_single_path_ng(self, struct): + single_path = [] + for ops_index, index in enumerate(struct): + if ops_index % 4 == 0: + single_path.append(COMP_PRIMITIVES[index]) + elif ops_index % 4 == 1: + single_path.append(AGG_PRIMITIVES[index]) + elif ops_index % 4 == 2: + single_path.append(COMB_PRIMITIVES[index]) + elif ops_index % 4 == 3: + single_path.append(ACT_PRIMITIVES[index]) + self.model.ops = single_path + +if 
__name__ == '__main__': + parser = argparse.ArgumentParser(description='Parser For Arguments', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument('--name', default='test_run', + help='Set run name for saving/restoring models') + parser.add_argument('--dataset', default='drugbank', + help='Dataset to use, default: FB15k-237') + parser.add_argument('--input_type', type=str, default='allgraph', choices=['subgraph', 'allgraph']) + parser.add_argument('--score_func', dest='score_func', default='none', + help='Score Function for Link prediction') + parser.add_argument('--opn', dest='opn', default='corr', + help='Composition Operation to be used in CompGCN') + + parser.add_argument('--batch', dest='batch_size', + default=256, type=int, help='Batch size') + parser.add_argument('--gpu', type=int, default=0, + help='Set GPU Ids : Eg: For CPU = -1, For Single GPU = 0') + parser.add_argument('--epoch', dest='max_epochs', + type=int, default=500, help='Number of epochs') + parser.add_argument('--l2', type=float, default=5e-4, + help='L2 Regularization for Optimizer') + parser.add_argument('--lr', type=float, default=0.001, + help='Starting Learning Rate') + parser.add_argument('--lbl_smooth', dest='lbl_smooth', + type=float, default=0.1, help='Label Smoothing') + parser.add_argument('--num_workers', type=int, default=0, + help='Number of processes to construct batches') + parser.add_argument('--seed', dest='seed', default=12345, + type=int, help='Seed for randomization') + + parser.add_argument('--restore', dest='restore', action='store_true', + help='Restore from the previously saved model') + parser.add_argument('--bias', dest='bias', action='store_true', + help='Whether to use bias in the model') + + parser.add_argument('--num_bases', dest='num_bases', default=-1, type=int, + help='Number of basis relation vectors to use') + parser.add_argument('--init_dim', dest='init_dim', default=100, type=int, + help='Initial dimension size for entities and relations') + parser.add_argument('--gcn_dim', dest='gcn_dim', default=200, + type=int, help='Number of hidden units in GCN') + parser.add_argument('--embed_dim', dest='embed_dim', default=None, type=int, + help='Embedding dimension to give as input to score function') + parser.add_argument('--n_layer', dest='n_layer', default=1, + type=int, help='Number of GCN Layers to use') + parser.add_argument('--gcn_drop', dest='gcn_drop', default=0.1, + type=float, help='Dropout to use in GCN Layer') + parser.add_argument('--hid_drop', dest='hid_drop', + default=0.3, type=float, help='Dropout after GCN') + + # ConvE specific hyperparameters + parser.add_argument('--conve_hid_drop', dest='conve_hid_drop', default=0.3, type=float, + help='ConvE: Hidden dropout') + parser.add_argument('--feat_drop', dest='feat_drop', + default=0.2, type=float, help='ConvE: Feature Dropout') + parser.add_argument('--input_drop', dest='input_drop', default=0.2, + type=float, help='ConvE: Stacked Input Dropout') + parser.add_argument('--k_w', dest='k_w', default=20, + type=int, help='ConvE: k_w') + parser.add_argument('--k_h', dest='k_h', default=10, + type=int, help='ConvE: k_h') + parser.add_argument('--num_filt', dest='num_filt', default=200, type=int, + help='ConvE: Number of filters in convolution') + parser.add_argument('--ker_sz', dest='ker_sz', default=7, + type=int, help='ConvE: Kernel size to use') + + parser.add_argument('--gamma', dest='gamma', default=9.0, + type=float, help='TransE: Gamma to use') + + parser.add_argument('--rat', 
action='store_true', + default=False, help='random adacency tensors') + parser.add_argument('--wni', action='store_true', + default=False, help='without neighbor information') + parser.add_argument('--wsi', action='store_true', + default=False, help='without self-loop information') + parser.add_argument('--ss', dest='ss', default=-1, + type=int, help='sample size (sample neighbors)') + parser.add_argument('--nobn', action='store_true', + default=False, help='no use of batch normalization in aggregation') + parser.add_argument('--noltr', action='store_true', + default=False, help='no use of linear transformations for relation embeddings') + + parser.add_argument('--encoder', dest='encoder', + default='compgcn', type=str, help='which encoder to use') + + # for lte models + parser.add_argument('--x_ops', dest='x_ops', default="") + parser.add_argument('--r_ops', dest='r_ops', default="") + + parser.add_argument("--ss_num_layer", default=2, type=int) + parser.add_argument("--ss_input_dim", default=200, type=int) + parser.add_argument("--ss_hidden_dim", default=200, type=int) + parser.add_argument("--ss_lr", default=0.001, type=float) + parser.add_argument('--train_mode', default='', type=str, + choices=["train", "tune", "DEBUG", "vis_hop", "inference", "vis_class","joint_tune","spos_tune","vis_hop_pred","vis_rank_ccorelation","vis_rank_ccorelation_spfs"]) + parser.add_argument("--ss_model_path", type=str) + parser.add_argument("--ss_search_algorithm", default='darts', type=str) + parser.add_argument("--search_algorithm", default='darts', type=str) + parser.add_argument("--temperature", default=0.07, type=float) + parser.add_argument("--temperature_min", default=0.005, type=float) + parser.add_argument("--lr_min", type=float, default=0.001) + parser.add_argument("--arch_lr", default=0.001, type=float) + parser.add_argument("--arch_lr_min", default=0.001, type=float) + parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding') + parser.add_argument("--w_update_epoch", type=int, default=1) + parser.add_argument("--alpha_mode", type=str, default='valid_loss') + parser.add_argument('--loc_mean', type=float, default=10.0, help='initial mean value to generate the location') + parser.add_argument('--loc_std', type=float, default=0.01, help='initial std to generate the location') + parser.add_argument("--genotype", type=str, default=None) + parser.add_argument("--baseline_sample_num", type=int, default=200) + parser.add_argument("--cos_temp", action='store_true', default=False, help='temp decay') + parser.add_argument("--spos_arch_sample_num", default=1000, type=int) + parser.add_argument("--tune_sample_num", default=10, type=int) + + # subgraph config + parser.add_argument("--subgraph_type", type=str, default='seal') + parser.add_argument("--subgraph_hop", type=int, default=2) + parser.add_argument("--subgraph_edge_sample_ratio", type=float, default=1) + parser.add_argument("--subgraph_is_saved", type=bool, default=True) + parser.add_argument("--subgraph_max_num_nodes", type=int, default=100) + parser.add_argument("--subgraph_sample_type", type=str, default='enclosing_subgraph') + parser.add_argument("--save_mode", type=str, default='graph') + parser.add_argument("--num_neg_samples_per_link", type=int, default=0) + + # transformer config + parser.add_argument("--d_model", type=int, default=100) + parser.add_argument("--num_transformer_layers", type=int, default=2) + parser.add_argument("--nhead", type=int, default=4) + 
parser.add_argument("--dim_feedforward", type=int, default=100) + parser.add_argument("--transformer_dropout", type=float, default=0.1) + parser.add_argument("--transformer_activation", type=str, default='relu') + parser.add_argument("--concat_type", type=str, default='so') + parser.add_argument("--graph_pooling_type", type=str, default='mean') + + parser.add_argument("--loss_type", type=str, default='ce') + parser.add_argument("--eval_mode", type=str, default='rel') + parser.add_argument("--wandb_project", type=str, default='') + parser.add_argument("--search_mode", type=str, default='', choices=["ps2", "ps2_random", "", "arch_search", "arch_random","joint_search","arch_spos"]) + parser.add_argument("--add_reverse", action='store_true', default=False) + parser.add_argument("--clip_grad", action='store_true', default=False) + parser.add_argument("--fine_tune_with_implicit_subgraph", action='store_true', default=False) + parser.add_argument("--combine_type", type=str, default='concat') + parser.add_argument("--exp_note", type=str, default=None) + parser.add_argument("--few_shot_op", type=str, default=None) + parser.add_argument("--weight_sharing", action='store_true', default=False) + + parser.add_argument("--asng_sample_num", type=int, default=16) + parser.add_argument("--arch_search_mode", type=str, default='random') + + args = parser.parse_args() + opn = '_' + args.opn if args.encoder == 'compgcn' else '' + reverse = '_add_reverse' if args.add_reverse else '' + genotype = '_'+args.genotype if args.genotype is not None else '' + input_type = '_'+args.input_type if args.input_type == 'subgraph' else '' + exp_note = '_'+args.exp_note if args.exp_note is not None else '' + num_bases = '_b'+str(args.num_bases) if args.num_bases!=-1 else '' + alpha_mode = '_'+str(args.alpha_mode) + few_shot_op = '_'+args.few_shot_op if args.few_shot_op is not None else '' + weight_sharing = '_ws' if args.weight_sharing else '' + ss_search_algorithm = '_snas' if args.ss_search_algorithm == 'snas' else '' + + if args.input_type == 'subgraph': + args.name = 'seal' + else: + if args.train_mode in ['train', 'vis_hop_pred']: + args.name = f'{args.encoder}{opn}_{args.score_func}_{args.combine_type}_train_layer{args.n_layer}_seed{args.seed}{num_bases}{reverse}{genotype}{exp_note}' + elif args.search_mode in ['ps2']: + args.name = f'{args.encoder}{opn}_{args.score_func}_{args.combine_type}_{args.search_mode}{ss_search_algorithm}_layer{args.n_layer}_seed{args.seed}{num_bases}{reverse}{genotype}{exp_note}' + elif args.search_mode in ['arch_random']: + args.name = f'{args.encoder}_{args.score_func}_{args.combine_type}_{args.search_mode}_layer{args.n_layer}_seed{args.seed}{reverse}{exp_note}' + elif args.search_algorithm == 'spos_arch_search': + args.name = f'{args.encoder}_{args.score_func}_{args.combine_type}_spos_train_supernet_layer{args.n_layer}_seed{args.seed}{reverse}{genotype}{exp_note}{few_shot_op}{weight_sharing}' + elif args.search_mode == 'joint_search' and args.train_mode == 'vis_hop': + args.name = f'{args.encoder}_{args.score_func}_{args.combine_type}_spos_train_supernet_ps2{ss_search_algorithm}_layer{args.n_layer}_seed{args.seed}{reverse}{exp_note}{few_shot_op}{weight_sharing}' + elif args.search_algorithm == 'spos_arch_search_ps2': + args.name = f'{args.encoder}_{args.score_func}_{args.combine_type}_spos_train_supernet_ps2{ss_search_algorithm}_layer{args.n_layer}_seed{args.seed}{reverse}{genotype}' + elif args.search_mode == 'joint_search' and args.train_mode == 'spos_tune': + args.name = 
f'{args.encoder}_{args.score_func}_{args.combine_type}_spos_train_supernet_ps2{ss_search_algorithm}_layer{args.n_layer}_seed{args.seed}{reverse}{genotype}' + elif args.search_mode == 'joint_search' and args.search_algorithm == 'random_ps2': + args.name = f'{args.encoder}_{args.score_func}_{args.combine_type}_{args.search_algorithm}{ss_search_algorithm}_layer{args.n_layer}_seed{args.seed}' + elif args.search_algorithm == 'spos_train_supernet_ps2': + args.name = f'{args.encoder}_{args.score_func}_{args.combine_type}_{args.search_algorithm}{ss_search_algorithm}_layer{args.n_layer}_seed{args.seed}{reverse}{genotype}{exp_note}{few_shot_op}{weight_sharing}' + elif (args.search_algorithm == 'spos_train_supernet' and args.train_mode!='spos_tune') or args.train_mode == 'vis_rank_correlation' or args.train_mode == 'vis_rank_ccorelation_spfs': + args.name = f'{args.encoder}_{args.score_func}_{args.combine_type}_{args.search_algorithm}_layer{args.n_layer}_seed{args.seed}{reverse}{genotype}{exp_note}{few_shot_op}{weight_sharing}' + else: + args.name = f'{args.encoder}_{args.score_func}_{args.combine_type}_spos_train_supernet_layer{args.n_layer}_seed{args.seed}{reverse}{genotype}{exp_note}{few_shot_op}{weight_sharing}' + + args.embed_dim = args.k_w * \ + args.k_h if args.embed_dim is None else args.embed_dim + + # os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + os.environ.setdefault("HYPEROPT_FMIN_SEED", str(args.seed)) + + runner = Runner(args) + + if args.search_mode == 'joint_search' and args.search_algorithm!='spos_arch_search_ps2' and args.train_mode!='spos_tune': + wandb.init( + project=args.wandb_project, + config={ + "dataset": args.dataset, + "encoder": args.encoder, + "score_function": args.score_func, + "batch_size": args.batch_size, + "learning_rate": args.lr, + "encoder_layer_num": args.n_layer, + "epochs": args.max_epochs, + "seed": args.seed, + "init_dim": args.init_dim, + "embed_dim": args.embed_dim, + "loss_type": args.loss_type, + "search_mode": args.search_mode, + "combine_type": args.combine_type, + "note": args.exp_note, + "search_algorithm": args.search_algorithm, + "weight_sharing": args.weight_sharing, + "few_shot_op": args.few_shot_op, + "ss_search_algorithm": args.ss_search_algorithm + }) + + if args.train_mode == 'train': + if args.exp_note is not None: + runner.train_mix_hop() + else: + runner.train() + elif args.train_mode == 'tune' and args.search_mode in ['ps2', 'joint_search']: + runner.fine_tune() + elif args.train_mode == 'vis_hop' and args.search_mode in ['ps2', "ps2_random", "joint_search"]: + runner.vis_hop_distrubution() + elif args.train_mode == 'vis_hop_pred': + runner.vis_hop_pred() + elif args.train_mode == 'inference' and args.search_mode in ['ps2']: + runner.inference() + elif args.train_mode == 'vis_class': + runner.vis_class_distribution() + elif args.train_mode == 'joint_tune': + runner.joint_tune() + else: + if args.search_mode == 'ps2_random': + runner.random_search() + elif args.search_mode == 'ps2': + runner.ps2() + elif args.search_mode == 'arch_search': + if args.train_mode == 'spos_tune': + runner.spos_fine_tune() + elif args.train_mode in ["vis_rank_ccorelation"]: + runner.vis_rank_ccorelation() + elif args.train_mode in ["vis_rank_ccorelation_spfs"]: + runner.vis_rank_ccorelation_spfs() + elif args.search_algorithm in ["darts", "snas"]: + 
runner.architecture_search() + elif args.search_algorithm in ["spos_train_supernet"]: + runner.spos_train_supernet() + elif args.search_algorithm in ["spos_arch_search"]: + runner.spos_arch_search() + elif args.search_mode == 'arch_random': + runner.arch_random_search() + elif args.search_mode == 'joint_search': + if args.search_algorithm in ["spos_train_supernet_ps2"]: + runner.spos_train_supernet_ps2() + elif args.search_algorithm in ["spos_arch_search_ps2"] and args.arch_search_mode == 'random': + runner.spos_arch_search_ps2() + elif args.search_algorithm in ["spos_arch_search_ps2"] and args.arch_search_mode == 'ng': + runner.spos_arch_search_ps2_ng() + elif args.train_mode in ["spos_tune"]: + runner.joint_spos_ps2_fine_tune() + elif args.search_algorithm == 'random_ps2': + runner.joint_random_ps2() + else: + runner.joint_search() diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..7b83ce9 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,7 @@ +from .process_data import process, process_multi_label +from .data_set import TestDataset, TrainDataset, GraphTrainDataset, GraphTestDataset, NCNDataset +from .logger import get_logger +from .sampler import TripleSampler, SEALSampler, SEALSampler_NG, GrailSampler +from .utils import get_current_memory_gb, get_f1_score_list, get_acc_list, deserialize, Temp_Scheduler +from .dgl_utils import get_neighbor_nodes +from .asng import CategoricalASNG \ No newline at end of file diff --git a/utils/asng.py b/utils/asng.py new file mode 100644 index 0000000..139b33c --- /dev/null +++ b/utils/asng.py @@ -0,0 +1,169 @@ +import numpy as np + + +class CategoricalASNG: + """ + code refers to https://github.com/shirakawas/ASNG-NAS + """ + + def __init__(self, categories, alpha=1.5, delta_init=1., lam=2., Delta_max=10, init_theta=None): + if init_theta is not None: + self.theta = init_theta + + self.N = np.sum(categories - 1) + self.d = len(categories) + self.C = categories + self.Cmax = np.max(categories) + self.theta = np.zeros((self.d, self.Cmax)) + + for i in range(self.d): + self.theta[i, :self.C[i]] = 1. / self.C[i] + + for i in range(self.d): + self.theta[i, self.C[i]:] = 0. + + self.valid_param_num = int(np.sum(self.C - 1)) + self.valid_d = len(self.C[self.C > 1]) + + self.alpha = alpha + self.delta_init = delta_init + self.lam = lam + self.Delta_max = Delta_max + + self.Delta = 1. 
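+        # Descriptive note (not in the original source): Delta, gamma and the accumulator s
+        # are the adaptive step-size state used by update(); the effective learning rate is
+        # returned by get_delta() (= delta / Delta). Intended usage, mirroring the runner's
+        # spos_arch_search_ps2_ng above:
+        #   M = asng.sampling(); struct = np.argmax(M, axis=1)   # sample one-hot operation choices
+        #   ... evaluate the sampled architecture ...
+        #   asng.update(np.array(Ms), -np.array(scores), True)   # update() minimises, hence the negated scores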
+ self.gamma = 0.0 + self.delta = self.delta_init / self.Delta + self.eps = self.delta + self.s = np.zeros(self.N) + + def get_lam(self): + return self.lam + + def get_delta(self): + return self.delta / self.Delta + + def sampling(self): + rand = np.random.rand(self.d, 1) + cum_theta = self.theta.cumsum(axis=1) + + c = (cum_theta - self.theta <= rand) & (rand < cum_theta) + return c + + def sampling_lam(self, lam): + rand = np.random.rand(lam, self.d, 1) + cum_theta = self.theta.cumsum(axis=1) + + c = (cum_theta - self.theta <= rand) & (rand < cum_theta) + return c + + def mle(self): + m = self.theta.argmax(axis=1) + x = np.zeros((self.d, self.Cmax)) + for i, c in enumerate(m): + x[i, c] = 1 + return x + + def update(self, c_one, fxc, range_restriction=True): + delta = self.get_delta() + # print('delta:', delta) + beta = self.delta * self.N ** (-0.5) + + aru, idx = self.utility(fxc) + mu_W, var_W = aru.mean(), aru.var() + # print(fxc, idx, aru) + if var_W == 0: + return + + ngrad = np.mean((aru - mu_W)[:, np.newaxis, np.newaxis] * (c_one[idx] - self.theta), axis=0) + + if (np.abs(ngrad) < 1e-18).all(): + # print('skip update') + return + + sl = [] + for i, K in enumerate(self.C): + theta_i = self.theta[i, :K - 1] + theta_K = self.theta[i, K - 1] + s_i = 1. / np.sqrt(theta_i) * ngrad[i, :K - 1] + s_i += np.sqrt(theta_i) * ngrad[i, :K - 1].sum() / (theta_K + np.sqrt(theta_K)) + sl += list(s_i) + sl = np.array(sl) + + ngnorm = np.sqrt(np.dot(sl, sl)) + 1e-8 + dp = ngrad / ngnorm + assert not np.isnan(dp).any(), (ngrad, ngnorm) + + self.theta += delta * dp + + self.s = (1 - beta) * self.s + np.sqrt(beta * (2 - beta)) * sl / ngnorm + self.gamma = (1 - beta) ** 2 * self.gamma + beta * (2 - beta) + self.Delta *= np.exp(beta * (self.gamma - np.dot(self.s, self.s) / self.alpha)) + self.Delta = min(self.Delta, self.Delta_max) + + for i in range(self.d): + ci = self.C[i] + theta_min = 1. / (self.valid_d * (ci - 1)) if range_restriction and ci > 1 else 0. + self.theta[i, :ci] = np.maximum(self.theta[i, :ci], theta_min) + theta_sum = self.theta[i, :ci].sum() + tmp = theta_sum - theta_min * ci + self.theta[i, :ci] -= (theta_sum - 1.) 
* (self.theta[i, :ci] - theta_min) / tmp + self.theta[i, :ci] /= self.theta[i, :ci].sum() + + def get_arch(self, ): + return np.argmax(self.theta, axis=1) + + def get_max(self, ): + return np.max(self.theta, axis=1) + + def get_entropy(self, ): + ent = 0 + for i, K in enumerate(self.C): + the = self.theta[i, :K] + ent += np.sum(the * np.log(the)) + return -ent + + @staticmethod + def utility(f, rho=0.25, negative=True): + eps = 1e-3 + idx = np.argsort(f) + lam = len(f) + mu = int(np.ceil(lam * rho)) + _w = np.zeros(lam) + _w[:mu] = 1 / mu + _w[lam - mu:] = -1 / mu if negative else 0 + w = np.zeros(lam) + istart = 0 + for i in range(len(f) - 1): + if f[idx[i + 1]] - f[idx[i]] < eps * f[idx[i]]: + pass + elif istart < i: + w[istart:i + 1] = np.mean(_w[istart:i + 1]) + istart = i + 1 + else: + w[i] = _w[i] + istart = i + 1 + w[istart:] = np.mean(_w[istart:]) + return w, idx + + def log_header(self, theta_log=False): + header_list = ['delta', 'eps', 'snorm_alha', 'theta_converge'] + if theta_log: + for i in range(self.d): + header_list += ['theta%d_%d' % (i, j) for j in range(self.C[i])] + return header_list + + def log(self, theta_log=False): + log_list = [self.delta, self.eps, np.dot(self.s, self.s) / self.alpha, self.theta.max(axis=1).mean()] + + if theta_log: + for i in range(self.d): + log_list += ['%f' % self.theta[i, j] for j in range(self.C[i])] + return log_list + + def load_theta_from_log(self, theta_log): + self.theta = np.zeros((self.d, self.Cmax)) + k = 0 + for i in range(self.d): + for j in range(self.C[i]): + self.theta[i, j] = theta_log[k] + k += 1 \ No newline at end of file diff --git a/utils/data_set.py b/utils/data_set.py new file mode 100644 index 0000000..005dcbe --- /dev/null +++ b/utils/data_set.py @@ -0,0 +1,367 @@ +from torch.utils.data import Dataset +import numpy as np +import torch +import dgl +from dgl import NID +from scipy.sparse.csgraph import shortest_path +import lmdb +import pickle +from utils.dgl_utils import get_neighbor_nodes + + +class TrainDataset(Dataset): + def __init__(self, triplets, num_ent, num_rel, params): + super(TrainDataset, self).__init__() + self.p = params + self.triplets = triplets + self.label_smooth = params.lbl_smooth + self.num_ent = num_ent + self.num_rel = num_rel + + def __len__(self): + return len(self.triplets) + + def __getitem__(self, item): + ele = self.triplets[item] + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + triple, label, pos_neg = torch.tensor(ele['triple'], dtype=torch.long), torch.tensor(ele['label'], dtype=torch.long), torch.tensor(ele['pos_neg'], dtype=torch.float) + return triple, label, pos_neg + else: + triple, label, random_hop = torch.tensor(ele['triple'], dtype=torch.long), np.int32(ele['label']), torch.tensor(ele['random_hop']) + label = triple[1] + # label = torch.tensor(label, dtype=torch.long) + # label = self.get_label_rel(label) + # if self.label_smooth != 0.0: + # label = (1.0 - self.label_smooth) * label + (1.0 / self.num_rel) + if self.p.search_mode == 'random': + return triple, label, random_hop + else: + return triple, label + + def get_label_rel(self, label): + """ + get label corresponding to a (sub, rel) pair + :param label: a list containing indices of objects corresponding to a (sub, rel) pair + :return: a tensor of shape [nun_ent] + """ + y = np.zeros([self.num_rel*2], dtype=np.float32) + y[label] = 1 + return torch.tensor(y, dtype=torch.float32) + + +class TestDataset(Dataset): + def __init__(self, triplets, num_ent, num_rel, params): + super(TestDataset, 
self).__init__() + self.p = params + self.triplets = triplets + self.num_ent = num_ent + self.num_rel = num_rel + + def __len__(self): + return len(self.triplets) + + def __getitem__(self, item): + ele = self.triplets[item] + if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset: + triple, label, pos_neg = torch.tensor(ele['triple'], dtype=torch.long), torch.tensor(ele['label'], dtype=torch.long), torch.tensor(ele['pos_neg'], dtype=torch.float) + return triple, label, pos_neg + else: + triple, label, random_hop = torch.tensor(ele['triple'], dtype=torch.long), np.int32(ele['label']), torch.tensor(ele['random_hop']) + label = triple[1] + # label = torch.tensor(label, dtype=torch.long) + # label = self.get_label_rel(label) + if self.p.search_mode == 'random': + return triple, label, random_hop + else: + return triple, label + + def get_label_rel(self, label): + """ + get label corresponding to a (sub, rel) pair + :param label: a list containing indices of objects corresponding to a (sub, rel) pair + :return: a tensor of shape [nun_ent] + """ + y = np.zeros([self.num_rel*2], dtype=np.float32) + y[label] = 1 + return torch.tensor(y, dtype=torch.float32) + + +class GraphTrainDataset(Dataset): + def __init__(self, triplets, num_ent, num_rel, params, all_graph, db_name_pos=None): + super(GraphTrainDataset, self).__init__() + self.p = params + self.triplets = triplets + self.label_smooth = params.lbl_smooth + self.num_ent = num_ent + self.num_rel = num_rel + self.g = all_graph + db_path = f'subgraph/{self.p.dataset}/{self.p.subgraph_type}_neg_{self.p.num_neg_samples_per_link}_hop_{self.p.subgraph_hop}_seed_{self.p.seed}' + if self.p.save_mode == 'mdb': + self.main_env = lmdb.open(db_path, readonly=True, max_dbs=3, lock=False) + self.db_pos = self.main_env.open_db(db_name_pos.encode()) + # with self.main_env.begin(db=self.db_pos) as txn: + # num_graphs_pos = int.from_bytes(txn.get('num_graphs'.encode()), byteorder='little') + # print(num_graphs_pos) + + def __len__(self): + return len(self.triplets) + + def __getitem__(self, item): + ele = self.triplets[item] + if self.p.save_mode == 'pickle': + sample_nodes, triple, label = ele['sample_nodes'], torch.tensor(ele['triple'], dtype=torch.long), np.int32(ele['label']) + subgraph = self.subgraph_sample(triple[0], triple[2], sample_nodes) + elif self.p.save_mode == 'graph': + subgraph, triple, label = ele['subgraph'], torch.tensor(ele['triple'], dtype=torch.long), torch.tensor(ele['triple'][1], dtype=torch.long) + elif self.p.save_mode == 'mdb': + # triple, label = torch.tensor(ele['triple'], dtype=torch.long), np.int32(ele['label']) + with self.main_env.begin(db=self.db_pos) as txn: + str_id = '{:08}'.format(item).encode('ascii') + nodes_pos, r_label_pos, g_label_pos, n_labels_pos = deserialize(txn.get(str_id)).values() + # print(nodes_pos) + # head_id = np.argwhere([label[0] == 0 and label[1] == 1 for label in n_labels_pos]) + # tail_id = np.argwhere([label[0] == 1 and label[1] == 0 for label in n_labels_pos]) + # print(head_id) + # print(nodes_pos[head_id[0]], nodes_pos[tail_id[0]]) + # exit(0) + # n_ids = np.zeros(len(nodes_pos)) + # n_ids[head_id] = 1 # head + # n_ids[tail_id] = 2 # tail + # subgraph.ndata['id'] = torch.FloatTensor(n_ids) + if self.p.train_mode == 'tune': + subgraph = torch.zeros(self.num_ent, dtype=torch.bool) + subgraph[nodes_pos] = 1 + else: + subgraph = self.subgraph_sample(nodes_pos[0], nodes_pos[1], nodes_pos) + triple = torch.tensor([nodes_pos[0], r_label_pos,nodes_pos[1]], dtype=torch.long) + return subgraph, 
triple, torch.tensor(r_label_pos, dtype=torch.long) + # input_ids = self.convert_subgraph_to_tokens(subgraph, self.max_seq_length) + # label = self.get_label_rel(label) + # if self.label_smooth != 0.0: + # label = (1.0 - self.label_smooth) * label + (1.0 / self.num_rel) + return subgraph, triple, label + + def get_label_rel(self, label): + """ + get label corresponding to a (sub, rel) pair + :param label: a list containing indices of objects corresponding to a (sub, rel) pair + :return: a tensor of shape [nun_ent] + """ + y = np.zeros([self.num_rel*2], dtype=np.float32) + y[label] = 1 + return torch.tensor(y, dtype=torch.float32) + + def subgraph_sample(self, u, v, sample_nodes): + subgraph = dgl.node_subgraph(self.g, sample_nodes) + n_ids = np.zeros(len(sample_nodes)) + n_ids[0] = 1 # head + n_ids[1] = 2 # tail + subgraph.ndata['id'] = torch.tensor(n_ids, dtype=torch.long) + # print(sample_nodes) + # print(u,v) + # Each node should have unique node id in the new subgraph + u_id = int(torch.nonzero(subgraph.ndata[NID] == int(u), as_tuple=False)) + # print(torch.nonzero(subgraph.ndata[NID] == int(u), as_tuple=False)) + # print(torch.nonzero(subgraph.ndata[NID] == int(v), as_tuple=False)) + v_id = int(torch.nonzero(subgraph.ndata[NID] == int(v), as_tuple=False)) + + if subgraph.has_edges_between(u_id, v_id): + link_id = subgraph.edge_ids(u_id, v_id, return_uv=True)[2] + subgraph.remove_edges(link_id) + if subgraph.has_edges_between(v_id, u_id): + link_id = subgraph.edge_ids(v_id, u_id, return_uv=True)[2] + subgraph.remove_edges(link_id) + + n_ids = np.zeros(len(sample_nodes)) + n_ids[0] = 1 # head + n_ids[1] = 2 # tail + subgraph.ndata['id'] = torch.tensor(n_ids, dtype=torch.long) + + # z = drnl_node_labeling(subgraph, u_id, v_id) + # subgraph.ndata['z'] = z + return subgraph + + +class GraphTestDataset(Dataset): + def __init__(self, triplets, num_ent, num_rel, params, all_graph, db_name_pos=None): + super(GraphTestDataset, self).__init__() + self.p = params + self.triplets = triplets + self.num_ent = num_ent + self.num_rel = num_rel + self.g = all_graph + db_path = f'subgraph/{self.p.dataset}/{self.p.subgraph_type}_neg_{self.p.num_neg_samples_per_link}_hop_{self.p.subgraph_hop}_seed_{self.p.seed}' + if self.p.save_mode == 'mdb': + self.main_env = lmdb.open(db_path, readonly=True, max_dbs=3, lock=False) + self.db_pos = self.main_env.open_db(db_name_pos.encode()) + + def __len__(self): + return len(self.triplets) + + def __getitem__(self, item): + ele = self.triplets[item] + if self.p.save_mode == 'pickle': + sample_nodes, triple, label = ele['sample_nodes'], torch.tensor(ele['triple'], dtype=torch.long), np.int32(ele['label']) + subgraph = self.subgraph_sample(triple[0], triple[2], sample_nodes) + elif self.p.save_mode == 'graph': + subgraph, triple, label = ele['subgraph'], torch.tensor(ele['triple'], dtype=torch.long), torch.tensor(ele['triple'][1], dtype=torch.long) + elif self.p.save_mode == 'mdb': + # triple, label = torch.tensor(ele['triple'], dtype=torch.long), np.int32(ele['label']) + with self.main_env.begin(db=self.db_pos) as txn: + str_id = '{:08}'.format(item).encode('ascii') + nodes_pos, r_label_pos, g_label_pos, n_labels_pos = deserialize(txn.get(str_id)).values() + if self.p.train_mode == 'tune': + subgraph = torch.zeros(self.num_ent, dtype=torch.bool) + subgraph[nodes_pos] = 1 + else: + subgraph = self.subgraph_sample(nodes_pos[0], nodes_pos[1], nodes_pos) + triple = torch.tensor([nodes_pos[0], r_label_pos, nodes_pos[1]], dtype=torch.long) + return subgraph, triple, 
torch.tensor(r_label_pos, dtype=torch.long) + # label = self.get_label_rel(label) + return subgraph, triple, label + + def get_label_rel(self, label): + """ + get label corresponding to a (sub, rel) pair + :param label: a list containing indices of objects corresponding to a (sub, rel) pair + :return: a tensor of shape [nun_ent] + """ + y = np.zeros([self.num_rel*2], dtype=np.float32) + y[label] = 1 + return torch.tensor(y, dtype=torch.float32) + + def subgraph_sample(self, u, v, sample_nodes): + subgraph = dgl.node_subgraph(self.g, sample_nodes) + n_ids = np.zeros(len(sample_nodes)) + n_ids[0] = 1 # head + n_ids[1] = 2 # tail + subgraph.ndata['id'] = torch.tensor(n_ids, dtype=torch.long) + # Each node should have unique node id in the new subgraph + u_id = int(torch.nonzero(subgraph.ndata[NID] == int(u), as_tuple=False)) + v_id = int(torch.nonzero(subgraph.ndata[NID] == int(v), as_tuple=False)) + + if subgraph.has_edges_between(u_id, v_id): + link_id = subgraph.edge_ids(u_id, v_id, return_uv=True)[2] + subgraph.remove_edges(link_id) + if subgraph.has_edges_between(v_id, u_id): + link_id = subgraph.edge_ids(v_id, u_id, return_uv=True)[2] + subgraph.remove_edges(link_id) + + # z = drnl_node_labeling(subgraph, u_id, v_id) + # subgraph.ndata['z'] = z + return subgraph + + +def drnl_node_labeling(subgraph, src, dst): + """ + Double Radius Node labeling + d = r(i,u)+r(i,v) + label = 1+ min(r(i,u),r(i,v))+ (d//2)*(d//2+d%2-1) + Isolated nodes in subgraph will be set as zero. + Extreme large graph may cause memory error. + + Args: + subgraph(DGLGraph): The graph + src(int): node id of one of src node in new subgraph + dst(int): node id of one of dst node in new subgraph + Returns: + z(Tensor): node labeling tensor + """ + adj = subgraph.adj().to_dense().numpy() + src, dst = (dst, src) if src > dst else (src, dst) + if src != dst: + idx = list(range(src)) + list(range(src + 1, adj.shape[0])) + adj_wo_src = adj[idx, :][:, idx] + + idx = list(range(dst)) + list(range(dst + 1, adj.shape[0])) + adj_wo_dst = adj[idx, :][:, idx] + + dist2src = shortest_path(adj_wo_dst, directed=False, unweighted=True, indices=src) + dist2src = np.insert(dist2src, dst, 0, axis=0) + dist2src = torch.from_numpy(dist2src) + + dist2dst = shortest_path(adj_wo_src, directed=False, unweighted=True, indices=dst - 1) + dist2dst = np.insert(dist2dst, src, 0, axis=0) + dist2dst = torch.from_numpy(dist2dst) + else: + dist2src = shortest_path(adj, directed=False, unweighted=True, indices=src) + # dist2src = np.insert(dist2src, dst, 0, axis=0) + dist2src = torch.from_numpy(dist2src) + + dist2dst = shortest_path(adj, directed=False, unweighted=True, indices=dst) + # dist2dst = np.insert(dist2dst, src, 0, axis=0) + dist2dst = torch.from_numpy(dist2dst) + + dist = dist2src + dist2dst + dist_over_2, dist_mod_2 = dist // 2, dist % 2 + + z = 1 + torch.min(dist2src, dist2dst) + z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1) + z[src] = 1. + z[dst] = 1. + z[torch.isnan(z)] = 0. 
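+    # Worked example of the labeling formula above: a node with dist2src = 1 and dist2dst = 2 has
+    # d = 3 and z = 1 + min(1, 2) + (3 // 2) * (3 // 2 + 3 % 2 - 1) = 3; a common neighbor
+    # (dist2src = dist2dst = 1) gets z = 2, while the two end nodes themselves are fixed to z = 1.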
+ + return z.to(torch.long) + +def deserialize(data): + data_tuple = pickle.loads(data) + keys = ('nodes', 'r_label', 'g_label', 'n_label', 'common_neighbor') + return dict(zip(keys, data_tuple)) + +class NCNDataset(Dataset): + def __init__(self, triplets, num_ent, num_rel, params, adj, db_name_pos=None): + super(NCNDataset, self).__init__() + self.p = params + self.triplets = triplets + self.label_smooth = params.lbl_smooth + self.num_ent = num_ent + self.num_rel = num_rel + self.adj = adj + db_path = f'subgraph/{self.p.dataset}/{self.p.subgraph_type}_neg_{self.p.num_neg_samples_per_link}_hop_{self.p.subgraph_hop}_seed_{self.p.seed}' + if self.p.save_mode == 'mdb': + self.main_env = lmdb.open(db_path, readonly=True, max_dbs=3, lock=False) + self.db_pos = self.main_env.open_db(db_name_pos.encode()) + # with self.main_env.begin(db=self.db_pos) as txn: + # num_graphs_pos = int.from_bytes(txn.get('num_graphs'.encode()), byteorder='little') + # print(num_graphs_pos) + + def __len__(self): + return len(self.triplets) + + def __getitem__(self, item): + with self.main_env.begin(db=self.db_pos) as txn: + str_id = '{:08}'.format(item).encode('ascii') + nodes_pos, r_label_pos, g, n, cn = deserialize(txn.get(str_id)).values() + triple = torch.tensor([nodes_pos[0], r_label_pos,nodes_pos[1]], dtype=torch.long) + label = torch.tensor(r_label_pos, dtype=torch.long) + cn_index = torch.zeros([self.num_ent+1], dtype=torch.bool) + if len(cn) == 0: + cn_index[self.num_ent] = 1 + else: + cn_index[list(cn)] = 1 + return triple, label, cn_index + + def get_common_neighbors(self, u, v): + cns_list = [] + # for i_u in range(1, self.p.n_layer+1): + # for i_v in range(1, self.p.n_layer+1): + # root_u_nei = get_neighbor_nodes({u}, self.adj, i_u) + # root_v_nei = get_neighbor_nodes({v}, self.adj, i_v) + # subgraph_nei_nodes_int = root_u_nei.intersection(root_v_nei) + # ng = list(subgraph_nei_nodes_int) + # subgraph = torch.zeros([1, self.num_ent], dtype=torch.bool) + # if len(ng) == 0: + # cns_list.append(subgraph) + # continue + # subgraph[:,ng] = 1 + # cns_list.append(subgraph) + # root_u_nei = get_neighbor_nodes({u}, self.adj, 1) + # root_v_nei = get_neighbor_nodes({v}, self.adj, 1) + # subgraph_nei_nodes_int = root_u_nei.intersection(root_v_nei) + # ng = list(subgraph_nei_nodes_int) + # subgraph = torch.zeros([1, self.num_ent], dtype=torch.bool) + # if len(ng) == 0: + # return subgraph + # else: + # subgraph[:,ng] = 1 + return 1 \ No newline at end of file diff --git a/utils/dgl_utils.py b/utils/dgl_utils.py new file mode 100644 index 0000000..da99c42 --- /dev/null +++ b/utils/dgl_utils.py @@ -0,0 +1,131 @@ +import numpy as np +import scipy.sparse as ssp +import random +from scipy.sparse import csc_matrix + +"""All functions in this file are from dgl.contrib.data.knowledge_graph""" + + +def _bfs_relational(adj, roots, max_nodes_per_hop=None): + """ + BFS for graphs. + Modified from dgl.contrib.data.knowledge_graph to accomodate node sampling + """ + visited = set() + current_lvl = set(roots) + + next_lvl = set() + + while current_lvl: + + for v in current_lvl: + visited.add(v) + + next_lvl = _get_neighbors(adj, current_lvl) + next_lvl -= visited # set difference + + if max_nodes_per_hop and max_nodes_per_hop < len(next_lvl): + next_lvl = set(random.sample(next_lvl, max_nodes_per_hop)) + + yield next_lvl + + current_lvl = set.union(next_lvl) + + +def _get_neighbors(adj, nodes): + """Takes a set of nodes and a graph adjacency matrix and returns a set of neighbors. 
+ Directly copied from dgl.contrib.data.knowledge_graph""" + sp_nodes = _sp_row_vec_from_idx_list(list(nodes), adj.shape[1]) + sp_neighbors = sp_nodes.dot(adj) + neighbors = set(ssp.find(sp_neighbors)[1]) # convert to set of indices + return neighbors + + +def _sp_row_vec_from_idx_list(idx_list, dim): + """Create sparse vector of dimensionality dim from a list of indices.""" + shape = (1, dim) + data = np.ones(len(idx_list)) + row_ind = np.zeros(len(idx_list)) + col_ind = list(idx_list) + return ssp.csr_matrix((data, (row_ind, col_ind)), shape=shape) + + +def process_files_ddi(files, triple_file, saved_relation2id=None, keeptrainone = False): + entity2id = {} + relation2id = {} if saved_relation2id is None else saved_relation2id + + triplets = {} + kg_triple = [] + ent = 0 + rel = 0 + + for file_type, file_path in files.items(): + data = [] + # with open(file_path) as f: + # file_data = [line.split() for line in f.read().split('\n')[:-1]] + file_data = np.loadtxt(file_path) + for triplet in file_data: + #print(triplet) + triplet[0], triplet[1], triplet[2] = int(triplet[0]), int(triplet[1]), int(triplet[2]) + if triplet[0] not in entity2id: + entity2id[triplet[0]] = triplet[0] + #ent += 1 + if triplet[1] not in entity2id: + entity2id[triplet[1]] = triplet[1] + #ent += 1 + if not saved_relation2id and triplet[2] not in relation2id: + if keeptrainone: + triplet[2] = 0 + relation2id[triplet[2]] = 0 + rel = 1 + else: + relation2id[triplet[2]] = triplet[2] + rel += 1 + + # Save the triplets corresponding to only the known relations + if triplet[2] in relation2id: + data.append([entity2id[triplet[0]], entity2id[triplet[1]], relation2id[triplet[2]]]) + + triplets[file_type] = np.array(data) + #print(rel) + triplet_kg = np.loadtxt(triple_file) + # print(np.max(triplet_kg[:, -1])) + for (h, t, r) in triplet_kg: + h, t, r = int(h), int(t), int(r) + if h not in entity2id: + entity2id[h] = h + if t not in entity2id: + entity2id[t] = t + if not saved_relation2id and rel+r not in relation2id: + relation2id[rel+r] = rel + r + kg_triple.append([h, t, r]) + kg_triple = np.array(kg_triple) + id2entity = {v: k for k, v in entity2id.items()} + id2relation = {v: k for k, v in relation2id.items()} + #print(relation2id, rel) + + # Construct the list of adjacency matrix each corresponding to each relation. Note that this is constructed only from the train data. 
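For reference, a toy illustration (not part of the patch) of the per-relation adjacency list this function builds: the same `csc_matrix` pattern on a 5-node example, whereas the real code uses the hard-coded 34124-entity space of the DrugBank graph.

```python
import numpy as np
from scipy.sparse import csc_matrix

train = np.array([[0, 1, 0],    # (head, tail, relation)
                  [1, 2, 0],
                  [2, 3, 1],
                  [3, 4, 1]])
num_nodes, num_rels = 5, 2

adj_list = []
for r in range(num_rels):
    idx = np.argwhere(train[:, 2] == r)
    adj_list.append(csc_matrix((np.ones(len(idx), dtype=np.uint8),
                                (train[:, 0][idx].squeeze(1), train[:, 1][idx].squeeze(1))),
                               shape=(num_nodes, num_nodes)))
print(adj_list[0].toarray())    # 1s only at (0, 1) and (1, 2): the relation-0 edges
```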
+ adj_list = [] + #print(kg_triple) + #for i in range(len(relation2id)): + for i in range(rel): + idx = np.argwhere(triplets['train'][:, 2] == i) + adj_list.append(csc_matrix((np.ones(len(idx), dtype=np.uint8), (triplets['train'][:, 0][idx].squeeze(1), triplets['train'][:, 1][idx].squeeze(1))), shape=(34124, 34124))) + for i in range(rel, len(relation2id)): + idx = np.argwhere(kg_triple[:, 2] == i-rel) + #print(len(idx), i) + adj_list.append(csc_matrix((np.ones(len(idx), dtype=np.uint8), (kg_triple[:, 0][idx].squeeze(1), kg_triple[:, 1][idx].squeeze(1))), shape=(34124, 34124))) + #print(adj_list) + #assert 0 + return adj_list, triplets, entity2id, relation2id, id2entity, id2relation, rel + + +def get_neighbor_nodes(roots, adj, h=1, max_nodes_per_hop=None): + bfs_generator = _bfs_relational(adj, roots, max_nodes_per_hop) + lvls = list() + for _ in range(h): + try: + lvls.append(next(bfs_generator)) + except StopIteration: + pass + return set().union(*lvls) \ No newline at end of file diff --git a/utils/graph_utils.py b/utils/graph_utils.py new file mode 100644 index 0000000..ced0178 --- /dev/null +++ b/utils/graph_utils.py @@ -0,0 +1,185 @@ +import statistics +import numpy as np +import scipy.sparse as ssp +import torch +import networkx as nx +import dgl +import pickle + + +def serialize(data): + data_tuple = tuple(data.values()) + return pickle.dumps(data_tuple) + + +def deserialize(data): + data_tuple = pickle.loads(data) + keys = ('nodes', 'r_label', 'g_label', 'n_label') + return dict(zip(keys, data_tuple)) + + +def get_edge_count(adj_list): + count = [] + for adj in adj_list: + count.append(len(adj.tocoo().row.tolist())) + return np.array(count) + + +def incidence_matrix(adj_list): + ''' + adj_list: List of sparse adjacency matrices + ''' + + rows, cols, dats = [], [], [] + dim = adj_list[0].shape + for adj in adj_list: + adjcoo = adj.tocoo() + rows += adjcoo.row.tolist() + cols += adjcoo.col.tolist() + dats += adjcoo.data.tolist() + row = np.array(rows) + col = np.array(cols) + data = np.array(dats) + return ssp.csc_matrix((data, (row, col)), shape=dim) + + +def remove_nodes(A_incidence, nodes): + idxs_wo_nodes = list(set(range(A_incidence.shape[1])) - set(nodes)) + return A_incidence[idxs_wo_nodes, :][:, idxs_wo_nodes] + + +def ssp_to_torch(A, device, dense=False): + ''' + A : Sparse adjacency matrix + ''' + idx = torch.LongTensor([A.tocoo().row, A.tocoo().col]) + dat = torch.FloatTensor(A.tocoo().data) + A = torch.sparse.FloatTensor(idx, dat, torch.Size([A.shape[0], A.shape[1]])).to(device=device) + return A + + +# def ssp_multigraph_to_dgl(graph, n_feats=None): +# """ +# Converting ssp multigraph (i.e. list of adjs) to dgl multigraph. +# """ + +# g_nx = nx.MultiDiGraph() +# g_nx.add_nodes_from(list(range(graph[0].shape[0]))) +# # Add edges +# for rel, adj in enumerate(graph): +# # Convert adjacency matrix to tuples for nx0 +# nx_triplets = [] +# for src, dst in list(zip(adj.tocoo().row, adj.tocoo().col)): +# nx_triplets.append((src, dst, {'type': rel})) +# g_nx.add_edges_from(nx_triplets) + +# # make dgl graph +# g_dgl = dgl.DGLGraph(multigraph=True) +# g_dgl.from_networkx(g_nx, edge_attrs=['type']) +# # add node features +# if n_feats is not None: +# g_dgl.ndata['feat'] = torch.tensor(n_feats) + +# return g_dgl + +def ssp_multigraph_to_dgl(graph): + """ + Converting ssp multigraph (i.e. list of adjs) to dgl multigraph. 
+ """ + + g_nx = nx.MultiDiGraph() + g_nx.add_nodes_from(list(range(graph[0].shape[0]))) + # Add edges + for rel, adj in enumerate(graph): + # Convert adjacency matrix to tuples for nx0 + nx_triplets = [] + for src, dst in list(zip(adj.tocoo().row, adj.tocoo().col)): + nx_triplets.append((src, dst, {'type': rel})) + g_nx.add_edges_from(nx_triplets) + + # make dgl graph + g_dgl = dgl.DGLGraph(multigraph=True) + g_dgl = dgl.from_networkx(g_nx, edge_attrs=['type']) + + return g_dgl + + +def collate_dgl(samples): + # The input `samples` is a list of pairs + graphs_pos, g_labels_pos, r_labels_pos = map(list, zip(*samples)) + # print(graphs_pos, g_labels_pos, r_labels_pos, samples) + batched_graph_pos = dgl.batch(graphs_pos) + # print(batched_graph_pos) + + # graphs_neg = [item for sublist in graphs_negs for item in sublist] + # g_labels_neg = [item for sublist in g_labels_negs for item in sublist] + # r_labels_neg = [item for sublist in r_labels_negs for item in sublist] + + # batched_graph_neg = dgl.batch(graphs_neg) + return (batched_graph_pos, r_labels_pos), g_labels_pos # , drug_idx + + +def move_batch_to_device_dgl(batch, device): + (g_dgl_pos, r_labels_pos), targets_pos = batch + + targets_pos = torch.LongTensor(targets_pos).to(device=device) + r_labels_pos = torch.LongTensor(r_labels_pos).to(device=device) + # drug_idx = torch.LongTensor(drug_idx).to(device=device) + # targets_neg = torch.LongTensor(targets_neg).to(device=device) + # r_labels_neg = torch.LongTensor(r_labels_neg).to(device=device) + + g_dgl_pos = g_dgl_pos.to(device) + # g_dgl_neg = send_graph_to_device(g_dgl_neg, device) + + return g_dgl_pos, r_labels_pos, targets_pos + + +def move_batch_to_device_dgl_ddi2(batch, device): + (g_dgl_pos, r_labels_pos), targets_pos = batch + + targets_pos = torch.LongTensor(targets_pos).to(device=device) + r_labels_pos = torch.FloatTensor(r_labels_pos).to(device=device) + # drug_idx = torch.LongTensor(drug_idx).to(device=device) + # targets_neg = torch.LongTensor(targets_neg).to(device=device) + # r_labels_neg = torch.LongTensor(r_labels_neg).to(device=device) + + g_dgl_pos = send_graph_to_device(g_dgl_pos, device) + # g_dgl_neg = send_graph_to_device(g_dgl_neg, device) + + return g_dgl_pos, r_labels_pos, targets_pos + + +def send_graph_to_device(g, device): + # nodes + labels = g.node_attr_schemes() + for l in labels.keys(): + g.ndata[l] = g.ndata.pop(l).to(device) + + # edges + labels = g.edge_attr_schemes() + for l in labels.keys(): + g.edata[l] = g.edata.pop(l).to(device) + return g + + +# The following three functions are modified from networks source codes to +# accomodate diameter and radius for dirercted graphs + + +def eccentricity(G): + e = {} + for n in G.nbunch_iter(): + length = nx.single_source_shortest_path_length(G, n) + e[n] = max(length.values()) + return e + + +def radius(G): + e = eccentricity(G) + e = np.where(np.array(list(e.values())) > 0, list(e.values()), np.inf) + return min(e) + + +def diameter(G): + e = eccentricity(G) + return max(e.values()) diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 0000000..50fa9cc --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,12 @@ +import json +import logging +import logging.config +from os import makedirs + +def get_logger(log_dir, name): + config_dict = json.load(open('./config/' + 'logger_config.json')) + config_dict['handlers']['file_handler']['filename'] = log_dir + name + '.log' + makedirs(log_dir, exist_ok=True) + logging.config.dictConfig(config_dict) + logger = logging.getLogger(name) + return logger 
\ No newline at end of file diff --git a/utils/process_data.py b/utils/process_data.py new file mode 100644 index 0000000..cfe07f0 --- /dev/null +++ b/utils/process_data.py @@ -0,0 +1,78 @@ +from collections import defaultdict as ddict +import random + + +def process(dataset, num_rel, n_layer, add_reverse): + """ + pre-process dataset + :param dataset: a dictionary containing 'train', 'valid' and 'test' data. + :param num_rel: relation number + :return: + """ + + so2r = ddict(set) + so2randomhop = ddict() + class2num = ddict() + # print(len(dataset['train'])) + # index = 0 + # cnt = 0 + for subj, rel, obj in dataset['train']: + class2num[rel] = class2num.setdefault(rel, 0) + 1 + so2r[(subj,obj)].add(rel) + subj_hop, obj_hop = random.randint(1, n_layer), random.randint(1, n_layer) + so2randomhop.setdefault((subj, obj), (subj_hop, obj_hop)) + if add_reverse: + so2r[(obj, subj)].add(rel + num_rel) + class2num[rel + num_rel] = class2num.setdefault(rel + num_rel, 0) + 1 + so2randomhop.setdefault((obj, subj), (obj_hop, subj_hop)) + # index+=1 + # print("______________________") + # print(subj, rel, obj) + # print(so2r[(subj, obj)]) + # print(so2r[(obj, subj)]) + # print(len(so2r)) + # print(2*index) + # assert len(so2r) == 2*index + # print(len(so2r)) + # print(cnt) + so2r_train = {k: list(v) for k, v in so2r.items()} + for split in ['valid', 'test']: + for subj, rel, obj in dataset[split]: + so2r[(subj, obj)].add(rel) + so2r[(obj, subj)].add(rel + num_rel) + subj_hop, obj_hop = random.randint(1, n_layer), random.randint(1, n_layer) + so2randomhop.setdefault((subj, obj), (subj_hop, obj_hop)) + so2randomhop.setdefault((obj, subj), (obj_hop, subj_hop)) + so2r_all = {k: list(v) for k, v in so2r.items()} + triplets = ddict(list) + + # for (subj, obj), rel in so2r_train.items(): + # triplets['train_rel'].append({'triple': (subj, rel, obj), 'label': so2r_train[(subj, obj)], 'random_hop':so2randomhop[(subj, obj)]}) + # FOR DDI + for subj, rel, obj in dataset['train']: + triplets['train_rel'].append( + {'triple': (subj, rel, obj), 'label': so2r_train[(subj, obj)], 'random_hop': so2randomhop[(subj, obj)]}) + if add_reverse: + triplets['train_rel'].append( + {'triple': (obj, rel+num_rel, subj), 'label': so2r_train[(obj, subj)], 'random_hop': so2randomhop[(obj, subj)]}) + for split in ['valid', 'test']: + for subj, rel, obj in dataset[split]: + triplets[f"{split}_rel"].append({'triple': (subj, rel, obj), 'label': so2r_all[(subj, obj)], 'random_hop':so2randomhop[(subj, obj)]}) + triplets[f"{split}_rel_inv"].append( + {'triple': (obj, rel + num_rel, subj), 'label': so2r_all[(obj, subj)], 'random_hop':so2randomhop[(obj, subj)]}) + triplets = dict(triplets) + return triplets, class2num + + +def process_multi_label(input, multi_label, pos_neg): + triplets = ddict(list) + for index, data in enumerate(input['train']): + subj, _, obj = data + triplets['train_rel'].append( + {'triple': (subj, -1, obj), 'label': multi_label['train'][index], 'pos_neg': pos_neg['train'][index]}) + for split in ['valid', 'test']: + for index, data in enumerate(input[split]): + subj, _, obj = data + triplets[f"{split}_rel"].append({'triple': (subj, -1, obj), 'label': multi_label[split][index],'pos_neg': pos_neg[split][index]}) + triplets = dict(triplets) + return triplets \ No newline at end of file diff --git a/utils/sampler.py b/utils/sampler.py new file mode 100644 index 0000000..36fdf23 --- /dev/null +++ b/utils/sampler.py @@ -0,0 +1,1148 @@ +import os.path as osp +from tqdm import tqdm +from copy import deepcopy +import torch 
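Referring back to `utils/process_data.py::process` above, here is a small sketch (not part of the patch) of the structure it returns for a toy dataset with two relations and reverse edges enabled; the triples and counts are illustrative only.

```python
import random
from utils.process_data import process

random.seed(0)  # process() draws random per-pair hop counts
dataset = {
    'train': [(0, 1, 2), (2, 0, 3)],   # (subj, rel, obj) with relation ids < num_rel
    'valid': [(0, 0, 3)],
    'test':  [(1, 1, 4)],
}
triplets, class2num = process(dataset, num_rel=2, n_layer=3, add_reverse=True)
print(sorted(triplets.keys()))
# ['test_rel', 'test_rel_inv', 'train_rel', 'valid_rel', 'valid_rel_inv']
print(triplets['train_rel'][0]['triple'], triplets['train_rel'][0]['label'])
# (0, 1, 2) [1]   -> the label lists every relation seen for (subj, obj) in train
```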
+import dgl +from torch.utils.data import DataLoader, Dataset +from dgl import DGLGraph, NID +from dgl.dataloading.negative_sampler import Uniform +from dgl import add_self_loop +import numpy as np +from scipy.sparse.csgraph import shortest_path +#Grail +from scipy.sparse import csc_matrix +import os +import json +import matplotlib.pyplot as plt +import logging +from scipy.special import softmax +from tqdm import tqdm +import lmdb +import multiprocessing as mp +import scipy.sparse as ssp +from utils.graph_utils import serialize, incidence_matrix, remove_nodes +from utils.dgl_utils import _bfs_relational +import struct + + +class GraphDataSet(Dataset): + """ + GraphDataset for torch DataLoader + """ + + def __init__(self, graph_list, tensor, n_ent, n_rel, max_seq_length): + self.graph_list = graph_list + self.tensor = tensor + self.n_ent = n_ent + self.n_rel = n_rel + self.max_seq_length = max_seq_length + + def __len__(self): + return len(self.graph_list) + + def __getitem__(self, index): + input_ids, input_mask, mask_position, mask_label = self.convert_subgraph_to_feature(self.graph_list[index], + self.max_seq_length, + self.tensor[index]) + rela_mat = torch.zeros(self.max_seq_length, self.max_seq_length, self.n_rel) + rela_mat[self.graph_list[index].edges()[0], self.graph_list[index].edges()[1], self.graph_list[index].edata['rel']] = 1 + return (self.graph_list[index], self.tensor[index], input_ids, input_mask, mask_position, mask_label, rela_mat) + + def convert_subgraph_to_feature(self, subgraph, max_seq_length, triple): + input_ids = [subgraph.ndata[NID]] + input_mask = [1 for _ in range(subgraph.num_nodes())] + if max_seq_length - subgraph.num_nodes() > 0: + input_ids.append(torch.tensor([self.n_ent for _ in range(max_seq_length - subgraph.num_nodes())])) + input_mask += [0 for _ in range(max_seq_length - subgraph.num_nodes())] + input_ids = torch.cat(input_ids) + input_mask = torch.tensor(input_mask) + # input_ids = torch.cat( + # [subgraph.ndata[NID], torch.tensor([self.n_ent for _ in range(max_seq_length - subgraph.num_nodes())])]) + # input_mask = [1 for _ in range(subgraph.num_nodes())] + # input_mask += [0 for _ in range(max_seq_length - subgraph.num_nodes())] + # while len(input_mask) < max_seq_length: + # # input_ids.append(self.n_ent) + # input_mask.append(0) + # TODO: predict head entity, now predict tail entity + mask_position = ((subgraph.ndata[NID] == triple[2]).nonzero().flatten()) + input_ids[mask_position.item()] = self.n_ent + 1 + mask_label = triple[2] + # for position in list(torch.where(subgraph.ndata['z']==1)[0]): + # mask_position = position + # input_ids[mask_position]=self.n_ent+1 + # mask_label = subgraph.ndata[NID][mask_position] + # break + assert input_ids.size(0) == max_seq_length + assert input_mask.size(0) == max_seq_length + + return input_ids, input_mask, mask_position, mask_label + + +class GraphDataSetRP(Dataset): + """ + GraphDataset for torch DataLoader + """ + + def __init__(self, graph_list, tensor, n_ent, n_rel, max_seq_length): + self.graph_list = graph_list + self.tensor = tensor + self.n_ent = n_ent + self.n_rel = n_rel + self.max_seq_length = max_seq_length + + def __len__(self): + return len(self.graph_list) + + def __getitem__(self, index): + input_ids, num_nodes, head_position, tail_position = self.convert_subgraph_to_feature(self.graph_list[index],self.tensor[index],self.max_seq_length) + rela_mat = torch.zeros(self.max_seq_length, self.max_seq_length, self.n_rel) + rela_mat[self.graph_list[index].edges()[0], 
self.graph_list[index].edges()[1], self.graph_list[index].edata['rel']] = 1 + return (self.graph_list[index], self.tensor[index], input_ids, num_nodes, head_position, tail_position, rela_mat) + + def convert_subgraph_to_feature(self, subgraph, triple, max_seq_length): + input_ids = [subgraph.ndata[NID]] + if max_seq_length - subgraph.num_nodes() > 0: + input_ids.append(torch.tensor([self.n_ent for _ in range(max_seq_length - subgraph.num_nodes())])) + input_ids = torch.cat(input_ids) + # print(subgraph) + # print(input_ids) + # print(triple) + head_position = torch.where(subgraph.ndata[NID] == triple[0])[0] + tail_position = torch.where(subgraph.ndata[NID] == triple[2])[0] + # print(head_position) + # print(tail_position) + # exit(0) + # input_ids = torch.cat( + # [subgraph.ndata[NID], torch.tensor([self.n_ent for _ in range(max_seq_length - subgraph.num_nodes())])]) + # input_mask = [1 for _ in range(subgraph.num_nodes())] + # input_mask += [0 for _ in range(max_seq_length - subgraph.num_nodes())] + # while len(input_mask) < max_seq_length: + # # input_ids.append(self.n_ent) + # input_mask.append(0) + # for position in list(torch.where(subgraph.ndata['z']==1)[0]): + # mask_position = position + # input_ids[mask_position]=self.n_ent+1 + # mask_label = subgraph.ndata[NID][mask_position] + # break + assert input_ids.size(0) == max_seq_length + + return input_ids, subgraph.num_nodes(), head_position, tail_position + +class GraphDataSetGCN(Dataset): + """ + GraphDataset for torch DataLoader + """ + + def __init__(self, graph_list, tensor, n_ent, n_rel): + self.graph_list = graph_list + self.tensor = tensor + + def __len__(self): + return len(self.graph_list) + + def __getitem__(self, index): + head_idx = torch.where(self.graph_list[index].ndata[NID] == self.tensor[index][0])[0] + tail_idx = torch.where(self.graph_list[index].ndata[NID] == self.tensor[index][2])[0] + return (self.graph_list[index], self.tensor[index], head_idx, tail_idx) + +class PosNegEdgesGenerator(object): + """ + Generate positive and negative samples + Attributes: + g(dgl.DGLGraph): graph + split_edge(dict): split edge + neg_samples(int): num of negative samples per positive sample + subsample_ratio(float): ratio of subsample + shuffle(bool): if shuffle generated graph list + """ + + def __init__(self, g, split_edge, neg_samples=1, subsample_ratio=0.1, shuffle=True): + self.neg_sampler = Uniform(neg_samples) + self.subsample_ratio = subsample_ratio + self.split_edge = split_edge + self.g = g + self.shuffle = shuffle + + def __call__(self, split_type): + + if split_type == 'train': + subsample_ratio = self.subsample_ratio + else: + subsample_ratio = 1 + + pos_edges = self.g.edges() + pos_edges = torch.stack((pos_edges[0], pos_edges[1]), 1) + + if split_type == 'train': + # Adding self loop in train avoids sampling the source node itself. + g = add_self_loop(self.g) + eids = g.edge_ids(pos_edges[:, 0], pos_edges[:, 1]) + neg_edges = torch.stack(self.neg_sampler(g, eids), dim=1) + else: + neg_edges = self.split_edge[split_type]['edge_neg'] + pos_edges = self.subsample(pos_edges, subsample_ratio).long() + neg_edges = self.subsample(neg_edges, subsample_ratio).long() + + edges = torch.cat([pos_edges, neg_edges]) + labels = torch.cat([torch.ones(pos_edges.size(0), 1), torch.zeros(neg_edges.size(0), 1)]) + if self.shuffle: + perm = torch.randperm(edges.size(0)) + edges = edges[perm] + labels = labels[perm] + return edges, labels + + def subsample(self, edges, subsample_ratio): + """ + Subsample generated edges. 
+ Args: + edges(Tensor): edges to subsample + subsample_ratio(float): ratio of subsample + + Returns: + edges(Tensor): edges + + """ + + num_edges = edges.size(0) + perm = torch.randperm(num_edges) + perm = perm[:int(subsample_ratio * num_edges)] + edges = edges[perm] + return edges + + +class EdgeDataSet(Dataset): + """ + Assistant Dataset for speeding up the SEALSampler + """ + + def __init__(self, triples, transform): + self.transform = transform + self.triples = triples + + def __len__(self): + return len(self.triples) + + def __getitem__(self, index): + edge = torch.tensor([self.triples[index]['triple'][0], self.triples[index]['triple'][2]]) + subgraph = self.transform(edge) + return (subgraph) + + +class SEALSampler(object): + """ + Sampler for SEAL in paper(no-block version) + The strategy is to sample all the k-hop neighbors around the two target nodes. + Attributes: + graph(DGLGraph): The graph + hop(int): num of hop + num_workers(int): num of workers + + """ + + def __init__(self, graph, hop=1, max_num_nodes=100, num_workers=32, type='greedy', print_fn=print): + self.graph = graph + self.hop = hop + self.print_fn = print_fn + self.num_workers = num_workers + self.threshold = None + self.max_num_nodes = max_num_nodes + self.sample_type = type + + def sample_subgraph(self, target_nodes, mode='valid'): + """ + Args: + target_nodes(Tensor): Tensor of two target nodes + Returns: + subgraph(DGLGraph): subgraph + """ + # TODO: Add sample constrain on each hop + sample_nodes = [target_nodes] + frontiers = target_nodes + if self.sample_type == 'greedy': + for i in range(self.hop): + # get sampled node number + tmp_sample_nodes = torch.cat(sample_nodes) + tmp_sample_nodes = torch.unique(tmp_sample_nodes) + if tmp_sample_nodes.size(0) < self.max_num_nodes: # whether sample or not + frontiers = self.graph.out_edges(frontiers)[1] + frontiers = torch.unique(frontiers) + if frontiers.size(0) > self.max_num_nodes - tmp_sample_nodes.size(0): + frontiers = np.random.choice(frontiers.numpy(), self.max_num_nodes - tmp_sample_nodes.size(0), replace=False) + frontiers = torch.unique(torch.tensor(frontiers)) + sample_nodes.append(frontiers) + elif self.sample_type == 'greedy_set': + for i in range(self.hop): + # get sampled node number + tmp_sample_nodes = torch.cat(sample_nodes) + tmp_sample_nodes_set = set(tmp_sample_nodes.tolist()) + # tmp_sample_nodes = torch.unique(tmp_sample_nodes) + if len(tmp_sample_nodes_set) < self.max_num_nodes: # whether sample or not + tmp_frontiers = self.graph.out_edges(frontiers)[1] + tmp_frontiers_set = set(tmp_frontiers.tolist()) + tmp_frontiers_set = tmp_frontiers_set - tmp_sample_nodes_set + if not tmp_frontiers_set: + break + else: + frontiers_set = tmp_frontiers_set + frontiers = torch.tensor(list(frontiers_set)) + if frontiers.size(0) > self.max_num_nodes - len(tmp_sample_nodes_set): + frontiers = np.random.choice(frontiers.numpy(), self.max_num_nodes - len(tmp_sample_nodes_set), replace=False) + frontiers = torch.unique(torch.tensor(frontiers)) + sample_nodes.append(frontiers) + elif self.sample_type == 'average_set': + for i in range(self.hop): + # get sampled node number + tmp_sample_nodes = torch.cat(sample_nodes) + tmp_sample_nodes_set = set(tmp_sample_nodes.tolist()) + num_nodes = int(self.max_num_nodes / self.hop * (i + 1)) + if len(tmp_sample_nodes_set) < num_nodes: # whether sample or not + tmp_frontiers = self.graph.out_edges(frontiers)[1] + tmp_frontiers_set = set(tmp_frontiers.tolist()) + tmp_frontiers_set = tmp_frontiers_set - tmp_sample_nodes_set + 
if not tmp_frontiers_set: + break + else: + frontiers_set = tmp_frontiers_set + frontiers = torch.tensor(list(frontiers_set)) + if frontiers.size(0) > num_nodes - len(tmp_sample_nodes_set): + frontiers = np.random.choice(frontiers.numpy(), num_nodes - len(tmp_sample_nodes_set), replace=False) + frontiers = torch.unique(torch.tensor(frontiers)) + sample_nodes.append(frontiers) + elif self.sample_type == 'average': + for i in range(self.hop): + tmp_sample_nodes = torch.cat(sample_nodes) + tmp_sample_nodes = torch.unique(tmp_sample_nodes) + num_nodes = int(self.max_num_nodes/self.hop*(i+1)) + if tmp_sample_nodes.size(0) < num_nodes: # whether sample or not + frontiers = self.graph.out_edges(frontiers)[1] + frontiers = torch.unique(frontiers) + if frontiers.size(0) > num_nodes - tmp_sample_nodes.size(0): + frontiers = np.random.choice(frontiers.numpy(), num_nodes - tmp_sample_nodes.size(0), + replace=False) + frontiers = torch.unique(torch.tensor(frontiers)) + sample_nodes.append(frontiers) + elif self.sample_type == 'enclosing_subgraph': + u, v = target_nodes[0], target_nodes[1] + u_neighbour, v_neighbour = u, v + # print(target_nodes) + u_sample_nodes = [u_neighbour.reshape(1)] + v_sample_nodes = [v_neighbour.reshape(1)] + graph = self.graph + if graph.has_edges_between(u, v): + link_id = graph.edge_ids(u, v, return_uv=True)[2] + graph.remove_edges(link_id) + if graph.has_edges_between(v, u): + link_id = graph.edge_ids(v, u, return_uv=True)[2] + graph.remove_edges(link_id) + for i in range(self.hop): + u_frontiers = graph.out_edges(u_neighbour)[1] + # v_frontiers = self.graph.out_edges(v_neighbour)[1] + u_neighbour = u_frontiers + # set(u_frontiers.tolist()) + if u_frontiers.size(0) > self.max_num_nodes: + u_frontiers = np.random.choice(u_frontiers.numpy(), self.max_num_nodes, replace=False) + u_frontiers = torch.tensor(u_frontiers) + u_sample_nodes.append(u_frontiers) + for i in range(self.hop): + v_frontiers = graph.out_edges(v_neighbour)[1] + # v_frontiers = self.graph.out_edges(v_neighbour)[1] + v_neighbour = v_frontiers + if v_frontiers.size(0) > self.max_num_nodes: + v_frontiers = np.random.choice(v_frontiers.numpy(), self.max_num_nodes, replace=False) + v_frontiers = torch.tensor(v_frontiers) + v_sample_nodes.append(v_frontiers) + # print('U', u_sample_nodes) + # print('V', v_sample_nodes) + u_sample_nodes = torch.cat(u_sample_nodes) + u_sample_nodes = torch.unique(u_sample_nodes) + v_sample_nodes = torch.cat(v_sample_nodes) + v_sample_nodes = torch.unique(v_sample_nodes) + # print('U', u_sample_nodes) + # print('V', v_sample_nodes) + u_sample_nodes_set = set(u_sample_nodes.tolist()) + v_sample_nodes_set = set(v_sample_nodes.tolist()) + uv_inter_neighbour = u_sample_nodes_set.intersection(v_sample_nodes_set) + frontiers = torch.tensor(list(uv_inter_neighbour), dtype=torch.int64) + # print(frontiers) + sample_nodes.append(frontiers) + else: + raise NotImplementedError + sample_nodes = torch.cat(sample_nodes) + sample_nodes = torch.unique(sample_nodes) + subgraph = dgl.node_subgraph(self.graph, sample_nodes) + + # Each node should have unique node id in the new subgraph + u_id = int(torch.nonzero(subgraph.ndata[NID] == int(target_nodes[0]), as_tuple=False)) + v_id = int(torch.nonzero(subgraph.ndata[NID] == int(target_nodes[1]), as_tuple=False)) + # remove link between target nodes in positive subgraphs. + # Edge removing will rearange NID and EID, which lose the original NID and EID. 
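The target-link removal step that follows can be illustrated on a toy DGL graph; this is only a sketch under the pinned DGL version, with arbitrary node ids.

```python
import torch
import dgl

g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))  # 3-cycle
u_id, v_id = 0, 1                                                   # target pair inside the subgraph
if g.has_edges_between(u_id, v_id):
    eid = g.edge_ids(u_id, v_id, return_uv=True)[2]
    g.remove_edges(eid)
print(g.num_edges())  # 2: only the (0, 1) target edge was dropped
```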
+ + # if dgl.__version__[:5] < '0.6.0': + # nids = subgraph.ndata[NID] + # eids = subgraph.edata[EID] + # if subgraph.has_edges_between(u_id, v_id): + # link_id = subgraph.edge_ids(u_id, v_id, return_uv=True)[2] + # subgraph.remove_edges(link_id) + # eids = eids[subgraph.edata[EID]] + # if subgraph.has_edges_between(v_id, u_id): + # link_id = subgraph.edge_ids(v_id, u_id, return_uv=True)[2] + # subgraph.remove_edges(link_id) + # eids = eids[subgraph.edata[EID]] + # subgraph.ndata[NID] = nids + # subgraph.edata[EID] = eids + + if subgraph.has_edges_between(u_id, v_id): + link_id = subgraph.edge_ids(u_id, v_id, return_uv=True)[2] + subgraph.remove_edges(link_id) + if subgraph.has_edges_between(v_id, u_id): + link_id = subgraph.edge_ids(v_id, u_id, return_uv=True)[2] + subgraph.remove_edges(link_id) + + z = drnl_node_labeling(subgraph, u_id, v_id) + subgraph.ndata['z'] = z + return subgraph + + def _collate(self, batch_graphs): + + # batch_graphs = map(list, zip(*batch)) + # print(batch_graphs) + # print(batch_triples) + # print(batch_labels) + + batch_graphs = dgl.batch(batch_graphs) + return batch_graphs + + def __call__(self, triples): + subgraph_list = [] + triples_list = [] + labels_list = [] + edge_dataset = EdgeDataSet(triples, transform=self.sample_subgraph) + self.print_fn('Using {} workers in sampling job.'.format(self.num_workers)) + sampler = DataLoader(edge_dataset, batch_size=128, num_workers=self.num_workers, + shuffle=False, collate_fn=self._collate) + for subgraph in tqdm(sampler, ncols=100): + subgraph = dgl.unbatch(subgraph) + + subgraph_list += subgraph + + return subgraph_list + + +class SEALSampler_NG(object): + """ + Sampler for SEAL in paper(no-block version) + The strategy is to sample all the k-hop neighbors around the two target nodes. 
+ Attributes: + graph(DGLGraph): The graph + hop(int): num of hop + num_workers(int): num of workers + + """ + + def __init__(self, graph, hop=1, max_num_nodes=100, num_workers=32, type='greedy', print_fn=print): + self.graph = graph + self.hop = hop + self.print_fn = print_fn + self.num_workers = num_workers + self.threshold = None + self.max_num_nodes = max_num_nodes + self.sample_type = type + + def sample_subgraph(self, target_nodes, mode='valid'): + """ + Args: + target_nodes(Tensor): Tensor of two target nodes + Returns: + subgraph(DGLGraph): subgraph + """ + # TODO: Add sample constrain on each hop + sample_nodes = [target_nodes] + frontiers = target_nodes + if self.sample_type == 'greedy': + for i in range(self.hop): + # get sampled node number + tmp_sample_nodes = torch.cat(sample_nodes) + tmp_sample_nodes = torch.unique(tmp_sample_nodes) + if tmp_sample_nodes.size(0) < self.max_num_nodes: # whether sample or not + frontiers = self.graph.out_edges(frontiers)[1] + frontiers = torch.unique(frontiers) + if frontiers.size(0) > self.max_num_nodes - tmp_sample_nodes.size(0): + frontiers = np.random.choice(frontiers.numpy(), self.max_num_nodes - tmp_sample_nodes.size(0), replace=False) + frontiers = torch.unique(torch.tensor(frontiers)) + sample_nodes.append(frontiers) + elif self.sample_type == 'greedy_set': + for i in range(self.hop): + # get sampled node number + tmp_sample_nodes = torch.cat(sample_nodes) + tmp_sample_nodes_set = set(tmp_sample_nodes.tolist()) + # tmp_sample_nodes = torch.unique(tmp_sample_nodes) + if len(tmp_sample_nodes_set) < self.max_num_nodes: # whether sample or not + tmp_frontiers = self.graph.out_edges(frontiers)[1] + tmp_frontiers_set = set(tmp_frontiers.tolist()) + tmp_frontiers_set = tmp_frontiers_set - tmp_sample_nodes_set + if not tmp_frontiers_set: + break + else: + frontiers_set = tmp_frontiers_set + frontiers = torch.tensor(list(frontiers_set)) + if frontiers.size(0) > self.max_num_nodes - len(tmp_sample_nodes_set): + frontiers = np.random.choice(frontiers.numpy(), self.max_num_nodes - len(tmp_sample_nodes_set), replace=False) + frontiers = torch.unique(torch.tensor(frontiers)) + sample_nodes.append(frontiers) + elif self.sample_type == 'average_set': + for i in range(self.hop): + # get sampled node number + tmp_sample_nodes = torch.cat(sample_nodes) + tmp_sample_nodes_set = set(tmp_sample_nodes.tolist()) + num_nodes = int(self.max_num_nodes / self.hop * (i + 1)) + if len(tmp_sample_nodes_set) < num_nodes: # whether sample or not + tmp_frontiers = self.graph.out_edges(frontiers)[1] + tmp_frontiers_set = set(tmp_frontiers.tolist()) + tmp_frontiers_set = tmp_frontiers_set - tmp_sample_nodes_set + if not tmp_frontiers_set: + break + else: + frontiers_set = tmp_frontiers_set + frontiers = torch.tensor(list(frontiers_set)) + if frontiers.size(0) > num_nodes - len(tmp_sample_nodes_set): + frontiers = np.random.choice(frontiers.numpy(), num_nodes - len(tmp_sample_nodes_set), replace=False) + frontiers = torch.unique(torch.tensor(frontiers)) + sample_nodes.append(frontiers) + elif self.sample_type == 'average': + for i in range(self.hop): + tmp_sample_nodes = torch.cat(sample_nodes) + tmp_sample_nodes = torch.unique(tmp_sample_nodes) + num_nodes = int(self.max_num_nodes/self.hop*(i+1)) + if tmp_sample_nodes.size(0) < num_nodes: # whether sample or not + frontiers = self.graph.out_edges(frontiers)[1] + frontiers = torch.unique(frontiers) + if frontiers.size(0) > num_nodes - tmp_sample_nodes.size(0): + frontiers = np.random.choice(frontiers.numpy(), num_nodes - 
tmp_sample_nodes.size(0), + replace=False) + frontiers = torch.unique(torch.tensor(frontiers)) + sample_nodes.append(frontiers) + elif self.sample_type == 'enclosing_subgraph': + u, v = target_nodes[0], target_nodes[1] + u_neighbour, v_neighbour = u, v + # print(target_nodes) + u_sample_nodes = [u_neighbour.reshape(1)] + v_sample_nodes = [v_neighbour.reshape(1)] + graph = self.graph + if graph.has_edges_between(u, v): + link_id = graph.edge_ids(u, v, return_uv=True)[2] + graph.remove_edges(link_id) + if graph.has_edges_between(v, u): + link_id = graph.edge_ids(v, u, return_uv=True)[2] + graph.remove_edges(link_id) + for i in range(self.hop): + u_frontiers = self.graph.out_edges(u_neighbour)[1] + # v_frontiers = self.graph.out_edges(v_neighbour)[1] + u_neighbour = u_frontiers + # set(u_frontiers.tolist()) + if u_frontiers.size(0) > self.max_num_nodes: + u_frontiers = np.random.choice(u_frontiers.numpy(), self.max_num_nodes, replace=False) + u_frontiers = torch.tensor(u_frontiers) + u_sample_nodes.append(u_frontiers) + for i in range(self.hop): + v_frontiers = self.graph.out_edges(v_neighbour)[1] + # v_frontiers = self.graph.out_edges(v_neighbour)[1] + v_neighbour = v_frontiers + if v_frontiers.size(0) > self.max_num_nodes: + v_frontiers = np.random.choice(v_frontiers.numpy(), self.max_num_nodes, replace=False) + v_frontiers = torch.tensor(v_frontiers) + v_sample_nodes.append(v_frontiers) + # print('U', u_sample_nodes) + # print('V', v_sample_nodes) + u_sample_nodes = torch.cat(u_sample_nodes) + u_sample_nodes = torch.unique(u_sample_nodes) + v_sample_nodes = torch.cat(v_sample_nodes) + v_sample_nodes = torch.unique(v_sample_nodes) + # print('U', u_sample_nodes) + # print('V', v_sample_nodes) + u_sample_nodes_set = set(u_sample_nodes.tolist()) + v_sample_nodes_set = set(v_sample_nodes.tolist()) + uv_inter_neighbour = u_sample_nodes_set.intersection(v_sample_nodes_set) + frontiers = torch.tensor(list(uv_inter_neighbour), dtype=torch.int64) + # print(frontiers) + sample_nodes.append(frontiers) + # print(sample_nodes) + # print("____________________________________") + # exit(0) + # v_neighbour = v_frontiers + else: + raise NotImplementedError + + sample_nodes = torch.cat(sample_nodes) + sample_nodes = torch.unique(sample_nodes) + # print(sample_nodes) + # print("____________________________________") + + return sample_nodes + + def _collate(self, batch_nodes): + + # batch_graphs = map(list, zip(*batch)) + # print(batch_graphs) + # print(batch_triples) + # print(batch_labels) + + return batch_nodes + + def __call__(self, triples): + sample_nodes_list = [] + edge_dataset = EdgeDataSet(triples, transform=self.sample_subgraph) + self.print_fn('Using {} workers in sampling job.'.format(self.num_workers)) + sampler = DataLoader(edge_dataset, batch_size=1, num_workers=self.num_workers, + shuffle=False, collate_fn=self._collate) + for sample_nodes in tqdm(sampler): + sample_nodes_list.append(sample_nodes) + # for sample_nodes in tqdm(sampler, ncols=100): + # print(sample_nodes) + # exit(0) + # sample_nodes_list.append(sample_nodes) + + return sample_nodes_list + +class SEALData(object): + """ + 1. Generate positive and negative samples + 2. Subgraph sampling + + Attributes: + g(dgl.DGLGraph): graph + split_edge(dict): split edge + hop(int): num of hop + neg_samples(int): num of negative samples per positive sample + subsample_ratio(float): ratio of subsample + use_coalesce(bool): True for coalesce graph. 
Graph with multi-edge need to coalesce + """ + + def __init__(self, g, split_edge, hop=1, neg_samples=1, subsample_ratio=1, prefix=None, save_dir=None, + num_workers=32, shuffle=True, use_coalesce=True, print_fn=print): + self.g = g + self.hop = hop + self.subsample_ratio = subsample_ratio + self.prefix = prefix + self.save_dir = save_dir + self.print_fn = print_fn + + self.generator = PosNegEdgesGenerator(g=self.g, + split_edge=split_edge, + neg_samples=neg_samples, + subsample_ratio=subsample_ratio, + shuffle=shuffle) + # if use_coalesce: + # for k, v in g.edata.items(): + # g.edata[k] = v.float() # dgl.to_simple() requires data is float + # self.g = dgl.to_simple(g, copy_ndata=True, copy_edata=True, aggregator='sum') + # + # self.ndata = {k: v for k, v in self.g.ndata.items()} + # self.edata = {k: v for k, v in self.g.edata.items()} + # self.g.ndata.clear() + # self.g.edata.clear() + # self.print_fn("Save ndata and edata in class.") + # self.print_fn("Clear ndata and edata in graph.") + # + self.sampler = SEALSampler(graph=self.g, + hop=hop, + num_workers=num_workers, + print_fn=print_fn) + + def __call__(self, split_type): + + if split_type == 'train': + subsample_ratio = self.subsample_ratio + else: + subsample_ratio = 1 + + path = osp.join(self.save_dir or '', '{}_{}_{}-hop_{}-subsample.bin'.format(self.prefix, split_type, + self.hop, subsample_ratio)) + + if osp.exists(path): + self.print_fn("Load existing processed {} files".format(split_type)) + graph_list, data = dgl.load_graphs(path) + dataset = GraphDataSet(graph_list, data['labels']) + + else: + self.print_fn("Processed {} files not exist.".format(split_type)) + + edges, labels = self.generator(split_type) + self.print_fn("Generate {} edges totally.".format(edges.size(0))) + + graph_list, labels = self.sampler(edges, labels) + dataset = GraphDataSet(graph_list, labels) + dgl.save_graphs(path, graph_list, {'labels': labels}) + self.print_fn("Save preprocessed subgraph to {}".format(path)) + return dataset + + +class TripleSampler(object): + def __init__(self, g, data, sample_ratio=0.1): + self.g = g + self.sample_ratio = sample_ratio + self.data = data + + def __call__(self, split_type): + + if split_type == 'train': + sample_ratio = self.sample_ratio + self.shuffle = True + pos_edges = self.g.edges() + pos_edges = torch.stack((pos_edges[0], pos_edges[1]), 1) + + g = add_self_loop(self.g) + pos_edges = self.sample(pos_edges, sample_ratio).long() + eids = g.edge_ids(pos_edges[:, 0], pos_edges[:, 1]) + edges = pos_edges + # labels = torch.cat([torch.ones(pos_edges.size(0), 1), torch.zeros(neg_edges.size(0), 1)]) + triples = torch.stack(([edges[:, 0], g.edata['rel'][eids], edges[:, 1]]), 1) + else: + self.shuffle = False + triples = torch.tensor(self.data[split_type]) + edges = torch.stack((triples[:, 0], triples[:, 2]), 1) + + if self.shuffle: + perm = torch.randperm(edges.size(0)) + edges = edges[perm] + triples = triples[perm] + return edges, triples + + def sample(self, edges, sample_ratio): + """ + Subsample generated edges. + Args: + edges(Tensor): edges to subsample + subsample_ratio(float): ratio of subsample + + Returns: + edges(Tensor): edges + + """ + + num_edges = edges.size(0) + perm = torch.randperm(num_edges) + perm = perm[:int(sample_ratio * num_edges)] + edges = edges[perm] + return edges + + +def drnl_node_labeling(subgraph, src, dst): + """ + Double Radius Node labeling + d = r(i,u)+r(i,v) + label = 1+ min(r(i,u),r(i,v))+ (d//2)*(d//2+d%2-1) + Isolated nodes in subgraph will be set as zero. 
+ Extreme large graph may cause memory error. + + Args: + subgraph(DGLGraph): The graph + src(int): node id of one of src node in new subgraph + dst(int): node id of one of dst node in new subgraph + Returns: + z(Tensor): node labeling tensor + """ + adj = subgraph.adj().to_dense().numpy() + src, dst = (dst, src) if src > dst else (src, dst) + if src != dst: + idx = list(range(src)) + list(range(src + 1, adj.shape[0])) + adj_wo_src = adj[idx, :][:, idx] + + idx = list(range(dst)) + list(range(dst + 1, adj.shape[0])) + adj_wo_dst = adj[idx, :][:, idx] + + dist2src = shortest_path(adj_wo_dst, directed=False, unweighted=True, indices=src) + dist2src = np.insert(dist2src, dst, 0, axis=0) + dist2src = torch.from_numpy(dist2src) + + dist2dst = shortest_path(adj_wo_src, directed=False, unweighted=True, indices=dst - 1) + dist2dst = np.insert(dist2dst, src, 0, axis=0) + dist2dst = torch.from_numpy(dist2dst) + else: + dist2src = shortest_path(adj, directed=False, unweighted=True, indices=src) + # dist2src = np.insert(dist2src, dst, 0, axis=0) + dist2src = torch.from_numpy(dist2src) + + dist2dst = shortest_path(adj, directed=False, unweighted=True, indices=dst) + # dist2dst = np.insert(dist2dst, src, 0, axis=0) + dist2dst = torch.from_numpy(dist2dst) + + dist = dist2src + dist2dst + dist_over_2, dist_mod_2 = dist // 2, dist % 2 + + z = 1 + torch.min(dist2src, dist2dst) + z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1) + z[src] = 1. + z[dst] = 1. + z[torch.isnan(z)] = 0. + + return z.to(torch.long) + +class GrailSampler(object): + def __init__(self, dataset, file_paths, external_kg_file, db_path, hop, enclosing_sub_graph, max_nodes_per_hop): + self.dataset = dataset + self.file_paths = file_paths + self.external_kg_file = external_kg_file + self.db_path = db_path + self.max_links = 2500000 + self.params = dict() + self.params['hop'] = hop + self.params['enclosing_sub_graph'] = enclosing_sub_graph + self.params['max_nodes_per_hop'] = max_nodes_per_hop + + def generate_subgraph_datasets(self, num_neg_samples_per_link, constrained_neg_prob, splits=['train', 'valid', 'test'], saved_relation2id=None, max_label_value=None): + + + testing = 'test' in splits + #adj_list, triplets, entity2id, relation2id, id2entity, id2relation, rel = process_files(params.file_paths, saved_relation2id) + + # triple_file = f'data/{}/{}.txt'.format(params.dataset,params.BKG_file_name) + if self.dataset == 'drugbank': + adj_list, triplets, entity2id, relation2id, id2entity, id2relation, rel = self.process_files_ddi(self.file_paths, self.external_kg_file, saved_relation2id) + # else: + # adj_list, triplets, entity2id, relation2id, id2entity, id2relation, rel, triplets_mr, polarity_mr = self.process_files_decagon(params.file_paths, triple_file, saved_relation2id) + # self.plot_rel_dist(adj_list, f'rel_dist.png') + #print(triplets.keys(), triplets_mr.keys()) + data_path = f'datasets/{self.dataset}/relation2id.json' + if not os.path.isdir(data_path) and testing: + with open(data_path, 'w') as f: + json.dump(relation2id, f) + + graphs = {} + + for split_name in splits: + if self.dataset == 'drugbank': + graphs[split_name] = {'triplets': triplets[split_name], 'max_size': self.max_links} + # elif self.dataset == 'BioSNAP': + # graphs[split_name] = {'triplets': triplets_mr[split_name], 'max_size': params.max_links, "polarity_mr": polarity_mr[split_name]} + # Sample train and valid/test links + for split_name, split in graphs.items(): + print(f"Sampling negative links for {split_name}") + split['pos'], split['neg'] = 
self.sample_neg(adj_list, split['triplets'], num_neg_samples_per_link, max_size=split['max_size'], constrained_neg_prob=constrained_neg_prob) + #print(graphs.keys()) + # if testing: + # directory = os.path.join(params.main_dir, 'data/{}/'.format(params.dataset)) + # save_to_file(directory, f'neg_{params.test_file}_{params.constrained_neg_prob}.txt', graphs['test']['neg'], id2entity, id2relation) + + self.links2subgraphs(adj_list, graphs, self.params, max_label_value) + + def process_files_ddi(self, files, triple_file, saved_relation2id=None, keeptrainone = False): + entity2id = {} + relation2id = {} if saved_relation2id is None else saved_relation2id + + triplets = {} + kg_triple = [] + ent = 0 + rel = 0 + + for file_type, file_path in files.items(): + data = [] + # with open(file_path) as f: + # file_data = [line.split() for line in f.read().split('\n')[:-1]] + file_data = np.loadtxt(file_path) + for triplet in file_data: + #print(triplet) + triplet[0], triplet[1], triplet[2] = int(triplet[0]), int(triplet[1]), int(triplet[2]) + if triplet[0] not in entity2id: + entity2id[triplet[0]] = triplet[0] + #ent += 1 + if triplet[1] not in entity2id: + entity2id[triplet[1]] = triplet[1] + #ent += 1 + if not saved_relation2id and triplet[2] not in relation2id: + if keeptrainone: + triplet[2] = 0 + relation2id[triplet[2]] = 0 + rel = 1 + else: + relation2id[triplet[2]] = triplet[2] + rel += 1 + + # Save the triplets corresponding to only the known relations + if triplet[2] in relation2id: + data.append([entity2id[triplet[0]], entity2id[triplet[1]], relation2id[triplet[2]]]) + + triplets[file_type] = np.array(data) + #print(rel) + triplet_kg = np.loadtxt(triple_file) + # print(np.max(triplet_kg[:, -1])) + for (h, t, r) in triplet_kg: + h, t, r = int(h), int(t), int(r) + if h not in entity2id: + entity2id[h] = h + if t not in entity2id: + entity2id[t] = t + if not saved_relation2id and rel+r not in relation2id: + relation2id[rel+r] = rel + r + kg_triple.append([h, t, r]) + kg_triple = np.array(kg_triple) + id2entity = {v: k for k, v in entity2id.items()} + id2relation = {v: k for k, v in relation2id.items()} + #print(relation2id, rel) + + # Construct the list of adjacency matrix each corresponding to each relation. Note that this is constructed only from the train data. 
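Stepping back from the per-relation bookkeeping, a hypothetical end-to-end call of `GrailSampler` looks like the sketch below; the paths and hyper-parameters are illustrative placeholders, not values prescribed by the repository.

```python
# hypothetical usage sketch; file names and numbers are placeholders
file_paths = {
    'train': 'datasets/drugbank/train_raw.txt',
    'valid': 'datasets/drugbank/dev_raw.txt',
    'test':  'datasets/drugbank/test_raw.txt',
}
sampler = GrailSampler(
    dataset='drugbank',
    file_paths=file_paths,
    external_kg_file='datasets/drugbank/external_kg.txt',
    db_path='subgraph/drugbank/example_lmdb',
    hop=2,
    enclosing_sub_graph=True,
    max_nodes_per_hop=100,
)
# samples negative links and writes enclosing subgraphs for train/valid/test into LMDB
sampler.generate_subgraph_datasets(num_neg_samples_per_link=1, constrained_neg_prob=0.0)
```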
+ adj_list = [] + #print(kg_triple) + #for i in range(len(relation2id)): + for i in range(rel): + idx = np.argwhere(triplets['train'][:, 2] == i) + adj_list.append(csc_matrix((np.ones(len(idx), dtype=np.uint8), (triplets['train'][:, 0][idx].squeeze(1), triplets['train'][:, 1][idx].squeeze(1))), shape=(34124, 34124))) + for i in range(rel, len(relation2id)): + idx = np.argwhere(kg_triple[:, 2] == i-rel) + #print(len(idx), i) + adj_list.append(csc_matrix((np.ones(len(idx), dtype=np.uint8), (kg_triple[:, 0][idx].squeeze(1), kg_triple[:, 1][idx].squeeze(1))), shape=(34124, 34124))) + #print(adj_list) + #assert 0 + return adj_list, triplets, entity2id, relation2id, id2entity, id2relation, rel + + def plot_rel_dist(self, adj_list, filename): + rel_count = [] + for adj in adj_list: + rel_count.append(adj.count_nonzero()) + + fig = plt.figure(figsize=(12, 8)) + plt.plot(rel_count) + fig.savefig(filename, dpi=fig.dpi) + + def get_edge_count(self, adj_list): + count = [] + for adj in adj_list: + count.append(len(adj.tocoo().row.tolist())) + return np.array(count) + + def sample_neg(self, adj_list, edges, num_neg_samples_per_link=1, max_size=1000000, constrained_neg_prob=0): + pos_edges = edges + neg_edges = [] + + # if max_size is set, randomly sample train links + if max_size < len(pos_edges): + perm = np.random.permutation(len(pos_edges))[:max_size] + pos_edges = pos_edges[perm] + + # sample negative links for train/test + n, r = adj_list[0].shape[0], len(adj_list) + + # distribution of edges across reelations + theta = 0.001 + edge_count = self.get_edge_count(adj_list) + rel_dist = np.zeros(edge_count.shape) + idx = np.nonzero(edge_count) + rel_dist[idx] = softmax(theta * edge_count[idx]) + + # possible head and tails for each relation + valid_heads = [adj.tocoo().row.tolist() for adj in adj_list] + valid_tails = [adj.tocoo().col.tolist() for adj in adj_list] + + pbar = tqdm(total=len(pos_edges)) + while len(neg_edges) < num_neg_samples_per_link * len(pos_edges): + neg_head, neg_tail, rel = pos_edges[pbar.n % len(pos_edges)][0], pos_edges[pbar.n % len(pos_edges)][1], \ + pos_edges[pbar.n % len(pos_edges)][2] + if np.random.uniform() < constrained_neg_prob: + if np.random.uniform() < 0.5: + neg_head = np.random.choice(valid_heads[rel]) + else: + neg_tail = np.random.choice(valid_tails[rel]) + else: + if np.random.uniform() < 0.5: + neg_head = np.random.choice(n) + else: + neg_tail = np.random.choice(n) + + if neg_head != neg_tail and adj_list[rel][neg_head, neg_tail] == 0: + neg_edges.append([neg_head, neg_tail, rel]) + pbar.update(1) + + pbar.close() + + neg_edges = np.array(neg_edges) + return pos_edges, neg_edges + + def links2subgraphs(self, A, graphs, params, max_label_value=None): + ''' + extract enclosing subgraphs, write map mode + named dbs + ''' + max_n_label = {'value': np.array([0, 0])} + subgraph_sizes = [] + enc_ratios = [] + num_pruned_nodes = [] + + BYTES_PER_DATUM = self.get_average_subgraph_size(100, list(graphs.values())[0]['pos'], A, params) * 1.5 + links_length = 0 + for split_name, split in graphs.items(): + links_length += (len(split['pos']) + len(split['neg'])) * 2 + map_size = links_length * BYTES_PER_DATUM + + env = lmdb.open(self.db_path, map_size=map_size, max_dbs=6) + + def extraction_helper(A, links, g_labels, split_env): + + with env.begin(write=True, db=split_env) as txn: + txn.put('num_graphs'.encode(), (len(links)).to_bytes(int.bit_length(len(links)), byteorder='little')) + + with mp.Pool(processes=None, initializer=self.intialize_worker, initargs=(A, params, 
max_label_value)) as p: + args_ = zip(range(len(links)), links, g_labels) + for (str_id, datum) in tqdm(p.imap(self.extract_save_subgraph, args_), total=len(links)): + max_n_label['value'] = np.maximum(np.max(datum['n_labels'], axis=0), max_n_label['value']) + subgraph_sizes.append(datum['subgraph_size']) + enc_ratios.append(datum['enc_ratio']) + num_pruned_nodes.append(datum['num_pruned_nodes']) + + with env.begin(write=True, db=split_env) as txn: + txn.put(str_id, serialize(datum)) + + for split_name, split in graphs.items(): + logging.info(f"Extracting enclosing subgraphs for positive links in {split_name} set") + if self.dataset == 'BioSNAP': + labels = np.array(split["polarity_mr"]) + else: + labels = np.ones(len(split['pos'])) + db_name_pos = split_name + '_pos' + split_env = env.open_db(db_name_pos.encode()) + extraction_helper(A, split['pos'], labels, split_env) + + logging.info(f"Extracting enclosing subgraphs for negative links in {split_name} set") + if self.dataset == 'BioSNAP': + labels = np.array(split["polarity_mr"]) + else: + labels = np.ones(len(split['pos'])) + db_name_neg = split_name + '_neg' + split_env = env.open_db(db_name_neg.encode()) + extraction_helper(A, split['neg'], labels, split_env) + + max_n_label['value'] = max_label_value if max_label_value is not None else max_n_label['value'] + + with env.begin(write=True) as txn: + bit_len_label_sub = int.bit_length(int(max_n_label['value'][0])) + bit_len_label_obj = int.bit_length(int(max_n_label['value'][1])) + txn.put('max_n_label_sub'.encode(), + (int(max_n_label['value'][0])).to_bytes(bit_len_label_sub, byteorder='little')) + txn.put('max_n_label_obj'.encode(), + (int(max_n_label['value'][1])).to_bytes(bit_len_label_obj, byteorder='little')) + + txn.put('avg_subgraph_size'.encode(), struct.pack('f', float(np.mean(subgraph_sizes)))) + txn.put('min_subgraph_size'.encode(), struct.pack('f', float(np.min(subgraph_sizes)))) + txn.put('max_subgraph_size'.encode(), struct.pack('f', float(np.max(subgraph_sizes)))) + txn.put('std_subgraph_size'.encode(), struct.pack('f', float(np.std(subgraph_sizes)))) + + txn.put('avg_enc_ratio'.encode(), struct.pack('f', float(np.mean(enc_ratios)))) + txn.put('min_enc_ratio'.encode(), struct.pack('f', float(np.min(enc_ratios)))) + txn.put('max_enc_ratio'.encode(), struct.pack('f', float(np.max(enc_ratios)))) + txn.put('std_enc_ratio'.encode(), struct.pack('f', float(np.std(enc_ratios)))) + + txn.put('avg_num_pruned_nodes'.encode(), struct.pack('f', float(np.mean(num_pruned_nodes)))) + txn.put('min_num_pruned_nodes'.encode(), struct.pack('f', float(np.min(num_pruned_nodes)))) + txn.put('max_num_pruned_nodes'.encode(), struct.pack('f', float(np.max(num_pruned_nodes)))) + txn.put('std_num_pruned_nodes'.encode(), struct.pack('f', float(np.std(num_pruned_nodes)))) + + def get_average_subgraph_size(self, sample_size, links, A, params): + total_size = 0 + # print(links, len(links)) + lst = np.random.choice(len(links), sample_size) + for idx in lst: + (n1, n2, r_label) = links[idx] + # for (n1, n2, r_label) in links[np.random.choice(len(links), sample_size)]: + nodes, n_labels, subgraph_size, enc_ratio, num_pruned_nodes, common_neighbor = self.subgraph_extraction_labeling((n1, n2), + r_label, A, + params['hop'], + params['enclosing_sub_graph'], + params['max_nodes_per_hop']) + datum = {'nodes': nodes, 'r_label': r_label, 'g_label': 0, 'n_labels': n_labels, 'common_neighbor': common_neighbor, + 'subgraph_size': subgraph_size, 'enc_ratio': enc_ratio, 'num_pruned_nodes': num_pruned_nodes} + 
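A quick sketch (not part of the patch) of the LMDB value format used here: each value is just a pickled tuple of the datum dict's values, so reading it back only requires knowing the key order.

```python
import pickle
import numpy as np
from utils.graph_utils import serialize

datum = {'nodes': [0, 5, 9], 'r_label': 3, 'g_label': 1,
         'n_labels': np.array([[0, 1], [1, 0], [1, 1]]),
         'common_neighbor': {9},
         'subgraph_size': 3, 'enc_ratio': 0.5, 'num_pruned_nodes': 0}
blob = serialize(datum)                  # pickle.dumps(tuple(datum.values()))
nodes, r_label, g_label, n_labels, cn, *_ = pickle.loads(blob)
print(len(blob), nodes, r_label, cn)     # value size in bytes plus the first few fields
```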
total_size += len(serialize(datum)) + return total_size / sample_size + + def intialize_worker(self, A, params, max_label_value): + global A_, params_, max_label_value_ + A_, params_, max_label_value_ = A, params, max_label_value + + def extract_save_subgraph(self, args_): + idx, (n1, n2, r_label), g_label = args_ + nodes, n_labels, subgraph_size, enc_ratio, num_pruned_nodes, common_neighbor = self.subgraph_extraction_labeling((n1, n2), r_label, + A_, params_['hop'], + params_['enclosing_sub_graph'], + params_['max_nodes_per_hop']) + + # max_label_value_ is to set the maximum possible value of node label while doing double-radius labelling. + if max_label_value_ is not None: + n_labels = np.array([np.minimum(label, max_label_value_).tolist() for label in n_labels]) + + datum = {'nodes': nodes, 'r_label': r_label, 'g_label': g_label, 'n_labels': n_labels, 'common_neighbor': common_neighbor, + 'subgraph_size': subgraph_size, 'enc_ratio': enc_ratio, 'num_pruned_nodes': num_pruned_nodes} + str_id = '{:08}'.format(idx).encode('ascii') + + return (str_id, datum) + + def get_neighbor_nodes(self, roots, adj, h=1, max_nodes_per_hop=None): + bfs_generator = _bfs_relational(adj, roots, max_nodes_per_hop) + lvls = list() + for _ in range(h): + try: + lvls.append(next(bfs_generator)) + except StopIteration: + pass + return set().union(*lvls) + + def subgraph_extraction_labeling(self, ind, rel, A_list, h=1, enclosing_sub_graph=False, max_nodes_per_hop=None, + max_node_label_value=None): + # extract the h-hop enclosing subgraphs around link 'ind' + A_incidence = incidence_matrix(A_list) + A_incidence += A_incidence.T + ind = list(ind) + ind[0], ind[1] = int(ind[0]), int(ind[1]) + ind = (ind[0], ind[1]) + root1_nei = self.get_neighbor_nodes(set([ind[0]]), A_incidence, h, max_nodes_per_hop) + root2_nei = self.get_neighbor_nodes(set([ind[1]]), A_incidence, h, max_nodes_per_hop) + subgraph_nei_nodes_int = root1_nei.intersection(root2_nei) + subgraph_nei_nodes_un = root1_nei.union(root2_nei) + + root1_nei_1 = self.get_neighbor_nodes(set([ind[0]]), A_incidence, 1, max_nodes_per_hop) + root2_nei_1 = self.get_neighbor_nodes(set([ind[1]]), A_incidence, 1, max_nodes_per_hop) + common_neighbor = root1_nei_1.intersection(root2_nei_1) + + # Extract subgraph | Roots being in the front is essential for labelling and the model to work properly. 
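The enclosing-subgraph idea used in this extraction can be checked on a toy graph: the candidate node set is the intersection of the h-hop neighborhoods of the two endpoints, with the roots placed first so the labelling code can rely on their positions. This is a small sketch built on the helpers already defined in this patch.

```python
import numpy as np
from scipy.sparse import csc_matrix
from utils.dgl_utils import get_neighbor_nodes
from utils.graph_utils import incidence_matrix

# single-relation path graph 0-1-2-3-4
adj_list = [csc_matrix((np.ones(4, dtype=np.uint8),
                        ([0, 1, 2, 3], [1, 2, 3, 4])), shape=(5, 5))]
A = incidence_matrix(adj_list)
A += A.T                                  # symmetrise, as done for A_incidence above
u, v, h = 0, 2, 1
nei_u = get_neighbor_nodes({u}, A, h)
nei_v = get_neighbor_nodes({v}, A, h)
nodes = [u, v] + list(nei_u.intersection(nei_v) - {u, v})
print(nodes)                              # [0, 2, 1]: node 1 is the only shared 1-hop neighbour
```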
+ if enclosing_sub_graph: + if ind[0] in subgraph_nei_nodes_int: + subgraph_nei_nodes_int.remove(ind[0]) + if ind[1] in subgraph_nei_nodes_int: + subgraph_nei_nodes_int.remove(ind[1]) + subgraph_nodes = list(ind) + list(subgraph_nei_nodes_int) + else: + if ind[0] in subgraph_nei_nodes_un: + subgraph_nei_nodes_un.remove(ind[0]) + if ind[1] in subgraph_nei_nodes_un: + subgraph_nei_nodes_un.remove(ind[1]) + subgraph_nodes = list(ind) + list(subgraph_nei_nodes_un) # list(set(ind).union(subgraph_nei_nodes_un)) + + subgraph = [adj[subgraph_nodes, :][:, subgraph_nodes] for adj in A_list] + + labels, enclosing_subgraph_nodes = self.node_label(incidence_matrix(subgraph), max_distance=h) + # print(ind, subgraph_nodes[:32],enclosing_subgraph_nodes[:32], labels) + pruned_subgraph_nodes = np.array(subgraph_nodes)[enclosing_subgraph_nodes].tolist() + pruned_labels = labels[enclosing_subgraph_nodes] + # pruned_subgraph_nodes = subgraph_nodes + # pruned_labels = labels + + if max_node_label_value is not None: + pruned_labels = np.array([np.minimum(label, max_node_label_value).tolist() for label in pruned_labels]) + + subgraph_size = len(pruned_subgraph_nodes) + enc_ratio = len(subgraph_nei_nodes_int) / (len(subgraph_nei_nodes_un) + 1e-3) + num_pruned_nodes = len(subgraph_nodes) - len(pruned_subgraph_nodes) + # print(pruned_subgraph_nodes) + # import time + # time.sleep(10) + return pruned_subgraph_nodes, pruned_labels, subgraph_size, enc_ratio, num_pruned_nodes, common_neighbor + + def node_label(self, subgraph, max_distance=1): + # implementation of the node labeling scheme described in the paper + roots = [0, 1] + sgs_single_root = [remove_nodes(subgraph, [root]) for root in roots] + dist_to_roots = [ + np.clip(ssp.csgraph.dijkstra(sg, indices=[0], directed=False, unweighted=True, limit=1e6)[:, 1:], 0, 1e7) + for r, sg in enumerate(sgs_single_root)] + dist_to_roots = np.array(list(zip(dist_to_roots[0][0], dist_to_roots[1][0])), dtype=int) + + target_node_labels = np.array([[0, 1], [1, 0]]) + labels = np.concatenate((target_node_labels, dist_to_roots)) if dist_to_roots.size else target_node_labels + + enclosing_subgraph_nodes = np.where(np.max(labels, axis=1) <= max_distance)[0] + return labels, enclosing_subgraph_nodes \ No newline at end of file diff --git a/utils/utils.py b/utils/utils.py new file mode 100644 index 0000000..7b96b2e --- /dev/null +++ b/utils/utils.py @@ -0,0 +1,233 @@ +import argparse +from scipy.sparse.csgraph import shortest_path +import numpy as np +import pandas as pd +import torch +import dgl +import json +import logging +import logging.config +import os +import psutil +from statistics import mean +import pickle + + +def parse_arguments(): + """ + Parse arguments + """ + parser = argparse.ArgumentParser(description='SEAL') + parser.add_argument('--dataset', type=str, default='ogbl-collab') + parser.add_argument('--gpu_id', type=int, default=0) + parser.add_argument('--hop', type=int, default=1) + parser.add_argument('--model', type=str, default='dgcnn') + parser.add_argument('--gcn_type', type=str, default='gcn') + parser.add_argument('--num_layers', type=int, default=3) + parser.add_argument('--hidden_units', type=int, default=32) + parser.add_argument('--sort_k', type=int, default=30) + parser.add_argument('--pooling', type=str, default='sum') + parser.add_argument('--dropout', type=str, default=0.5) + parser.add_argument('--hits_k', type=int, default=50) + parser.add_argument('--lr', type=float, default=0.0001) + parser.add_argument('--neg_samples', type=int, default=1) + 
parser.add_argument('--subsample_ratio', type=float, default=0.1) + parser.add_argument('--epochs', type=int, default=60) + parser.add_argument('--batch_size', type=int, default=32) + parser.add_argument('--eval_steps', type=int, default=5) + parser.add_argument('--num_workers', type=int, default=32) + parser.add_argument('--random_seed', type=int, default=2023) + parser.add_argument('--save_dir', type=str, default='./processed') + args = parser.parse_args() + + return args + + +def coalesce_graph(graph, aggr_type='sum', copy_data=False): + """ + Coalesce multi-edge graph + Args: + graph(DGLGraph): graph + aggr_type(str): type of aggregator for multi edge weights + copy_data(bool): if copy ndata and edata in new graph + + Returns: + graph(DGLGraph): graph + + + """ + src, dst = graph.edges() + graph_df = pd.DataFrame({'src': src, 'dst': dst}) + graph_df['edge_weight'] = graph.edata['edge_weight'].numpy() + + if aggr_type == 'sum': + tmp = graph_df.groupby(['src', 'dst'])['edge_weight'].sum().reset_index() + elif aggr_type == 'mean': + tmp = graph_df.groupby(['src', 'dst'])['edge_weight'].mean().reset_index() + else: + raise ValueError("aggr type error") + + if copy_data: + graph = dgl.to_simple(graph, copy_ndata=True, copy_edata=True) + else: + graph = dgl.to_simple(graph) + + src, dst = graph.edges() + graph_df = pd.DataFrame({'src': src, 'dst': dst}) + graph_df = pd.merge(graph_df, tmp, how='left', on=['src', 'dst']) + graph.edata['edge_weight'] = torch.from_numpy(graph_df['edge_weight'].values).unsqueeze(1) + + graph.edata.pop('count') + return graph + + +def drnl_node_labeling(subgraph, src, dst): + """ + Double Radius Node labeling + d = r(i,u)+r(i,v) + label = 1+ min(r(i,u),r(i,v))+ (d//2)*(d//2+d%2-1) + Isolated nodes in subgraph will be set as zero. + Extreme large graph may cause memory error. + + Args: + subgraph(DGLGraph): The graph + src(int): node id of one of src node in new subgraph + dst(int): node id of one of dst node in new subgraph + Returns: + z(Tensor): node labeling tensor + """ + adj = subgraph.adj().to_dense().numpy() + src, dst = (dst, src) if src > dst else (src, dst) + if src != dst: + idx = list(range(src)) + list(range(src + 1, adj.shape[0])) + adj_wo_src = adj[idx, :][:, idx] + + idx = list(range(dst)) + list(range(dst + 1, adj.shape[0])) + adj_wo_dst = adj[idx, :][:, idx] + + dist2src = shortest_path(adj_wo_dst, directed=False, unweighted=True, indices=src) + dist2src = np.insert(dist2src, dst, 0, axis=0) + dist2src = torch.from_numpy(dist2src) + + dist2dst = shortest_path(adj_wo_src, directed=False, unweighted=True, indices=dst - 1) + dist2dst = np.insert(dist2dst, src, 0, axis=0) + dist2dst = torch.from_numpy(dist2dst) + else: + dist2src = shortest_path(adj, directed=False, unweighted=True, indices=src) + # dist2src = np.insert(dist2src, dst, 0, axis=0) + dist2src = torch.from_numpy(dist2src) + + dist2dst = shortest_path(adj, directed=False, unweighted=True, indices=dst) + # dist2dst = np.insert(dist2dst, src, 0, axis=0) + dist2dst = torch.from_numpy(dist2dst) + + dist = dist2src + dist2dst + dist_over_2, dist_mod_2 = dist // 2, dist % 2 + + z = 1 + torch.min(dist2src, dist2dst) + z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1) + z[src] = 1. + z[dst] = 1. + z[torch.isnan(z)] = 0. + + return z.to(torch.long) + + +def cal_ranks(probs, label): + sorted_idx = np.argsort(probs, axis=1)[:,::-1] + find_target = sorted_idx == np.expand_dims(label, 1) + ranks = np.nonzero(find_target)[1] + 1 + return ranks + +def cal_performance(ranks): + mrr = (1. 
/ ranks).sum() / len(ranks) + m_r = sum(ranks) * 1.0 / len(ranks) + h_1 = sum(ranks<=1) * 1.0 / len(ranks) + h_10 = sum(ranks<=10) * 1.0 / len(ranks) + return mrr, m_r, h_1, h_10 + +def get_logger(name, log_dir): + config_dict = json.load(open('./config/' + 'logger_config.json')) + config_dict['handlers']['file_handler']['filename'] = log_dir + name + '.log' + logging.config.dictConfig(config_dict) + logger = logging.getLogger(name) + return logger + +def get_current_memory_gb() -> float: + # get the memory usage (USS) of the current process, in GB + pid = os.getpid() + p = psutil.Process(pid) + info = p.memory_full_info() + return info.uss / 1024. / 1024. / 1024. + +def get_f1_score_list(class_dict): + f1_score_list = [[],[],[],[],[]] + for key in class_dict: + if key.isdigit(): + if class_dict[key]['support'] < 10: + f1_score_list[0].append(class_dict[key]['f1-score']) + elif 10 <= class_dict[key]['support'] < 50: + f1_score_list[1].append(class_dict[key]['f1-score']) + elif 50 <= class_dict[key]['support'] < 100: + f1_score_list[2].append(class_dict[key]['f1-score']) + elif 100 <= class_dict[key]['support'] < 1000: + f1_score_list[3].append(class_dict[key]['f1-score']) + elif 1000 <= class_dict[key]['support'] < 100000: + f1_score_list[4].append(class_dict[key]['f1-score']) + for index, _ in enumerate(f1_score_list): + f1_score_list[index] = mean(_) + return f1_score_list + +def get_acc_list(class_dict): + acc_list = [0.0,0.0,0.0,0.0,0.0] + support_list = [0,0,0,0,0] + proportion_list = [0.0,0.0,0.0,0.0,0.0] + for key in class_dict: + if key.isdigit(): + if class_dict[key]['support'] < 10: + acc_list[0] += (class_dict[key]['recall']*class_dict[key]['support']) + support_list[0]+=class_dict[key]['support'] + elif 10 <= class_dict[key]['support'] < 50: + acc_list[1] += (class_dict[key]['recall']*class_dict[key]['support']) + support_list[1] += class_dict[key]['support'] + elif 50 <= class_dict[key]['support'] < 100: + acc_list[2] += (class_dict[key]['recall']*class_dict[key]['support']) + support_list[2] += class_dict[key]['support'] + elif 100 <= class_dict[key]['support'] < 1000: + acc_list[3] += (class_dict[key]['recall']*class_dict[key]['support']) + support_list[3] += class_dict[key]['support'] + elif 1000 <= class_dict[key]['support'] < 100000: + acc_list[4] += (class_dict[key]['recall'] * class_dict[key]['support']) + support_list[4] += class_dict[key]['support'] + for index, _ in enumerate(acc_list): + acc_list[index] = acc_list[index] / support_list[index] + # proportion_list[index] = support_list[index] / class_dict['macro avg']['support'] + return acc_list + +def deserialize(data): + data_tuple = pickle.loads(data) + keys = ('nodes', 'r_label', 'g_label', 'n_label') + return dict(zip(keys, data_tuple)) + + +class Temp_Scheduler(object): + def __init__(self, total_epochs, curr_temp, base_temp, temp_min=0.33, last_epoch=-1): + self.curr_temp = curr_temp + self.base_temp = base_temp + self.temp_min = temp_min + self.last_epoch = last_epoch + self.total_epochs = total_epochs + self.step(last_epoch + 1) + + def step(self, epoch=None): + return self.decay_whole_process() + + def decay_whole_process(self, epoch=None): + if epoch is None: + epoch = self.last_epoch + 1 + self.last_epoch = epoch + # self.total_epochs = 150 + self.curr_temp = (1 - self.last_epoch / self.total_epochs) * (self.base_temp - self.temp_min) + self.temp_min + if self.curr_temp < self.temp_min: + self.curr_temp = self.temp_min + return self.curr_temp \ No newline at end of file
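
As a quick reference, here is a minimal sketch of how the ranking metrics and the `Temp_Scheduler` defined in `utils/utils.py` above might be exercised. The toy scores, the 400-epoch budget, and the import path (repo root on `PYTHONPATH`) are assumptions for illustration; how `Temp_Scheduler` is actually wired into the SNAS-style search is not shown in this diff.

```python
import numpy as np

from utils.utils import cal_ranks, cal_performance, Temp_Scheduler

# toy relation scores for 2 queries over 3 interaction types (values invented)
probs = np.array([[0.1, 0.7, 0.2],
                  [0.5, 0.3, 0.2]])
labels = np.array([1, 2])                  # ground-truth relation index per query

ranks = cal_ranks(probs, labels)           # rank of the true relation under descending sort -> [1, 3]
mrr, mr, h1, h10 = cal_performance(ranks)  # MRR ~= 0.667, MR = 2.0, Hits@1 = 0.5, Hits@10 = 1.0

# linear temperature annealing from base_temp down to temp_min over the epoch budget,
# e.g. for a Gumbel-softmax style relaxation during architecture search
sched = Temp_Scheduler(total_epochs=400, curr_temp=1.0, base_temp=1.0, temp_min=0.33)
for epoch in range(400):
    tau = sched.step()                     # tau shrinks linearly towards temp_min
```

The per-split statistics packed into LMDB during subgraph extraction (earlier in this diff) can be read back with `int.from_bytes` and `struct.unpack`; the database path below is a placeholder, since the actual location depends on how the extraction code opens its environment.

```python
import struct

import lmdb

# placeholder path; substitute the directory used when the LMDB environment was created
env = lmdb.open("datasets/drugbank/subgraphs_db", readonly=True, max_dbs=8)

with env.begin() as txn:
    max_n_label_sub = int.from_bytes(txn.get('max_n_label_sub'.encode()), byteorder='little')
    avg_subgraph_size = struct.unpack('f', txn.get('avg_subgraph_size'.encode()))[0]

print(max_n_label_sub, avg_subgraph_size)
```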