diff --git a/README.md b/README.md
index b7164c5..ec3ae50 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,72 @@
-# CSSE-DDI
\ No newline at end of file
+# Customized Subgraph Selection and Encoding for Drug-drug Interaction Prediction
+
+
+
+
+
+---
+
+## Requirements
+
+```shell
+torch==1.13.0
+dgl-cu111==0.6.1
+optuna==3.2.0
+```
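+
+The packages above can be installed with pip. A minimal sketch, assuming a CUDA 11.1 environment and the official DGL wheel index (adjust the versions and CUDA tag to your setup):
+
+```shell
+pip install torch==1.13.0
+pip install dgl-cu111==0.6.1 -f https://data.dgl.ai/wheels/repo.html
+pip install optuna==3.2.0
+```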
+
+## Run
+
+### Unpack Dataset
+```shell
+unzip datasets.zip
+```
+
+### Supernet Training
+```shell
+python run.py --encoder searchgcn --score_func mlp --combine_type concat --n_layer 3 --epoch 400 \
+--batch 512 --seed 0 --search_mode joint_search --search_algorithm spos_train_supernet_ps2 --input_type allgraph \
+--loss_type ce --dataset drugbank --ss_search_algorithm snas
+```
+### Sub-Supernet Training
+```shell
+python run.py --encoder searchgcn --score_func mlp --combine_type concat --n_layer 3 --epoch 400 \
+--batch 512 --seed 0 --search_mode joint_search --search_algorithm spos_train_supernet_ps2 --input_type allgraph \
+--loss_type ce --dataset drugbank --exp_note spfs --few_shot_op rotate --weight_sharing --ss_search_algorithm snas
+
+python run.py --encoder searchgcn --score_func mlp --combine_type concat --n_layer 3 --epoch 400 \
+--batch 512 --seed 0 --search_mode joint_search --search_algorithm spos_train_supernet_ps2 --input_type allgraph \
+--loss_type ce --dataset drugbank --exp_note spfs --few_shot_op ccorr --weight_sharing --ss_search_algorithm snas
+
+python run.py --encoder searchgcn --score_func mlp --combine_type concat --n_layer 3 --epoch 400 \
+--batch 512 --seed 0 --search_mode joint_search --search_algorithm spos_train_supernet_ps2 --input_type allgraph \
+--loss_type ce --dataset drugbank --exp_note spfs --few_shot_op mult --weight_sharing --ss_search_algorithm snas
+
+python run.py --encoder searchgcn --score_func mlp --combine_type concat --n_layer 3 --epoch 400 \
+--batch 512 --seed 0 --search_mode joint_search --search_algorithm spos_train_supernet_ps2 --input_type allgraph \
+--loss_type ce --dataset drugbank --exp_note spfs --few_shot_op sub --weight_sharing --ss_search_algorithm snas
+```
+### Subgraph Selection and Encoding Function Searching
+```shell
+python run.py --encoder searchgcn --score_func mlp --combine_type concat --n_layer 3 --epoch 400 \
+--batch 512 --seed 0 --search_mode joint_search --search_algorithm spos_train_supernet_ps2 --input_type allgraph \
+--loss_type ce --dataset drugbank --exp_note spfs --weight_sharing --ss_search_algorithm snas --arch_search_mode ng
+```
+## Citation
+
+Readers are welcome to follow our work. Please kindly cite our paper:
+
+```bibtex
+@inproceedings{du2024customized,
+ title={Customized Subgraph Selection and Encoding for Drug-drug Interaction Prediction},
+ author={Du, Haotong and Yao, Quanming and Zhang, Juzheng and Liu, Yang and Wang, Zhen},
+ booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
+ year={2024}
+}
+```
+
+## Contact
+If you have any questions, feel free to contact me at [duhaotong@mail.nwpu.edu.cn](mailto:duhaotong@mail.nwpu.edu.cn).
+
+## Acknowledgement
+
+The code of this paper is partially based on the code of [SEAL_dgl](https://github.com/Smilexuhc/SEAL_dgl), [PS2](https://github.com/qiaoyu-tan/PS2), and [Interstellar](https://github.com/LARS-research/Interstellar). We thank the authors of these works.
diff --git a/config/logger_config.json b/config/logger_config.json
new file mode 100644
index 0000000..f4d5547
--- /dev/null
+++ b/config/logger_config.json
@@ -0,0 +1,35 @@
+{
+ "version": 1,
+ "root": {
+ "handlers": [
+ "console_handler",
+ "file_handler"
+ ],
+ "level": "DEBUG"
+ },
+ "handlers": {
+ "console_handler": {
+ "class": "logging.StreamHandler",
+ "level": "DEBUG",
+ "formatter": "console_formatter"
+ },
+ "file_handler": {
+ "class": "logging.FileHandler",
+ "level": "DEBUG",
+ "formatter": "file_formatter",
+ "filename": "python_logging.log",
+ "encoding": "utf8",
+ "mode": "w"
+ }
+ },
+ "formatters": {
+ "console_formatter": {
+ "format": "%(asctime)s - %(message)s",
+ "datefmt": "%Y-%m-%d %H:%M:%S"
+ },
+ "file_formatter": {
+ "format": "%(asctime)s - %(message)s",
+ "datefmt": "%Y-%m-%d %H:%M:%S"
+ }
+ }
+}
\ No newline at end of file
diff --git a/data/knowledge_graph.py b/data/knowledge_graph.py
new file mode 100644
index 0000000..29aec3b
--- /dev/null
+++ b/data/knowledge_graph.py
@@ -0,0 +1,176 @@
+"""
+based on the implementation in DGL
+(https://github.com/dmlc/dgl/blob/master/python/dgl/contrib/data/knowledge_graph.py)
+Knowledge graph dataset for Relational-GCN
+Code adapted from authors' implementation of Relational-GCN
+https://github.com/tkipf/relational-gcn
+https://github.com/MichSchli/RelationPrediction
+"""
+
+from __future__ import print_function
+from __future__ import absolute_import
+import numpy as np
+import scipy.sparse as sp
+import os
+
+from dgl.data.utils import download, extract_archive, get_download_dir, _get_dgl_url
+from utils.dgl_utils import process_files_ddi
+from utils.graph_utils import incidence_matrix
+
+# np.random.seed(123)
+
+_download_prefix = _get_dgl_url('dataset/')
+
+
+def load_data(dataset):
+ if dataset in ['drugbank', 'twosides', 'twosides_200', 'drugbank_s1', 'twosides_s1']:
+ return load_link(dataset)
+ else:
+ raise ValueError('Unknown dataset: {}'.format(dataset))
+
+
+class RGCNLinkDataset(object):
+
+ def __init__(self, name):
+ self.name = name
+ self.dir = 'datasets'
+
+ # zip_path = os.path.join(self.dir, '{}.zip'.format(self.name))
+ self.dir = os.path.join(self.dir, self.name)
+ # extract_archive(zip_path, self.dir)
+
+ def load(self):
+ entity_path = os.path.join(self.dir, 'entities.dict')
+ relation_path = os.path.join(self.dir, 'relations.dict')
+ train_path = os.path.join(self.dir, 'train.txt')
+ valid_path = os.path.join(self.dir, 'valid.txt')
+ test_path = os.path.join(self.dir, 'test.txt')
+ entity_dict = _read_dictionary(entity_path)
+ relation_dict = _read_dictionary(relation_path)
+ self.train = np.asarray(_read_triplets_as_list(
+ train_path, entity_dict, relation_dict))
+ self.valid = np.asarray(_read_triplets_as_list(
+ valid_path, entity_dict, relation_dict))
+ self.test = np.asarray(_read_triplets_as_list(
+ test_path, entity_dict, relation_dict))
+ self.num_nodes = len(entity_dict)
+ print("# entities: {}".format(self.num_nodes))
+ self.num_rels = len(relation_dict)
+ print("# relations: {}".format(self.num_rels))
+ print("# training sample: {}".format(len(self.train)))
+ print("# valid sample: {}".format(len(self.valid)))
+ print("# testing sample: {}".format(len(self.test)))
+ file_paths = {
+ 'train': f'{self.dir}/train_raw.txt',
+ 'valid': f'{self.dir}/dev_raw.txt',
+ 'test': f'{self.dir}/test_raw.txt'
+ }
+ external_kg_file = f'{self.dir}/external_kg.txt'
+ adj_list, triplets, entity2id, relation2id, id2entity, id2relation, rel = process_files_ddi(file_paths, external_kg_file)
+ A_incidence = incidence_matrix(adj_list)
+ A_incidence += A_incidence.T
+ self.adj = A_incidence
+
+
+
+def load_link(dataset):
+ if 'twosides' in dataset or 'ogbl_biokg' in dataset:
+ data = MultiLabelDataset(dataset)
+ else:
+ data = RGCNLinkDataset(dataset)
+ data.load()
+ return data
+
+
+def _read_dictionary(filename):
+ d = {}
+ with open(filename, 'r+') as f:
+ for line in f:
+ line = line.strip().split('\t')
+ d[line[1]] = int(line[0])
+ return d
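+
+# Note (inferred from the parsing in _read_dictionary above, not from any dataset file
+# shipped here): each line of entities.dict / relations.dict is expected to look like
+# "index<TAB>name" and is mapped to {name: index}.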
+
+
+def _read_triplets(filename):
+ with open(filename, 'r+') as f:
+ for line in f:
+ processed_line = line.strip().split('\t')
+ yield processed_line
+
+
+def _read_triplets_as_list(filename, entity_dict, relation_dict):
+ l = []
+ for triplet in _read_triplets(filename):
+ s = entity_dict[triplet[0]]
+ r = relation_dict[triplet[1]]
+ o = entity_dict[triplet[2]]
+ l.append([s, r, o])
+ return l
+
+
+def _read_multi_rel_triplets(filename):
+ with open(filename, 'r+') as f:
+ for line in f:
+ processed_line = line.strip().split('\t')
+ yield processed_line
+
+def _read_multi_rel_triplets_as_array(filename, entity_dict):
+ graph_list = []
+ input_list = []
+ multi_label_list = []
+ pos_neg_list = []
+ for triplet in _read_triplets(filename):
+ s = entity_dict[triplet[0]]
+ o = entity_dict[triplet[1]]
+ r_list = list(map(int, triplet[2].split(',')))
+ multi_label_list.append(r_list)
+        r_label = [i for i, label in enumerate(r_list) if label == 1]
+ for r in r_label:
+ graph_list.append([s, r, o])
+ input_list.append([s, -1, o])
+ pos_neg = int(triplet[3])
+ pos_neg_list.append(pos_neg)
+ return np.asarray(graph_list), np.asarray(input_list), np.asarray(multi_label_list), np.asarray(pos_neg_list)
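+
+# Note (inferred from the parser in _read_multi_rel_triplets_as_array above, not from the
+# dataset itself): each multi-label line is expected to look like
+# "head<TAB>tail<TAB>l_1,l_2,...,l_R<TAB>pos_neg", where each l_i in {0, 1} marks whether
+# relation i holds for the pair and pos_neg flags the sample as positive or negative.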
+
+class MultiLabelDataset(object):
+ def __init__(self, name):
+ self.name = name
+ self.dir = 'datasets'
+
+ # zip_path = os.path.join(self.dir, '{}.zip'.format(self.name))
+ self.dir = os.path.join(self.dir, self.name)
+ # extract_archive(zip_path, self.dir)
+
+ def load(self):
+ entity_path = os.path.join(self.dir, 'entities.dict')
+ train_path = os.path.join(self.dir, 'train.txt')
+ valid_path = os.path.join(self.dir, 'valid.txt')
+ test_path = os.path.join(self.dir, 'test.txt')
+ entity_dict = _read_dictionary(entity_path)
+ self.train_graph, self.train_input, self.train_multi_label, self.train_pos_neg = _read_multi_rel_triplets_as_array(
+ train_path, entity_dict)
+ _, self.valid_input, self.valid_multi_label, self.valid_pos_neg = _read_multi_rel_triplets_as_array(
+ valid_path, entity_dict)
+ _, self.test_input, self.test_multi_label, self.test_pos_neg = _read_multi_rel_triplets_as_array(
+ test_path, entity_dict)
+ self.num_nodes = len(entity_dict)
+ print("# entities: {}".format(self.num_nodes))
+ self.num_rels = self.train_multi_label.shape[1]
+ print("# relations: {}".format(self.num_rels))
+ print("# training sample: {}".format(self.train_input.shape[0]))
+ print("# valid sample: {}".format(self.valid_input.shape[0]))
+ print("# testing sample: {}".format(self.test_input.shape[0]))
+ # print("# training sample: {}".format(len(self.train)))
+ # print("# valid sample: {}".format(len(self.valid)))
+ # print("# testing sample: {}".format(len(self.test)))
+ # file_paths = {
+ # 'train': f'{self.dir}/train_raw.txt',
+ # 'valid': f'{self.dir}/dev_raw.txt',
+ # 'test': f'{self.dir}/test_raw.txt'
+ # }
+ # external_kg_file = f'{self.dir}/external_kg.txt'
+ # adj_list, triplets, entity2id, relation2id, id2entity, id2relation, rel = process_files_ddi(file_paths,
+ # external_kg_file)
+ # A_incidence = incidence_matrix(adj_list)
+ # A_incidence += A_incidence.T
+ # self.adj = A_incidence
\ No newline at end of file
diff --git a/datasets.zip b/datasets.zip
new file mode 100644
index 0000000..44d839b
Binary files /dev/null and b/datasets.zip differ
diff --git a/model/__init__.py b/model/__init__.py
new file mode 100644
index 0000000..07b73fb
--- /dev/null
+++ b/model/__init__.py
@@ -0,0 +1,7 @@
+from .gcns import GCN_TransE, GCN_DistMult, GCN_ConvE, GCN_ConvE_Rel, GCN_Transformer, GCN_None, GCN_MLP, GCN_MLP_NCN
+from .subgraph_selector import SubgraphSelector
+from .model_search import SearchGCN_MLP
+from .model import SearchedGCN_MLP
+from .model_fast import NetworkGNN_MLP
+from .model_spos import SearchGCN_MLP_SPOS
+from .seal_model import SEAL_GCN
diff --git a/model/compgcn_layer.py b/model/compgcn_layer.py
new file mode 100644
index 0000000..777da9c
--- /dev/null
+++ b/model/compgcn_layer.py
@@ -0,0 +1,158 @@
+import torch
+from torch import nn
+import dgl
+import dgl.function as fn
+
+
+class CompGCNCov(nn.Module):
+ def __init__(self, in_channels, out_channels, act=lambda x: x, bias=True, drop_rate=0., opn='corr', num_base=-1,
+ num_rel=None, wni=False, wsi=False, use_bn=True, ltr=True, add_reverse=True):
+ super(CompGCNCov, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.act = act # activation function
+ self.device = None
+ if add_reverse:
+ self.rel = nn.Parameter(torch.empty([num_rel * 2, in_channels], dtype=torch.float))
+ else:
+ self.rel = nn.Parameter(torch.empty([num_rel, in_channels], dtype=torch.float))
+ self.opn = opn
+
+ self.use_bn = use_bn
+ self.ltr = ltr
+
+ # relation-type specific parameter
+ self.in_w = self.get_param([in_channels, out_channels])
+ self.out_w = self.get_param([in_channels, out_channels])
+ self.loop_w = self.get_param([in_channels, out_channels])
+ # transform embedding of relations to next layer
+ self.w_rel = self.get_param([in_channels, out_channels])
+ self.loop_rel = self.get_param([1, in_channels]) # self-loop embedding
+
+ self.drop = nn.Dropout(drop_rate)
+ self.bn = torch.nn.BatchNorm1d(out_channels)
+ self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None
+ if num_base > 0:
+ if add_reverse:
+ self.rel_wt = self.get_param([num_rel * 2, num_base])
+ else:
+ self.rel_wt = self.get_param([num_rel, num_base])
+ else:
+ self.rel_wt = None
+
+ self.wni = wni
+ self.wsi = wsi
+
+ def get_param(self, shape):
+ param = nn.Parameter(torch.Tensor(*shape))
+ nn.init.xavier_normal_(param, gain=nn.init.calculate_gain('relu'))
+ return param
+
+ def message_func(self, edges):
+ edge_type = edges.data['type'] # [E, 1]
+ edge_num = edge_type.shape[0]
+ edge_data = self.comp(
+ edges.src['h'], self.rel[edge_type]) # [E, in_channel]
+ # NOTE: first half edges are all in-directions, last half edges are out-directions.
+ msg = torch.cat([torch.matmul(edge_data[:edge_num // 2, :], self.in_w),
+ torch.matmul(edge_data[edge_num // 2:, :], self.out_w)])
+ msg = msg * edges.data['norm'].reshape(-1, 1) # [E, D] * [E, 1]
+ return {'msg': msg}
+
+ def reduce_func(self, nodes):
+ return {'h': self.drop(nodes.data['h'])}
+
+ def comp(self, h, edge_data):
+ # def com_mult(a, b):
+ # r1, i1 = a[..., 0], a[..., 1]
+ # r2, i2 = b[..., 0], b[..., 1]
+ # return torch.stack([r1 * r2 - i1 * i2, r1 * i2 + i1 * r2], dim=-1)
+ #
+ # def conj(a):
+ # a[..., 1] = -a[..., 1]
+ # return a
+ #
+ # def ccorr(a, b):
+ # # return torch.irfft(com_mult(conj(torch.rfft(a, 1)), torch.rfft(b, 1)), 1, signal_sizes=(a.shape[-1],))
+ # return torch.fft.irfftn(torch.conj(torch.fft.rfftn(a, (-1))) * torch.fft.rfftn(b, (-1)), (-1))
+
+ def com_mult(a, b):
+ r1, i1 = a.real, a.imag
+ r2, i2 = b.real, b.imag
+ real = r1 * r2 - i1 * i2
+ imag = r1 * i2 + i1 * r2
+ return torch.complex(real, imag)
+
+ def conj(a):
+ a.imag = -a.imag
+ return a
+
+ def ccorr(a, b):
+ return torch.fft.irfft(com_mult(conj(torch.fft.rfft(a)), torch.fft.rfft(b)), a.shape[-1])
+
+ def rotate(h, r):
+ # re: first half, im: second half
+ # assume embedding dim is the last dimension
+ d = h.shape[-1]
+ h_re, h_im = torch.split(h, d // 2, -1)
+ r_re, r_im = torch.split(r, d // 2, -1)
+ return torch.cat([h_re * r_re - h_im * r_im,
+ h_re * r_im + h_im * r_re], dim=-1)
+
+ if self.opn == 'mult':
+ return h * edge_data
+ elif self.opn == 'sub':
+ return h - edge_data
+ elif self.opn == 'add':
+ return h + edge_data
+ elif self.opn == 'corr':
+ return ccorr(h, edge_data.expand_as(h))
+ elif self.opn == 'rotate':
+ return rotate(h, edge_data)
+ else:
+ raise KeyError(f'composition operator {self.opn} not recognized.')
+
+ def forward(self, g: dgl.DGLGraph, x, rel_repr, edge_type, edge_norm):
+ """
+ :param g: dgl Graph, a graph without self-loop
+ :param x: input node features, [V, in_channel]
+ :param rel_repr: input relation features: 1. not using bases: [num_rel*2, in_channel]
+ 2. using bases: [num_base, in_channel]
+ :param edge_type: edge type, [E]
+ :param edge_norm: edge normalization, [E]
+ :return: x: output node features: [V, out_channel]
+ rel: output relation features: [num_rel*2, out_channel]
+ """
+ self.device = x.device
+ g = g.local_var()
+ g.ndata['h'] = x
+ g.edata['type'] = edge_type
+ g.edata['norm'] = edge_norm
+ # print(self.rel.data)
+ if self.rel_wt is None:
+ self.rel.data = rel_repr
+ else:
+ # [num_rel*2, num_base] @ [num_base, in_c]
+ self.rel.data = torch.mm(self.rel_wt, rel_repr)
+ g.update_all(self.message_func, fn.sum(
+ msg='msg', out='h'), self.reduce_func)
+
+ if (not self.wni) and (not self.wsi):
+ x = (g.ndata.pop('h') +
+ torch.mm(self.comp(x, self.loop_rel), self.loop_w)) / 3
+ else:
+ if self.wsi:
+ x = g.ndata.pop('h') / 2
+ if self.wni:
+ x = torch.mm(self.comp(x, self.loop_rel), self.loop_w)
+
+ if self.bias is not None:
+ x = x + self.bias
+
+ if self.use_bn:
+ x = self.bn(x)
+
+ if self.ltr:
+ return self.act(x), torch.matmul(self.rel.data, self.w_rel)
+ else:
+ return self.act(x), self.rel.data
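+
+
+if __name__ == '__main__':
+    # Minimal smoke-test sketch (illustrative only, not part of the training pipeline):
+    # build a toy graph whose first half of edges are forward edges and second half are
+    # their reverses, matching the layout assumed in message_func above.
+    num_nodes, num_rel, dim = 4, 1, 8
+    src = torch.tensor([0, 2, 1, 3])        # forward edges 0->1, 2->3, then reverses 1->0, 3->2
+    dst = torch.tensor([1, 3, 0, 2])
+    g = dgl.graph((src, dst), num_nodes=num_nodes)
+    edge_type = torch.tensor([0, 0, 1, 1])  # reverse edges use relation id num_rel + r
+    edge_norm = torch.ones(g.num_edges())
+    layer = CompGCNCov(dim, dim, act=torch.tanh, num_rel=num_rel, opn='rotate')
+    x = torch.randn(num_nodes, dim)
+    rel = torch.randn(num_rel * 2, dim)
+    x_out, rel_out = layer(g, x, rel, edge_type, edge_norm)
+    print(x_out.shape, rel_out.shape)       # torch.Size([4, 8]) torch.Size([2, 8])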
diff --git a/model/fune_layer.py b/model/fune_layer.py
new file mode 100644
index 0000000..2f702c3
--- /dev/null
+++ b/model/fune_layer.py
@@ -0,0 +1,237 @@
+import torch
+from torch import nn
+import dgl
+import dgl.function as fn
+from model.genotypes import COMP_PRIMITIVES, AGG_PRIMITIVES, COMB_PRIMITIVES, ACT_PRIMITIVES
+from model.operations import *
+
+
+class CompOp(nn.Module):
+ def __init__(self, primitive):
+ super(CompOp, self).__init__()
+ self._op = COMP_OPS[primitive]()
+
+ def reset_parameters(self):
+ self._op.reset_parameters()
+
+ def forward(self, src_emb, rel_emb):
+ return self._op(src_emb, rel_emb)
+
+
+class AggOp(nn.Module):
+ def __init__(self, primitive):
+ super(AggOp, self).__init__()
+ self._op = AGG_OPS[primitive]()
+
+ def reset_parameters(self):
+ self._op.reset_parameters()
+
+ def forward(self, msg):
+ return self._op(msg)
+
+
+class CombOp(nn.Module):
+ def __init__(self, primitive, out_channels):
+ super(CombOp, self).__init__()
+ self._op = COMB_OPS[primitive](out_channels)
+
+ def reset_parameters(self):
+ self._op.reset_parameters()
+
+ def forward(self, self_emb, msg):
+ return self._op(self_emb, msg)
+
+
+class ActOp(nn.Module):
+ def __init__(self, primitive):
+ super(ActOp, self).__init__()
+ self._op = ACT_OPS[primitive]()
+
+ def reset_parameters(self):
+ self._op.reset_parameters()
+
+ def forward(self, emb):
+ return self._op(emb)
+
+def act_map(act):
+ if act == "identity":
+ return lambda x: x
+ elif act == "elu":
+ return torch.nn.functional.elu
+ elif act == "sigmoid":
+ return torch.sigmoid
+ elif act == "tanh":
+ return torch.tanh
+ elif act == "relu":
+ return torch.nn.functional.relu
+ elif act == "relu6":
+ return torch.nn.functional.relu6
+ elif act == "softplus":
+ return torch.nn.functional.softplus
+ elif act == "leaky_relu":
+ return torch.nn.functional.leaky_relu
+ else:
+ raise Exception("wrong activate function")
+
+
+class SearchedGCNConv(nn.Module):
+ def __init__(self, in_channels, out_channels, bias=True, drop_rate=0., num_base=-1,
+ num_rel=None, wni=False, wsi=False, use_bn=True, ltr=True, comp=None, agg=None, comb=None, act=None):
+ super(SearchedGCNConv, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.comp_op = comp
+ self.act = act_map(act) # activation function
+ self.agg_op = agg
+ self.device = None
+
+ self.rel = nn.Parameter(torch.empty([num_rel, in_channels], dtype=torch.float))
+
+ self.use_bn = use_bn
+ self.ltr = ltr
+
+ # relation-type specific parameter
+ self.in_w = self.get_param([in_channels, out_channels])
+ self.out_w = self.get_param([in_channels, out_channels])
+ self.loop_w = self.get_param([in_channels, out_channels])
+ # transform embedding of relations to next layer
+ self.w_rel = self.get_param([in_channels, out_channels])
+ self.loop_rel = self.get_param([1, in_channels]) # self-loop embedding
+
+ self.drop = nn.Dropout(drop_rate)
+ self.bn = torch.nn.BatchNorm1d(out_channels)
+ self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None
+ if num_base > 0:
+ self.rel_wt = self.get_param([num_rel, num_base])
+ else:
+ self.rel_wt = None
+
+ self.wni = wni
+ self.wsi = wsi
+ # self.comp = CompOp(comp)
+ # self.agg = AggOp(agg)
+ self.comb = CombOp(comb, out_channels)
+ # self.act = ActOp(act)
+
+ def get_param(self, shape):
+ param = nn.Parameter(torch.Tensor(*shape))
+ nn.init.xavier_normal_(param, gain=nn.init.calculate_gain('relu'))
+ return param
+
+ def message_func(self, edges):
+ edge_type = edges.data['type'] # [E, 1]
+ edge_num = edge_type.shape[0]
+ edge_data = self.comp(
+ edges.src['h'], self.rel[edge_type]) # [E, in_channel]
+ # NOTE: first half edges are all in-directions, last half edges are out-directions.
+ msg = torch.cat([torch.matmul(edge_data[:edge_num // 2, :], self.in_w),
+ torch.matmul(edge_data[edge_num // 2:, :], self.out_w)])
+ msg = msg * edges.data['norm'].reshape(-1, 1) # [E, D] * [E, 1]
+ return {'msg': msg}
+
+ def reduce_func(self, nodes):
+ # return {'h': torch.sum(nodes.mailbox['msg'], dim=1)}
+ return {'h': self.agg(nodes.mailbox['msg'])}
+
+ def apply_node_func(self, nodes):
+ return {'h': self.drop(nodes.data['h'])}
+
+ def comp(self, h, edge_data):
+ # def com_mult(a, b):
+ # r1, i1 = a[..., 0], a[..., 1]
+ # r2, i2 = b[..., 0], b[..., 1]
+ # return torch.stack([r1 * r2 - i1 * i2, r1 * i2 + i1 * r2], dim=-1)
+ #
+ # def conj(a):
+ # a[..., 1] = -a[..., 1]
+ # return a
+ #
+ # def ccorr(a, b):
+ # # return torch.irfft(com_mult(conj(torch.rfft(a, 1)), torch.rfft(b, 1)), 1, signal_sizes=(a.shape[-1],))
+ # return torch.fft.irfftn(torch.conj(torch.fft.rfftn(a, (-1))) * torch.fft.rfftn(b, (-1)), (-1))
+
+ def com_mult(a, b):
+ r1, i1 = a.real, a.imag
+ r2, i2 = b.real, b.imag
+ real = r1 * r2 - i1 * i2
+ imag = r1 * i2 + i1 * r2
+ return torch.complex(real, imag)
+
+ def conj(a):
+ a.imag = -a.imag
+ return a
+
+ def ccorr(a, b):
+ return torch.fft.irfft(com_mult(conj(torch.fft.rfft(a)), torch.fft.rfft(b)), a.shape[-1])
+
+ def rotate(h, r):
+ # re: first half, im: second half
+ # assume embedding dim is the last dimension
+ d = h.shape[-1]
+ h_re, h_im = torch.split(h, d // 2, -1)
+ r_re, r_im = torch.split(r, d // 2, -1)
+ return torch.cat([h_re * r_re - h_im * r_im,
+ h_re * r_im + h_im * r_re], dim=-1)
+
+ if self.comp_op == 'mult':
+ return h * edge_data
+ elif self.comp_op == 'add':
+ return h + edge_data
+ elif self.comp_op == 'sub':
+ return h - edge_data
+ elif self.comp_op == 'ccorr':
+ return ccorr(h, edge_data.expand_as(h))
+ elif self.comp_op == 'rotate':
+ return rotate(h, edge_data)
+ else:
+            raise KeyError(f'composition operator {self.comp_op} not recognized.')
+
+ def forward(self, g: dgl.DGLGraph, x, rel_repr, edge_type, edge_norm):
+ """
+ :param g: dgl Graph, a graph without self-loop
+ :param x: input node features, [V, in_channel]
+ :param rel_repr: input relation features: 1. not using bases: [num_rel*2, in_channel]
+ 2. using bases: [num_base, in_channel]
+ :param edge_type: edge type, [E]
+ :param edge_norm: edge normalization, [E]
+ :return: x: output node features: [V, out_channel]
+ rel: output relation features: [num_rel*2, out_channel]
+ """
+ self.device = x.device
+ g = g.local_var()
+ g.ndata['h'] = x
+ g.edata['type'] = edge_type
+ g.edata['norm'] = edge_norm
+ if self.rel_wt is None:
+ self.rel.data = rel_repr
+ else:
+ # [num_rel*2, num_base] @ [num_base, in_c]
+ self.rel.data = torch.mm(self.rel_wt, rel_repr)
+ if self.agg_op == 'max':
+ g.update_all(self.message_func, fn.max(msg='msg', out='h'), self.apply_node_func)
+ elif self.agg_op == 'mean':
+ g.update_all(self.message_func, fn.mean(msg='msg', out='h'), self.apply_node_func)
+ elif self.agg_op == 'sum':
+ g.update_all(self.message_func, fn.sum(msg='msg', out='h'), self.apply_node_func)
+ # g.update_all(self.message_func, self.reduce_func, self.apply_node_func)
+
+ if (not self.wni) and (not self.wsi):
+ x = self.comb(g.ndata.pop('h'), torch.mm(self.comp(x, self.loop_rel), self.loop_w))*(1/3)
+ # x = (g.ndata.pop('h') +
+ # torch.mm(self.comp(x, self.loop_rel, self.comp_weights), self.loop_w)) / 3
+ # else:
+ # if self.wsi:
+ # x = g.ndata.pop('h') / 2
+ # if self.wni:
+ # x = torch.mm(self.comp(x, self.loop_rel), self.loop_w)
+
+ if self.bias is not None:
+ x = x + self.bias
+
+ if self.use_bn:
+ x = self.bn(x)
+
+ if self.ltr:
+ return self.act(x), torch.matmul(self.rel.data, self.w_rel)
+ else:
+ return self.act(x), self.rel.data
\ No newline at end of file
diff --git a/model/gcns.py b/model/gcns.py
new file mode 100644
index 0000000..4919469
--- /dev/null
+++ b/model/gcns.py
@@ -0,0 +1,1607 @@
+import torch
+from torch import nn
+import dgl
+from model.rgcn_layer import RelGraphConv
+from model.compgcn_layer import CompGCNCov
+import torch.nn.functional as F
+from dgl import NID, EID, readout_nodes
+from dgl.nn.pytorch import GraphConv
+import time
+
+
+class GCNs(nn.Module):
+ def __init__(self, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm,
+ conv_bias=True, gcn_drop=0., opn='mult', wni=False, wsi=False, encoder='compgcn', use_bn=True,
+ ltr=True, input_type='subgraph', loss_type='ce', add_reverse=True):
+ super(GCNs, self).__init__()
+ self.act = torch.tanh
+ if loss_type == 'ce':
+ self.loss = nn.CrossEntropyLoss()
+ elif loss_type == 'bce':
+            self.loss = nn.BCELoss(reduction='none')
+ elif loss_type == 'bce_logits':
+ self.loss = nn.BCEWithLogitsLoss()
+ else:
+ raise NotImplementedError
+ self.num_ent, self.num_rel, self.num_base = num_ent, num_rel, num_base
+ self.init_dim, self.gcn_dim, self.embed_dim = init_dim, gcn_dim, embed_dim
+ self.conv_bias = conv_bias
+ self.gcn_drop = gcn_drop
+ self.opn = opn
+ self.edge_type = edge_type # [E]
+ self.edge_norm = edge_norm # [E]
+ self.n_layer = n_layer
+ self.input_type = input_type
+
+ self.wni = wni
+
+ self.encoder = encoder
+ if input_type == 'subgraph':
+ self.init_embed = self.get_param([self.num_ent, self.init_dim])
+ # self.init_embed = nn.Embedding(self.num_ent+2, self.init_dim)
+ else:
+ self.init_embed = self.get_param([self.num_ent + 1, self.init_dim])
+ if add_reverse:
+ self.init_rel = self.get_param([self.num_rel * 2, self.init_dim])
+ self.bias_rel = nn.Parameter(torch.zeros(self.num_rel * 2))
+ else:
+ self.init_rel = self.get_param([self.num_rel, self.init_dim])
+ self.bias_rel = nn.Parameter(torch.zeros(self.num_rel))
+
+ if encoder == 'compgcn':
+ if n_layer < 3:
+ self.conv1 = CompGCNCov(self.init_dim, self.gcn_dim, self.act, conv_bias, gcn_drop, opn, num_base=-1,
+ num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr,
+ add_reverse=add_reverse)
+ self.conv2 = CompGCNCov(self.gcn_dim, self.embed_dim, self.act, conv_bias, gcn_drop,
+ opn, num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr,
+ add_reverse=add_reverse) if n_layer == 2 else None
+ else:
+ self.conv1 = CompGCNCov(self.init_dim, self.gcn_dim, self.act, conv_bias, gcn_drop, opn, num_base=-1,
+ num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr,
+ add_reverse=add_reverse)
+ self.conv2 = CompGCNCov(self.gcn_dim, self.gcn_dim, self.act, conv_bias, gcn_drop, opn, num_base=-1,
+ num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr,
+ add_reverse=add_reverse)
+ self.conv3 = CompGCNCov(self.gcn_dim, self.embed_dim, self.act, conv_bias, gcn_drop,
+ opn, num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr,
+ add_reverse=add_reverse)
+ elif encoder == 'rgcn':
+ if n_layer < 3:
+ self.conv1 = RelGraphConv(self.init_dim, self.gcn_dim, self.num_rel * 2, "bdd",
+ num_bases=self.num_base, activation=self.act, self_loop=(not wsi),
+ dropout=gcn_drop, wni=wni)
+ self.conv2 = RelGraphConv(self.gcn_dim, self.embed_dim, self.num_rel * 2, "bdd", num_bases=self.num_base,
+ activation=self.act, self_loop=(not wsi), dropout=gcn_drop,
+ wni=wni) if n_layer == 2 else None
+ else:
+ self.conv1 = RelGraphConv(self.init_dim, self.gcn_dim, self.num_rel * 2, "bdd",
+ num_bases=self.num_base, activation=self.act, self_loop=(not wsi),
+ dropout=gcn_drop, wni=wni)
+ self.conv2 = RelGraphConv(self.gcn_dim, self.gcn_dim, self.num_rel * 2, "bdd",
+ num_bases=self.num_base, activation=self.act, self_loop=(not wsi),
+ dropout=gcn_drop, wni=wni)
+ self.conv3 = RelGraphConv(self.gcn_dim, self.embed_dim, self.num_rel * 2, "bdd", num_bases=self.num_base,
+ activation=self.act, self_loop=(not wsi), dropout=gcn_drop,
+                                      wni=wni) if n_layer >= 3 else None
+ elif encoder == 'gcn':
+ self.conv1 = GraphConv(self.init_dim, self.gcn_dim, allow_zero_in_degree=True)
+ self.conv2 = GraphConv(self.gcn_dim, self.gcn_dim, allow_zero_in_degree=True)
+
+ self.bias = nn.Parameter(torch.zeros(self.num_ent))
+ # self.bias_rel = nn.Parameter(torch.zeros(self.num_rel*2))
+
+ def get_param(self, shape):
+ param = nn.Parameter(torch.Tensor(*shape))
+ nn.init.xavier_normal_(param, gain=nn.init.calculate_gain('relu'))
+ return param
+
+ def calc_loss(self, pred, label, pos_neg=None):
+ if pos_neg is not None:
+ m = nn.Sigmoid()
+ score_pos = m(pred)
+ targets_pos = pos_neg.unsqueeze(1)
+ loss = self.loss(score_pos, label * targets_pos)
+ return torch.sum(loss * label)
+ return self.loss(pred, label)
+
+ def forward_base(self, g, subj, rel, drop1, drop2):
+ """
+ :param g: graph
+ :param sub: subjects in a batch [batch]
+ :param rel: relations in a batch [batch]
+ :param drop1: dropout rate in first layer
+ :param drop2: dropout rate in second layer
+ :return: sub_emb: [batch, D]
+ rel_emb: [num_rel*2, D]
+ x: [num_ent, D]
+ """
+ x, r = self.init_embed, self.init_rel # embedding of relations
+
+ if self.n_layer > 0:
+ if self.encoder == 'compgcn':
+ if self.n_layer < 3:
+ x, r = self.conv1(g, x, r, self.edge_type, self.edge_norm)
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x, r = self.conv2(
+ g, x, r, self.edge_type, self.edge_norm) if self.n_layer == 2 else (x, r)
+ x = drop2(x) if self.n_layer == 2 else x
+ else:
+ x, r = self.conv1(g, x, r, self.edge_type, self.edge_norm)
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x, r = self.conv2(g, x, r, self.edge_type, self.edge_norm)
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x, r = self.conv3(g, x, r, self.edge_type, self.edge_norm)
+ x = drop2(x)
+ elif self.encoder == 'rgcn':
+ if self.n_layer < 3:
+ x = self.conv1(g, x, self.edge_type, self.edge_norm.unsqueeze(-1))
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x = self.conv2(
+ g, x, self.edge_type, self.edge_norm.unsqueeze(-1)) if self.n_layer == 2 else x
+ x = drop2(x) if self.n_layer == 2 else x
+ else:
+ x = self.conv1(g, x, self.edge_type, self.edge_norm.unsqueeze(-1))
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x = self.conv2(g, x, self.edge_type, self.edge_norm.unsqueeze(-1))
+ x = drop1(x)
+ x = self.conv3(g, x, self.edge_type, self.edge_norm.unsqueeze(-1))
+ x = drop2(x)
+
+ # filter out embeddings of subjects in this batch
+ sub_emb = torch.index_select(x, 0, subj)
+ # filter out embeddings of relations in this batch
+ rel_emb = torch.index_select(r, 0, rel)
+
+ return sub_emb, rel_emb, x
+
+ def forward_base_search(self, g, subj, rel, drop1, drop2):
+ """
+ :param g: graph
+ :param sub: subjects in a batch [batch]
+ :param rel: relations in a batch [batch]
+ :param drop1: dropout rate in first layer
+ :param drop2: dropout rate in second layer
+ :return: sub_emb: [batch, D]
+ rel_emb: [num_rel*2, D]
+ x: [num_ent, D]
+ """
+ x, r = self.init_embed, self.init_rel # embedding of relations
+ x_hidden = []
+ if self.n_layer > 0:
+ if self.encoder == 'compgcn':
+ if self.n_layer < 3:
+ x, r = self.conv1(g, x, r, self.edge_type, self.edge_norm)
+ x_hidden.append(x)
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x, r = self.conv2(
+ g, x, r, self.edge_type, self.edge_norm) if self.n_layer == 2 else (x, r)
+ x_hidden.append(x)
+ x = drop2(x) if self.n_layer == 2 else x
+ else:
+ x, r = self.conv1(g, x, r, self.edge_type, self.edge_norm)
+ x_hidden.append(x)
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x, r = self.conv2(g, x, r, self.edge_type, self.edge_norm)
+ x_hidden.append(x)
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x, r = self.conv3(g, x, r, self.edge_type, self.edge_norm)
+ x_hidden.append(x)
+ x = drop2(x)
+ elif self.encoder == 'rgcn':
+ x = self.conv1(g, x, self.edge_type,
+ self.edge_norm.unsqueeze(-1))
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x = self.conv2(
+ g, x, self.edge_type, self.edge_norm.unsqueeze(-1)) if self.n_layer == 2 else x
+ x = drop2(x) if self.n_layer == 2 else x
+
+ # filter out embeddings of subjects in this batch
+ sub_emb = torch.index_select(x, 0, subj)
+ # filter out embeddings of relations in this batch
+ rel_emb = torch.index_select(r, 0, rel)
+
+ x_hidden = torch.stack(x_hidden, dim=1)
+
+ return x_hidden
+
+ def forward_base_rel(self, g, subj, obj, drop1, drop2):
+ """
+ :param g: graph
+ :param sub: subjects in a batch [batch]
+ :param rel: relations in a batch [batch]
+ :param drop1: dropout rate in first layer
+ :param drop2: dropout rate in second layer
+ :return: sub_emb: [batch, D]
+ rel_emb: [num_rel*2, D]
+ x: [num_ent, D]
+ """
+ x, r = self.init_embed, self.init_rel # embedding of relations
+
+ if self.n_layer > 0:
+ if self.encoder == 'compgcn':
+ if self.n_layer < 3:
+ x, r = self.conv1(g, x, r, self.edge_type, self.edge_norm)
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x, r = self.conv2(
+ g, x, r, self.edge_type, self.edge_norm) if self.n_layer == 2 else (x, r)
+ x = drop2(x) if self.n_layer == 2 else x
+ else:
+ x, r = self.conv1(g, x, r, self.edge_type, self.edge_norm)
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x, r = self.conv2(g, x, r, self.edge_type, self.edge_norm)
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x, r = self.conv3(g, x, r, self.edge_type, self.edge_norm)
+ x = drop2(x)
+ elif self.encoder == 'rgcn':
+ x = self.conv1(g, x, self.edge_type,
+ self.edge_norm.unsqueeze(-1))
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x = self.conv2(
+ g, x, self.edge_type, self.edge_norm.unsqueeze(-1)) if self.n_layer == 2 else x
+ x = drop2(x) if self.n_layer == 2 else x
+
+ # filter out embeddings of subjects in this batch
+ sub_emb = torch.index_select(x, 0, subj)
+ # filter out embeddings of objects in this batch
+ obj_emb = torch.index_select(x, 0, obj)
+
+ return sub_emb, obj_emb, x, r
+
+ def forward_base_rel_vis_hop(self, g, subj, obj, drop1, drop2):
+ """
+ :param g: graph
+ :param sub: subjects in a batch [batch]
+ :param rel: relations in a batch [batch]
+ :param drop1: dropout rate in first layer
+ :param drop2: dropout rate in second layer
+ :return: sub_emb: [batch, D]
+ rel_emb: [num_rel*2, D]
+ x: [num_ent, D]
+ """
+ x, r = self.init_embed, self.init_rel # embedding of relations
+ x_hidden = []
+
+ if self.n_layer > 0:
+ if self.encoder == 'compgcn':
+ if self.n_layer < 3:
+ x, r = self.conv1(g, x, r, self.edge_type, self.edge_norm)
+ x_hidden.append(x)
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x, r = self.conv2(
+ g, x, r, self.edge_type, self.edge_norm) if self.n_layer == 2 else (x, r)
+ x_hidden.append(x)
+ x = drop2(x) if self.n_layer == 2 else x
+ else:
+ x, r = self.conv1(g, x, r, self.edge_type, self.edge_norm)
+ x_hidden.append(x)
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x, r = self.conv2(g, x, r, self.edge_type, self.edge_norm)
+ x_hidden.append(x)
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x, r = self.conv3(g, x, r, self.edge_type, self.edge_norm)
+ x_hidden.append(x)
+ x = drop2(x)
+ elif self.encoder == 'rgcn':
+ x = self.conv1(g, x, self.edge_type,
+ self.edge_norm.unsqueeze(-1))
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x = self.conv2(
+ g, x, self.edge_type, self.edge_norm.unsqueeze(-1)) if self.n_layer == 2 else x
+ x = drop2(x) if self.n_layer == 2 else x
+
+ # filter out embeddings of subjects in this batch
+ sub_emb = torch.index_select(x, 0, subj)
+ # filter out embeddings of objects in this batch
+ obj_emb = torch.index_select(x, 0, obj)
+ x_hidden = torch.stack(x_hidden, dim=1)
+
+ return sub_emb, obj_emb, x, r, x_hidden
+
+ def forward_base_rel_search(self, g, drop1, drop2):
+ """
+ :param g: graph
+ :param sub: subjects in a batch [batch]
+ :param rel: relations in a batch [batch]
+ :param drop1: dropout rate in first layer
+ :param drop2: dropout rate in second layer
+ :return: sub_emb: [batch, D]
+ rel_emb: [num_rel*2, D]
+ x: [num_ent, D]
+ """
+ x, r = self.init_embed, self.init_rel # embedding of relations
+ x_hidden = []
+ if self.n_layer > 0:
+ if self.encoder == 'compgcn':
+ if self.n_layer < 3:
+ x, r = self.conv1(g, x, r, self.edge_type, self.edge_norm)
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x_hidden.append(x)
+ x, r = self.conv2(
+ g, x, r, self.edge_type, self.edge_norm) if self.n_layer == 2 else (x, r)
+ x = drop2(x) if self.n_layer == 2 else x
+ x_hidden.append(x)
+ else:
+ x, r = self.conv1(g, x, r, self.edge_type, self.edge_norm)
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x_hidden.append(x)
+ x, r = self.conv2(g, x, r, self.edge_type, self.edge_norm)
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x_hidden.append(x)
+ x, r = self.conv3(g, x, r, self.edge_type, self.edge_norm)
+ x = drop2(x)
+ x_hidden.append(x)
+ elif self.encoder == 'rgcn':
+ x = self.conv1(g, x, self.edge_type,
+ self.edge_norm.unsqueeze(-1))
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x = self.conv2(
+ g, x, self.edge_type, self.edge_norm.unsqueeze(-1)) if self.n_layer == 2 else x
+ x = drop2(x) if self.n_layer == 2 else x
+
+ x_hidden = torch.stack(x_hidden, dim=1)
+
+ return x_hidden, x
+
+ def forward_base_rel_subgraph(self, bg, subj, obj, drop1, drop2):
+ """
+ :param g: graph
+ :param sub: subjects in a batch [batch]
+ :param rel: relations in a batch [batch]
+ :param drop1: dropout rate in first layer
+ :param drop2: dropout rate in second layer
+ :return: sub_emb: [batch, D]
+ rel_emb: [num_rel*2, D]
+ x: [num_ent, D]
+ """
+ x, r = self.init_embed[bg.ndata[NID]], self.init_rel # embedding of relations
+ edge_type = self.edge_type[bg.edata[EID]]
+ edge_norm = self.edge_norm[bg.edata[EID]]
+ if self.n_layer > 0:
+ if self.encoder == 'compgcn':
+ if self.n_layer < 3:
+ x, r = self.conv1(bg, x, r, edge_type, edge_norm)
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x, r = self.conv2(
+ bg, x, r, edge_type, edge_norm) if self.n_layer == 2 else (x, r)
+ x = drop2(x) if self.n_layer == 2 else x
+ else:
+ x, r = self.conv1(bg, x, r, edge_type, edge_norm)
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x, r = self.conv2(bg, x, r, edge_type, edge_norm)
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x, r = self.conv3(bg, x, r, edge_type, edge_norm)
+ x = drop2(x)
+ elif self.encoder == 'rgcn':
+ x = self.conv1(bg, x, self.edge_type,
+ self.edge_norm.unsqueeze(-1))
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x = self.conv2(
+ bg, x, self.edge_type, self.edge_norm.unsqueeze(-1)) if self.n_layer == 2 else x
+ x = drop2(x) if self.n_layer == 2 else x
+
+ bg.ndata['h'] = x
+ sub_list = []
+ obj_list = []
+ for idx, g in enumerate(dgl.unbatch(bg)):
+ head_idx = torch.where(g.ndata[NID] == subj[idx])[0]
+ tail_idx = torch.where(g.ndata[NID] == obj[idx])[0]
+ head_emb = g.ndata['h'][head_idx]
+ tail_emb = g.ndata['h'][tail_idx]
+ sub_list.append(head_emb)
+ obj_list.append(tail_emb)
+ sub_emb = torch.cat(sub_list, dim=0)
+ obj_emb = torch.cat(obj_list, dim=0)
+ # # filter out embeddings of subjects in this batch
+ # sub_emb = torch.index_select(x, 0, subj)
+ # # filter out embeddings of objects in this batch
+ # obj_emb = torch.index_select(x, 0, obj)
+
+ return sub_emb, obj_emb, x, r
+
+ def forward_base_rel_subgraph_trans(self, bg, input_ids, drop1, drop2):
+ """
+ :param g: graph
+ :param sub: subjects in a batch [batch]
+ :param rel: relations in a batch [batch]
+ :param drop1: dropout rate in first layer
+ :param drop2: dropout rate in second layer
+ :return: sub_emb: [batch, D]
+ rel_emb: [num_rel*2, D]
+ x: [num_ent, D]
+ """
+ x, r = self.init_embed[bg.ndata[NID]], self.init_rel # embedding of relations
+ edge_type = self.edge_type[bg.edata[EID]]
+ edge_norm = self.edge_norm[bg.edata[EID]]
+ if self.n_layer > 0:
+ if self.encoder == 'compgcn':
+ if self.n_layer < 3:
+ x, r = self.conv1(bg, x, r, edge_type, edge_norm)
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x, r = self.conv2(
+ bg, x, r, edge_type, edge_norm) if self.n_layer == 2 else (x, r)
+ x = drop2(x) if self.n_layer == 2 else x
+ else:
+ x, r = self.conv1(bg, x, r, edge_type, edge_norm)
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x, r = self.conv2(bg, x, r, edge_type, edge_norm)
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x, r = self.conv3(bg, x, r, edge_type, edge_norm)
+ x = drop2(x)
+ elif self.encoder == 'rgcn':
+ x = self.conv1(bg, x, self.edge_type,
+ self.edge_norm.unsqueeze(-1))
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x = self.conv2(
+ bg, x, self.edge_type, self.edge_norm.unsqueeze(-1)) if self.n_layer == 2 else x
+ x = drop2(x) if self.n_layer == 2 else x
+
+ bg.ndata['h'] = x
+ input_emb = self.init_embed[input_ids]
+ for idx, g in enumerate(dgl.unbatch(bg)):
+ # print(g.ndata['h'].size())
+ # print(input_emb[idx][:].size())
+ input_emb[idx][1:g.num_nodes() + 1] = g.ndata['h']
+ # sub_list = []
+ # obj_list = []
+ # for idx, g in enumerate(dgl.unbatch(bg)):
+ # head_idx = torch.where(g.ndata[NID] == subj[idx])[0]
+ # tail_idx = torch.where(g.ndata[NID] == obj[idx])[0]
+ # head_emb = g.ndata['h'][head_idx]
+ # tail_emb = g.ndata['h'][tail_idx]
+ # sub_list.append(head_emb)
+ # obj_list.append(tail_emb)
+ # sub_emb = torch.cat(sub_list,dim=0)
+ # obj_emb = torch.cat(obj_list,dim=0)
+ # # filter out embeddings of subjects in this batch
+ # sub_emb = torch.index_select(x, 0, subj)
+ # # filter out embeddings of objects in this batch
+ # obj_emb = torch.index_select(x, 0, obj)
+
+ return input_emb
+
+ def forward_base_rel_subgraph_trans_new(self, bg, drop1, drop2):
+ """
+ :param g: graph
+ :param sub: subjects in a batch [batch]
+ :param rel: relations in a batch [batch]
+ :param drop1: dropout rate in first layer
+ :param drop2: dropout rate in second layer
+ :return: sub_emb: [batch, D]
+ rel_emb: [num_rel*2, D]
+ x: [num_ent, D]
+ """
+ x, r = self.init_embed[bg.ndata[NID]], self.init_rel # embedding of relations
+ # print(bg.ndata[NID])
+ # print(self.edge_type.size())
+ # exit(0)
+ edge_type = self.edge_type[bg.edata[EID]]
+ edge_norm = self.edge_norm[bg.edata[EID]]
+ if self.n_layer > 0:
+ if self.encoder == 'compgcn':
+ if self.n_layer < 3:
+ # print(bg)
+ x, r = self.conv1(bg, x, r, edge_type, edge_norm)
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x, r = self.conv2(
+ bg, x, r, edge_type, edge_norm) if self.n_layer == 2 else (x, r)
+ x = drop2(x) if self.n_layer == 2 else x
+ else:
+ x, r = self.conv1(bg, x, r, edge_type, edge_norm)
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x, r = self.conv2(bg, x, r, edge_type, edge_norm)
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x, r = self.conv3(bg, x, r, edge_type, edge_norm)
+ x = drop2(x)
+ elif self.encoder == 'rgcn':
+ x = self.conv1(bg, x, self.edge_type,
+ self.edge_norm.unsqueeze(-1))
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x = self.conv2(
+ bg, x, self.edge_type, self.edge_norm.unsqueeze(-1)) if self.n_layer == 2 else x
+ x = drop2(x) if self.n_layer == 2 else x
+
+ bg.ndata['h'] = x
+
+ return bg, r
+
+ def forward_base_no_transform(self, subj, rel):
+ x, r = self.init_embed, self.init_rel
+ sub_emb = torch.index_select(x, 0, subj)
+ rel_emb = torch.index_select(r, 0, rel)
+
+ return sub_emb, rel_emb, x
+
+ def forward_base_subgraph_search(self, bg, drop1, drop2):
+ """
+ :param g: graph
+ :param sub: subjects in a batch [batch]
+ :param rel: relations in a batch [batch]
+ :param drop1: dropout rate in first layer
+ :param drop2: dropout rate in second layer
+ :return: sub_emb: [batch, D]
+ rel_emb: [num_rel*2, D]
+ x: [num_ent, D]
+ """
+ x, r = self.init_embed[bg.ndata[NID]], self.init_rel # embedding of relations
+ edge_type = self.edge_type[bg.edata[EID]]
+ edge_norm = self.edge_norm[bg.edata[EID]]
+ x_hidden = []
+ if self.n_layer > 0:
+ if self.encoder == 'compgcn':
+ if self.n_layer < 3:
+ x, r = self.conv1(bg, x, r, edge_type, edge_norm)
+ x_hidden.append(x)
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x, r = self.conv2(
+ bg, x, r, edge_type, edge_norm) if self.n_layer == 2 else (x, r)
+ x_hidden.append(x)
+ x = drop2(x) if self.n_layer == 2 else x
+ else:
+ x, r = self.conv1(bg, x, r, edge_type, edge_norm)
+ x_hidden.append(x)
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x, r = self.conv2(bg, x, r, edge_type, edge_norm)
+ x_hidden.append(x)
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x, r = self.conv3(bg, x, r, edge_type, edge_norm)
+ x_hidden.append(x)
+ x = drop2(x)
+ elif self.encoder == 'rgcn':
+ x = self.conv1(bg, x, self.edge_type,
+ self.edge_norm.unsqueeze(-1))
+ x = drop1(x) # embeddings of entities [num_ent, dim]
+ x = self.conv2(
+ bg, x, self.edge_type, self.edge_norm.unsqueeze(-1)) if self.n_layer == 2 else x
+ x = drop2(x) if self.n_layer == 2 else x
+
+ x_hidden = torch.stack(x_hidden, dim=1)
+
+ return x_hidden
+
+
+
+class GCN_TransE(GCNs):
+ def __init__(self, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm,
+ bias=True, gcn_drop=0., opn='mult', hid_drop=0., gamma=9., wni=False, wsi=False, encoder='compgcn',
+ use_bn=True, ltr=True):
+ super(GCN_TransE, self).__init__(num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer,
+ edge_type, edge_norm, bias, gcn_drop, opn, wni, wsi, encoder, use_bn, ltr)
+ self.drop = nn.Dropout(hid_drop)
+ self.gamma = gamma
+
+ def forward(self, g, subj, rel):
+ """
+ :param g: dgl graph
+ :param sub: subject in batch [batch_size]
+ :param rel: relation in batch [batch_size]
+ :return: score: [batch_size, ent_num], the prob in link-prediction
+ """
+ sub_emb, rel_emb, all_ent = self.forward_base(
+ g, subj, rel, self.drop, self.drop)
+ obj_emb = sub_emb + rel_emb
+
+ x = self.gamma - torch.norm(obj_emb.unsqueeze(1) - all_ent, p=1, dim=2)
+
+ score = torch.sigmoid(x)
+
+ return score
+
+
+class GCN_DistMult(GCNs):
+ def __init__(self, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm,
+ bias=True, gcn_drop=0., opn='mult', hid_drop=0., wni=False, wsi=False, encoder='compgcn', use_bn=True,
+ ltr=True):
+ super(GCN_DistMult, self).__init__(num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer,
+ edge_type, edge_norm, bias, gcn_drop, opn, wni, wsi, encoder, use_bn, ltr)
+ self.drop = nn.Dropout(hid_drop)
+
+ def forward(self, g, subj, rel):
+ """
+ :param g: dgl graph
+ :param sub: subject in batch [batch_size]
+ :param rel: relation in batch [batch_size]
+ :return: score: [batch_size, ent_num], the prob in link-prediction
+ """
+ sub_emb, rel_emb, all_ent = self.forward_base(
+ g, subj, rel, self.drop, self.drop)
+ obj_emb = sub_emb * rel_emb # [batch_size, emb_dim]
+ x = torch.mm(obj_emb, all_ent.transpose(1, 0)) # [batch_size, ent_num]
+ x += self.bias.expand_as(x)
+ score = torch.sigmoid(x)
+ return score
+
+
+class GCN_ConvE(GCNs):
+ def __init__(self, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm,
+ bias=True, gcn_drop=0., opn='mult', hid_drop=0., input_drop=0., conve_hid_drop=0., feat_drop=0.,
+ num_filt=None, ker_sz=None, k_h=None, k_w=None, wni=False, wsi=False, encoder='compgcn', use_bn=True,
+ ltr=True):
+ """
+ :param num_ent: number of entities
+ :param num_rel: number of different relations
+ :param num_base: number of bases to use
+ :param init_dim: initial dimension
+ :param gcn_dim: dimension after first layer
+ :param embed_dim: dimension after second layer
+ :param n_layer: number of layer
+ :param edge_type: relation type of each edge, [E]
+        :param bias: whether to add bias
+ :param gcn_drop: dropout rate in compgcncov
+ :param opn: combination operator
+ :param hid_drop: gcn output (embedding of each entity) dropout
+ :param input_drop: dropout in conve input
+ :param conve_hid_drop: dropout in conve hidden layer
+ :param feat_drop: feature dropout in conve
+ :param num_filt: number of filters in conv2d
+ :param ker_sz: kernel size in conv2d
+ :param k_h: height of 2D reshape
+ :param k_w: width of 2D reshape
+ """
+ super(GCN_ConvE, self).__init__(num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer,
+ edge_type, edge_norm, bias, gcn_drop, opn, wni, wsi, encoder, use_bn, ltr)
+ self.hid_drop, self.input_drop, self.conve_hid_drop, self.feat_drop = hid_drop, input_drop, conve_hid_drop, feat_drop
+ self.num_filt = num_filt
+ self.ker_sz, self.k_w, self.k_h = ker_sz, k_w, k_h
+
+ # one channel, do bn on initial embedding
+ self.bn0 = torch.nn.BatchNorm2d(1)
+ self.bn1 = torch.nn.BatchNorm2d(
+ self.num_filt) # do bn on output of conv
+ self.bn2 = torch.nn.BatchNorm1d(self.embed_dim)
+
+ self.drop = torch.nn.Dropout(self.hid_drop) # gcn output dropout
+ self.input_drop = torch.nn.Dropout(
+ self.input_drop) # stacked input dropout
+ self.feature_drop = torch.nn.Dropout(
+ self.feat_drop) # feature map dropout
+ self.hidden_drop = torch.nn.Dropout(
+ self.conve_hid_drop) # hidden layer dropout
+
+ self.conv2d = torch.nn.Conv2d(in_channels=1, out_channels=self.num_filt,
+ kernel_size=(self.ker_sz, self.ker_sz), stride=1, padding=0, bias=bias)
+
+ flat_sz_h = int(2 * self.k_h) - self.ker_sz + 1 # height after conv
+ flat_sz_w = self.k_w - self.ker_sz + 1 # width after conv
+ self.flat_sz = flat_sz_h * flat_sz_w * self.num_filt
+ # fully connected projection
+ self.fc = torch.nn.Linear(self.flat_sz, self.embed_dim)
+        self.cat_type = 'multi'
+
+ def concat(self, ent_embed, rel_embed):
+ """
+ :param ent_embed: [batch_size, embed_dim]
+ :param rel_embed: [batch_size, embed_dim]
+ :return: stack_input: [B, C, H, W]
+ """
+ ent_embed = ent_embed.view(-1, 1, self.embed_dim)
+ rel_embed = rel_embed.view(-1, 1, self.embed_dim)
+ # [batch_size, 2, embed_dim]
+ stack_input = torch.cat([ent_embed, rel_embed], 1)
+
+ assert self.embed_dim == self.k_h * self.k_w
+ # reshape to 2D [batch, 1, 2*k_h, k_w]
+ stack_input = stack_input.reshape(-1, 1, 2 * self.k_h, self.k_w)
+ return stack_input
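+
+    # Illustrative example (values assumed, not taken from the paper's configuration):
+    # with embed_dim = 200, k_h = 10 and k_w = 20, concat() stacks the two [B, 200]
+    # embeddings and reshapes them into a [B, 1, 20, 20] "image" for the 2D convolution.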
+
+ def forward(self, g, subj, rel):
+ """
+ :param g: dgl graph
+ :param sub: subject in batch [batch_size]
+ :param rel: relation in batch [batch_size]
+ :return: score: [batch_size, ent_num], the prob in link-prediction
+ """
+ sub_emb, rel_emb, all_ent = self.forward_base(
+ g, subj, rel, self.drop, self.input_drop)
+ # [batch_size, 1, 2*k_h, k_w]
+ stack_input = self.concat(sub_emb, rel_emb)
+ x = self.bn0(stack_input)
+ x = self.conv2d(x) # [batch_size, num_filt, flat_sz_h, flat_sz_w]
+ x = self.bn1(x)
+ x = F.relu(x)
+ x = self.feature_drop(x)
+ x = x.view(-1, self.flat_sz) # [batch_size, flat_sz]
+ x = self.fc(x) # [batch_size, embed_dim]
+ x = self.hidden_drop(x)
+ x = self.bn2(x)
+ x = F.relu(x)
+ x = torch.mm(x, all_ent.transpose(1, 0)) # [batch_size, ent_num]
+ x += self.bias.expand_as(x)
+ score = torch.sigmoid(x)
+ return score
+
+ def forward_search(self, g, subj, rel):
+ sub_emb, rel_emb, all_ent, hidden_all_ent = self.forward_base_search(
+ g, subj, rel, self.drop, self.input_drop)
+
+ return hidden_all_ent
+
+ def compute_pred(self, hidden_x, subj, obj):
+ # raise NotImplementedError
+ h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ print(h.size())
+
+ def cross_pair(self, x_i, x_all):
+ x = []
+ # print(x_i.size())
+ # print(x_all.size())
+ x_all = x_all.repeat(x_i.size(0), 1, 1, 1)
+ # print(x_all.size())
+ for i in range(self.n_layer):
+ for j in range(self.n_layer):
+ if self.cat_type == 'multi':
+ # print(x_i[:, i, :].size())
+ # print(x_all[:,:,j,:].size())
+ test = x_i[:, i, :].unsqueeze(1) * x_all[:, :, j, :]
+ # print(test.size())
+ x.append(test)
+ else:
+ test = torch.cat([x_i[:, i, :].unsqueeze(1), x_all[:, :, j, :]], dim=1)
+ print(test.size())
+ x.append(test)
+ x = torch.stack(x, dim=1)
+ return x
+
+
+class GCN_ConvE_Rel(GCNs):
+ def __init__(self, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm,
+ bias=True, gcn_drop=0., opn='mult', hid_drop=0., input_drop=0., conve_hid_drop=0., feat_drop=0.,
+ num_filt=None, ker_sz=None, k_h=None, k_w=None, wni=False, wsi=False, encoder='compgcn', use_bn=True,
+ ltr=True, input_type='subgraph'):
+ """
+ :param num_ent: number of entities
+ :param num_rel: number of different relations
+ :param num_base: number of bases to use
+ :param init_dim: initial dimension
+ :param gcn_dim: dimension after first layer
+ :param embed_dim: dimension after second layer
+ :param n_layer: number of layer
+ :param edge_type: relation type of each edge, [E]
+        :param bias: whether to add bias
+ :param gcn_drop: dropout rate in compgcncov
+ :param opn: combination operator
+ :param hid_drop: gcn output (embedding of each entity) dropout
+ :param input_drop: dropout in conve input
+ :param conve_hid_drop: dropout in conve hidden layer
+ :param feat_drop: feature dropout in conve
+ :param num_filt: number of filters in conv2d
+ :param ker_sz: kernel size in conv2d
+ :param k_h: height of 2D reshape
+ :param k_w: width of 2D reshape
+ """
+ super(GCN_ConvE_Rel, self).__init__(num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer,
+ edge_type, edge_norm, bias, gcn_drop, opn, wni, wsi, encoder, use_bn, ltr,
+ input_type)
+ self.hid_drop, self.input_drop, self.conve_hid_drop, self.feat_drop = hid_drop, input_drop, conve_hid_drop, feat_drop
+ self.num_filt = num_filt
+ self.ker_sz, self.k_w, self.k_h = ker_sz, k_w, k_h
+ self.n_layer = n_layer
+
+ # one channel, do bn on initial embedding
+ self.bn0 = torch.nn.BatchNorm2d(1)
+ self.bn1 = torch.nn.BatchNorm2d(
+ self.num_filt) # do bn on output of conv
+ self.bn2 = torch.nn.BatchNorm1d(self.embed_dim)
+
+ self.drop = torch.nn.Dropout(self.hid_drop) # gcn output dropout
+ self.input_drop = torch.nn.Dropout(
+ self.input_drop) # stacked input dropout
+ self.feature_drop = torch.nn.Dropout(
+ self.feat_drop) # feature map dropout
+ self.hidden_drop = torch.nn.Dropout(
+ self.conve_hid_drop) # hidden layer dropout
+
+ self.conv2d = torch.nn.Conv2d(in_channels=1, out_channels=self.num_filt,
+ kernel_size=(self.ker_sz, self.ker_sz), stride=1, padding=0, bias=bias)
+
+ flat_sz_h = int(2 * self.k_h) - self.ker_sz + 1 # height after conv
+ flat_sz_w = self.k_w - self.ker_sz + 1 # width after conv
+ self.flat_sz = flat_sz_h * flat_sz_w * self.num_filt
+ # fully connected projection
+ self.fc = torch.nn.Linear(self.flat_sz, self.embed_dim)
+ self.combine_type = 'concat'
+
+ def concat(self, ent_embed, rel_embed):
+ """
+ :param ent_embed: [batch_size, embed_dim]
+ :param rel_embed: [batch_size, embed_dim]
+ :return: stack_input: [B, C, H, W]
+ """
+ ent_embed = ent_embed.view(-1, 1, self.embed_dim)
+ rel_embed = rel_embed.view(-1, 1, self.embed_dim)
+ # [batch_size, 2, embed_dim]
+ stack_input = torch.cat([ent_embed, rel_embed], 1)
+
+ assert self.embed_dim == self.k_h * self.k_w
+ # reshape to 2D [batch, 1, 2*k_h, k_w]
+ stack_input = stack_input.reshape(-1, 1, 2 * self.k_h, self.k_w)
+ return stack_input
+
+ def forward(self, g, subj, obj):
+ """
+ :param g: dgl graph
+ :param sub: subject in batch [batch_size]
+ :param rel: relation in batch [batch_size]
+ :return: score: [batch_size, ent_num], the prob in link-prediction
+ """
+ if self.input_type == 'subgraph':
+ sub_emb, obj_emb, all_ent, all_rel = self.forward_base_rel_subgraph(g, subj, obj, self.drop,
+ self.input_drop)
+ else:
+ sub_emb, obj_emb, all_ent, all_rel = self.forward_base_rel(g, subj, obj, self.drop, self.input_drop)
+ # [batch_size, 1, 2*k_h, k_w]
+ stack_input = self.concat(sub_emb, obj_emb)
+ x = self.bn0(stack_input)
+ x = self.conv2d(x) # [batch_size, num_filt, flat_sz_h, flat_sz_w]
+ x = self.bn1(x)
+ x = F.relu(x)
+ x = self.feature_drop(x)
+ x = x.view(-1, self.flat_sz) # [batch_size, flat_sz]
+ x = self.fc(x) # [batch_size, embed_dim]
+ x = self.hidden_drop(x)
+ x = self.bn2(x)
+ x = F.relu(x)
+ x = torch.mm(x, all_rel.transpose(1, 0)) # [batch_size, ent_num]
+ x += self.bias_rel.expand_as(x)
+ # score = torch.sigmoid(x)
+ score = x
+ return score
+
+ def forward_search(self, g, subj, obj):
+ hidden_all_ent, all_rel = self.forward_base_rel_search(
+ g, subj, obj, self.drop, self.input_drop)
+
+ return hidden_all_ent, all_rel
+
+ def compute_pred(self, hidden_x, all_rel, subj, obj, subgraph_sampler, mode='search'):
+ h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ atten_matrix = subgraph_sampler(h, mode)
+ # print(atten_matrix.size()) # [batch_size, encoder_layer^2]
+ # print(subgraph_sampler(h,mode='argmax'))
+ n, c = atten_matrix.shape
+ h = h * atten_matrix.view(n, c, 1)
+ # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ h = torch.sum(h, dim=1)
+ # print(h.size()) # [batch_size, 2*dim]
+ h = h.reshape(-1, 1, 2 * self.k_h, self.k_w)
+ x = self.bn0(h)
+ x = self.conv2d(x) # [batch_size, num_filt, flat_sz_h, flat_sz_w]
+ x = self.bn1(x)
+ x = F.relu(x)
+ x = self.feature_drop(x)
+ x = x.view(-1, self.flat_sz) # [batch_size, flat_sz]
+ x = self.fc(x) # [batch_size, embed_dim]
+ x = self.hidden_drop(x)
+ x = self.bn2(x)
+ x = F.relu(x)
+ x = torch.mm(x, all_rel.transpose(1, 0)) # [batch_size, ent_num]
+ x += self.bias_rel.expand_as(x)
+ # score = torch.sigmoid(x)
+ score = x
+ return score
+
+ def compute_pred_rs(self, hidden_x, all_rel, subj, obj, random_hops):
+ h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ atten_matrix = torch.zeros(h.size(0), h.size(1)).to('cuda:0')
+ for i in range(h.size(0)):
+ atten_matrix[i][self.n_layer * (random_hops[i][0] - 1) + random_hops[i][1] - 1] = 1
+ # print(atten_matrix.size()) # [batch_size, encoder_layer^2]
+ # print(subgraph_sampler(h,mode='argmax'))
+ n, c = atten_matrix.shape
+ h = h * atten_matrix.view(n, c, 1)
+ # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ h = torch.sum(h, dim=1)
+ # print(h.size()) # [batch_size, 2*dim]
+ h = h.reshape(-1, 1, 2 * self.k_h, self.k_w)
+ x = self.bn0(h)
+ x = self.conv2d(x) # [batch_size, num_filt, flat_sz_h, flat_sz_w]
+ x = self.bn1(x)
+ x = F.relu(x)
+ x = self.feature_drop(x)
+ x = x.view(-1, self.flat_sz) # [batch_size, flat_sz]
+ x = self.fc(x) # [batch_size, embed_dim]
+ x = self.hidden_drop(x)
+ x = self.bn2(x)
+ x = F.relu(x)
+ x = torch.mm(x, all_rel.transpose(1, 0)) # [batch_size, num_rel]
+ x += self.bias_rel.expand_as(x)
+ # score = torch.sigmoid(x)
+ score = x
+ return score
+
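+ # cross_pair enumerates all n_layer x n_layer combinations of subject hop
+ # depth i and object hop depth j; each combination yields one candidate
+ # enclosing-subgraph representation for the drug pair, built by element-wise
+ # product or concatenation depending on combine_type.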
+ def cross_pair(self, x_i, x_j):
+ x = []
+ for i in range(self.n_layer):
+ for j in range(self.n_layer):
+ if self.combine_type == 'multi':
+ x.append(x_i[:, i, :] * x_j[:, j, :])
+ elif self.combine_type == 'concat':
+ x.append(torch.cat([x_i[:, i, :], x_j[:, j, :]], dim=1))
+ x = torch.stack(x, dim=1)
+ return x
+
+ def vis_hop_distribution(self, hidden_x, all_rel, subj, obj, subgraph_sampler, mode='search'):
+ h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ atten_matrix = subgraph_sampler(h, mode)
+ return torch.sum(atten_matrix, dim=0)
+
+ def vis_hop_distribution_rs(self, hidden_x, all_rel, subj, obj, random_hops):
+ h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ atten_matrix = torch.zeros(h.size(0), h.size(1)).to('cuda:0')
+ for i in range(h.size(0)):
+ atten_matrix[i][self.n_layer * (random_hops[i][0] - 1) + random_hops[i][1] - 1] = 1
+ return torch.sum(atten_matrix, dim=0)
+
+
+class GCN_Transformer(GCNs):
+ def __init__(self, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm,
+ bias=True, gcn_drop=0., opn='mult', hid_drop=0., input_drop=0., conve_hid_drop=0., feat_drop=0.,
+ num_filt=None, ker_sz=None, k_h=None, k_w=None, wni=False, wsi=False, encoder='compgcn', use_bn=True,
+ ltr=True, input_type='subgraph',
+ d_model=100, num_transformer_layers=2, nhead=8, dim_feedforward=100, transformer_dropout=0.1,
+ transformer_activation='relu',
+ graph_pooling='cls', concat_type="gso", max_input_len=100, loss_type='ce'):
+ super(GCN_Transformer, self).__init__(num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer,
+ edge_type, edge_norm, bias, gcn_drop, opn, wni, wsi, encoder, use_bn, ltr,
+ input_type, loss_type)
+
+ self.hid_drop, self.input_drop, self.conve_hid_drop, self.feat_drop = hid_drop, input_drop, conve_hid_drop, feat_drop
+ self.num_filt = num_filt
+ self.ker_sz, self.k_w, self.k_h = ker_sz, k_w, k_h
+
+ # one channel, do bn on initial embedding
+ self.bn0 = torch.nn.BatchNorm2d(1)
+ self.bn1 = torch.nn.BatchNorm2d(
+ self.num_filt) # do bn on output of conv
+ self.bn2 = torch.nn.BatchNorm1d(self.embed_dim)
+
+ self.drop = torch.nn.Dropout(self.hid_drop) # gcn output dropout
+ self.input_drop = torch.nn.Dropout(
+ self.input_drop) # stacked input dropout
+ self.feature_drop = torch.nn.Dropout(
+ self.feat_drop) # feature map dropout
+ self.hidden_drop = torch.nn.Dropout(
+ self.conve_hid_drop) # hidden layer dropout
+
+ self.conv2d = torch.nn.Conv2d(in_channels=1, out_channels=self.num_filt,
+ kernel_size=(self.ker_sz, self.ker_sz), stride=1, padding=0, bias=bias)
+
+ flat_sz_h = int(2 * self.k_h) - self.ker_sz + 1 # height after conv
+ flat_sz_w = self.k_w - self.ker_sz + 1 # width after conv
+ self.flat_sz = flat_sz_h * flat_sz_w * self.num_filt
+ # fully connected projection
+ self.fc = torch.nn.Linear(self.flat_sz, self.embed_dim)
+
+ self.d_model = d_model
+ self.num_layer = num_transformer_layers
+ self.gnn2transformer = nn.Linear(gcn_dim, d_model)
+ # Creating Transformer Encoder Model
+ encoder_layer = nn.TransformerEncoderLayer(
+ d_model, nhead, dim_feedforward, transformer_dropout, transformer_activation
+ )
+ encoder_norm = nn.LayerNorm(d_model)
+ self.transformer = nn.TransformerEncoder(encoder_layer, num_transformer_layers, encoder_norm)
+ # self.max_input_len = args.max_input_len
+ self.norm_input = nn.LayerNorm(d_model)
+ self.graph_pooling = graph_pooling
+ self.max_input_len = max_input_len
+ self.concat_type = concat_type
+ self.cls_embedding = None
+ if self.graph_pooling == "cls":
+ self.cls_embedding = nn.Parameter(torch.randn([1, 1, self.d_model], requires_grad=True))
+ if self.concat_type == "gso":
+ self.fc1 = nn.Linear(d_model * 3, 256)
+ elif self.concat_type == "so":
+ self.fc1 = nn.Linear(d_model * 2, 256)
+ elif self.concat_type == "g":
+ self.fc1 = nn.Linear(d_model, 256)
+ self.fc2 = nn.Linear(256, num_rel * 2)
+
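+ # forward: GNN node states of the batched subgraphs are projected to d_model,
+ # padded into a [seq_len, batch, d_model] tensor (see pad_batch below),
+ # optionally extended with a learned CLS token, encoded by the Transformer,
+ # and pooled according to graph_pooling / concat_type before the two-layer
+ # MLP head produces relation logits.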
+ def forward(self, bg, input_ids, subj, obj):
+ # tokens_emb = self.forward_base_rel_subgraph_trans(bg,input_ids,self.drop,self.input_drop)
+ # tokens_emb = tokens_emb.permute(1,0,2) # (s, b, d)
+ # tokens_emb = self.gnn2transformer(tokens_emb)
+ # padding_mask = get_pad_mask(input_ids, self.num_ent+1)
+ # tokens_emb = self.norm_input(tokens_emb)
+ # transformer_out = self.transformer(tokens_emb, src_key_padding_mask=padding_mask)
+ # h_graph = transformer_out[0]
+ # output = self.fc_out(h_graph)
+ # output = torch.sigmoid(output)
+ # return output
+
+ bg, all_rel = self.forward_base_rel_subgraph_trans_new(bg, self.drop, self.input_drop)
+ batch_size = bg.batch_size
+ h_node = self.gnn2transformer(bg.ndata['h'])
+ padded_h_node, src_padding_mask, subj_idx, obj_idx = pad_batch(h_node, bg, self.max_input_len, subj, obj)
+ if self.cls_embedding is not None:
+ expand_cls_embedding = self.cls_embedding.expand(1, padded_h_node.size(1), -1)
+ padded_h_node = torch.cat([padded_h_node, expand_cls_embedding], dim=0)
+ zeros = src_padding_mask.data.new(src_padding_mask.size(0), 1).fill_(0)
+ src_padding_mask = torch.cat([src_padding_mask, zeros], dim=1)
+ padded_h_node = self.norm_input(padded_h_node)
+ transformer_out = self.transformer(padded_h_node, src_key_padding_mask=src_padding_mask)
+ subj_emb_list = []
+ obj_emb_list = []
+ for batch_idx in range(batch_size):
+ subj_emb = transformer_out[subj_idx[batch_idx], batch_idx, :]
+ obj_emb = transformer_out[obj_idx[batch_idx], batch_idx, :]
+ subj_emb_list.append(subj_emb)
+ obj_emb_list.append(obj_emb)
+ subj_emb = torch.stack(subj_emb_list)
+ obj_emb = torch.stack(obj_emb_list)
+ if self.graph_pooling in ["last", "cls"]:
+ h_graph = transformer_out[-1]
+ if self.concat_type == "gso":
+ h_repr = torch.cat([h_graph, subj_emb, obj_emb], dim=1)
+ elif self.concat_type == "g":
+ h_repr = h_graph
+ else:
+ if self.concat_type == "so":
+ h_repr = torch.cat([subj_emb, obj_emb], dim=1)
+ h_repr = F.relu(self.fc1(h_repr))
+ score = self.fc2(h_repr)
+ # print(score.size())
+ # print(score)
+ # score = self.conve(subj_emb, obj_emb, all_rel)
+
+ return score
+
+ def conve_rel(self, sub_emb, obj_emb, all_rel):
+ stack_input = self.concat(sub_emb, obj_emb)
+ x = self.bn0(stack_input)
+ x = self.conv2d(x) # [batch_size, num_filt, flat_sz_h, flat_sz_w]
+ x = self.bn1(x)
+ x = F.relu(x)
+ x = self.feature_drop(x)
+ x = x.view(-1, self.flat_sz) # [batch_size, flat_sz]
+ x = self.fc(x) # [batch_size, embed_dim]
+ x = self.hidden_drop(x)
+ x = self.bn2(x)
+ x = F.relu(x)
+ x = torch.mm(x, all_rel.transpose(1, 0)) # [batch_size, num_rel]
+ x += self.bias_rel.expand_as(x)
+ score = torch.sigmoid(x)
+ return score
+
+ def concat(self, ent_embed, rel_embed):
+ """
+ :param ent_embed: [batch_size, embed_dim]
+ :param rel_embed: [batch_size, embed_dim]
+ :return: stack_input: [B, C, H, W]
+ """
+ ent_embed = ent_embed.view(-1, 1, self.embed_dim)
+ rel_embed = rel_embed.view(-1, 1, self.embed_dim)
+ # [batch_size, 2, embed_dim]
+ stack_input = torch.cat([ent_embed, rel_embed], 1)
+
+ assert self.embed_dim == self.k_h * self.k_w
+ # reshape to 2D [batch, 1, 2*k_h, k_w]
+ stack_input = stack_input.reshape(-1, 1, 2 * self.k_h, self.k_w)
+ return stack_input
+
+ def conve_ent(self, sub_emb, rel_emb, all_ent):
+ stack_input = self.concat(sub_emb, rel_emb)
+ x = self.bn0(stack_input)
+ x = self.conv2d(x) # [batch_size, num_filt, flat_sz_h, flat_sz_w]
+ x = self.bn1(x)
+ x = F.relu(x)
+ x = self.feature_drop(x)
+ x = x.view(-1, self.flat_sz) # [batch_size, flat_sz]
+ x = self.fc(x) # [batch_size, embed_dim]
+ x = self.hidden_drop(x)
+ x = self.bn2(x)
+ x = F.relu(x)
+ x = torch.mm(x, all_ent.transpose(1, 0)) # [batch_size, ent_num]
+ x += self.bias.expand_as(x)
+ score = torch.sigmoid(x)
+ return score
+
+ def evaluate(self, subj, obj):
+ sub_emb, rel_emb, all_ent = self.forward_base_no_transform(subj, obj)
+ score = self.conve_ent(sub_emb, rel_emb, all_ent)
+ return score
+
+
+class GCN_None(GCNs):
+
+ def __init__(self, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm,
+ bias=True, gcn_drop=0., opn='mult', hid_drop=0., input_drop=0., conve_hid_drop=0., feat_drop=0.,
+ num_filt=None, ker_sz=None, k_h=None, k_w=None, wni=False, wsi=False, encoder='compgcn', use_bn=True,
+ ltr=True,
+ input_type='subgraph', graph_pooling='mean', concat_type='gso', loss_type='ce', add_reverse=True):
+ super(GCN_None, self).__init__(num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer,
+ edge_type, edge_norm, bias, gcn_drop, opn, wni, wsi, encoder, use_bn, ltr,
+ input_type, loss_type, add_reverse)
+
+ self.hid_drop, self.input_drop, self.conve_hid_drop, self.feat_drop = hid_drop, input_drop, conve_hid_drop, feat_drop
+ self.graph_pooling_type = graph_pooling
+ self.concat_type = concat_type
+ self.drop = torch.nn.Dropout(self.hid_drop) # gcn output dropout
+ self.input_drop = torch.nn.Dropout(
+ self.input_drop) # stacked input dropout
+ if self.concat_type == "gso":
+ self.fc1 = nn.Linear(gcn_dim * 3, 256)
+ elif self.concat_type == "so":
+ self.fc1 = nn.Linear(gcn_dim * 2, 256)
+ elif self.concat_type == "g":
+ self.fc1 = nn.Linear(gcn_dim, 256)
+ if add_reverse:
+ self.fc2 = nn.Linear(256, num_rel * 2)
+ else:
+ self.fc2 = nn.Linear(256, num_rel)
+
+ def forward(self, g, subj, obj):
+ """
+ :param g: dgl graph
+ :param subj: subject (head drug) indices in batch [batch_size]
+ :param obj: object (tail drug) indices in batch [batch_size]
+ :return: score: relation logits for each drug pair, one column per interaction type
+ """
+ # start_time = time.time()
+ bg, all_rel = self.forward_base_rel_subgraph_trans_new(g, self.drop, self.input_drop)
+ # print(f'{time.time() - start_time:.2f}s')
+ h_repr_list = []
+ for batch_idx, g in enumerate(dgl.unbatch(bg)):
+ h_graph = readout_nodes(g, 'h', op=self.graph_pooling_type)
+ sub_idx = torch.where(g.ndata[NID] == subj[batch_idx])[0]
+ obj_idx = torch.where(g.ndata[NID] == obj[batch_idx])[0]
+ sub_emb = g.ndata['h'][sub_idx]
+ obj_emb = g.ndata['h'][obj_idx]
+ h_repr = torch.cat([h_graph, sub_emb, obj_emb], dim=1)
+ h_repr_list.append(h_repr)
+ h_repr = torch.stack(h_repr_list, dim=0).squeeze(1)
+ # print(f'{time.time() - start_time:.2f}s')
+ score = F.relu(self.fc1(h_repr))
+ score = self.fc2(score)
+
+ return score
+
+
+class GCN_MLP(GCNs):
+ def __init__(self, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm,
+ bias=True, gcn_drop=0., opn='mult', hid_drop=0., input_drop=0., conve_hid_drop=0., feat_drop=0.,
+ num_filt=None, ker_sz=None, k_h=None, k_w=None, wni=False, wsi=False, encoder='compgcn', use_bn=True,
+ ltr=True, input_type='subgraph', graph_pooling='mean', combine_type='mult', loss_type='ce',
+ add_reverse=True):
+ """
+ :param num_ent: number of entities
+ :param num_rel: number of different relations
+ :param num_base: number of bases to use
+ :param init_dim: initial dimension
+ :param gcn_dim: dimension after first layer
+ :param embed_dim: dimension after second layer
+ :param n_layer: number of layers
+ :param edge_type: relation type of each edge, [E]
+ :param bias: whether to add bias
+ :param gcn_drop: dropout rate in compgcncov
+ :param opn: combination operator
+ :param hid_drop: gcn output (embedding of each entity) dropout
+ :param input_drop: dropout in conve input
+ :param conve_hid_drop: dropout in conve hidden layer
+ :param feat_drop: feature dropout in conve
+ :param num_filt: number of filters in conv2d
+ :param ker_sz: kernel size in conv2d
+ :param k_h: height of 2D reshape
+ :param k_w: width of 2D reshape
+ """
+ super(GCN_MLP, self).__init__(num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer,
+ edge_type, edge_norm, bias, gcn_drop, opn, wni, wsi, encoder,
+ use_bn, ltr, input_type, loss_type, add_reverse)
+ self.hid_drop, self.input_drop = hid_drop, input_drop
+ if add_reverse:
+ self.num_rel = num_rel * 2
+ else:
+ self.num_rel = num_rel
+
+ self.drop = torch.nn.Dropout(self.hid_drop) # gcn output dropout
+ self.input_drop = torch.nn.Dropout(self.input_drop) # stacked input dropout
+
+ self.combine_type = combine_type
+ self.graph_pooling_type = graph_pooling # needed by the 'subgraph' branch in forward
+ if self.combine_type == 'concat':
+ self.fc = torch.nn.Linear(2 * self.embed_dim, self.num_rel)
+ elif self.combine_type == 'mult':
+ self.fc = torch.nn.Linear(self.embed_dim, self.num_rel)
+
+ def forward(self, g, subj, obj):
+ """
+ :param g: dgl graph
+ :param subj: subject (head drug) indices in batch [batch_size]
+ :param obj: object (tail drug) indices in batch [batch_size]
+ :return: score: relation logits for each drug pair, one column per interaction type
+ """
+ if self.input_type == 'subgraph':
+ # sub_emb, obj_emb, all_ent, all_rel = self.forward_base_rel_subgraph(g, subj, obj, self.drop, self.input_drop)
+ bg, all_rel = self.forward_base_rel_subgraph_trans_new(g, self.drop, self.input_drop)
+ sub_ids = (bg.ndata['id'] == 1).nonzero().squeeze(1)
+ sub_embs = bg.ndata['h'][sub_ids]
+ obj_ids = (bg.ndata['id'] == 2).nonzero().squeeze(1)
+ obj_embs = bg.ndata['h'][obj_ids]
+ h_graph = readout_nodes(bg, 'h', op=self.graph_pooling_type)
+ # edge_embs = torch.concat([sub_embs, obj_embs], dim=1)
+ # score = self.fc(edge_embs)
+ else:
+ sub_embs, obj_embs, all_ent, all_rel = self.forward_base_rel(g, subj, obj, self.drop, self.input_drop)
+ if self.combine_type == 'concat':
+ edge_embs = torch.concat([sub_embs, obj_embs], dim=1)
+ elif self.combine_type == 'mult':
+ edge_embs = sub_embs * obj_embs
+ else:
+ raise NotImplementedError
+ score = self.fc(edge_embs)
+ return score
+
+ def compute_vid_pred(self, hidden_x, subj, obj):
+ scores = []
+ scores_sigmoid = []
+ h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ for i in range(h.size()[1]):
+ # print(h.size())
+ # exit(0)
+ edge_embs = h[:,i,:]
+ # if self.combine_type == 'concat':
+ # edge_embs = torch.concat([sub_embs, obj_embs], dim=1)
+ # elif self.combine_type == 'mult':
+ # edge_embs = sub_embs * obj_embs
+ # else:
+ # raise NotImplementedError
+ # print(edge_embs.size())
+ score = self.fc(edge_embs)
+ scores.append(score)
+ # score_sigmoid = torch.sigmoid(score)
+ # print(score.size())
+ # scores.append(torch.max(score[:6,:],1))
+ # scores_sigmoid.append(torch.max(score_sigmoid[:6,:],1))
+ # print(scores[-1])
+ return scores
+
+ def forward_search(self, g, mode='allgraph'):
+ # if mode == 'allgraph':
+ hidden_all_ent, all_ent = self.forward_base_rel_search(
+ g, self.drop, self.input_drop)
+ # elif mode == 'subgraph':
+ # hidden_all_ent = self.forward_base_subgraph_search(
+ # g, self.drop, self.input_drop)
+
+ return hidden_all_ent, all_ent
+
+ def compute_pred(self, hidden_x, subj, obj, subgraph_sampler, mode='search', search_algorithm='darts'):
+ h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ atten_matrix = subgraph_sampler(h, mode, search_algorithm)
+ # print(atten_matrix.size()) # [batch_size, encoder_layer^2]
+ # print(subgraph_sampler(h,mode='argmax'))
+ n, c = atten_matrix.shape
+ h = h * atten_matrix.view(n, c, 1)
+ # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ h = torch.sum(h, dim=1)
+ # print(h.size()) # [batch_size, 2*dim]
+ score = self.fc(h)
+ return score
+
+ def compute_mix_hop_pred(self, hidden_x, subj, obj, hop_index):
+ h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ edge_embs = h[:,hop_index,:]
+ score = self.fc(edge_embs)
+ return score
+ # def fine_tune_with_implicit_subgraph(self, all_ent, subgraph, subj, obj):
+ # sg_list = []
+ # for idx in range(subgraph.size(0)):
+ # sg_list.append(torch.mean(all_ent[subgraph[idx,:]], dim=0).unsqueeze(0))
+ # sg_embs = torch.concat(sg_list)
+ # # print(sg_embs.size())
+ # sub_embs = torch.index_select(all_ent, 0, subj)
+ # # print(sub_embs.size())
+ # # filter out embeddings of relations in this batch
+ # obj_embs = torch.index_select(all_ent, 0, obj)
+ # # print(obj_embs.size())
+ # edge_embs = torch.concat([sub_embs, obj_embs, sg_embs], dim=1)
+ # score = self.predictor(edge_embs)
+ # # print(F.embedding(subgraph, all_ent))
+ # return score
+
+ # def compute_pred_rs(self, hidden_x, all_rel, subj, obj, random_hops):
+ # h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ # atten_matrix = torch.zeros(h.size(0),h.size(1)).to('cuda:0')
+ # for i in range(h.size(0)):
+ # atten_matrix[i][self.n_layer*(random_hops[i][0]-1)+random_hops[i][1]-1] = 1
+ # # print(atten_matrix.size()) # [batch_size, encoder_layer^2]
+ # # print(subgraph_sampler(h,mode='argmax'))
+ # n, c = atten_matrix.shape
+ # h = h * atten_matrix.view(n,c,1)
+ # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ # h = torch.sum(h,dim=1)
+ # # print(h.size()) # [batch_size, 2*dim]
+ # h = h.reshape(-1, 1, 2 * self.k_h, self.k_w)
+ # # x = self.bn0(h)
+ # # x = self.conv2d(x) # [batch_size, num_filt, flat_sz_h, flat_sz_w]
+ # # x = self.bn1(x)
+ # # x = F.relu(x)
+ # # x = self.feature_drop(x)
+ # # x = x.view(-1, self.flat_sz) # [batch_size, flat_sz]
+ # # x = self.fc(x) # [batch_size, embed_dim]
+ # # x = self.hidden_drop(x)
+ # # x = self.bn2(x)
+ # # x = F.relu(x)
+ # # x = torch.mm(x, all_rel.transpose(1, 0)) # [batch_size, ent_num]
+ # # x += self.bias_rel.expand_as(x)
+ # # # score = torch.sigmoid(x)
+ # # score = x
+ # return score
+ # print(h.size())
+
+ def cross_pair(self, x_i, x_j):
+ x = []
+ for i in range(self.n_layer):
+ for j in range(self.n_layer):
+ if self.combine_type == 'mult':
+ x.append(x_i[:, i, :] * x_j[:, j, :])
+ elif self.combine_type == 'concat':
+ x.append(torch.cat([x_i[:, i, :], x_j[:, j, :]], dim=1))
+ x = torch.stack(x, dim=1)
+ return x
+
+ def vis_hop_distribution(self, hidden_x, subj, obj, subgraph_sampler, mode='search'):
+ h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ atten_matrix = subgraph_sampler(h, mode)
+ return torch.sum(atten_matrix, dim=0)
+
+ def vis_hop_distribution_rs(self, hidden_x, subj, obj, random_hops):
+ h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ atten_matrix = torch.zeros(h.size(0), h.size(1)).to('cuda:0')
+ for i in range(h.size(0)):
+ atten_matrix[i][self.n_layer * (random_hops[i][0] - 1) + random_hops[i][1] - 1] = 1
+ return torch.sum(atten_matrix, dim=0)
+
+
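+# GCN_MLP_NCN: MLP decoder augmented with a common-neighbor signal. The pair
+# representation (element-wise product or concatenation of subject/object
+# embeddings) is concatenated with the mean embedding of the pair's common
+# 1-hop neighbors (cns) and scored by the small MLP built in _init_predictor.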
+class GCN_MLP_NCN(GCNs):
+ def __init__(self, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm,
+ bias=True, gcn_drop=0., opn='mult', hid_drop=0., input_drop=0., conve_hid_drop=0., feat_drop=0.,
+ num_filt=None, ker_sz=None, k_h=None, k_w=None, wni=False, wsi=False, encoder='compgcn', use_bn=True,
+ ltr=True, input_type='subgraph', graph_pooling='mean', combine_type='mult', loss_type='ce',
+ add_reverse=True):
+ """
+ :param num_ent: number of entities
+ :param num_rel: number of different relations
+ :param num_base: number of bases to use
+ :param init_dim: initial dimension
+ :param gcn_dim: dimension after first layer
+ :param embed_dim: dimension after second layer
+ :param n_layer: number of layers
+ :param edge_type: relation type of each edge, [E]
+ :param bias: whether to add bias
+ :param gcn_drop: dropout rate in compgcncov
+ :param opn: combination operator
+ :param hid_drop: gcn output (embedding of each entity) dropout
+ :param input_drop: dropout in conve input
+ :param conve_hid_drop: dropout in conve hidden layer
+ :param feat_drop: feature dropout in conve
+ :param num_filt: number of filters in conv2d
+ :param ker_sz: kernel size in conv2d
+ :param k_h: height of 2D reshape
+ :param k_w: width of 2D reshape
+ """
+ super(GCN_MLP_NCN, self).__init__(num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer,
+ edge_type, edge_norm, bias, gcn_drop, opn, wni, wsi, encoder,
+ use_bn, ltr, input_type, loss_type, add_reverse)
+ self.hid_drop, self.input_drop = hid_drop, input_drop
+ self.num_rel = num_rel
+
+ self.drop = torch.nn.Dropout(self.hid_drop) # gcn output dropout
+ self.input_drop = torch.nn.Dropout(self.input_drop) # stacked input dropout
+
+ # fully connected projection
+ self.combine_type = combine_type
+ self.graph_pooling_type = graph_pooling
+ self.lin_layers = 2
+ self._init_predictor()
+
+ def _init_predictor(self):
+ self.lins = torch.nn.ModuleList()
+ if self.combine_type == 'mult':
+ input_channels = self.embed_dim
+ else:
+ input_channels = self.embed_dim * 2
+ self.lins.append(torch.nn.Linear(input_channels + self.embed_dim, self.embed_dim))
+ for _ in range(self.lin_layers - 2):
+ self.lins.append(torch.nn.Linear(self.embed_dim, self.embed_dim))
+ self.lins.append(torch.nn.Linear(self.embed_dim, self.num_rel))
+
+ def forward(self, g, subj, obj, cns):
+ """
+ :param g: dgl graph
+ :param subj: subject (head drug) indices in batch [batch_size]
+ :param obj: object (tail drug) indices in batch [batch_size]
+ :param cns: common 1-hop neighbor indices for each pair
+ :return: score: relation logits for each drug pair, one column per interaction type
+ """
+ sub_embs, obj_embs, all_ent, all_rel = self.forward_base_rel(g, subj, obj, self.drop, self.input_drop)
+ cn_embs = self.get_common_1hopneighbor_emb(all_ent, cns)
+
+ if self.combine_type == 'mult':
+ edge_embs = torch.concat([sub_embs * obj_embs, cn_embs], dim=1)
+ elif self.combine_type == 'concat':
+ edge_embs = torch.concat([sub_embs, obj_embs, cn_embs], dim=1)
+ x = edge_embs
+ for lin in self.lins[:-1]:
+ x = lin(x)
+ x = F.relu(x)
+ x = F.dropout(x, p=0.1, training=self.training)
+ score = self.lins[-1](x)
+
+ # [batch_size, 1, 2*k_h, k_w]
+ # stack_input = self.concat(sub_emb, obj_emb)
+ # x = self.bn0(stack_input)
+ # x = self.conv2d(x) # [batch_size, num_filt, flat_sz_h, flat_sz_w]
+ # x = self.bn1(x)
+ # x = F.relu(x)
+ # x = self.feature_drop(x)
+ # x = x.view(-1, self.flat_sz) # [batch_size, flat_sz]
+ # x = self.fc(x) # [batch_size, embed_dim]
+ # x = self.hidden_drop(x)
+ # x = self.bn2(x)
+ # x = F.relu(x)
+ # x = torch.mm(x, all_rel.transpose(1, 0)) # [batch_size, ent_num]
+ # x += self.bias_rel.expand_as(x)
+ # # score = torch.sigmoid(x)
+ # score = x
+ return score
+
+ def forward_search(self, g, mode='allgraph'):
+ hidden_all_ent, all_ent = self.forward_base_rel_search(
+ g, self.drop, self.input_drop)
+
+ return hidden_all_ent, all_ent
+
+ def compute_pred(self, hidden_x, all_ent, subj, obj, cns, subgraph_sampler, mode='search'):
+ h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ atten_matrix = subgraph_sampler(h, mode)
+ # print(atten_matrix.size()) # [batch_size, encoder_layer^2]
+ # print(subgraph_sampler(h,mode='argmax'))
+ n, c = atten_matrix.shape
+ h = h * atten_matrix.view(n, c, 1)
+ cn_embs = self.get_common_1hopneighbor_emb(all_ent, cns)
+ # cn = self.get_common_neighbor_emb(hidden_x, cns)
+ # cn = self.get_common_neighbor_emb_(all_ent, cns)
+ # cn = cn * atten_matrix.view(n, c, 1)
+
+ # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ h = torch.sum(h, dim=1)
+ # cn = torch.sum(cn, dim=1)
+ concat_embs = torch.concat([h, cn_embs], dim=1)
+ x = concat_embs
+ for lin in self.lins[:-1]:
+ x = lin(x)
+ x = F.relu(x)
+ x = F.dropout(x, p=0.1, training=self.training)
+ score = self.lins[-1](x)
+ # score = self.fc(h)
+ # if self.combine_type == 'so':
+ # score = self.fc(h)
+ # elif self.combine_type == 'gso':
+ # bg, all_rel = self.forward_base_rel_subgraph_trans_new(g, self.drop, self.input_drop)
+ # h_graph = readout_nodes(bg, 'h', op=self.graph_pooling_type)
+ # edge_embs = torch.concat([h_graph, h], dim=1)
+ # score = self.predictor(edge_embs)
+ # print(h.size()) # [batch_size, 2*dim]
+
+ return score
+
+ # # print(h.size())
+
+ def get_common_neighbor_emb(self, hidden_x, cns):
+ x = []
+ for i in range(self.n_layer):
+ for j in range(self.n_layer):
+ x_tmp = []
+ for idx in range(cns.size(0)):
+ print(hidden_x[cns[idx, i * 2 + j, :], i, :].size())
+ print(hidden_x[cns[idx, i * 2 + j, :], j, :].size())
+ x_tmp.append(torch.mean(hidden_x[cns[idx, i * 2 + j, :], i * 2 + j, :], dim=0))
+ x.append(torch.stack(x_tmp, dim=0))
+ x = torch.stack(x, dim=1)
+ return x
+
+ def get_common_neighbor_emb_(self, all_ent, cns):
+ x = []
+ for i in range(self.n_layer):
+ for j in range(self.n_layer):
+ x_tmp = []
+ for idx in range(cns.size(0)):
+ x_tmp.append(torch.mean(all_ent[cns[idx, i * 2 + j, :]], dim=0))
+ x.append(torch.stack(x_tmp, dim=0))
+ x = torch.stack(x, dim=1)
+ return x
+
+ def get_common_1hopneighbor_emb(self, all_ent, cns):
+ x = []
+ for idx in range(cns.size(0)):
+ x.append(torch.mean(all_ent[cns[idx, :], :], dim=0))
+ cn_embs = torch.stack(x, dim=0)
+ return cn_embs
+
+ # def compute_pred_rs(self, hidden_x, all_rel, subj, obj, random_hops):
+ # h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ # atten_matrix = torch.zeros(h.size(0),h.size(1)).to('cuda:0')
+ # for i in range(h.size(0)):
+ # atten_matrix[i][self.n_layer*(random_hops[i][0]-1)+random_hops[i][1]-1] = 1
+ # # print(atten_matrix.size()) # [batch_size, encoder_layer^2]
+ # # print(subgraph_sampler(h,mode='argmax'))
+ # n, c = atten_matrix.shape
+ # h = h * atten_matrix.view(n,c,1)
+ # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ # h = torch.sum(h,dim=1)
+ # # print(h.size()) # [batch_size, 2*dim]
+ # h = h.reshape(-1, 1, 2 * self.k_h, self.k_w)
+ # # x = self.bn0(h)
+ # # x = self.conv2d(x) # [batch_size, num_filt, flat_sz_h, flat_sz_w]
+ # # x = self.bn1(x)
+ # # x = F.relu(x)
+ # # x = self.feature_drop(x)
+ # # x = x.view(-1, self.flat_sz) # [batch_size, flat_sz]
+ # # x = self.fc(x) # [batch_size, embed_dim]
+ # # x = self.hidden_drop(x)
+ # # x = self.bn2(x)
+ # # x = F.relu(x)
+ # # x = torch.mm(x, all_rel.transpose(1, 0)) # [batch_size, ent_num]
+ # # x += self.bias_rel.expand_as(x)
+ # # # score = torch.sigmoid(x)
+ # # score = x
+ # return score
+ # print(h.size())
+
+ def cross_pair(self, x_i, x_j):
+ x = []
+ for i in range(self.n_layer):
+ for j in range(self.n_layer):
+ if self.combine_type == 'mult':
+ x.append(x_i[:, i, :] * x_j[:, j, :])
+ elif self.combine_type == 'concat':
+ x.append(torch.cat([x_i[:, i, :], x_j[:, j, :]], dim=1))
+ x = torch.stack(x, dim=1)
+ return x
+
+ def vis_hop_distribution(self, hidden_x, subj, obj, subgraph_sampler, mode='search'):
+ h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ atten_matrix = subgraph_sampler(h, mode)
+ return torch.sum(atten_matrix, dim=0)
+
+ def vis_hop_distribution_rs(self, hidden_x, subj, obj, random_hops):
+ h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ atten_matrix = torch.zeros(h.size(0), h.size(1)).to('cuda:0')
+ for i in range(h.size(0)):
+ atten_matrix[i][self.n_layer * (random_hops[i][0] - 1) + random_hops[i][1] - 1] = 1
+ return torch.sum(atten_matrix, dim=0)
+
+
+def get_pad_mask(seq, pad_idx):
+ return (seq == pad_idx)
+
+
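+# pad_batch right-aligns each subgraph's node states inside a
+# [max_num_nodes, batch_size, dim] tensor: shorter graphs are zero-padded at
+# the front, src_padding_mask is True at padded positions, and the subject /
+# object node indices are shifted by the same offset so they address the
+# padded tensor. max_input_len caps the sequence length fed to the Transformer.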
+def pad_batch(h_node, bg, max_input_len, subj, obj):
+ subj_list = []
+ obj_list = []
+ batch_size = bg.batch_size
+ max_num_nodes = min(max(bg.batch_num_nodes()).item(), max_input_len)
+ padded_h_node = h_node.data.new(max_num_nodes, batch_size, h_node.size(-1)).fill_(0)
+ src_padding_mask = h_node.data.new(batch_size, max_num_nodes).fill_(0).bool()
+ for batch_idx, g in enumerate(dgl.unbatch(bg)):
+ num_nodes = g.num_nodes()
+ padded_h_node[-num_nodes:, batch_idx] = g.ndata['h']
+ src_padding_mask[batch_idx, : max_num_nodes - num_nodes] = True
+ subj_idx = torch.where(g.ndata[NID] == subj[batch_idx])[0] + max_num_nodes - num_nodes
+ obj_idx = torch.where(g.ndata[NID] == obj[batch_idx])[0] + max_num_nodes - num_nodes
+ subj_list.append(subj_idx)
+ obj_list.append(obj_idx)
+ subj_idx = torch.cat(subj_list)
+ obj_idx = torch.cat(obj_list)
+ return padded_h_node, src_padding_mask, subj_idx, obj_idx
diff --git a/model/genotypes.py b/model/genotypes.py
new file mode 100644
index 0000000..abce99e
--- /dev/null
+++ b/model/genotypes.py
@@ -0,0 +1,23 @@
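+# Per-layer search space for the GNN encoder: each layer selects one
+# composition operator, one aggregator, one combination function, and one
+# activation. A searched architecture is serialized as four op names per
+# layer joined by '||' (decoded via ops[4*idx .. 4*idx+3] in model/model.py).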
+COMP_PRIMITIVES = [
+ 'sub',
+ 'mult',
+ 'ccorr',
+ 'rotate'
+]
+
+AGG_PRIMITIVES = [
+ 'mean',
+ 'sum',
+ 'max',
+]
+
+COMB_PRIMITIVES = [
+ 'mlp',
+ 'concat'
+]
+
+ACT_PRIMITIVES = [
+ 'identity',
+ 'relu',
+ 'tanh'
+]
\ No newline at end of file
diff --git a/model/lte_models.py b/model/lte_models.py
new file mode 100644
index 0000000..ca193e3
--- /dev/null
+++ b/model/lte_models.py
@@ -0,0 +1,171 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+def get_param(shape):
+ param = nn.Parameter(torch.Tensor(*shape))
+ nn.init.xavier_normal_(param.data)
+ return param
+
+
+class LTEModel(nn.Module):
+ def __init__(self, num_ents, num_rels, params=None):
+ super(LTEModel, self).__init__()
+
+ self.bceloss = torch.nn.BCELoss()
+
+ self.p = params
+ self.init_embed = get_param((num_ents, self.p.init_dim))
+ self.device = "cuda"
+
+ self.init_rel = get_param((num_rels * 2, self.p.init_dim))
+
+ self.bias = nn.Parameter(torch.zeros(num_ents))
+
+ self.h_ops_dict = nn.ModuleDict({
+ 'p': nn.Linear(self.p.init_dim, self.p.gcn_dim, bias=False),
+ 'b': nn.BatchNorm1d(self.p.gcn_dim),
+ 'd': nn.Dropout(self.p.hid_drop),
+ 'a': nn.Tanh(),
+ })
+
+ self.t_ops_dict = nn.ModuleDict({
+ 'p': nn.Linear(self.p.init_dim, self.p.gcn_dim, bias=False),
+ 'b': nn.BatchNorm1d(self.p.gcn_dim),
+ 'd': nn.Dropout(self.p.hid_drop),
+ 'a': nn.Tanh(),
+ })
+
+ self.r_ops_dict = nn.ModuleDict({
+ 'p': nn.Linear(self.p.init_dim, self.p.gcn_dim, bias=False),
+ 'b': nn.BatchNorm1d(self.p.gcn_dim),
+ 'd': nn.Dropout(self.p.hid_drop),
+ 'a': nn.Tanh(),
+ })
+
+ self.x_ops = self.p.x_ops
+ self.r_ops = self.p.r_ops
+ self.diff_ht = False
+
+ def calc_loss(self, pred, label):
+ return self.loss(pred, label)
+
+ def loss(self, pred, true_label):
+ return self.bceloss(pred, true_label)
+
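+ # exop applies a dot-separated sequence of transforms to the entity (and,
+ # independently, relation) embeddings: 'p' = linear projection, 'b' =
+ # BatchNorm1d, 'd' = dropout, 'a' = tanh. An illustrative setting of
+ # x_ops="p.b.d" would project, normalize, then drop out the entity embeddings.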
+ def exop(self, x, r, x_ops=None, r_ops=None, diff_ht=False):
+ x_head = x_tail = x
+ if len(x_ops) > 0:
+ for x_op in x_ops.split("."):
+ if diff_ht:
+ x_head = self.h_ops_dict[x_op](x_head)
+ x_tail = self.t_ops_dict[x_op](x_tail)
+ else:
+ x_head = x_tail = self.h_ops_dict[x_op](x_head)
+
+ if len(r_ops) > 0:
+ for r_op in r_ops.split("."):
+ r = self.r_ops_dict[r_op](r)
+
+ return x_head, x_tail, r
+
+
+class TransE(LTEModel):
+ def __init__(self, num_ents, num_rels, params=None):
+ super(self.__class__, self).__init__(num_ents, num_rels, params)
+ self.loop_emb = get_param([1, self.p.init_dim])
+
+ def forward(self, g, sub, rel):
+ x = self.init_embed
+ r = self.init_rel
+
+ x_h, x_t, r = self.exop(x - self.loop_emb, r, self.x_ops, self.r_ops)
+
+ sub_emb = torch.index_select(x_h, 0, sub)
+ rel_emb = torch.index_select(r, 0, rel)
+ all_ent = x_t
+
+ obj_emb = sub_emb + rel_emb
+ x = self.p.gamma - \
+ torch.norm(obj_emb.unsqueeze(1) - all_ent, p=1, dim=2)
+ score = torch.sigmoid(x)
+
+ return score
+
+
+class DistMult(LTEModel):
+ def __init__(self, num_ents, num_rels, params=None):
+ super(self.__class__, self).__init__(num_ents, num_rels, params)
+
+ def forward(self, g, sub, rel):
+ x = self.init_embed
+ r = self.init_rel
+
+ x_h, x_t, r = self.exop(x, r, self.x_ops, self.r_ops)
+
+ sub_emb = torch.index_select(x_h, 0, sub)
+ rel_emb = torch.index_select(r, 0, rel)
+ all_ent = x_t
+
+ obj_emb = sub_emb * rel_emb
+ x = torch.mm(obj_emb, all_ent.transpose(1, 0))
+ x += self.bias.expand_as(x)
+ score = torch.sigmoid(x)
+
+ return score
+
+
+class ConvE(LTEModel):
+ def __init__(self, num_ents, num_rels, params=None):
+ super(self.__class__, self).__init__(num_ents, num_rels, params)
+ self.bn0 = torch.nn.BatchNorm2d(1)
+ self.bn1 = torch.nn.BatchNorm2d(self.p.num_filt)
+ self.bn2 = torch.nn.BatchNorm1d(self.p.embed_dim)
+
+ self.hidden_drop = torch.nn.Dropout(self.p.hid_drop)
+ self.hidden_drop2 = torch.nn.Dropout(self.p.conve_hid_drop)
+ self.feature_drop = torch.nn.Dropout(self.p.feat_drop)
+ self.m_conv1 = torch.nn.Conv2d(1, out_channels=self.p.num_filt, kernel_size=(self.p.ker_sz, self.p.ker_sz),
+ stride=1, padding=0, bias=self.p.bias)
+
+ flat_sz_h = int(2 * self.p.k_w) - self.p.ker_sz + 1
+ flat_sz_w = self.p.k_h - self.p.ker_sz + 1
+ self.flat_sz = flat_sz_h * flat_sz_w * self.p.num_filt
+ self.fc = torch.nn.Linear(self.flat_sz, self.p.embed_dim)
+
+ def concat(self, e1_embed, rel_embed):
+ e1_embed = e1_embed.view(-1, 1, self.p.embed_dim)
+ rel_embed = rel_embed.view(-1, 1, self.p.embed_dim)
+ stack_inp = torch.cat([e1_embed, rel_embed], 1)
+ stack_inp = torch.transpose(stack_inp, 2, 1).reshape(
+ (-1, 1, 2 * self.p.k_w, self.p.k_h))
+ return stack_inp
+
+ def forward(self, g, sub, rel):
+ x = self.init_embed
+ r = self.init_rel
+
+ x_h, x_t, r = self.exop(x, r, self.x_ops, self.r_ops)
+
+ sub_emb = torch.index_select(x_h, 0, sub)
+ rel_emb = torch.index_select(r, 0, rel)
+ all_ent = x_t
+
+ stk_inp = self.concat(sub_emb, rel_emb)
+ x = self.bn0(stk_inp)
+ x = self.m_conv1(x)
+ x = self.bn1(x)
+ x = F.relu(x)
+ x = self.feature_drop(x)
+ x = x.view(-1, self.flat_sz)
+ x = self.fc(x)
+ x = self.hidden_drop2(x)
+ x = self.bn2(x)
+ x = F.relu(x)
+
+ x = torch.mm(x, all_ent.transpose(1, 0))
+ x += self.bias.expand_as(x)
+
+ score = torch.sigmoid(x)
+ return score
diff --git a/model/model.py b/model/model.py
new file mode 100644
index 0000000..a99da4b
--- /dev/null
+++ b/model/model.py
@@ -0,0 +1,262 @@
+import torch
+from torch import nn
+import dgl
+import dgl.function as fn
+from torch.autograd import Variable
+from model.genotypes import COMP_PRIMITIVES, AGG_PRIMITIVES, COMB_PRIMITIVES, ACT_PRIMITIVES
+from model.operations import *
+from model.fune_layer import SearchedGCNConv
+from torch.nn.functional import softmax
+from pprint import pprint
+
+
+class NetworkGNN(nn.Module):
+ def __init__(self, args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm,
+ conv_bias=True, gcn_drop=0., wni=False, wsi=False, use_bn=True, ltr=True, loss_type='ce', genotype=None):
+ super(NetworkGNN, self).__init__()
+ self.act = torch.tanh
+ self.args = args
+ self.loss_type = loss_type
+ self.num_ent, self.num_rel, self.num_base = num_ent, num_rel, num_base
+ self.init_dim, self.gcn_dim, self.embed_dim = init_dim, gcn_dim, embed_dim
+
+ self.gcn_drop = gcn_drop
+ self.edge_type = edge_type # [E]
+ self.edge_norm = edge_norm # [E]
+ self.n_layer = n_layer
+
+ self.init_embed = self.get_param([self.num_ent + 1, self.init_dim])
+ self.init_rel = self.get_param([self.num_rel, self.init_dim])
+ self.bias_rel = nn.Parameter(torch.zeros(self.num_rel))
+
+ self._initialize_loss()
+ self.gnn_layers = nn.ModuleList()
+ ops = genotype.split('||')
+ for idx in range(self.args.n_layer):
+ if idx == 0:
+ self.gnn_layers.append(
+ SearchedGCNConv(self.init_dim, self.gcn_dim, conv_bias, gcn_drop, num_base=-1,
+ num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr,
+ comp=ops[4*idx], agg=ops[4*idx+1], comb=ops[4*idx+2], act=ops[4*idx+3]))
+ elif idx == self.args.n_layer-1:
+ self.gnn_layers.append(
+ SearchedGCNConv(self.gcn_dim, self.embed_dim, conv_bias, gcn_drop, num_base=-1,
+ num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr,
+ comp=ops[4*idx], agg=ops[4*idx+1], comb=ops[4*idx+2], act=ops[4*idx+3]))
+ else:
+ self.gnn_layers.append(
+ SearchedGCNConv(self.gcn_dim, self.gcn_dim, conv_bias, gcn_drop, num_base=-1,
+ num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr,
+ comp=ops[4*idx], agg=ops[4*idx+1], comb=ops[4*idx+2], act=ops[4*idx+3]))
+
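+ # Note: `genotype` is the architecture string produced by the search phase,
+ # four op names per layer joined by '||' (see Network.genotype() in
+ # model_search.py); an illustrative 3-layer genotype would be
+ # "mult||mean||mlp||relu||sub||sum||concat||tanh||ccorr||max||mlp||identity".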
+ def get_param(self, shape):
+ param = nn.Parameter(torch.Tensor(*shape))
+ nn.init.xavier_normal_(param, gain=nn.init.calculate_gain('relu'))
+ return param
+
+
+ def _initialize_loss(self):
+ if self.loss_type == 'ce':
+ self.loss = nn.CrossEntropyLoss()
+ elif self.loss_type == 'bce':
+ self.loss = nn.BCELoss(reduction='none') # element-wise loss; reduced manually in calc_loss
+ elif self.loss_type == 'bce_logits':
+ self.loss = nn.BCEWithLogitsLoss()
+ else:
+ raise NotImplementedError
+
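+ # calc_loss: with the default cross-entropy loss, `pred` holds relation
+ # logits and `label` the target class indices. When `pos_neg` is given, the
+ # element-wise BCE path is used instead: predictions are passed through a
+ # sigmoid, compared against label * pos_neg, and only entries where label is
+ # nonzero contribute to the summed loss.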
+ def calc_loss(self, pred, label, pos_neg=None):
+ if pos_neg is not None:
+ m = nn.Sigmoid()
+ score_pos = m(pred)
+ targets_pos = pos_neg.unsqueeze(1)
+ loss = self.loss(score_pos, label * targets_pos)
+ return torch.sum(loss * label)
+ return self.loss(pred, label)
+
+ def forward_base(self, g, subj, obj, drop1, drop2, mode=None):
+ x, r = self.init_embed, self.init_rel # entity and relation embeddings
+
+ for i, layer in enumerate(self.gnn_layers):
+ if i != self.args.n_layer-1:
+ x, r = layer(g, x, r, self.edge_type, self.edge_norm)
+ x = drop1(x)
+ else:
+ x, r = layer(g, x, r, self.edge_type, self.edge_norm)
+ x = drop2(x)
+
+ sub_emb = torch.index_select(x, 0, subj)
+ # select embeddings of objects in this batch
+ obj_emb = torch.index_select(x, 0, obj)
+
+ return sub_emb, obj_emb, x, r
+
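+ # forward_base_search keeps the node embeddings produced after every layer
+ # and stacks them into x_hidden of shape [num_ent, n_layer, dim]; cross_pair
+ # later indexes these hop-wise states to build candidate pair representations.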
+ def forward_base_search(self, g, drop1, drop2):
+ x, r = self.init_embed, self.init_rel # entity and relation embeddings
+ x_hidden = []
+ for i, layer in enumerate(self.gnn_layers):
+ if i != self.args.n_layer-1:
+ x, r = layer(g, x, r, self.edge_type, self.edge_norm)
+ x_hidden.append(x)
+ x = drop1(x)
+ else:
+ x, r = layer(g, x, r, self.edge_type, self.edge_norm)
+ x_hidden.append(x)
+ x = drop2(x)
+ x_hidden = torch.stack(x_hidden, dim=1)
+
+ return x_hidden, x
+
+
+class SearchedGCN_MLP(NetworkGNN):
+ def __init__(self, args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm,
+ bias=True, gcn_drop=0., hid_drop=0., input_drop=0., wni=False, wsi=False, use_bn=True,
+ ltr=True, combine_type='mult', loss_type='ce', genotype=None):
+ """
+ :param num_ent: number of entities
+ :param num_rel: number of different relations
+ :param num_base: number of bases to use
+ :param init_dim: initial dimension
+ :param gcn_dim: dimension after first layer
+ :param embed_dim: dimension after second layer
+ :param n_layer: number of layers
+ :param edge_type: relation type of each edge, [E]
+ :param bias: whether to add bias
+ :param gcn_drop: dropout rate in compgcncov
+ :param opn: combination operator
+ :param hid_drop: gcn output (embedding of each entity) dropout
+ :param input_drop: dropout in conve input
+ :param conve_hid_drop: dropout in conve hidden layer
+ :param feat_drop: feature dropout in conve
+ :param num_filt: number of filters in conv2d
+ :param ker_sz: kernel size in conv2d
+ :param k_h: height of 2D reshape
+ :param k_w: width of 2D reshape
+ """
+ super(SearchedGCN_MLP, self).__init__(args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer,
+ edge_type, edge_norm, bias, gcn_drop, wni, wsi, use_bn, ltr, loss_type, genotype)
+ self.hid_drop, self.input_drop = hid_drop, input_drop
+
+ self.num_rel = num_rel
+
+ self.drop = torch.nn.Dropout(self.hid_drop) # gcn output dropout
+ self.input_drop = torch.nn.Dropout(self.input_drop) # stacked input dropout
+
+ self.combine_type = combine_type
+ if self.combine_type == 'concat':
+ self.fc = torch.nn.Linear(2 * self.embed_dim, self.num_rel)
+ elif self.combine_type == 'mult':
+ self.fc = torch.nn.Linear(self.embed_dim, self.num_rel)
+
+ def forward(self, g, subj, obj, mode=None):
+ """
+ :param g: dgl graph
+ :param subj: subject (head drug) indices in batch [batch_size]
+ :param obj: object (tail drug) indices in batch [batch_size]
+ :return: score: relation logits for each drug pair, one column per interaction type
+ """
+
+ sub_embs, obj_embs, all_ent, all_rel = self.forward_base(g, subj, obj, self.drop, self.input_drop, mode)
+ if self.combine_type == 'concat':
+ edge_embs = torch.concat([sub_embs, obj_embs], dim=1)
+ elif self.combine_type == 'mult':
+ edge_embs = sub_embs * obj_embs
+ else:
+ raise NotImplementedError
+ score = self.fc(edge_embs)
+ return score
+
+ def forward_search(self, g, mode='allgraph'):
+ # if mode == 'allgraph':
+ hidden_all_ent, all_ent = self.forward_base_search(
+ g, self.drop, self.input_drop)
+ # elif mode == 'subgraph':
+ # hidden_all_ent = self.forward_base_subgraph_search(
+ # g, self.drop, self.input_drop)
+
+ return hidden_all_ent, all_ent
+
+ def compute_pred(self, hidden_x, subj, obj, subgraph_sampler, mode='search', search_algorithm='darts'):
+ h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ atten_matrix = subgraph_sampler(h, mode, search_algorithm)
+ # print(atten_matrix.size()) # [batch_size, encoder_layer^2]
+ # print(subgraph_sampler(h,mode='argmax'))
+ n, c = atten_matrix.shape
+ h = h * atten_matrix.view(n, c, 1)
+ # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ h = torch.sum(h, dim=1)
+ # print(h.size()) # [batch_size, 2*dim]
+ score = self.fc(h)
+ return score
+
+ # def fine_tune_with_implicit_subgraph(self, all_ent, subgraph, subj, obj):
+ # sg_list = []
+ # for idx in range(subgraph.size(0)):
+ # sg_list.append(torch.mean(all_ent[subgraph[idx,:]], dim=0).unsqueeze(0))
+ # sg_embs = torch.concat(sg_list)
+ # # print(sg_embs.size())
+ # sub_embs = torch.index_select(all_ent, 0, subj)
+ # # print(sub_embs.size())
+ # # filter out embeddings of relations in this batch
+ # obj_embs = torch.index_select(all_ent, 0, obj)
+ # # print(obj_embs.size())
+ # edge_embs = torch.concat([sub_embs, obj_embs, sg_embs], dim=1)
+ # score = self.predictor(edge_embs)
+ # # print(F.embedding(subgraph, all_ent))
+ # return score
+
+ # def compute_pred_rs(self, hidden_x, all_rel, subj, obj, random_hops):
+ # h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ # atten_matrix = torch.zeros(h.size(0),h.size(1)).to('cuda:0')
+ # for i in range(h.size(0)):
+ # atten_matrix[i][self.n_layer*(random_hops[i][0]-1)+random_hops[i][1]-1] = 1
+ # # print(atten_matrix.size()) # [batch_size, encoder_layer^2]
+ # # print(subgraph_sampler(h,mode='argmax'))
+ # n, c = atten_matrix.shape
+ # h = h * atten_matrix.view(n,c,1)
+ # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ # h = torch.sum(h,dim=1)
+ # # print(h.size()) # [batch_size, 2*dim]
+ # h = h.reshape(-1, 1, 2 * self.k_h, self.k_w)
+ # # x = self.bn0(h)
+ # # x = self.conv2d(x) # [batch_size, num_filt, flat_sz_h, flat_sz_w]
+ # # x = self.bn1(x)
+ # # x = F.relu(x)
+ # # x = self.feature_drop(x)
+ # # x = x.view(-1, self.flat_sz) # [batch_size, flat_sz]
+ # # x = self.fc(x) # [batch_size, embed_dim]
+ # # x = self.hidden_drop(x)
+ # # x = self.bn2(x)
+ # # x = F.relu(x)
+ # # x = torch.mm(x, all_rel.transpose(1, 0)) # [batch_size, ent_num]
+ # # x += self.bias_rel.expand_as(x)
+ # # # score = torch.sigmoid(x)
+ # # score = x
+ # return score
+ # print(h.size())
+
+ def cross_pair(self, x_i, x_j):
+ x = []
+ for i in range(self.n_layer):
+ for j in range(self.n_layer):
+ if self.combine_type == 'mult':
+ x.append(x_i[:, i, :] * x_j[:, j, :])
+ elif self.combine_type == 'concat':
+ x.append(torch.cat([x_i[:, i, :], x_j[:, j, :]], dim=1))
+ x = torch.stack(x, dim=1)
+ return x
+
+ def vis_hop_distribution(self, hidden_x, subj, obj, subgraph_sampler, mode='search', search_algorithm='darts'):
+ h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ atten_matrix = subgraph_sampler(h, mode, search_algorithm)
+ return torch.sum(atten_matrix, dim=0)
+
+ def vis_hop_distribution_rs(self, hidden_x, subj, obj, random_hops):
+ h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ atten_matrix = torch.zeros(h.size(0), h.size(1)).to('cuda:0')
+ for i in range(h.size(0)):
+ atten_matrix[i][self.n_layer * (random_hops[i][0] - 1) + random_hops[i][1] - 1] = 1
+ return torch.sum(atten_matrix, dim=0)
diff --git a/model/model_fast.py b/model/model_fast.py
new file mode 100644
index 0000000..193743b
--- /dev/null
+++ b/model/model_fast.py
@@ -0,0 +1,284 @@
+import torch
+from torch import nn
+import dgl
+import dgl.function as fn
+from torch.autograd import Variable
+from model.genotypes import COMP_PRIMITIVES, AGG_PRIMITIVES, COMB_PRIMITIVES, ACT_PRIMITIVES
+from model.operations import *
+from model.fune_layer import SearchedGCNConv
+from torch.nn.functional import softmax
+from pprint import pprint
+
+
+class NetworkGNN_MLP(nn.Module):
+ def __init__(self, args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm,
+ conv_bias=True, gcn_drop=0., hid_drop=0., input_drop=0, wni=False, wsi=False, use_bn=True, ltr=True,
+ combine_type='mult', loss_type='ce', genotype=None):
+ super(NetworkGNN_MLP, self).__init__()
+ self.act = torch.tanh
+ self.args = args
+ self.loss_type = loss_type
+ self.num_ent, self.num_rel, self.num_base = num_ent, num_rel, num_base
+ self.init_dim, self.gcn_dim, self.embed_dim = init_dim, gcn_dim, embed_dim
+ self.hid_drop, self.input_drop = hid_drop, input_drop
+ self.gcn_drop = gcn_drop
+ self.edge_type = edge_type # [E]
+ self.edge_norm = edge_norm # [E]
+ self.n_layer = n_layer
+
+ self.init_embed = self.get_param([self.num_ent + 1, self.init_dim])
+ self.init_rel = self.get_param([self.num_rel, self.init_dim])
+ self.bias_rel = nn.Parameter(torch.zeros(self.num_rel))
+
+ self._initialize_loss()
+ self.gnn_layers = nn.ModuleList()
+ ops = genotype.split('||')
+ for idx in range(self.args.n_layer):
+ if idx == 0:
+ self.gnn_layers.append(
+ SearchedGCNConv(self.init_dim, self.gcn_dim, conv_bias, gcn_drop, num_base=-1,
+ num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr,
+ comp=ops[4*idx], agg=ops[4*idx+1], comb=ops[4*idx+2], act=ops[4*idx+3]))
+ elif idx == self.args.n_layer-1:
+ self.gnn_layers.append(
+ SearchedGCNConv(self.gcn_dim, self.embed_dim, conv_bias, gcn_drop, num_base=-1,
+ num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr,
+ comp=ops[4*idx], agg=ops[4*idx+1], comb=ops[4*idx+2], act=ops[4*idx+3]))
+ else:
+ self.gnn_layers.append(
+ SearchedGCNConv(self.gcn_dim, self.gcn_dim, conv_bias, gcn_drop, num_base=-1,
+ num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr,
+ comp=ops[4*idx], agg=ops[4*idx+1], comb=ops[4*idx+2], act=ops[4*idx+3]))
+
+ self.drop = torch.nn.Dropout(self.hid_drop) # gcn output dropout
+ self.input_drop = torch.nn.Dropout(self.input_drop) # stacked input dropout
+
+ self.combine_type = combine_type
+ if self.combine_type == 'concat':
+ self.fc = torch.nn.Linear(2 * self.embed_dim, self.num_rel)
+ elif self.combine_type == 'mult':
+ self.fc = torch.nn.Linear(self.embed_dim, self.num_rel)
+
+ def get_param(self, shape):
+ param = nn.Parameter(torch.Tensor(*shape))
+ nn.init.xavier_normal_(param, gain=nn.init.calculate_gain('relu'))
+ return param
+
+
+ def _initialize_loss(self):
+ if self.loss_type == 'ce':
+ self.loss = nn.CrossEntropyLoss()
+ elif self.loss_type == 'bce':
+ self.loss = nn.BCELoss()
+ elif self.loss_type == 'bce_logits':
+ self.loss = nn.BCEWithLogitsLoss()
+ else:
+ raise NotImplementedError
+
+ def calc_loss(self, pred, label):
+ return self.loss(pred, label)
+
+ def forward_base(self, g, subj, obj, drop1, drop2, mode=None):
+ x, r = self.init_embed, self.init_rel # entity and relation embeddings
+
+ for i, layer in enumerate(self.gnn_layers):
+ if i != self.args.n_layer-1:
+ x, r = layer(g, x, r, self.edge_type, self.edge_norm)
+ x = drop1(x)
+ else:
+ x, r = layer(g, x, r, self.edge_type, self.edge_norm)
+ x = drop2(x)
+
+ sub_emb = torch.index_select(x, 0, subj)
+ # select embeddings of objects in this batch
+ obj_emb = torch.index_select(x, 0, obj)
+
+ return sub_emb, obj_emb, x, r
+
+ def forward_base_search(self, g, drop1, drop2):
+ x, r = self.init_embed, self.init_rel # entity and relation embeddings
+ x_hidden = []
+ for i, layer in enumerate(self.gnn_layers):
+ if i != self.args.n_layer-1:
+ x, r = layer(g, x, r, self.edge_type, self.edge_norm)
+ x_hidden.append(x)
+ x = drop1(x)
+ else:
+ x, r = layer(g, x, r, self.edge_type, self.edge_norm)
+ x_hidden.append(x)
+ x = drop2(x)
+ x_hidden = torch.stack(x_hidden, dim=1)
+
+ return x_hidden, x
+
+ def forward(self, g, subj, obj, mode=None):
+ """
+ :param g: dgl graph
+ :param subj: subject (head drug) indices in batch [batch_size]
+ :param obj: object (tail drug) indices in batch [batch_size]
+ :return: score: relation logits for each drug pair, one column per interaction type
+ """
+
+ sub_embs, obj_embs, all_ent, all_rel = self.forward_base(g, subj, obj, self.drop, self.input_drop, mode)
+ if self.combine_type == 'concat':
+ edge_embs = torch.concat([sub_embs, obj_embs], dim=1)
+ elif self.combine_type == 'mult':
+ edge_embs = sub_embs * obj_embs
+ else:
+ raise NotImplementedError
+ score = self.fc(edge_embs)
+ return score
+
+
+# class SearchedGCN_MLP(NetworkGNN):
+# def __init__(self, args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm,
+# bias=True, gcn_drop=0., hid_drop=0., input_drop=0., wni=False, wsi=False, use_bn=True,
+# ltr=True, combine_type='mult', loss_type='ce', genotype=None):
+# """
+# :param num_ent: number of entities
+# :param num_rel: number of different relations
+# :param num_base: number of bases to use
+# :param init_dim: initial dimension
+# :param gcn_dim: dimension after first layer
+# :param embed_dim: dimension after second layer
+# :param n_layer: number of layer
+# :param edge_type: relation type of each edge, [E]
+# :param bias: weather to add bias
+# :param gcn_drop: dropout rate in compgcncov
+# :param opn: combination operator
+# :param hid_drop: gcn output (embedding of each entity) dropout
+# :param input_drop: dropout in conve input
+# :param conve_hid_drop: dropout in conve hidden layer
+# :param feat_drop: feature dropout in conve
+# :param num_filt: number of filters in conv2d
+# :param ker_sz: kernel size in conv2d
+# :param k_h: height of 2D reshape
+# :param k_w: width of 2D reshape
+# """
+# super(SearchedGCN_MLP, self).__init__(args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer,
+# edge_type, edge_norm, bias, gcn_drop, wni, wsi, use_bn, ltr, loss_type, genotype)
+# self.hid_drop, self.input_drop = hid_drop, input_drop
+#
+# self.num_rel = num_rel
+#
+# self.drop = torch.nn.Dropout(self.hid_drop) # gcn output dropout
+# self.input_drop = torch.nn.Dropout(self.input_drop) # stacked input dropout
+#
+# self.combine_type = combine_type
+# if self.combine_type == 'concat':
+# self.fc = torch.nn.Linear(2 * self.embed_dim, self.num_rel)
+# elif self.combine_type == 'mult':
+# self.fc = torch.nn.Linear(self.embed_dim, self.num_rel)
+#
+# def forward(self, g, subj, obj, mode=None):
+# """
+# :param g: dgl graph
+# :param sub: subject in batch [batch_size]
+# :param rel: relation in batch [batch_size]
+# :return: score: [batch_size, ent_num], the prob in link-prediction
+# """
+#
+# sub_embs, obj_embs, all_ent, all_rel = self.forward_base(g, subj, obj, self.drop, self.input_drop, mode)
+# if self.combine_type == 'concat':
+# edge_embs = torch.concat([sub_embs, obj_embs], dim=1)
+# elif self.combine_type == 'mult':
+# edge_embs = sub_embs * obj_embs
+# else:
+# raise NotImplementedError
+# score = self.fc(edge_embs)
+# return score
+#
+# def forward_search(self, g, mode='allgraph'):
+# # if mode == 'allgraph':
+# hidden_all_ent, all_ent = self.forward_base_rel_search(
+# g, self.drop, self.input_drop)
+# # elif mode == 'subgraph':
+# # hidden_all_ent = self.forward_base_subgraph_search(
+# # g, self.drop, self.input_drop)
+#
+# return hidden_all_ent, all_ent
+#
+# def compute_pred(self, hidden_x, subj, obj, subgraph_sampler, mode='search', search_algorithm='darts'):
+# h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+# # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+# atten_matrix = subgraph_sampler(h, mode, search_algorithm)
+# # print(atten_matrix.size()) # [batch_size, encoder_layer^2]
+# # print(subgraph_sampler(h,mode='argmax'))
+# n, c = atten_matrix.shape
+# h = h * atten_matrix.view(n, c, 1)
+# # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+# h = torch.sum(h, dim=1)
+# # print(h.size()) # [batch_size, 2*dim]
+# score = self.fc(h)
+# return score
+#
+# # def fine_tune_with_implicit_subgraph(self, all_ent, subgraph, subj, obj):
+# # sg_list = []
+# # for idx in range(subgraph.size(0)):
+# # sg_list.append(torch.mean(all_ent[subgraph[idx,:]], dim=0).unsqueeze(0))
+# # sg_embs = torch.concat(sg_list)
+# # # print(sg_embs.size())
+# # sub_embs = torch.index_select(all_ent, 0, subj)
+# # # print(sub_embs.size())
+# # # filter out embeddings of relations in this batch
+# # obj_embs = torch.index_select(all_ent, 0, obj)
+# # # print(obj_embs.size())
+# # edge_embs = torch.concat([sub_embs, obj_embs, sg_embs], dim=1)
+# # score = self.predictor(edge_embs)
+# # # print(F.embedding(subgraph, all_ent))
+# # return score
+#
+# # def compute_pred_rs(self, hidden_x, all_rel, subj, obj, random_hops):
+# # h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+# # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+# # atten_matrix = torch.zeros(h.size(0),h.size(1)).to('cuda:0')
+# # for i in range(h.size(0)):
+# # atten_matrix[i][self.n_layer*(random_hops[i][0]-1)+random_hops[i][1]-1] = 1
+# # # print(atten_matrix.size()) # [batch_size, encoder_layer^2]
+# # # print(subgraph_sampler(h,mode='argmax'))
+# # n, c = atten_matrix.shape
+# # h = h * atten_matrix.view(n,c,1)
+# # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+# # h = torch.sum(h,dim=1)
+# # # print(h.size()) # [batch_size, 2*dim]
+# # h = h.reshape(-1, 1, 2 * self.k_h, self.k_w)
+# # # x = self.bn0(h)
+# # # x = self.conv2d(x) # [batch_size, num_filt, flat_sz_h, flat_sz_w]
+# # # x = self.bn1(x)
+# # # x = F.relu(x)
+# # # x = self.feature_drop(x)
+# # # x = x.view(-1, self.flat_sz) # [batch_size, flat_sz]
+# # # x = self.fc(x) # [batch_size, embed_dim]
+# # # x = self.hidden_drop(x)
+# # # x = self.bn2(x)
+# # # x = F.relu(x)
+# # # x = torch.mm(x, all_rel.transpose(1, 0)) # [batch_size, ent_num]
+# # # x += self.bias_rel.expand_as(x)
+# # # # score = torch.sigmoid(x)
+# # # score = x
+# # return score
+# # print(h.size())
+#
+# def cross_pair(self, x_i, x_j):
+# x = []
+# for i in range(self.n_layer):
+# for j in range(self.n_layer):
+# if self.combine_type == 'mult':
+# x.append(x_i[:, i, :] * x_j[:, j, :])
+# elif self.combine_type == 'concat':
+# x.append(torch.cat([x_i[:, i, :], x_j[:, j, :]], dim=1))
+# x = torch.stack(x, dim=1)
+# return x
+#
+# def vis_hop_distribution(self, hidden_x, subj, obj, subgraph_sampler, mode='search'):
+# h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+# atten_matrix = subgraph_sampler(h, mode)
+# return torch.sum(atten_matrix, dim=0)
+#
+# def vis_hop_distribution_rs(self, hidden_x, subj, obj, random_hops):
+# h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+# # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+# atten_matrix = torch.zeros(h.size(0), h.size(1)).to('cuda:0')
+# for i in range(h.size(0)):
+# atten_matrix[i][self.n_layer * (random_hops[i][0] - 1) + random_hops[i][1] - 1] = 1
+# return torch.sum(atten_matrix, dim=0)
diff --git a/model/model_search.py b/model/model_search.py
new file mode 100644
index 0000000..da6a662
--- /dev/null
+++ b/model/model_search.py
@@ -0,0 +1,448 @@
+import torch
+from torch import nn
+import dgl
+import dgl.function as fn
+from torch.autograd import Variable
+from model.genotypes import COMP_PRIMITIVES, AGG_PRIMITIVES, COMB_PRIMITIVES, ACT_PRIMITIVES
+from model.operations import *
+from model.search_layer import SearchGCNConv
+from torch.nn.functional import softmax
+from pprint import pprint
+
+
+class CompMixOP(nn.Module):
+ def __init__(self):
+ super(CompMixOP, self).__init__()
+ self._ops = nn.ModuleList()
+ for primitive in COMP_PRIMITIVES:
+ op = COMP_OPS[primitive]()
+ self._ops.append(op)
+
+ def reset_parameters(self):
+ for op in self._ops:
+ op.reset_parameters()
+
+ def forward(self, src_emb, rel_emb, weights):
+ mixed_res = []
+ for w, op in zip(weights, self._ops):
+ mixed_res.append(w * op(src_emb, rel_emb))
+ return sum(mixed_res)
+
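+# CompMixOP is the continuous relaxation used during architecture search: its
+# output is the weighted sum of all composition primitives (sub / mult / ccorr
+# / rotate), with weights supplied by the architecture parameters.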
+
+class CompOp(nn.Module):
+
+ def __init__(self, op_name):
+ super(CompOp, self).__init__()
+ self.op = COMP_OPS[op_name]()
+
+ def reset_parameters(self):
+ self.op.reset_parameters()
+
+ def forward(self, src_emb, rel_emb):
+ return self.op(src_emb, rel_emb)
+
+
+class Network(nn.Module):
+ def __init__(self, args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm,
+ conv_bias=True, gcn_drop=0., wni=False, wsi=False, use_bn=True, ltr=True, loss_type='ce'):
+ super(Network, self).__init__()
+ self.act = torch.tanh
+ self.args = args
+ self.loss_type = loss_type
+ self.num_ent, self.num_rel, self.num_base = num_ent, num_rel, num_base
+ self.init_dim, self.gcn_dim, self.embed_dim = init_dim, gcn_dim, embed_dim
+
+ self.gcn_drop = gcn_drop
+ self.edge_type = edge_type # [E]
+ self.edge_norm = edge_norm # [E]
+ self.n_layer = n_layer
+
+ self.init_embed = self.get_param([self.num_ent + 1, self.init_dim])
+ self.init_rel = self.get_param([self.num_rel, self.init_dim])
+ self.bias_rel = nn.Parameter(torch.zeros(self.num_rel))
+
+ self._initialize_alphas()
+ self._initialize_loss()
+ self.gnn_layers = nn.ModuleList()
+ for idx in range(self.args.n_layer):
+ if idx == 0:
+ self.gnn_layers.append(
+ SearchGCNConv(self.init_dim, self.gcn_dim, self.act, conv_bias, gcn_drop, num_base=-1,
+ num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr,
+ comp_weights=self.comp_alphas[idx]))
+ elif idx == self.args.n_layer-1:
+ self.gnn_layers.append(SearchGCNConv(self.gcn_dim, self.embed_dim, self.act, conv_bias, gcn_drop, num_base=-1,
+ num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr,
+ comp_weights=self.comp_alphas[idx]))
+ else:
+ self.gnn_layers.append(
+ SearchGCNConv(self.gcn_dim, self.gcn_dim, self.act, conv_bias, gcn_drop, num_base=-1,
+ num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr,
+ comp_weights=self.comp_alphas[idx]))
+
+ def arch_parameters(self):
+ return self._arch_parameters
+
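+    # genotype() exports the architecture currently preferred by the relaxed
+    # weights as "comp||agg||comb||act" repeated per layer; e.g. (illustrative
+    # only) a 2-layer model could decode to "mult||sum||add||tanh||ccorr||mean||mlp||relu".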
+ def genotype(self):
+ gene = []
+ comp_max, comp_indices = torch.max(softmax(self.comp_alphas, dim=-1).data.cpu(), dim=-1)
+ agg_max, agg_indices = torch.max(softmax(self.agg_alphas, dim=-1).data.cpu(), dim=-1)
+ comb_max, comb_indices = torch.max(softmax(self.comb_alphas, dim=-1).data.cpu(), dim=-1)
+ act_max, act_indices = torch.max(softmax(self.act_alphas, dim=-1).data.cpu(), dim=-1)
+ pprint(comp_max)
+ pprint(agg_max)
+ pprint(comb_max)
+ pprint(act_max)
+ for idx in range(self.args.n_layer):
+ gene.append(COMP_PRIMITIVES[comp_indices[idx]])
+ gene.append(AGG_PRIMITIVES[agg_indices[idx]])
+ gene.append(COMB_PRIMITIVES[comb_indices[idx]])
+ gene.append(ACT_PRIMITIVES[act_indices[idx]])
+ return "||".join(gene)
+
+ # self.in_channels = in_channels
+ # self.out_channels = out_channels
+ # self.act = act # activation function
+ # self.device = None
+ # if add_reverse:
+ # self.rel = nn.Parameter(torch.empty([num_rel * 2, in_channels], dtype=torch.float))
+ # else:
+ # self.rel = nn.Parameter(torch.empty([num_rel, in_channels], dtype=torch.float))
+ # self.opn = opn
+ #
+ # self.use_bn = use_bn
+ # self.ltr = ltr
+ #
+ # # relation-type specific parameter
+ # self.in_w = self.get_param([in_channels, out_channels])
+ # self.out_w = self.get_param([in_channels, out_channels])
+ # self.loop_w = self.get_param([in_channels, out_channels])
+ # # transform embedding of relations to next layer
+ # self.w_rel = self.get_param([in_channels, out_channels])
+ # self.loop_rel = self.get_param([1, in_channels]) # self-loop embedding
+ #
+ # self.drop = nn.Dropout(drop_rate)
+ # self.bn = torch.nn.BatchNorm1d(out_channels)
+ # self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None
+
+ # if num_base > 0:
+ # if add_reverse:
+ # self.rel_wt = self.get_param([num_rel * 2, num_base])
+ # else:
+ # self.rel_wt = self.get_param([num_rel, num_base])
+ # else:
+ # self.rel_wt = None
+ #
+ # self.wni = wni
+ # self.wsi = wsi
+
+ def get_param(self, shape):
+ param = nn.Parameter(torch.Tensor(*shape))
+ nn.init.xavier_normal_(param, gain=nn.init.calculate_gain('relu'))
+ return param
+
+ def _initialize_alphas(self):
+ comp_ops_num = len(COMP_PRIMITIVES)
+ agg_ops_num = len(AGG_PRIMITIVES)
+ comb_ops_num = len(COMB_PRIMITIVES)
+ act_ops_num = len(ACT_PRIMITIVES)
+ if self.args.search_algorithm == "darts":
+ self.comp_alphas = Variable(1e-3 * torch.randn(self.args.n_layer, comp_ops_num).cuda(), requires_grad=True)
+ self.agg_alphas = Variable(1e-3 * torch.randn(self.args.n_layer, agg_ops_num).cuda(), requires_grad=True)
+ self.comb_alphas = Variable(1e-3 * torch.randn(self.args.n_layer, comb_ops_num).cuda(), requires_grad=True)
+ self.act_alphas = Variable(1e-3 * torch.randn(self.args.n_layer, act_ops_num).cuda(), requires_grad=True)
+ elif self.args.search_algorithm == "snas":
+ # self.comp_alphas = Variable(1e-3 * torch.randn(self.args.n_layer, comp_ops_num).cuda(), requires_grad=True)
+ self.comp_alphas = Variable(
+ torch.ones(self.args.n_layer, comp_ops_num).normal_(self.args.loc_mean, self.args.loc_std).cuda(),
+ requires_grad=True)
+ self.agg_alphas = Variable(
+ torch.ones(self.args.n_layer, agg_ops_num).normal_(self.args.loc_mean, self.args.loc_std).cuda(),
+ requires_grad=True)
+ self.comb_alphas = Variable(
+ torch.ones(self.args.n_layer, comb_ops_num).normal_(self.args.loc_mean, self.args.loc_std).cuda(),
+ requires_grad=True)
+ self.act_alphas = Variable(
+ torch.ones(self.args.n_layer, act_ops_num).normal_(self.args.loc_mean, self.args.loc_std).cuda(),
+ requires_grad=True)
+ # self.la_alphas = Variable(torch.ones(1, la_ops_num).normal_(self.args.loc_mean, self.args.loc_std).cuda(),
+ # requires_grad=True)
+ # self.seq_alphas = Variable(
+ # torch.ones(self.layer_num, seq_ops_num).normal_(self.args.loc_mean, self.args.loc_std).cuda(),
+ # requires_grad=True)
+ else:
+ raise NotImplementedError
+ self._arch_parameters = [
+ self.comp_alphas,
+ self.agg_alphas,
+ self.comb_alphas,
+ self.act_alphas
+ ]
+
+ def _initialize_loss(self):
+ if self.loss_type == 'ce':
+ self.loss = nn.CrossEntropyLoss()
+ elif self.loss_type == 'bce':
+            self.loss = nn.BCELoss(reduction='none')  # 'reduce' is deprecated; 'none' keeps per-sample losses for the pos_neg masking in calc_loss
+ elif self.loss_type == 'bce_logits':
+ self.loss = nn.BCEWithLogitsLoss()
+ else:
+ raise NotImplementedError
+
+ def calc_loss(self, pred, label, pos_neg=None):
+ if pos_neg is not None:
+ m = nn.Sigmoid()
+ score_pos = m(pred)
+ targets_pos = pos_neg.unsqueeze(1)
+ loss = self.loss(score_pos, label * targets_pos)
+ return torch.sum(loss * label)
+ return self.loss(pred, label)
+
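+    # _get_categ_mask draws relaxed one-hot architecture weights with the
+    # Gumbel-softmax trick: for u ~ Uniform(0, 1), g = -log(-log(u)) is Gumbel
+    # noise, and softmax((alpha + g) / temperature) tends to a hard one-hot
+    # selection as the temperature is annealed towards zero.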
+ def _get_categ_mask(self, alpha):
+ # log_alpha = torch.log(alpha)
+ log_alpha = alpha
+ u = torch.zeros_like(log_alpha).uniform_()
+ softmax = torch.nn.Softmax(-1)
+ one_hot = softmax((log_alpha + (-((-(u.log())).log()))) / self.args.temperature)
+ return one_hot
+
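+    # NOTE: experimental variant built on torch.distributions'
+    # RelaxedOneHotCategorical; it only prints samples and is not used by
+    # forward_base / forward_base_search.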
+ def _get_categ_mask_new(self, alpha):
+ # log_alpha = torch.log(alpha)
+ print(alpha)
+ m = torch.distributions.relaxed_categorical.RelaxedOneHotCategorical(
+ torch.tensor([self.args.temperature]).cuda(), alpha)
+ print(m.sample())
+ print(m.rsample())
+
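+    # In 'evaluate_single_path' mode the relaxed weights are discretized: each
+    # layer keeps only its argmax operation, so the supernet behaves like the
+    # single architecture that genotype() would export.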
+ def get_one_hot_alpha(self, alpha):
+ one_hot_alpha = torch.zeros_like(alpha, device=alpha.device)
+ idx = torch.argmax(alpha, dim=-1)
+
+ for i in range(one_hot_alpha.size(0)):
+ one_hot_alpha[i, idx[i]] = 1.0
+
+ return one_hot_alpha
+
+ def forward_base(self, g, subj, obj, drop1, drop2, mode=None):
+ if self.args.search_algorithm == "darts":
+ comp_weights = softmax(self.comp_alphas / self.args.temperature, dim=-1)
+ agg_weights = softmax(self.agg_alphas / self.args.temperature, dim=-1)
+ comb_weights = softmax(self.comb_alphas / self.args.temperature, dim=-1)
+ act_weights = softmax(self.act_alphas / self.args.temperature, dim=-1)
+ elif self.args.search_algorithm == "snas":
+ # comp_weights = self._get_categ_mask_new(self.comp_alphas)
+ comp_weights = self._get_categ_mask(self.comp_alphas)
+ agg_weights = self._get_categ_mask(self.agg_alphas)
+ comb_weights = self._get_categ_mask(self.comb_alphas)
+ act_weights = self._get_categ_mask(self.act_alphas)
+ else:
+ raise NotImplementedError
+ if mode == 'evaluate_single_path':
+ comp_weights = self.get_one_hot_alpha(comp_weights)
+ agg_weights = self.get_one_hot_alpha(agg_weights)
+ comb_weights = self.get_one_hot_alpha(comb_weights)
+ act_weights = self.get_one_hot_alpha(act_weights)
+ # weights = dict()
+ # weights['comp'] = comp_weights
+ x, r = self.init_embed, self.init_rel # embedding of relations
+
+ for i, layer in enumerate(self.gnn_layers):
+ if i != self.args.n_layer-1:
+ x, r = layer(g, x, r, self.edge_type, self.edge_norm, comp_weights[i], agg_weights[i], comb_weights[i], act_weights[i])
+ x = drop1(x)
+ else:
+ x, r = layer(g, x, r, self.edge_type, self.edge_norm, comp_weights[i], agg_weights[i], comb_weights[i], act_weights[i])
+ x = drop2(x)
+
+ sub_emb = torch.index_select(x, 0, subj)
+ # filter out embeddings of objects in this batch
+ obj_emb = torch.index_select(x, 0, obj)
+
+ return sub_emb, obj_emb, x, r
+
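+    # forward_base_search additionally returns the stacked per-layer node
+    # embeddings x_hidden with shape [num_nodes, n_layer, dim]; these are what
+    # the subgraph sampler later attends over via cross_pair.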
+ def forward_base_search(self, g, drop1, drop2, mode=None):
+ if self.args.search_algorithm == "darts":
+ comp_weights = softmax(self.comp_alphas / self.args.temperature, dim=-1)
+ agg_weights = softmax(self.agg_alphas / self.args.temperature, dim=-1)
+ comb_weights = softmax(self.comb_alphas / self.args.temperature, dim=-1)
+ act_weights = softmax(self.act_alphas / self.args.temperature, dim=-1)
+ elif self.args.search_algorithm == "snas":
+ comp_weights = self._get_categ_mask(self.comp_alphas)
+ agg_weights = self._get_categ_mask(self.agg_alphas)
+ comb_weights = self._get_categ_mask(self.comb_alphas)
+ act_weights = self._get_categ_mask(self.act_alphas)
+ else:
+ raise NotImplementedError
+ if mode == 'evaluate_single_path':
+ comp_weights = self.get_one_hot_alpha(comp_weights)
+ agg_weights = self.get_one_hot_alpha(agg_weights)
+ comb_weights = self.get_one_hot_alpha(comb_weights)
+ act_weights = self.get_one_hot_alpha(act_weights)
+ x, r = self.init_embed, self.init_rel # embedding of relations
+ x_hidden = []
+ for i, layer in enumerate(self.gnn_layers):
+ if i != self.args.n_layer-1:
+ x, r = layer(g, x, r, self.edge_type, self.edge_norm, comp_weights[i], agg_weights[i], comb_weights[i], act_weights[i])
+ x_hidden.append(x)
+ x = drop1(x)
+ else:
+ x, r = layer(g, x, r, self.edge_type, self.edge_norm, comp_weights[i], agg_weights[i], comb_weights[i], act_weights[i])
+ x_hidden.append(x)
+ x = drop2(x)
+ x_hidden = torch.stack(x_hidden, dim=1)
+
+ return x_hidden, x
+
+
+class SearchGCN_MLP(Network):
+ def __init__(self, args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm,
+ bias=True, gcn_drop=0., hid_drop=0., input_drop=0., wni=False, wsi=False, use_bn=True,
+ ltr=True, combine_type='mult', loss_type='ce'):
+        """
+        :param num_ent: number of entities
+        :param num_rel: number of different relations
+        :param num_base: number of bases to use
+        :param init_dim: initial embedding dimension
+        :param gcn_dim: dimension after the first layer
+        :param embed_dim: dimension after the last layer
+        :param n_layer: number of layers
+        :param edge_type: relation type of each edge, [E]
+        :param edge_norm: normalization coefficient of each edge, [E]
+        :param bias: whether to add bias
+        :param gcn_drop: dropout rate inside the GCN layers
+        :param hid_drop: dropout on the GCN output (entity embeddings)
+        :param input_drop: dropout on the stacked input
+        :param combine_type: how subject and object embeddings are combined ('mult' or 'concat')
+        :param loss_type: training loss ('ce', 'bce', or 'bce_logits')
+        """
+ super(SearchGCN_MLP, self).__init__(args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer,
+ edge_type, edge_norm, bias, gcn_drop, wni, wsi, use_bn, ltr, loss_type)
+ self.hid_drop, self.input_drop = hid_drop, input_drop
+
+ self.num_rel = num_rel
+
+ self.drop = torch.nn.Dropout(self.hid_drop) # gcn output dropout
+ self.input_drop = torch.nn.Dropout(self.input_drop) # stacked input dropout
+
+ self.combine_type = combine_type
+ if self.combine_type == 'concat':
+ self.fc = torch.nn.Linear(2 * self.embed_dim, self.num_rel)
+ elif self.combine_type == 'mult':
+ self.fc = torch.nn.Linear(self.embed_dim, self.num_rel)
+
+ def forward(self, g, subj, obj, mode=None):
+        """
+        :param g: dgl graph
+        :param subj: subject (head drug) indices in batch [batch_size]
+        :param obj: object (tail drug) indices in batch [batch_size]
+        :return: score: [batch_size, num_rel], logits over interaction types
+        """
+
+ sub_embs, obj_embs, all_ent, all_rel = self.forward_base(g, subj, obj, self.drop, self.input_drop, mode)
+ if self.combine_type == 'concat':
+ edge_embs = torch.concat([sub_embs, obj_embs], dim=1)
+ elif self.combine_type == 'mult':
+ edge_embs = sub_embs * obj_embs
+ else:
+ raise NotImplementedError
+ score = self.fc(edge_embs)
+ return score
+
+ def forward_search(self, g, mode=None):
+ # if mode == 'allgraph':
+ hidden_all_ent, all_ent = self.forward_base_search(
+ g, self.drop, self.input_drop, mode)
+ # elif mode == 'subgraph':
+ # hidden_all_ent = self.forward_base_subgraph_search(
+ # g, self.drop, self.input_drop)
+
+ return hidden_all_ent, all_ent
+
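+    # compute_pred scores a drug pair by letting the subgraph sampler weight the
+    # n_layer^2 hop-pair representations from cross_pair, summing them into a
+    # single vector, and feeding that vector to the relation classifier.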
+ def compute_pred(self, hidden_x, subj, obj, subgraph_sampler, mode='search', search_algorithm='darts'):
+ h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ atten_matrix = subgraph_sampler(h, mode, search_algorithm)
+ # print(atten_matrix.size()) # [batch_size, encoder_layer^2]
+ # print(subgraph_sampler(h,mode='argmax'))
+ n, c = atten_matrix.shape
+ h = h * atten_matrix.view(n, c, 1)
+ # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ h = torch.sum(h, dim=1)
+ # print(h.size()) # [batch_size, 2*dim]
+ score = self.fc(h)
+ return score
+
+ # def fine_tune_with_implicit_subgraph(self, all_ent, subgraph, subj, obj):
+ # sg_list = []
+ # for idx in range(subgraph.size(0)):
+ # sg_list.append(torch.mean(all_ent[subgraph[idx,:]], dim=0).unsqueeze(0))
+ # sg_embs = torch.concat(sg_list)
+ # # print(sg_embs.size())
+ # sub_embs = torch.index_select(all_ent, 0, subj)
+ # # print(sub_embs.size())
+ # # filter out embeddings of relations in this batch
+ # obj_embs = torch.index_select(all_ent, 0, obj)
+ # # print(obj_embs.size())
+ # edge_embs = torch.concat([sub_embs, obj_embs, sg_embs], dim=1)
+ # score = self.predictor(edge_embs)
+ # # print(F.embedding(subgraph, all_ent))
+ # return score
+
+ # def compute_pred_rs(self, hidden_x, all_rel, subj, obj, random_hops):
+ # h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ # atten_matrix = torch.zeros(h.size(0),h.size(1)).to('cuda:0')
+ # for i in range(h.size(0)):
+ # atten_matrix[i][self.n_layer*(random_hops[i][0]-1)+random_hops[i][1]-1] = 1
+ # # print(atten_matrix.size()) # [batch_size, encoder_layer^2]
+ # # print(subgraph_sampler(h,mode='argmax'))
+ # n, c = atten_matrix.shape
+ # h = h * atten_matrix.view(n,c,1)
+ # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ # h = torch.sum(h,dim=1)
+ # # print(h.size()) # [batch_size, 2*dim]
+ # h = h.reshape(-1, 1, 2 * self.k_h, self.k_w)
+ # # x = self.bn0(h)
+ # # x = self.conv2d(x) # [batch_size, num_filt, flat_sz_h, flat_sz_w]
+ # # x = self.bn1(x)
+ # # x = F.relu(x)
+ # # x = self.feature_drop(x)
+ # # x = x.view(-1, self.flat_sz) # [batch_size, flat_sz]
+ # # x = self.fc(x) # [batch_size, embed_dim]
+ # # x = self.hidden_drop(x)
+ # # x = self.bn2(x)
+ # # x = F.relu(x)
+ # # x = torch.mm(x, all_rel.transpose(1, 0)) # [batch_size, ent_num]
+ # # x += self.bias_rel.expand_as(x)
+ # # # score = torch.sigmoid(x)
+ # # score = x
+ # return score
+ # print(h.size())
+
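+    # cross_pair enumerates all n_layer * n_layer combinations of subject and
+    # object representations taken from different encoder layers, i.e. every
+    # (i-hop, j-hop) subgraph pair the sampler can select. The result has shape
+    # [batch_size, n_layer^2, 2*dim] for 'concat' (or [..., dim] for 'mult').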
+ def cross_pair(self, x_i, x_j):
+ x = []
+ for i in range(self.n_layer):
+ for j in range(self.n_layer):
+ if self.combine_type == 'mult':
+ x.append(x_i[:, i, :] * x_j[:, j, :])
+ elif self.combine_type == 'concat':
+ x.append(torch.cat([x_i[:, i, :], x_j[:, j, :]], dim=1))
+ x = torch.stack(x, dim=1)
+ return x
+
+ def vis_hop_distribution(self, hidden_x, subj, obj, subgraph_sampler, mode='search'):
+ h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ atten_matrix = subgraph_sampler(h, mode)
+ return torch.sum(atten_matrix, dim=0)
+
+ def vis_hop_distribution_rs(self, hidden_x, subj, obj, random_hops):
+ h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ atten_matrix = torch.zeros(h.size(0), h.size(1)).to('cuda:0')
+ for i in range(h.size(0)):
+ atten_matrix[i][self.n_layer * (random_hops[i][0] - 1) + random_hops[i][1] - 1] = 1
+ return torch.sum(atten_matrix, dim=0)
diff --git a/model/model_spos.py b/model/model_spos.py
new file mode 100644
index 0000000..97a8e3f
--- /dev/null
+++ b/model/model_spos.py
@@ -0,0 +1,421 @@
+import torch
+from torch import nn
+import dgl
+import dgl.function as fn
+from torch.autograd import Variable
+from model.genotypes import COMP_PRIMITIVES, AGG_PRIMITIVES, COMB_PRIMITIVES, ACT_PRIMITIVES
+from model.operations import *
+from model.search_layer_spos import SearchSPOSGCNConv
+from torch.nn.functional import softmax
+from numpy.random import choice
+from pprint import pprint
+
+
+class NetworkSPOS(nn.Module):
+ def __init__(self, args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm,
+ conv_bias=True, gcn_drop=0., wni=False, wsi=False, use_bn=True, ltr=True, loss_type='ce'):
+ super(NetworkSPOS, self).__init__()
+ self.act = torch.tanh
+ self.args = args
+ self.loss_type = loss_type
+ self.num_ent, self.num_rel, self.num_base = num_ent, num_rel, num_base
+ self.init_dim, self.gcn_dim, self.embed_dim = init_dim, gcn_dim, embed_dim
+
+ self.gcn_drop = gcn_drop
+ self.edge_type = edge_type # [E]
+ self.edge_norm = edge_norm # [E]
+ self.n_layer = n_layer
+
+ self.init_embed = self.get_param([self.num_ent + 1, self.init_dim])
+ self.init_rel = self.get_param([self.num_rel, self.init_dim])
+ self.bias_rel = nn.Parameter(torch.zeros(self.num_rel))
+
+ self._initialize_loss()
+ self.ops = None
+ self.gnn_layers = nn.ModuleList()
+ for idx in range(self.args.n_layer):
+ if idx == 0:
+ self.gnn_layers.append(
+ SearchSPOSGCNConv(self.init_dim, self.gcn_dim, self.act, conv_bias, gcn_drop, num_base=-1,
+ num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr))
+ elif idx == self.args.n_layer-1:
+ self.gnn_layers.append(SearchSPOSGCNConv(self.gcn_dim, self.embed_dim, self.act, conv_bias, gcn_drop, num_base=-1,
+ num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr))
+ else:
+ self.gnn_layers.append(
+ SearchSPOSGCNConv(self.gcn_dim, self.gcn_dim, self.act, conv_bias, gcn_drop, num_base=-1,
+ num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr))
+
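+    # Single-path sampling (SPOS): a path is a flat list of length 4 * n_layer
+    # holding [comp, agg, comb, act] for each layer, consumed by forward_base as
+    # self.ops[4*i] ... self.ops[4*i+3]. The generate_single_path_* variants
+    # below pin some slots to a fixed operator so that few-shot sub-supernets
+    # can be trained per composition op.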
+ def generate_single_path(self, op_subsupernet=''):
+ single_path = []
+ for i in range(self.args.n_layer):
+ single_path.append(choice(COMP_PRIMITIVES))
+ single_path.append(choice(AGG_PRIMITIVES))
+ single_path.append(choice(COMB_PRIMITIVES))
+ single_path.append(choice(ACT_PRIMITIVES))
+ return single_path
+
+ def generate_single_path_ccorr(self, op_subsupernet=''):
+ single_path = []
+ for i in range(self.args.n_layer):
+ single_path.append('ccorr')
+ single_path.append(choice(AGG_PRIMITIVES))
+ single_path.append(choice(COMB_PRIMITIVES))
+ single_path.append(choice(ACT_PRIMITIVES))
+ return single_path
+
+ def generate_single_path_act(self, op_subsupernet=''):
+ single_path = []
+ for i in range(self.args.n_layer):
+ single_path.append('ccorr')
+ single_path.append('sum')
+ single_path.append('add')
+ single_path.append(choice(ACT_PRIMITIVES))
+ return single_path
+
+ def generate_single_path_comb(self, op_subsupernet=''):
+ single_path = []
+ for i in range(self.args.n_layer):
+ single_path.append('ccorr')
+ single_path.append('sum')
+ single_path.append(choice(COMB_PRIMITIVES))
+ single_path.append('tanh')
+ return single_path
+
+ def generate_single_path_comp(self, op_subsupernet=''):
+ single_path = []
+ for i in range(self.args.n_layer):
+ single_path.append(choice(COMP_PRIMITIVES))
+ single_path.append('sum')
+ single_path.append('add')
+ single_path.append('tanh')
+ return single_path
+
+ def generate_single_path_agg(self, op_subsupernet=''):
+ single_path = []
+ for i in range(self.args.n_layer):
+ single_path.append('ccorr')
+ single_path.append(choice(AGG_PRIMITIVES))
+ single_path.append('add')
+ single_path.append('tanh')
+ return single_path
+
+ def generate_single_path_agg_comb(self, op_subsupernet=''):
+ single_path = []
+ for i in range(self.args.n_layer):
+ single_path.append('ccorr')
+ single_path.append(choice(AGG_PRIMITIVES))
+ single_path.append(choice(COMB_PRIMITIVES))
+ single_path.append('tanh')
+ return single_path
+
+ def generate_single_path_agg_comb_comp(self, op_subsupernet=''):
+ single_path = []
+ for i in range(self.args.n_layer):
+ single_path.append(choice(COMP_PRIMITIVES))
+ single_path.append(choice(AGG_PRIMITIVES))
+ single_path.append(choice(COMB_PRIMITIVES))
+ single_path.append('relu')
+ return single_path
+
+ def generate_single_path_agg_comb_act_rotate(self, op_subsupernet=''):
+ single_path = []
+ for i in range(self.args.n_layer):
+ single_path.append('rotate')
+ single_path.append(choice(AGG_PRIMITIVES))
+ single_path.append(choice(COMB_PRIMITIVES))
+ single_path.append(choice(ACT_PRIMITIVES))
+ return single_path
+
+ def generate_single_path_agg_comb_act_ccorr(self, op_subsupernet=''):
+ single_path = []
+ for i in range(self.args.n_layer):
+ single_path.append('ccorr')
+ single_path.append(choice(AGG_PRIMITIVES))
+ single_path.append(choice(COMB_PRIMITIVES))
+ single_path.append(choice(ACT_PRIMITIVES))
+ return single_path
+
+ def generate_single_path_agg_comb_act_mult(self, op_subsupernet=''):
+ single_path = []
+ for i in range(self.args.n_layer):
+ single_path.append('mult')
+ single_path.append(choice(AGG_PRIMITIVES))
+ single_path.append(choice(COMB_PRIMITIVES))
+ single_path.append(choice(ACT_PRIMITIVES))
+ return single_path
+
+ def generate_single_path_agg_comb_act_sub(self, op_subsupernet=''):
+ single_path = []
+ for i in range(self.args.n_layer):
+ single_path.append('sub')
+ single_path.append(choice(AGG_PRIMITIVES))
+ single_path.append(choice(COMB_PRIMITIVES))
+ single_path.append(choice(ACT_PRIMITIVES))
+ return single_path
+
+ def generate_single_path_agg_comb_act_1mult(self, op_subsupernet=''):
+ single_path = []
+ for i in range(self.args.n_layer):
+ if i == 0:
+ single_path.append('mult')
+ else:
+ single_path.append(choice(COMP_PRIMITIVES))
+ single_path.append(choice(AGG_PRIMITIVES))
+ single_path.append(choice(COMB_PRIMITIVES))
+ single_path.append(choice(ACT_PRIMITIVES))
+ return single_path
+
+ def generate_single_path_agg_comb_act_1ccorr(self, op_subsupernet=''):
+ single_path = []
+ for i in range(self.args.n_layer):
+ if i == 0:
+ single_path.append('ccorr')
+ else:
+ single_path.append(choice(COMP_PRIMITIVES))
+ single_path.append(choice(AGG_PRIMITIVES))
+ single_path.append(choice(COMB_PRIMITIVES))
+ single_path.append(choice(ACT_PRIMITIVES))
+ return single_path
+
+ def generate_single_path_agg_comb_act_few_shot_comp(self, op_subsupernet=''):
+ single_path = []
+ for i in range(self.args.n_layer):
+ if i == 0:
+ single_path.append(op_subsupernet)
+ else:
+ single_path.append(choice(COMP_PRIMITIVES))
+ single_path.append(choice(AGG_PRIMITIVES))
+ single_path.append(choice(COMB_PRIMITIVES))
+ single_path.append(choice(ACT_PRIMITIVES))
+ return single_path
+
+ def get_param(self, shape):
+ param = nn.Parameter(torch.Tensor(*shape))
+ nn.init.xavier_normal_(param, gain=nn.init.calculate_gain('relu'))
+ return param
+
+ def _initialize_loss(self):
+ if self.loss_type == 'ce':
+ self.loss = nn.CrossEntropyLoss()
+ elif self.loss_type == 'bce':
+ self.loss = nn.BCELoss()
+ elif self.loss_type == 'bce_logits':
+ self.loss = nn.BCEWithLogitsLoss()
+ else:
+ raise NotImplementedError
+
+ def calc_loss(self, pred, label, pos_neg=None):
+ if pos_neg is not None:
+ m = nn.Sigmoid()
+ score_pos = m(pred)
+ targets_pos = pos_neg.unsqueeze(1)
+ loss = self.loss(score_pos, label * targets_pos)
+ return torch.sum(loss * label)
+ return self.loss(pred, label)
+
+ def forward_base(self, g, subj, obj, drop1, drop2, mode=None):
+ x, r = self.init_embed, self.init_rel # embedding of relations
+
+ for i, layer in enumerate(self.gnn_layers):
+ if i != self.args.n_layer-1:
+ x, r = layer(g, x, r, self.edge_type, self.edge_norm, self.ops[4*i], self.ops[4*i+1], self.ops[4*i+2], self.ops[4*i+3])
+ x = drop1(x)
+ else:
+ x, r = layer(g, x, r, self.edge_type, self.edge_norm, self.ops[4*i], self.ops[4*i+1], self.ops[4*i+2], self.ops[4*i+3])
+ x = drop2(x)
+
+ sub_emb = torch.index_select(x, 0, subj)
+ # filter out embeddings of objects in this batch
+ obj_emb = torch.index_select(x, 0, obj)
+
+ return sub_emb, obj_emb, x, r
+
+ def forward_base_search(self, g, drop1, drop2, mode=None):
+ # if self.args.search_algorithm == "darts":
+ # comp_weights = softmax(self.comp_alphas, dim=-1)
+ # agg_weights = softmax(self.agg_alphas, dim=-1)
+ # comb_weights = softmax(self.comb_alphas, dim=-1)
+ # act_weights = softmax(self.act_alphas, dim=-1)
+ # elif self.args.search_algorithm == "snas":
+ # comp_weights = self._get_categ_mask(self.comp_alphas)
+ # agg_weights = self._get_categ_mask(self.agg_alphas)
+ # comb_weights = self._get_categ_mask(self.comb_alphas)
+ # act_weights = self._get_categ_mask(self.act_alphas)
+ # else:
+ # raise NotImplementedError
+ # if mode == 'evaluate_single_path':
+ # comp_weights = self.get_one_hot_alpha(comp_weights)
+ # agg_weights = self.get_one_hot_alpha(agg_weights)
+ # comb_weights = self.get_one_hot_alpha(comb_weights)
+ # act_weights = self.get_one_hot_alpha(act_weights)
+ x, r = self.init_embed, self.init_rel # embedding of relations
+ x_hidden = []
+ for i, layer in enumerate(self.gnn_layers):
+ if i != self.args.n_layer-1:
+ x, r = layer(g, x, r, self.edge_type, self.edge_norm, self.ops[4 * i], self.ops[4 * i + 1],
+ self.ops[4 * i + 2], self.ops[4 * i + 3])
+ x_hidden.append(x)
+ x = drop1(x)
+ else:
+ x, r = layer(g, x, r, self.edge_type, self.edge_norm, self.ops[4 * i], self.ops[4 * i + 1],
+ self.ops[4 * i + 2], self.ops[4 * i + 3])
+ x_hidden.append(x)
+ x = drop2(x)
+ x_hidden = torch.stack(x_hidden, dim=1)
+
+ return x_hidden, x
+
+
+class SearchGCN_MLP_SPOS(NetworkSPOS):
+ def __init__(self, args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm,
+ bias=True, gcn_drop=0., hid_drop=0., input_drop=0., wni=False, wsi=False, use_bn=True,
+ ltr=True, combine_type='mult', loss_type='ce'):
+        """
+        :param num_ent: number of entities
+        :param num_rel: number of different relations
+        :param num_base: number of bases to use
+        :param init_dim: initial embedding dimension
+        :param gcn_dim: dimension after the first layer
+        :param embed_dim: dimension after the last layer
+        :param n_layer: number of layers
+        :param edge_type: relation type of each edge, [E]
+        :param edge_norm: normalization coefficient of each edge, [E]
+        :param bias: whether to add bias
+        :param gcn_drop: dropout rate inside the GCN layers
+        :param hid_drop: dropout on the GCN output (entity embeddings)
+        :param input_drop: dropout on the stacked input
+        :param combine_type: how subject and object embeddings are combined ('mult' or 'concat')
+        :param loss_type: training loss ('ce', 'bce', or 'bce_logits')
+        """
+ super(SearchGCN_MLP_SPOS, self).__init__(args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer,
+ edge_type, edge_norm, bias, gcn_drop, wni, wsi, use_bn, ltr, loss_type)
+ self.hid_drop, self.input_drop = hid_drop, input_drop
+
+ self.num_rel = num_rel
+
+ self.drop = torch.nn.Dropout(self.hid_drop) # gcn output dropout
+ self.input_drop = torch.nn.Dropout(self.input_drop) # stacked input dropout
+
+ self.combine_type = combine_type
+ if self.combine_type == 'concat':
+ self.fc = torch.nn.Linear(2 * self.embed_dim, self.num_rel)
+ elif self.combine_type == 'mult':
+ self.fc = torch.nn.Linear(self.embed_dim, self.num_rel)
+
+ def forward(self, g, subj, obj, mode=None):
+        """
+        :param g: dgl graph
+        :param subj: subject (head drug) indices in batch [batch_size]
+        :param obj: object (tail drug) indices in batch [batch_size]
+        :return: score: [batch_size, num_rel], logits over interaction types
+        """
+
+ sub_embs, obj_embs, all_ent, all_rel = self.forward_base(g, subj, obj, self.drop, self.input_drop, mode)
+ if self.combine_type == 'concat':
+ edge_embs = torch.concat([sub_embs, obj_embs], dim=1)
+ elif self.combine_type == 'mult':
+ edge_embs = sub_embs * obj_embs
+ else:
+ raise NotImplementedError
+ score = self.fc(edge_embs)
+ return score
+
+ def forward_search(self, g, mode=None):
+ # if mode == 'allgraph':
+ hidden_all_ent, all_ent = self.forward_base_search(
+ g, self.drop, self.input_drop, mode)
+ # elif mode == 'subgraph':
+ # hidden_all_ent = self.forward_base_subgraph_search(
+ # g, self.drop, self.input_drop)
+
+ return hidden_all_ent, all_ent
+
+ def compute_pred(self, hidden_x, subj, obj, subgraph_sampler, mode='search', search_algorithm='darts'):
+ h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ atten_matrix = subgraph_sampler(h, mode, search_algorithm)
+ # print(atten_matrix.size()) # [batch_size, encoder_layer^2]
+ # print(subgraph_sampler(h,mode='argmax'))
+ n, c = atten_matrix.shape
+ h = h * atten_matrix.view(n, c, 1)
+ # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ h = torch.sum(h, dim=1)
+ # print(h.size()) # [batch_size, 2*dim]
+ score = self.fc(h)
+ return score
+
+ # def fine_tune_with_implicit_subgraph(self, all_ent, subgraph, subj, obj):
+ # sg_list = []
+ # for idx in range(subgraph.size(0)):
+ # sg_list.append(torch.mean(all_ent[subgraph[idx,:]], dim=0).unsqueeze(0))
+ # sg_embs = torch.concat(sg_list)
+ # # print(sg_embs.size())
+ # sub_embs = torch.index_select(all_ent, 0, subj)
+ # # print(sub_embs.size())
+ # # filter out embeddings of relations in this batch
+ # obj_embs = torch.index_select(all_ent, 0, obj)
+ # # print(obj_embs.size())
+ # edge_embs = torch.concat([sub_embs, obj_embs, sg_embs], dim=1)
+ # score = self.predictor(edge_embs)
+ # # print(F.embedding(subgraph, all_ent))
+ # return score
+
+ # def compute_pred_rs(self, hidden_x, all_rel, subj, obj, random_hops):
+ # h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ # atten_matrix = torch.zeros(h.size(0),h.size(1)).to('cuda:0')
+ # for i in range(h.size(0)):
+ # atten_matrix[i][self.n_layer*(random_hops[i][0]-1)+random_hops[i][1]-1] = 1
+ # # print(atten_matrix.size()) # [batch_size, encoder_layer^2]
+ # # print(subgraph_sampler(h,mode='argmax'))
+ # n, c = atten_matrix.shape
+ # h = h * atten_matrix.view(n,c,1)
+ # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ # h = torch.sum(h,dim=1)
+ # # print(h.size()) # [batch_size, 2*dim]
+ # h = h.reshape(-1, 1, 2 * self.k_h, self.k_w)
+ # # x = self.bn0(h)
+ # # x = self.conv2d(x) # [batch_size, num_filt, flat_sz_h, flat_sz_w]
+ # # x = self.bn1(x)
+ # # x = F.relu(x)
+ # # x = self.feature_drop(x)
+ # # x = x.view(-1, self.flat_sz) # [batch_size, flat_sz]
+ # # x = self.fc(x) # [batch_size, embed_dim]
+ # # x = self.hidden_drop(x)
+ # # x = self.bn2(x)
+ # # x = F.relu(x)
+ # # x = torch.mm(x, all_rel.transpose(1, 0)) # [batch_size, ent_num]
+ # # x += self.bias_rel.expand_as(x)
+ # # # score = torch.sigmoid(x)
+ # # score = x
+ # return score
+ # print(h.size())
+
+ def cross_pair(self, x_i, x_j):
+ x = []
+ for i in range(self.n_layer):
+ for j in range(self.n_layer):
+ if self.combine_type == 'mult':
+ x.append(x_i[:, i, :] * x_j[:, j, :])
+ elif self.combine_type == 'concat':
+ x.append(torch.cat([x_i[:, i, :], x_j[:, j, :]], dim=1))
+ x = torch.stack(x, dim=1)
+ return x
+
+ def vis_hop_distribution(self, hidden_x, subj, obj, subgraph_sampler, mode='search', search_algorithm='darts'):
+ h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ atten_matrix = subgraph_sampler(h, mode, search_algorithm)
+ return torch.sum(atten_matrix, dim=0)
+
+ def vis_hop_distribution_rs(self, hidden_x, subj, obj, random_hops):
+ h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ atten_matrix = torch.zeros(h.size(0), h.size(1)).to('cuda:0')
+ for i in range(h.size(0)):
+ atten_matrix[i][self.n_layer * (random_hops[i][0] - 1) + random_hops[i][1] - 1] = 1
+ return torch.sum(atten_matrix, dim=0)
diff --git a/model/model_spos_fast.py b/model/model_spos_fast.py
new file mode 100644
index 0000000..545752f
--- /dev/null
+++ b/model/model_spos_fast.py
@@ -0,0 +1,287 @@
+import torch
+from torch import nn
+import dgl
+import dgl.function as fn
+from torch.autograd import Variable
+from model.genotypes import COMP_PRIMITIVES, AGG_PRIMITIVES, COMB_PRIMITIVES, ACT_PRIMITIVES
+from model.operations import *
+from model.search_layer_spos import SearchSPOSGCNConv
+from torch.nn.functional import softmax
+from numpy.random import choice
+from pprint import pprint
+
+
+class NetworkSPOS(nn.Module):
+ def __init__(self, args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm,
+ conv_bias=True, gcn_drop=0., wni=False, wsi=False, use_bn=True, ltr=True, loss_type='ce'):
+ super(NetworkSPOS, self).__init__()
+ self.act = torch.tanh
+ self.args = args
+ self.loss_type = loss_type
+ self.num_ent, self.num_rel, self.num_base = num_ent, num_rel, num_base
+ self.init_dim, self.gcn_dim, self.embed_dim = init_dim, gcn_dim, embed_dim
+
+ self.gcn_drop = gcn_drop
+ self.edge_type = edge_type # [E]
+ self.edge_norm = edge_norm # [E]
+ self.n_layer = n_layer
+
+ self.init_embed = self.get_param([self.num_ent + 1, self.init_dim])
+ self.init_rel = self.get_param([self.num_rel, self.init_dim])
+ self.bias_rel = nn.Parameter(torch.zeros(self.num_rel))
+
+ self._initialize_loss()
+ self.ops = None
+ self.gnn_layers = nn.ModuleList()
+ for idx in range(self.args.n_layer):
+ if idx == 0:
+ self.gnn_layers.append(
+ SearchSPOSGCNConv(self.init_dim, self.gcn_dim, self.act, conv_bias, gcn_drop, num_base=-1,
+ num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr))
+ elif idx == self.args.n_layer-1:
+ self.gnn_layers.append(SearchSPOSGCNConv(self.gcn_dim, self.embed_dim, self.act, conv_bias, gcn_drop, num_base=-1,
+ num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr))
+ else:
+ self.gnn_layers.append(
+ SearchSPOSGCNConv(self.gcn_dim, self.gcn_dim, self.act, conv_bias, gcn_drop, num_base=-1,
+ num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr))
+
+ def generate_single_path(self, op_subsupernet=''):
+ single_path = []
+ for i in range(self.args.n_layer):
+ single_path.append(choice(COMP_PRIMITIVES))
+ single_path.append(choice(AGG_PRIMITIVES))
+ single_path.append(choice(COMB_PRIMITIVES))
+ single_path.append(choice(ACT_PRIMITIVES))
+ return single_path
+
+ def generate_single_path_ccorr(self, op_subsupernet=''):
+ single_path = []
+ for i in range(self.args.n_layer):
+ single_path.append('ccorr')
+ single_path.append(choice(AGG_PRIMITIVES))
+ single_path.append(choice(COMB_PRIMITIVES))
+ single_path.append(choice(ACT_PRIMITIVES))
+ return single_path
+
+ def get_param(self, shape):
+ param = nn.Parameter(torch.Tensor(*shape))
+ nn.init.xavier_normal_(param, gain=nn.init.calculate_gain('relu'))
+ return param
+
+ def _initialize_loss(self):
+ if self.loss_type == 'ce':
+ self.loss = nn.CrossEntropyLoss()
+ elif self.loss_type == 'bce':
+ self.loss = nn.BCELoss()
+ elif self.loss_type == 'bce_logits':
+ self.loss = nn.BCEWithLogitsLoss()
+ else:
+ raise NotImplementedError
+
+ def calc_loss(self, pred, label):
+ return self.loss(pred, label)
+
+ def forward_base(self, g, subj, obj, drop1, drop2, mode=None):
+ x, r = self.init_embed, self.init_rel # embedding of relations
+
+ for i, layer in enumerate(self.gnn_layers):
+ if i != self.args.n_layer-1:
+ x, r = layer(g, x, r, self.edge_type, self.edge_norm, self.ops[4*i], self.ops[4*i+1], self.ops[4*i+2], self.ops[4*i+3])
+ x = drop1(x)
+ else:
+ x, r = layer(g, x, r, self.edge_type, self.edge_norm, self.ops[4*i], self.ops[4*i+1], self.ops[4*i+2], self.ops[4*i+3])
+ x = drop2(x)
+
+ sub_emb = torch.index_select(x, 0, subj)
+ # filter out embeddings of objects in this batch
+ obj_emb = torch.index_select(x, 0, obj)
+
+ return sub_emb, obj_emb, x, r
+
+ def forward_base_search(self, g, drop1, drop2, mode=None):
+        # This SPOS variant keeps no architecture weights; self.ops is expected to
+        # hold a sampled discrete path (see generate_single_path), so message
+        # passing is driven directly by those operators, as in forward_base above.
+        x, r = self.init_embed, self.init_rel  # entity and relation embeddings
+        x_hidden = []
+        for i, layer in enumerate(self.gnn_layers):
+            if i != self.args.n_layer-1:
+                x, r = layer(g, x, r, self.edge_type, self.edge_norm, self.ops[4 * i], self.ops[4 * i + 1],
+                             self.ops[4 * i + 2], self.ops[4 * i + 3])
+                x_hidden.append(x)
+                x = drop1(x)
+            else:
+                x, r = layer(g, x, r, self.edge_type, self.edge_norm, self.ops[4 * i], self.ops[4 * i + 1],
+                             self.ops[4 * i + 2], self.ops[4 * i + 3])
+                x_hidden.append(x)
+                x = drop2(x)
+ x_hidden = torch.stack(x_hidden, dim=1)
+
+ return x_hidden, x
+
+
+class SearchGCN_MLP_SPOS(NetworkSPOS):
+ def __init__(self, args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm,
+ bias=True, gcn_drop=0., hid_drop=0., input_drop=0., wni=False, wsi=False, use_bn=True,
+ ltr=True, combine_type='mult', loss_type='ce'):
+        """
+        :param num_ent: number of entities
+        :param num_rel: number of different relations
+        :param num_base: number of bases to use
+        :param init_dim: initial embedding dimension
+        :param gcn_dim: dimension after the first layer
+        :param embed_dim: dimension after the last layer
+        :param n_layer: number of layers
+        :param edge_type: relation type of each edge, [E]
+        :param edge_norm: normalization coefficient of each edge, [E]
+        :param bias: whether to add bias
+        :param gcn_drop: dropout rate inside the GCN layers
+        :param hid_drop: dropout on the GCN output (entity embeddings)
+        :param input_drop: dropout on the stacked input
+        :param combine_type: how subject and object embeddings are combined ('mult' or 'concat')
+        :param loss_type: training loss ('ce', 'bce', or 'bce_logits')
+        """
+ super(SearchGCN_MLP_SPOS, self).__init__(args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer,
+ edge_type, edge_norm, bias, gcn_drop, wni, wsi, use_bn, ltr, loss_type)
+ self.hid_drop, self.input_drop = hid_drop, input_drop
+
+ self.num_rel = num_rel
+
+ self.drop = torch.nn.Dropout(self.hid_drop) # gcn output dropout
+ self.input_drop = torch.nn.Dropout(self.input_drop) # stacked input dropout
+
+ self.combine_type = combine_type
+ if self.combine_type == 'concat':
+ self.fc = torch.nn.Linear(2 * self.embed_dim, self.num_rel)
+ elif self.combine_type == 'mult':
+ self.fc = torch.nn.Linear(self.embed_dim, self.num_rel)
+
+ def forward(self, g, subj, obj, mode=None):
+        """
+        :param g: dgl graph
+        :param subj: subject (head drug) indices in batch [batch_size]
+        :param obj: object (tail drug) indices in batch [batch_size]
+        :return: score: [batch_size, num_rel], logits over interaction types
+        """
+
+ sub_embs, obj_embs, all_ent, all_rel = self.forward_base(g, subj, obj, self.drop, self.input_drop, mode)
+ if self.combine_type == 'concat':
+ edge_embs = torch.concat([sub_embs, obj_embs], dim=1)
+ elif self.combine_type == 'mult':
+ edge_embs = sub_embs * obj_embs
+ else:
+ raise NotImplementedError
+ score = self.fc(edge_embs)
+ return score
+
+ def forward_search(self, g, mode=None):
+ # if mode == 'allgraph':
+ hidden_all_ent, all_ent = self.forward_base_search(
+ g, self.drop, self.input_drop, mode)
+ # elif mode == 'subgraph':
+ # hidden_all_ent = self.forward_base_subgraph_search(
+ # g, self.drop, self.input_drop)
+
+ return hidden_all_ent, all_ent
+
+ def compute_pred(self, hidden_x, subj, obj, subgraph_sampler, mode='search', search_algorithm='darts'):
+ h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ atten_matrix = subgraph_sampler(h, mode, search_algorithm)
+ # print(atten_matrix.size()) # [batch_size, encoder_layer^2]
+ # print(subgraph_sampler(h,mode='argmax'))
+ n, c = atten_matrix.shape
+ h = h * atten_matrix.view(n, c, 1)
+ # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ h = torch.sum(h, dim=1)
+ # print(h.size()) # [batch_size, 2*dim]
+ score = self.fc(h)
+ return score
+
+ # def fine_tune_with_implicit_subgraph(self, all_ent, subgraph, subj, obj):
+ # sg_list = []
+ # for idx in range(subgraph.size(0)):
+ # sg_list.append(torch.mean(all_ent[subgraph[idx,:]], dim=0).unsqueeze(0))
+ # sg_embs = torch.concat(sg_list)
+ # # print(sg_embs.size())
+ # sub_embs = torch.index_select(all_ent, 0, subj)
+ # # print(sub_embs.size())
+ # # filter out embeddings of relations in this batch
+ # obj_embs = torch.index_select(all_ent, 0, obj)
+ # # print(obj_embs.size())
+ # edge_embs = torch.concat([sub_embs, obj_embs, sg_embs], dim=1)
+ # score = self.predictor(edge_embs)
+ # # print(F.embedding(subgraph, all_ent))
+ # return score
+
+ # def compute_pred_rs(self, hidden_x, all_rel, subj, obj, random_hops):
+ # h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ # atten_matrix = torch.zeros(h.size(0),h.size(1)).to('cuda:0')
+ # for i in range(h.size(0)):
+ # atten_matrix[i][self.n_layer*(random_hops[i][0]-1)+random_hops[i][1]-1] = 1
+ # # print(atten_matrix.size()) # [batch_size, encoder_layer^2]
+ # # print(subgraph_sampler(h,mode='argmax'))
+ # n, c = atten_matrix.shape
+ # h = h * atten_matrix.view(n,c,1)
+ # # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ # h = torch.sum(h,dim=1)
+ # # print(h.size()) # [batch_size, 2*dim]
+ # h = h.reshape(-1, 1, 2 * self.k_h, self.k_w)
+ # # x = self.bn0(h)
+ # # x = self.conv2d(x) # [batch_size, num_filt, flat_sz_h, flat_sz_w]
+ # # x = self.bn1(x)
+ # # x = F.relu(x)
+ # # x = self.feature_drop(x)
+ # # x = x.view(-1, self.flat_sz) # [batch_size, flat_sz]
+ # # x = self.fc(x) # [batch_size, embed_dim]
+ # # x = self.hidden_drop(x)
+ # # x = self.bn2(x)
+ # # x = F.relu(x)
+ # # x = torch.mm(x, all_rel.transpose(1, 0)) # [batch_size, ent_num]
+ # # x += self.bias_rel.expand_as(x)
+ # # # score = torch.sigmoid(x)
+ # # score = x
+ # return score
+ # print(h.size())
+
+ def cross_pair(self, x_i, x_j):
+ x = []
+ for i in range(self.n_layer):
+ for j in range(self.n_layer):
+ if self.combine_type == 'mult':
+ x.append(x_i[:, i, :] * x_j[:, j, :])
+ elif self.combine_type == 'concat':
+ x.append(torch.cat([x_i[:, i, :], x_j[:, j, :]], dim=1))
+ x = torch.stack(x, dim=1)
+ return x
+
+ def vis_hop_distribution(self, hidden_x, subj, obj, subgraph_sampler, mode='search'):
+ h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ atten_matrix = subgraph_sampler(h, mode)
+ return torch.sum(atten_matrix, dim=0)
+
+ def vis_hop_distribution_rs(self, hidden_x, subj, obj, random_hops):
+ h = self.cross_pair(hidden_x[subj], hidden_x[obj])
+ # print(h.size()) # [batch_size, encoder_layer^2, 2*dim]
+ atten_matrix = torch.zeros(h.size(0), h.size(1)).to('cuda:0')
+ for i in range(h.size(0)):
+ atten_matrix[i][self.n_layer * (random_hops[i][0] - 1) + random_hops[i][1] - 1] = 1
+ return torch.sum(atten_matrix, dim=0)
diff --git a/model/operations.py b/model/operations.py
new file mode 100644
index 0000000..7f21f40
--- /dev/null
+++ b/model/operations.py
@@ -0,0 +1,148 @@
+import torch
+from torch.nn import Module, Linear
+
+COMP_OPS = {
+ 'mult': lambda: MultOp(),
+ 'sub': lambda: SubOp(),
+ 'add': lambda: AddOp(),
+ 'ccorr': lambda: CcorrOp(),
+ 'rotate': lambda: RotateOp()
+}
+
+AGG_OPS = {
+ 'max': lambda: MaxOp(),
+ 'sum': lambda: SumOp(),
+ 'mean': lambda: MeanOp()
+}
+
+COMB_OPS = {
+ 'add': lambda out_channels: CombAddOp(out_channels),
+ 'mlp': lambda out_channels: CombMLPOp(out_channels),
+ 'concat': lambda out_channels: CombConcatOp(out_channels)
+}
+
+ACT_OPS = {
+ 'identity': lambda: torch.nn.Identity(),
+ 'relu': lambda: torch.nn.ReLU(),
+ 'tanh': lambda: torch.nn.Tanh(),
+}
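+
+# Minimal usage sketch (illustrative only, not called anywhere in this module):
+# operators are looked up by name and instantiated lazily from these registries.
+#   comp = COMP_OPS['rotate']()            # message composition
+#   agg = AGG_OPS['mean']()                # neighborhood aggregation
+#   comb = COMB_OPS['mlp'](out_channels)   # combine self- and neighbor-messages
+#   act = ACT_OPS['tanh']()                # activation
+#   msg = comp(src_emb, rel_emb)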
+
+class CcorrOp(Module):
+ def __init__(self):
+ super(CcorrOp, self).__init__()
+
+ def forward(self, src_emb, rel_emb):
+ return self.comp(src_emb, rel_emb)
+
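+    # Circular correlation (as in HolE/CompGCN): ccorr(a, b) = irfft(conj(rfft(a)) * rfft(b)),
+    # taken along the embedding dimension; conj and com_mult below implement the
+    # complex conjugate and complex multiplication explicitly.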
+ def comp(self, h, edge_data):
+ def com_mult(a, b):
+ r1, i1 = a.real, a.imag
+ r2, i2 = b.real, b.imag
+ real = r1 * r2 - i1 * i2
+ imag = r1 * i2 + i1 * r2
+ return torch.complex(real, imag)
+
+ def conj(a):
+ a.imag = -a.imag
+ return a
+
+ def ccorr(a, b):
+ return torch.fft.irfft(com_mult(conj(torch.fft.rfft(a)), torch.fft.rfft(b)), a.shape[-1])
+
+ return ccorr(h, edge_data.expand_as(h))
+
+
+class MultOp(Module):
+
+ def __init__(self):
+ super(MultOp, self).__init__()
+
+ def forward(self, src_emb, rel_emb):
+ # print('hr.shape', hr.shape)
+ return src_emb * rel_emb
+
+
+class SubOp(Module):
+ def __init__(self):
+ super(SubOp, self).__init__()
+
+ def forward(self, src_emb, rel_emb):
+ # print('hr.shape', hr.shape)
+ return src_emb - rel_emb
+
+
+class AddOp(Module):
+ def __init__(self):
+ super(AddOp, self).__init__()
+
+ def forward(self, src_emb, rel_emb):
+ # print('hr.shape', hr.shape)
+ return src_emb + rel_emb
+
+
+class RotateOp(Module):
+ def __init__(self):
+ super(RotateOp, self).__init__()
+
+ def forward(self, src_emb, rel_emb):
+ # print('hr.shape', hr.shape)
+ return self.rotate(src_emb, rel_emb)
+
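+    # RotatE-style composition: the embedding is split into real/imaginary halves
+    # and the source embedding is rotated by the relation embedding in the complex
+    # plane, (h_re + i*h_im) * (r_re + i*r_im); an even embedding dimension is assumed.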
+ def rotate(self, h, r):
+ # re: first half, im: second half
+ # assume embedding dim is the last dimension
+ d = h.shape[-1]
+ h_re, h_im = torch.split(h, d // 2, -1)
+ r_re, r_im = torch.split(r, d // 2, -1)
+ return torch.cat([h_re * r_re - h_im * r_im,
+ h_re * r_im + h_im * r_re], dim=-1)
+
+
+class MaxOp(Module):
+ def __init__(self):
+ super(MaxOp, self).__init__()
+
+ def forward(self, msg):
+ return torch.max(msg, dim=1)[0]
+
+
+class SumOp(Module):
+ def __init__(self):
+ super(SumOp, self).__init__()
+
+ def forward(self, msg):
+ return torch.sum(msg, dim=1)
+
+
+class MeanOp(Module):
+ def __init__(self):
+ super(MeanOp, self).__init__()
+
+ def forward(self, msg):
+ return torch.mean(msg, dim=1)
+
+
+class CombAddOp(Module):
+ def __init__(self, out_channels):
+ super(CombAddOp, self).__init__()
+
+ def forward(self, self_emb, msg):
+ return self_emb + msg
+
+
+class CombMLPOp(Module):
+ def __init__(self, out_channels):
+ super(CombMLPOp, self).__init__()
+ self.linear = Linear(out_channels, out_channels)
+
+ def forward(self, self_emb, msg):
+ return self.linear(self_emb + msg)
+
+
+class CombConcatOp(Module):
+ def __init__(self, out_channels):
+ super(CombConcatOp, self).__init__()
+ self.linear = Linear(2*out_channels, out_channels)
+
+ def forward(self, self_emb, msg):
+        return self.linear(torch.concat([self_emb, msg], dim=1))
\ No newline at end of file
diff --git a/model/rgcn_layer.py b/model/rgcn_layer.py
new file mode 100644
index 0000000..25161f5
--- /dev/null
+++ b/model/rgcn_layer.py
@@ -0,0 +1,417 @@
+"""
+based on the implementation in DGL
+(https://github.com/dmlc/dgl/blob/master/python/dgl/nn/pytorch/conv/relgraphconv.py)
+"""
+
+
+"""Torch Module for Relational graph convolution layer"""
+# pylint: disable= no-member, arguments-differ, invalid-name
+
+
+
+
+import functools
+import numpy as np
+import torch as th
+from torch import nn
+import dgl.function as fn
+from dgl.nn.pytorch import utils
+from dgl.base import DGLError
+from dgl.subgraph import edge_subgraph
+class RelGraphConv(nn.Module):
+ r"""Relational graph convolution layer.
+
+    Relational graph convolution is introduced in "`Modeling Relational Data with Graph
+    Convolutional Networks <https://arxiv.org/abs/1703.06103>`__"
+ and can be described in DGL as below:
+
+ .. math::
+
+ h_i^{(l+1)} = \sigma(\sum_{r\in\mathcal{R}}
+ \sum_{j\in\mathcal{N}^r(i)}e_{j,i}W_r^{(l)}h_j^{(l)}+W_0^{(l)}h_i^{(l)})
+
+ where :math:`\mathcal{N}^r(i)` is the neighbor set of node :math:`i` w.r.t. relation
+ :math:`r`. :math:`e_{j,i}` is the normalizer. :math:`\sigma` is an activation
+ function. :math:`W_0` is the self-loop weight.
+
+ The basis regularization decomposes :math:`W_r` by:
+
+ .. math::
+
+ W_r^{(l)} = \sum_{b=1}^B a_{rb}^{(l)}V_b^{(l)}
+
+ where :math:`B` is the number of bases, :math:`V_b^{(l)}` are linearly combined
+ with coefficients :math:`a_{rb}^{(l)}`.
+
+ The block-diagonal-decomposition regularization decomposes :math:`W_r` into :math:`B`
+ number of block diagonal matrices. We refer :math:`B` as the number of bases.
+
+ The block regularization decomposes :math:`W_r` by:
+
+ .. math::
+
+ W_r^{(l)} = \oplus_{b=1}^B Q_{rb}^{(l)}
+
+ where :math:`B` is the number of bases, :math:`Q_{rb}^{(l)}` are block
+ bases with shape :math:`R^{(d^{(l+1)}/B)*(d^{l}/B)}`.
+
+ Parameters
+ ----------
+ in_feat : int
+ Input feature size; i.e, the number of dimensions of :math:`h_j^{(l)}`.
+ out_feat : int
+ Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.
+ num_rels : int
+        Number of relations.
+ regularizer : str
+ Which weight regularizer to use "basis" or "bdd".
+ "basis" is short for basis-diagonal-decomposition.
+ "bdd" is short for block-diagonal-decomposition.
+ num_bases : int, optional
+ Number of bases. If is none, use number of relations. Default: ``None``.
+ bias : bool, optional
+ True if bias is added. Default: ``True``.
+ activation : callable, optional
+ Activation function. Default: ``None``.
+ self_loop : bool, optional
+ True to include self loop message. Default: ``True``.
+ low_mem : bool, optional
+ True to use low memory implementation of relation message passing function. Default: False.
+ This option trades speed with memory consumption, and will slowdown the forward/backward.
+ Turn it on when you encounter OOM problem during training or evaluation. Default: ``False``.
+ dropout : float, optional
+ Dropout rate. Default: ``0.0``
+ layer_norm: float, optional
+ Add layer norm. Default: ``False``
+
+ Examples
+ --------
+ >>> import dgl
+ >>> import numpy as np
+ >>> import torch as th
+ >>> from dgl.nn import RelGraphConv
+ >>>
+ >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
+ >>> feat = th.ones(6, 10)
+ >>> conv = RelGraphConv(10, 2, 3, regularizer='basis', num_bases=2)
+ >>> conv.weight.shape
+ torch.Size([2, 10, 2])
+ >>> etype = th.tensor(np.array([0,1,2,0,1,2]).astype(np.int64))
+ >>> res = conv(g, feat, etype)
+ >>> res
+ tensor([[ 0.3996, -2.3303],
+ [-0.4323, -0.1440],
+ [ 0.3996, -2.3303],
+ [ 2.1046, -2.8654],
+ [-0.4323, -0.1440],
+            [-0.1309, -1.0000]], grad_fn=<AddBackward0>)
+
+ >>> # One-hot input
+ >>> one_hot_feat = th.tensor(np.array([0,1,2,3,4,5]).astype(np.int64))
+ >>> res = conv(g, one_hot_feat, etype)
+ >>> res
+ tensor([[ 0.5925, 0.0985],
+ [-0.3953, 0.8408],
+ [-0.9819, 0.5284],
+ [-1.0085, -0.1721],
+ [ 0.5962, 1.2002],
+            [ 0.0365, -0.3532]], grad_fn=<AddBackward0>)
+ """
+
+ def __init__(self,
+ in_feat,
+ out_feat,
+ num_rels,
+ regularizer="basis",
+ num_bases=None,
+ bias=True,
+ activation=None,
+ self_loop=True,
+ low_mem=False,
+ dropout=0.0,
+ layer_norm=False,
+ wni=False):
+ super(RelGraphConv, self).__init__()
+ self.in_feat = in_feat
+ self.out_feat = out_feat
+ self.num_rels = num_rels
+ self.regularizer = regularizer
+ self.num_bases = num_bases
+ if self.num_bases is None or self.num_bases > self.num_rels or self.num_bases <= 0:
+ self.num_bases = self.num_rels
+ self.bias = bias
+ self.activation = activation
+ self.self_loop = self_loop
+ self.low_mem = low_mem
+ self.layer_norm = layer_norm
+
+ self.wni = wni
+
+ if regularizer == "basis":
+ # add basis weights
+ self.weight = nn.Parameter(
+ th.Tensor(self.num_bases, self.in_feat, self.out_feat))
+ if self.num_bases < self.num_rels:
+ # linear combination coefficients
+ self.w_comp = nn.Parameter(
+ th.Tensor(self.num_rels, self.num_bases))
+ nn.init.xavier_uniform_(
+ self.weight, gain=nn.init.calculate_gain('relu'))
+ if self.num_bases < self.num_rels:
+ nn.init.xavier_uniform_(self.w_comp,
+ gain=nn.init.calculate_gain('relu'))
+ # message func
+ self.message_func = self.basis_message_func
+ elif regularizer == "bdd":
+ print(in_feat)
+ print(out_feat)
+ if in_feat % self.num_bases != 0 or out_feat % self.num_bases != 0:
+ raise ValueError(
+ 'Feature size must be a multiplier of num_bases (%d).'
+ % self.num_bases
+ )
+ # add block diagonal weights
+ self.submat_in = in_feat // self.num_bases
+ self.submat_out = out_feat // self.num_bases
+
+ # assuming in_feat and out_feat are both divisible by num_bases
+ self.weight = nn.Parameter(th.Tensor(
+ self.num_rels, self.num_bases * self.submat_in * self.submat_out))
+ nn.init.xavier_uniform_(
+ self.weight, gain=nn.init.calculate_gain('relu'))
+ # message func
+ self.message_func = self.bdd_message_func
+ else:
+ raise ValueError("Regularizer must be either 'basis' or 'bdd'")
+
+ # bias
+ if self.bias:
+ self.h_bias = nn.Parameter(th.Tensor(out_feat))
+ nn.init.zeros_(self.h_bias)
+
+ # layer norm
+ if self.layer_norm:
+ self.layer_norm_weight = nn.LayerNorm(
+ out_feat, elementwise_affine=True)
+
+ # weight for self loop
+ if self.self_loop:
+ self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
+ nn.init.xavier_uniform_(self.loop_weight,
+ gain=nn.init.calculate_gain('relu'))
+
+ self.dropout = nn.Dropout(dropout)
+
+ def basis_message_func(self, edges, etypes):
+ """Message function for basis regularizer.
+
+ Parameters
+ ----------
+ edges : dgl.EdgeBatch
+ Input to DGL message UDF.
+ etypes : torch.Tensor or list[int]
+ Edge type data. Could be either:
+
+ * An :math:`(|E|,)` dense tensor. Each element corresponds to the edge's type ID.
+ Preferred format if ``lowmem == False``.
+ * An integer list. The i^th element is the number of edges of the i^th type.
+ This requires the input graph to store edges sorted by their type IDs.
+ Preferred format if ``lowmem == True``.
+ """
+ if self.num_bases < self.num_rels:
+ # generate all weights from bases
+ weight = self.weight.view(self.num_bases,
+ self.in_feat * self.out_feat)
+ weight = th.matmul(self.w_comp, weight).view(
+ self.num_rels, self.in_feat, self.out_feat)
+ else:
+ weight = self.weight
+
+ h = edges.src['h']
+ device = h.device
+
+ if h.dtype == th.int64 and h.ndim == 1:
+ # Each element is the node's ID. Use index select: weight[etypes, h, :]
+ # The following is a faster version of it.
+ if isinstance(etypes, list):
+ etypes = th.repeat_interleave(th.arange(len(etypes), device=device),
+ th.tensor(etypes, device=device))
+ idim = weight.shape[1]
+ weight = weight.view(-1, weight.shape[2])
+ flatidx = etypes * idim + h
+ msg = weight.index_select(0, flatidx)
+ elif self.low_mem:
+ # A more memory-friendly implementation.
+ # Calculate msg @ W_r before put msg into edge.
+ assert isinstance(etypes, list)
+ h_t = th.split(h, etypes)
+ msg = []
+ for etype in range(self.num_rels):
+ if h_t[etype].shape[0] == 0:
+ continue
+ msg.append(th.matmul(h_t[etype], weight[etype]))
+ msg = th.cat(msg)
+ else:
+ # Use batched matmult
+ if isinstance(etypes, list):
+ etypes = th.repeat_interleave(th.arange(len(etypes), device=device),
+ th.tensor(etypes, device=device))
+ weight = weight.index_select(0, etypes)
+ msg = th.bmm(h.unsqueeze(1), weight).squeeze(1)
+
+ if 'norm' in edges.data:
+ msg = msg * edges.data['norm']
+ return {'msg': msg}
+
+ def bdd_message_func(self, edges, etypes):
+ """Message function for block-diagonal-decomposition regularizer.
+
+ Parameters
+ ----------
+ edges : dgl.EdgeBatch
+ Input to DGL message UDF.
+ etypes : torch.Tensor or list[int]
+ Edge type data. Could be either:
+
+ * An :math:`(|E|,)` dense tensor. Each element corresponds to the edge's type ID.
+ Preferred format if ``lowmem == False``.
+ * An integer list. The i^th element is the number of edges of the i^th type.
+ This requires the input graph to store edges sorted by their type IDs.
+ Preferred format if ``lowmem == True``.
+ """
+ h = edges.src['h']
+ device = h.device
+
+ if h.dtype == th.int64 and h.ndim == 1:
+ raise TypeError(
+ 'Block decomposition does not allow integer ID feature.')
+
+ if self.low_mem:
+ # A more memory-friendly implementation.
+ # Calculate msg @ W_r before put msg into edge.
+ assert isinstance(etypes, list)
+ h_t = th.split(h, etypes)
+ msg = []
+ for etype in range(self.num_rels):
+ if h_t[etype].shape[0] == 0:
+ continue
+ tmp_w = self.weight[etype].view(
+ self.num_bases, self.submat_in, self.submat_out)
+ tmp_h = h_t[etype].view(-1, self.num_bases, self.submat_in)
+ msg.append(th.einsum('abc,bcd->abd', tmp_h,
+ tmp_w).reshape(-1, self.out_feat))
+ msg = th.cat(msg)
+ else:
+ # Use batched matmult
+ if isinstance(etypes, list):
+ etypes = th.repeat_interleave(th.arange(len(etypes), device=device),
+ th.tensor(etypes, device=device))
+ weight = self.weight.index_select(0, etypes).view(
+ -1, self.submat_in, self.submat_out)
+ node = h.view(-1, 1, self.submat_in)
+ msg = th.bmm(node, weight).view(-1, self.out_feat)
+ if 'norm' in edges.data:
+ msg = msg * edges.data['norm']
+ return {'msg': msg}
+
+ def forward(self, g, feat, etypes, norm=None):
+ """Forward computation.
+
+ Parameters
+ ----------
+ g : DGLGraph
+ The graph.
+ feat : torch.Tensor
+ Input node features. Could be either
+
+ * :math:`(|V|, D)` dense tensor
+            * :math:`(|V|,)` int64 vector, representing the categorical values of each
+              node. The input feature is then treated as a one-hot encoding.
+ etypes : torch.Tensor or list[int]
+ Edge type data. Could be either
+
+ * An :math:`(|E|,)` dense tensor. Each element corresponds to the edge's type ID.
+              Preferred format if ``low_mem == False``.
+ * An integer list. The i^th element is the number of edges of the i^th type.
+ This requires the input graph to store edges sorted by their type IDs.
+              Preferred format if ``low_mem == True``.
+        norm : torch.Tensor, optional
+            Edge normalizer: an :math:`(|E|, 1)` tensor storing the normalizer on each edge.
+
+ Returns
+ -------
+ torch.Tensor
+ New node features.
+
+ Notes
+ -----
+ Under the ``low_mem`` mode, DGL will sort the graph based on the edge types
+        and compute message passing one type at a time. DGL recommends sorting the
+        graph beforehand (and caching it if possible) and providing the integer-list
+        format for the ``etypes`` argument. Use DGL's :func:`~dgl.to_homogeneous` API
+        to get a sorted homogeneous graph from a heterogeneous graph, and pass
+        ``return_count=True`` to obtain ``etypes`` in the integer-list format.
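+
+        A minimal, illustrative sketch (``hg``, ``conv``, ``feat`` and ``norm`` are
+        hypothetical placeholders)::
+
+            g, _, etype_count = dgl.to_homogeneous(hg, return_count=True)
+            out = conv(g, feat, etype_count, norm)  # integer-list ``etypes``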
+ """
+ if isinstance(etypes, th.Tensor):
+ if len(etypes) != g.num_edges():
+ raise DGLError('"etypes" tensor must have length equal to the number of edges'
+ ' in the graph. But got {} and {}.'.format(
+ len(etypes), g.num_edges()))
+ if self.low_mem and not (feat.dtype == th.int64 and feat.ndim == 1):
+ # Low-mem optimization is not enabled for node ID input. When enabled,
+ # it first sorts the graph based on the edge types (the sorting will not
+ # change the node IDs). It then converts the etypes tensor to an integer
+ # list, where each element is the number of edges of the type.
+ # Sort the graph based on the etypes
+ sorted_etypes, index = th.sort(etypes)
+ g = edge_subgraph(g, index, relabel_nodes=False)
+ # Create a new etypes to be an integer list of number of edges.
+ pos = _searchsorted(sorted_etypes, th.arange(
+ self.num_rels, device=g.device))
+ num = th.tensor([len(etypes)], device=g.device)
+ etypes = (th.cat([pos[1:], num]) - pos).tolist()
+ if norm is not None:
+ norm = norm[index]
+
+ with g.local_scope():
+ g.srcdata['h'] = feat
+ if norm is not None:
+ g.edata['norm'] = norm
+ if self.self_loop:
+ loop_message = utils.matmul_maybe_select(feat[:g.number_of_dst_nodes()],
+ self.loop_weight)
+
+ if not self.wni:
+ # message passing
+ g.update_all(functools.partial(self.message_func, etypes=etypes),
+ fn.sum(msg='msg', out='h'))
+ # apply bias and activation
+ node_repr = g.dstdata['h']
+ if self.layer_norm:
+ node_repr = self.layer_norm_weight(node_repr)
+ if self.bias:
+ node_repr = node_repr + self.h_bias
+ else:
+ node_repr = 0
+
+ if self.self_loop:
+ node_repr = node_repr + loop_message
+ if self.activation:
+ node_repr = self.activation(node_repr)
+ node_repr = self.dropout(node_repr)
+ return node_repr
+
+
+_TORCH_HAS_SEARCHSORTED = getattr(th, 'searchsorted', None)
+
+
+def _searchsorted(sorted_sequence, values):
+    # torch.searchsorted was introduced in PyTorch 1.6.0; fall back to NumPy otherwise
+ if _TORCH_HAS_SEARCHSORTED:
+ return th.searchsorted(sorted_sequence, values)
+ else:
+ device = values.device
+ return th.from_numpy(np.searchsorted(sorted_sequence.cpu().numpy(),
+ values.cpu().numpy())).to(device)
diff --git a/model/seal_model.py b/model/seal_model.py
new file mode 100644
index 0000000..a61c59f
--- /dev/null
+++ b/model/seal_model.py
@@ -0,0 +1,62 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from dgl.nn.pytorch import SortPooling, SumPooling
+from dgl.nn.pytorch import GraphConv, SAGEConv
+from dgl import NID, EID
+
+
+class SEAL_GCN(nn.Module):
+ def __init__(self, num_ent, num_rel, init_dim, gcn_dim, embed_dim, n_layer, loss_type, max_z=1000):
+ super(SEAL_GCN, self).__init__()
+
+ if loss_type == 'ce':
+ self.loss = nn.CrossEntropyLoss()
+ elif loss_type == 'bce':
+            self.loss = nn.BCELoss(reduction='none')  # 'reduce=False' is deprecated
+ elif loss_type == 'bce_logits':
+ self.loss = nn.BCEWithLogitsLoss()
+ self.init_embed = self.get_param([num_ent, init_dim])
+ self.init_rel = self.get_param([num_rel, init_dim])
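+        # embedding of structural node labels ``z`` (e.g. DRNL labels, as in SEAL)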
+ self.z_embedding = nn.Embedding(max_z, init_dim)
+        init_dim += init_dim  # node features are concatenated with the z-label embedding in forward()
+
+ self.layers = nn.ModuleList()
+ self.layers.append(GraphConv(init_dim, gcn_dim, allow_zero_in_degree=True))
+ for _ in range(n_layer - 1):
+ self.layers.append(GraphConv(gcn_dim, gcn_dim, allow_zero_in_degree=True))
+
+ self.linear_1 = nn.Linear(embed_dim, embed_dim)
+ self.linear_2 = nn.Linear(embed_dim, num_rel)
+ self.pooling = SumPooling()
+
+ def get_param(self, shape):
+ param = nn.Parameter(torch.Tensor(*shape))
+ nn.init.xavier_normal_(param, gain=nn.init.calculate_gain('relu'))
+ return param
+
+ def calc_loss(self, pred, label, pos_neg=None):
+ if pos_neg is not None:
+ m = nn.Sigmoid()
+ score_pos = m(pred)
+ targets_pos = pos_neg.unsqueeze(1)
+ loss = self.loss(score_pos, label * targets_pos)
+ return torch.sum(loss * label)
+ return self.loss(pred, label)
+
+ def forward(self, g, z):
+ x, r = self.init_embed[g.ndata[NID]], self.init_rel
+ z_emb = self.z_embedding(z)
+ x = torch.cat([x, z_emb], 1)
+ for layer in self.layers[:-1]:
+ x = layer(g, x)
+ x = F.relu(x)
+ x = F.dropout(x, p=0.5, training=self.training)
+ x = self.layers[-1](g, x)
+
+ x = self.pooling(g, x)
+ x = F.relu(self.linear_1(x))
+        x = F.dropout(x, p=0.5, training=self.training)
+ x = self.linear_2(x)
+
+ return x
\ No newline at end of file
diff --git a/model/search_layer.py b/model/search_layer.py
new file mode 100644
index 0000000..ef76eee
--- /dev/null
+++ b/model/search_layer.py
@@ -0,0 +1,192 @@
+import torch
+from torch import nn
+import dgl
+import dgl.function as fn
+from model.genotypes import COMP_PRIMITIVES, AGG_PRIMITIVES, COMB_PRIMITIVES, ACT_PRIMITIVES
+from model.operations import *
+
+
+class CompMixOp(nn.Module):
+ def __init__(self):
+ super(CompMixOp, self).__init__()
+ self._ops = nn.ModuleList()
+ for primitive in COMP_PRIMITIVES:
+ op = COMP_OPS[primitive]()
+ self._ops.append(op)
+
+ def reset_parameters(self):
+ for op in self._ops:
+ op.reset_parameters()
+
+ def forward(self, src_emb, rel_emb, weights):
+ mixed_res = []
+ for w, op in zip(weights, self._ops):
+ mixed_res.append(w * op(src_emb, rel_emb))
+ return sum(mixed_res)
+
+
+class AggMixOp(nn.Module):
+ def __init__(self):
+ super(AggMixOp, self).__init__()
+ self._ops = nn.ModuleList()
+ for primitive in AGG_PRIMITIVES:
+ op = AGG_OPS[primitive]()
+ self._ops.append(op)
+
+ def reset_parameters(self):
+ for op in self._ops:
+ op.reset_parameters()
+
+ def forward(self, msg, weights):
+ mixed_res = []
+ for w, op in zip(weights, self._ops):
+ mixed_res.append(w * op(msg))
+ return sum(mixed_res)
+
+
+class CombMixOp(nn.Module):
+ def __init__(self, out_channels):
+ super(CombMixOp, self).__init__()
+ self._ops = nn.ModuleList()
+ for primitive in COMB_PRIMITIVES:
+ op = COMB_OPS[primitive](out_channels)
+ self._ops.append(op)
+
+ def reset_parameters(self):
+ for op in self._ops:
+ op.reset_parameters()
+
+ def forward(self, self_emb, msg, weights):
+ mixed_res = []
+ for w, op in zip(weights, self._ops):
+ mixed_res.append(w * op(self_emb, msg))
+ return sum(mixed_res)
+
+
+class ActMixOp(nn.Module):
+ def __init__(self):
+ super(ActMixOp, self).__init__()
+ self._ops = nn.ModuleList()
+ for primitive in ACT_PRIMITIVES:
+ op = ACT_OPS[primitive]()
+ self._ops.append(op)
+
+ def reset_parameters(self):
+ for op in self._ops:
+ op.reset_parameters()
+
+ def forward(self, emb, weights):
+ mixed_res = []
+ for w, op in zip(weights, self._ops):
+ mixed_res.append(w * op(emb))
+ return sum(mixed_res)
+
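+# Illustrative usage sketch (not part of the model): each *MixOp above computes a
+# DARTS-style soft mixture over its candidate primitives, so the per-operation
+# weights are typically a softmax over learnable architecture parameters.
+# ``alpha``, ``src_emb`` and ``rel_emb`` below are hypothetical names.
+#
+#   alpha = torch.randn(len(COMP_PRIMITIVES), requires_grad=True)
+#   comp_weights = torch.softmax(alpha, dim=-1)
+#   mixed_msg = CompMixOp()(src_emb, rel_emb, comp_weights)
+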
+class SearchGCNConv(nn.Module):
+ def __init__(self, in_channels, out_channels, act=lambda x: x, bias=True, drop_rate=0., num_base=-1,
+ num_rel=None, wni=False, wsi=False, use_bn=True, ltr=True, comp_weights=None, agg_weights=None):
+ super(SearchGCNConv, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ # self.act = act # activation function
+ self.device = None
+
+ self.rel = nn.Parameter(torch.empty([num_rel, in_channels], dtype=torch.float))
+
+ self.use_bn = use_bn
+ self.ltr = ltr
+
+ # relation-type specific parameter
+ self.in_w = self.get_param([in_channels, out_channels])
+ self.out_w = self.get_param([in_channels, out_channels])
+ self.loop_w = self.get_param([in_channels, out_channels])
+ # transform embedding of relations to next layer
+ self.w_rel = self.get_param([in_channels, out_channels])
+ self.loop_rel = self.get_param([1, in_channels]) # self-loop embedding
+
+ self.drop = nn.Dropout(drop_rate)
+ self.bn = torch.nn.BatchNorm1d(out_channels)
+ self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None
+ if num_base > 0:
+ self.rel_wt = self.get_param([num_rel, num_base])
+ else:
+ self.rel_wt = None
+
+ self.wni = wni
+ self.wsi = wsi
+ self.comp_weights = comp_weights
+ self.agg_weights = agg_weights
+ self.comp = CompMixOp()
+ self.agg = AggMixOp()
+ self.comb = CombMixOp(out_channels)
+ self.act = ActMixOp()
+
+ def get_param(self, shape):
+ param = nn.Parameter(torch.Tensor(*shape))
+ nn.init.xavier_normal_(param, gain=nn.init.calculate_gain('relu'))
+ return param
+
+ def message_func(self, edges):
+ edge_type = edges.data['type'] # [E, 1]
+ edge_num = edge_type.shape[0]
+ edge_data = self.comp(
+ edges.src['h'], self.rel[edge_type], self.comp_weights) # [E, in_channel]
+ # NOTE: first half edges are all in-directions, last half edges are out-directions.
+ msg = torch.cat([torch.matmul(edge_data[:edge_num // 2, :], self.in_w),
+ torch.matmul(edge_data[edge_num // 2:, :], self.out_w)])
+ msg = msg * edges.data['norm'].reshape(-1, 1) # [E, D] * [E, 1]
+ return {'msg': msg}
+
+ def reduce_func(self, nodes):
+ # return {'h': torch.sum(nodes.mailbox['msg'], dim=1)}
+ return {'h': self.agg(nodes.mailbox['msg'], self.agg_weights)}
+
+ def apply_node_func(self, nodes):
+ return {'h': self.drop(nodes.data['h'])}
+
+ def forward(self, g: dgl.DGLGraph, x, rel_repr, edge_type, edge_norm, comp_weights, agg_weights, comb_weights, act_weights):
+ """
+        :param g: dgl Graph, a graph without self-loops
+ :param x: input node features, [V, in_channel]
+ :param rel_repr: input relation features: 1. not using bases: [num_rel*2, in_channel]
+ 2. using bases: [num_base, in_channel]
+ :param edge_type: edge type, [E]
+ :param edge_norm: edge normalization, [E]
+ :return: x: output node features: [V, out_channel]
+ rel: output relation features: [num_rel*2, out_channel]
+ """
+ self.device = x.device
+ g = g.local_var()
+ g.ndata['h'] = x
+ g.edata['type'] = edge_type
+ g.edata['norm'] = edge_norm
+ self.comp_weights = comp_weights
+ self.agg_weights = agg_weights
+ self.comb_weights = comb_weights
+ self.act_weights = act_weights
+ if self.rel_wt is None:
+ self.rel.data = rel_repr
+ else:
+ # [num_rel*2, num_base] @ [num_base, in_c]
+ self.rel.data = torch.mm(self.rel_wt, rel_repr)
+ g.update_all(self.message_func, self.reduce_func, self.apply_node_func)
+
+ if (not self.wni) and (not self.wsi):
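+            # combine aggregated neighbour messages with the transformed self-loop term;
+            # the 1/3 scaling matches the sum-and-divide-by-3 variant kept commented out
+            # below (averaging over in-, out- and self-loop directions)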
+            x = self.comb(g.ndata.pop('h'),
+                          torch.mm(self.comp(x, self.loop_rel, self.comp_weights), self.loop_w),
+                          self.comb_weights) * (1 / 3)
+ # x = (g.ndata.pop('h') +
+ # torch.mm(self.comp(x, self.loop_rel, self.comp_weights), self.loop_w)) / 3
+ # else:
+ # if self.wsi:
+ # x = g.ndata.pop('h') / 2
+ # if self.wni:
+ # x = torch.mm(self.comp(x, self.loop_rel), self.loop_w)
+
+ if self.bias is not None:
+ x = x + self.bias
+
+ if self.use_bn:
+ x = self.bn(x)
+
+ if self.ltr:
+ return self.act(x, self.act_weights), torch.matmul(self.rel.data, self.w_rel)
+ else:
+ return self.act(x, self.act_weights), self.rel.data
\ No newline at end of file
diff --git a/model/search_layer_spos.py b/model/search_layer_spos.py
new file mode 100644
index 0000000..919b762
--- /dev/null
+++ b/model/search_layer_spos.py
@@ -0,0 +1,185 @@
+import torch
+from torch import nn
+import dgl
+import dgl.function as fn
+from model.genotypes import COMP_PRIMITIVES, AGG_PRIMITIVES, COMB_PRIMITIVES, ACT_PRIMITIVES
+from model.operations import *
+
+
+class CompOpBlock(nn.Module):
+ def __init__(self):
+ super(CompOpBlock, self).__init__()
+ self._ops = nn.ModuleList()
+ for primitive in COMP_PRIMITIVES:
+ op = COMP_OPS[primitive]()
+ self._ops.append(op)
+
+ def reset_parameters(self):
+ for op in self._ops:
+ op.reset_parameters()
+
+ def forward(self, src_emb, rel_emb, primitive):
+ return self._ops[COMP_PRIMITIVES.index(primitive)](src_emb, rel_emb)
+
+
+class AggOpBlock(nn.Module):
+ def __init__(self):
+ super(AggOpBlock, self).__init__()
+ self._ops = nn.ModuleList()
+ for primitive in AGG_PRIMITIVES:
+ op = AGG_OPS[primitive]()
+ self._ops.append(op)
+
+ def reset_parameters(self):
+ for op in self._ops:
+ op.reset_parameters()
+
+ def forward(self, msg, primitive):
+ return self._ops[AGG_PRIMITIVES.index(primitive)](msg)
+
+
+class CombOpBlock(nn.Module):
+ def __init__(self, out_channels):
+ super(CombOpBlock, self).__init__()
+ self._ops = nn.ModuleList()
+ for primitive in COMB_PRIMITIVES:
+ op = COMB_OPS[primitive](out_channels)
+ self._ops.append(op)
+
+ def reset_parameters(self):
+ for op in self._ops:
+ op.reset_parameters()
+
+ def forward(self, self_emb, msg, primitive):
+ return self._ops[COMB_PRIMITIVES.index(primitive)](self_emb, msg)
+
+
+class ActOpBlock(nn.Module):
+ def __init__(self):
+ super(ActOpBlock, self).__init__()
+ self._ops = nn.ModuleList()
+ for primitive in ACT_PRIMITIVES:
+ op = ACT_OPS[primitive]()
+ self._ops.append(op)
+
+ def reset_parameters(self):
+ for op in self._ops:
+ op.reset_parameters()
+
+ def forward(self, emb, primitive):
+ return self._ops[ACT_PRIMITIVES.index(primitive)](emb)
+
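+# Illustrative single-path (SPOS) sketch: unlike the *MixOp modules in
+# model/search_layer.py, each *OpBlock here evaluates exactly one sampled
+# primitive, selected by name, e.g. (assuming 'mult' is one of COMP_PRIMITIVES):
+#
+#   msg = CompOpBlock()(src_emb, rel_emb, 'mult')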
+
+class SearchSPOSGCNConv(nn.Module):
+ def __init__(self, in_channels, out_channels, act=lambda x: x, bias=True, drop_rate=0., num_base=-1,
+ num_rel=None, wni=False, wsi=False, use_bn=True, ltr=True):
+ super(SearchSPOSGCNConv, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ # self.act = act # activation function
+ self.device = None
+
+ self.rel = nn.Parameter(torch.empty([num_rel, in_channels], dtype=torch.float))
+
+ self.use_bn = use_bn
+ self.ltr = ltr
+
+ # relation-type specific parameter
+ self.in_w = self.get_param([in_channels, out_channels])
+ self.out_w = self.get_param([in_channels, out_channels])
+ self.loop_w = self.get_param([in_channels, out_channels])
+ # transform embedding of relations to next layer
+ self.w_rel = self.get_param([in_channels, out_channels])
+ self.loop_rel = self.get_param([1, in_channels]) # self-loop embedding
+
+ self.drop = nn.Dropout(drop_rate)
+ self.bn = torch.nn.BatchNorm1d(out_channels)
+ self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None
+ if num_base > 0:
+ self.rel_wt = self.get_param([num_rel, num_base])
+ else:
+ self.rel_wt = None
+
+ self.wni = wni
+ self.wsi = wsi
+ self.comp = CompOpBlock()
+ # self.agg = AggOpBlock()
+ self.comb = CombOpBlock(out_channels)
+ self.act = ActOpBlock()
+
+ def get_param(self, shape):
+ param = nn.Parameter(torch.Tensor(*shape))
+ nn.init.xavier_normal_(param, gain=nn.init.calculate_gain('relu'))
+ return param
+
+ def message_func(self, edges):
+ edge_type = edges.data['type'] # [E, 1]
+ edge_num = edge_type.shape[0]
+ edge_data = self.comp(
+ edges.src['h'], self.rel[edge_type], self.comp_primitive) # [E, in_channel]
+ # NOTE: first half edges are all in-directions, last half edges are out-directions.
+ msg = torch.cat([torch.matmul(edge_data[:edge_num // 2, :], self.in_w),
+ torch.matmul(edge_data[edge_num // 2:, :], self.out_w)])
+ msg = msg * edges.data['norm'].reshape(-1, 1) # [E, D] * [E, 1]
+ return {'msg': msg}
+
+ def reduce_func(self, nodes):
+ # return {'h': torch.sum(nodes.mailbox['msg'], dim=1)}
+ return {'h': self.agg(nodes.mailbox['msg'], self.agg_primitive)}
+
+ def apply_node_func(self, nodes):
+ return {'h': self.drop(nodes.data['h'])}
+
+ def forward(self, g: dgl.DGLGraph, x, rel_repr, edge_type, edge_norm, comp_primitive, agg_primitive, comb_primitive, act_primitive):
+ """
+        :param g: dgl Graph, a graph without self-loops
+ :param x: input node features, [V, in_channel]
+ :param rel_repr: input relation features: 1. not using bases: [num_rel*2, in_channel]
+ 2. using bases: [num_base, in_channel]
+ :param edge_type: edge type, [E]
+ :param edge_norm: edge normalization, [E]
+ :return: x: output node features: [V, out_channel]
+ rel: output relation features: [num_rel*2, out_channel]
+ """
+ self.device = x.device
+ g = g.local_var()
+ g.ndata['h'] = x
+ g.edata['type'] = edge_type
+ g.edata['norm'] = edge_norm
+ self.comp_primitive = comp_primitive
+ self.agg_primitive = agg_primitive
+ self.comb_primitive = comb_primitive
+ self.act_primitive = act_primitive
+ if self.rel_wt is None:
+ self.rel.data = rel_repr
+ else:
+ # [num_rel*2, num_base] @ [num_base, in_c]
+ self.rel.data = torch.mm(self.rel_wt, rel_repr)
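+        # dispatch to DGL's built-in reducers (fn.max/mean/sum), which are fused and
+        # faster than the generic Python ``reduce_func`` UDF kept commented out below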
+ if self.agg_primitive == 'max':
+ g.update_all(self.message_func, fn.max(msg='msg', out='h'), self.apply_node_func)
+ elif self.agg_primitive == 'mean':
+ g.update_all(self.message_func, fn.mean(msg='msg', out='h'), self.apply_node_func)
+ elif self.agg_primitive == 'sum':
+ g.update_all(self.message_func, fn.sum(msg='msg', out='h'), self.apply_node_func)
+ # g.update_all(self.message_func, self.reduce_func, self.apply_node_func)
+
+ if (not self.wni) and (not self.wsi):
+            x = self.comb(g.ndata.pop('h'),
+                          torch.mm(self.comp(x, self.loop_rel, self.comp_primitive), self.loop_w),
+                          self.comb_primitive) * (1 / 3)
+ # x = (g.ndata.pop('h') +
+ # torch.mm(self.comp(x, self.loop_rel, self.comp_weights), self.loop_w)) / 3
+ # else:
+ # if self.wsi:
+ # x = g.ndata.pop('h') / 2
+ # if self.wni:
+ # x = torch.mm(self.comp(x, self.loop_rel), self.loop_w)
+
+ if self.bias is not None:
+ x = x + self.bias
+
+ if self.use_bn:
+ x = self.bn(x)
+
+ if self.ltr:
+ return self.act(x, self.act_primitive), torch.matmul(self.rel.data, self.w_rel)
+ else:
+ return self.act(x, self.act_primitive), self.rel.data
\ No newline at end of file
diff --git a/model/subgraph_selector.py b/model/subgraph_selector.py
new file mode 100644
index 0000000..ecb6583
--- /dev/null
+++ b/model/subgraph_selector.py
@@ -0,0 +1,51 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+class SubgraphSelector(nn.Module):
+ def __init__(self, args):
+ super(SubgraphSelector,self).__init__()
+ self.args = args
+ self.temperature = self.args.temperature
+ self.num_layers = self.args.ss_num_layer
+ self.cat_type = self.args.combine_type
+ if self.cat_type == 'mult':
+ in_channels = self.args.ss_input_dim
+ else:
+ in_channels = self.args.ss_input_dim * 2
+ hidden_channels = self.args.ss_hidden_dim
+ self.trans = nn.ModuleList()
+ for i in range(self.num_layers - 1):
+ if i == 0:
+ self.trans.append(nn.Linear(in_channels, hidden_channels, bias=False))
+ else:
+ self.trans.append(nn.Linear(hidden_channels, hidden_channels, bias=False))
+ self.trans.append(nn.Linear(hidden_channels, 1, bias=False))
+
+ def forward(self, x, mode='argmax', search_algorithm='darts'):
+ for layer in self.trans[:-1]:
+ x = layer(x)
+            x = F.relu(x)
+        x = self.trans[-1](x)
+        x = torch.squeeze(x, dim=2)
+        if search_algorithm == 'darts':
+            arch_set = torch.softmax(x / self.temperature, dim=1)
+ elif search_algorithm == 'snas':
+ arch_set = self._get_categ_mask(x)
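+        # 'argmax' hardens the relaxed selection into a one-hot choice
+        # (used when the selector is frozen, e.g. during fine-tuning / evaluation)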
+ if mode == 'argmax':
+ device = arch_set.device
+ n, c = arch_set.shape
+ eyes_atten = torch.eye(c).to(device)
+            atten_, atten_indice = torch.max(arch_set, dim=1)
+ arch_set = eyes_atten[atten_indice]
+ return arch_set
+ # raise NotImplementedError
+
+ def _get_categ_mask(self, alpha):
+ # log_alpha = torch.log(alpha)
+ log_alpha = alpha
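+        # Gumbel-softmax (concrete) relaxation: -log(-log(u)) with u ~ Uniform(0, 1)
+        # is a Gumbel(0, 1) sample, so the tempered softmax below yields an
+        # approximately one-hot draw from the categorical distribution given by alpha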
+ u = torch.zeros_like(log_alpha).uniform_()
+ softmax = torch.nn.Softmax(-1)
+ one_hot = softmax((log_alpha + (-((-(u.log())).log()))) / self.temperature)
+ return one_hot
\ No newline at end of file
diff --git a/run.py b/run.py
new file mode 100644
index 0000000..5dd1ae8
--- /dev/null
+++ b/run.py
@@ -0,0 +1,3089 @@
+import csv
+import os
+import argparse
+import time
+import logging
+from pprint import pprint
+import numpy as np
+import random
+
+import pandas as pd
+import torch
+from torch.utils.data import DataLoader
+import dgl
+from data.knowledge_graph import load_data
+from model import GCN_TransE, GCN_DistMult, GCN_ConvE, SubgraphSelector, GCN_ConvE_Rel, GCN_Transformer, GCN_None, \
+ GCN_MLP, GCN_MLP_NCN, SearchGCN_MLP, SearchedGCN_MLP, NetworkGNN_MLP, SearchGCN_MLP_SPOS, SEAL_GCN
+from model.lte_models import TransE, DistMult, ConvE
+from utils import process, process_multi_label, TrainDataset, TestDataset, get_logger, GraphTrainDataset, GraphTestDataset, \
+ get_f1_score_list, get_acc_list, get_neighbor_nodes, NCNDataset, Temp_Scheduler
+import wandb
+from os.path import exists
+from os import mkdir, makedirs
+from dgl.dataloading import GraphDataLoader
+from sklearn import metrics
+import matplotlib.pyplot as plt
+import pickle
+from tqdm import tqdm
+from torch.nn.utils import clip_grad_norm_
+import setproctitle
+from hyperopt import fmin, tpe, hp, Trials, partial, STATUS_OK, rand, space_eval
+from model.genotypes import *
+import optuna
+from optuna.samplers import RandomSampler
+from utils import CategoricalASNG
+import itertools
+from sortedcontainers import SortedDict
+from dgl import NID, EID
+
+torch.multiprocessing.set_sharing_strategy('file_system')
+
+
+class Runner(object):
+ def __init__(self, params):
+ self.p = params
+ self.prj_path = os.getcwd()
+ self.data = load_data(self.p.dataset)
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ self.num_ent, self.train_data, self.num_rels = self.data.num_nodes, self.data.train_graph, self.data.num_rels
+ # self.train_input, self.valid_rel, self.test_rel = self.data.train_rel, self.data.valid_rel, self.data.test_rel
+ # self.train_pos_neg, self.valid_pos_neg, self.test_pos_neg = self.data.train_pos_neg, self.data.valid_pos_neg, self.data.test_pos_neg
+ self.triplets = process_multi_label(
+ {'train': self.data.train_input, 'valid': self.data.valid_input, 'test': self.data.test_input},
+ {'train': self.data.train_multi_label, 'valid': self.data.valid_multi_label, 'test': self.data.test_multi_label},
+ {'train': self.data.train_pos_neg, 'valid': self.data.valid_pos_neg, 'test': self.data.test_pos_neg}
+ )
+ else:
+ self.num_ent, self.train_data, self.valid_data, self.test_data, self.num_rels = self.data.num_nodes, self.data.train, self.data.valid, self.data.test, self.data.num_rels
+ self.triplets, self.class2num = process(
+ {'train': self.train_data, 'valid': self.valid_data, 'test': self.test_data},
+ self.num_rels, self.p.n_layer, self.p.add_reverse)
+ self.p.embed_dim = self.p.k_w * \
+ self.p.k_h if self.p.embed_dim is None else self.p.embed_dim # output dim of gnn
+ self.g = self.build_graph()
+ self.edge_type, self.edge_norm = self.get_edge_dir_and_norm()
+ if self.p.input_type == "subgraph":
+ self.get_subgraph()
+ self.data_iter = self.get_data_iter()
+
+        if self.p.search_mode not in ('arch_random', 'arch_search') and self.p.search_algorithm != 'random_ps2':
+ self.model = self.get_model()
+ self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.p.lr, weight_decay=self.p.l2)
+ self.best_val_f1, self.best_val_auroc, self.best_epoch, self.best_val_results, self.best_test_results = 0., 0., 0., {}, {}
+ self.best_test_f1, self.best_test_auroc = 0., 0.
+ self.early_stop_cnt = 0
+ os.makedirs(f'logs/{self.p.dataset}/', exist_ok=True)
+ if self.p.train_mode == 'tune':
+ tmp_name = self.p.name + '_tune'
+ self.logger = get_logger(f'logs/{self.p.dataset}/', tmp_name)
+ else:
+ self.logger = get_logger(f'logs/{self.p.dataset}/', self.p.name)
+ pprint(vars(self.p))
+
+ def save_model(self, path):
+ """
+ Function to save a model. It saves the model parameters, best validation scores,
+ best epoch corresponding to best validation, state of the optimizer and all arguments for the run.
+ :param path: path where the model is saved
+ :return:
+ """
+ state = {
+ 'model': self.model.state_dict(),
+ 'best_val': self.best_val_results,
+ 'best_epoch': self.best_epoch,
+ 'optimizer': self.optimizer.state_dict(),
+ 'args': vars(self.p)
+ }
+ torch.save(state, path)
+
+ def save_search_model(self, path):
+ """
+ Function to save a model. It saves the model parameters, best validation scores,
+ best epoch corresponding to best validation, state of the optimizer and all arguments for the run.
+ :param path: path where the model is saved
+ :return:
+ """
+ state = {
+ 'model': self.model.state_dict(),
+ 'best_val': self.best_val_results,
+ 'best_epoch': self.best_epoch,
+ 'optimizer': self.optimizer.state_dict(),
+ 'args': vars(self.p)
+ }
+ torch.save(state, path)
+
+ def load_model(self, path):
+ """
+ Function to load a saved model
+ :param path: path where model is loaded
+ :return:
+ """
+ state = torch.load(path)
+ self.best_val_results = state['best_val']
+ self.best_epoch = state['best_epoch']
+ self.model.load_state_dict(state['model'])
+ self.optimizer.load_state_dict(state['optimizer'])
+
+ def build_graph(self):
+ g = dgl.DGLGraph()
+ g.add_nodes(self.num_ent + 1)
+ g.add_edges(self.train_data[:, 0], self.train_data[:, 2])
+ if self.p.add_reverse:
+ g.add_edges(self.train_data[:, 2], self.train_data[:, 0])
+ return g
+
+ def get_data_iter(self):
+
+ def get_data_loader(dataset_class, split, shuffle=True):
+ return DataLoader(
+ dataset_class(self.triplets[split], self.num_ent, self.num_rels, self.p),
+ batch_size=self.p.batch_size,
+ shuffle=shuffle,
+ num_workers=self.p.num_workers
+ )
+
+ def get_graph_data_loader(dataset_class, split, db_name_pos=None):
+ return GraphDataLoader(
+ dataset_class(self.triplets[split], self.num_ent, self.num_rels, self.p, self.g, db_name_pos),
+ batch_size=self.p.batch_size,
+ shuffle=True,
+ num_workers=self.p.num_workers
+ )
+
+ def get_ncndata_loader(dataset_class, split, db_name_pos=None):
+ return DataLoader(
+ dataset_class(self.triplets[split], self.num_ent, self.num_rels, self.p, self.adj, db_name_pos),
+ batch_size=self.p.batch_size,
+ shuffle=True,
+ num_workers=self.p.num_workers
+ )
+
+ if self.p.input_type == 'subgraph' or self.p.fine_tune_with_implicit_subgraph:
+ if self.p.add_reverse:
+ return {
+ 'train_rel': get_graph_data_loader(GraphTrainDataset, 'train_rel'),
+ 'valid_rel': get_graph_data_loader(GraphTestDataset, 'valid_rel'),
+ 'valid_rel_inv': get_graph_data_loader(GraphTestDataset, 'valid_rel_inv'),
+ 'test_rel': get_graph_data_loader(GraphTestDataset, 'test_rel'),
+ 'test_rel_inv': get_graph_data_loader(GraphTestDataset, 'test_rel_inv'),
+ # 'valid_head': get_data_loader(TestDataset, 'valid_head'),
+ # 'valid_tail': get_data_loader(TestDataset, 'valid_tail'),
+ # 'test_head': get_data_loader(TestDataset, 'test_head'),
+ # 'test_tail': get_data_loader(TestDataset, 'test_tail'),
+ }
+ else:
+ return {
+ 'train_rel': get_graph_data_loader(GraphTrainDataset, 'train_rel', 'train_pos'),
+ 'valid_rel': get_graph_data_loader(GraphTestDataset, 'valid_rel', 'valid_pos'),
+ 'test_rel': get_graph_data_loader(GraphTestDataset, 'test_rel', 'test_pos'),
+ }
+ elif self.p.input_type == 'allgraph' and self.p.score_func == 'mlp_ncn':
+ return {
+ 'train_rel': get_ncndata_loader(NCNDataset, 'train_rel', 'train_pos'),
+ 'valid_rel': get_ncndata_loader(NCNDataset, 'valid_rel', 'valid_pos'),
+ 'test_rel': get_ncndata_loader(NCNDataset, 'test_rel', 'test_pos'),
+ }
+ else:
+ if self.p.add_reverse:
+ return {
+ 'train_rel': get_data_loader(TrainDataset, 'train_rel'),
+ 'valid_rel': get_data_loader(TestDataset, 'valid_rel'),
+ 'valid_rel_inv': get_data_loader(TestDataset, 'valid_rel_inv'),
+ 'test_rel': get_data_loader(TestDataset, 'test_rel'),
+ 'test_rel_inv': get_data_loader(TestDataset, 'test_rel_inv')
+ }
+ else:
+ return {
+ 'train_rel': get_data_loader(TrainDataset, 'train_rel'),
+ 'valid_rel': get_data_loader(TestDataset, 'valid_rel'),
+ 'test_rel': get_data_loader(TestDataset, 'test_rel')
+ }
+
+ def get_edge_dir_and_norm(self):
+        """
+        :return: edge_type: type id of each edge, [E]
+                 norm: edge normalization 1 / sqrt(in_deg(src) * in_deg(dst)), [E]
+        """
+ in_deg = self.g.in_degrees(range(self.g.number_of_nodes())).float()
+ norm = in_deg ** -0.5
+ norm[torch.isinf(norm).bool()] = 0
+ self.g.ndata['xxx'] = norm
+ self.g.apply_edges(
+ lambda edges: {'xxx': edges.dst['xxx'] * edges.src['xxx']})
+ if self.p.gpu >= 0:
+ norm = self.g.edata.pop('xxx').squeeze().to("cuda:0")
+ if self.p.add_reverse:
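+                # reverse edges reuse the forward triples with relation ids offset by num_rels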
+ edge_type = torch.tensor(np.concatenate(
+ [self.train_data[:, 1], self.train_data[:, 1] + self.num_rels])).to("cuda:0")
+ else:
+ edge_type = torch.tensor(self.train_data[:, 1]).to("cuda:0")
+ else:
+ norm = self.g.edata.pop('xxx').squeeze()
+ edge_type = torch.tensor(np.concatenate(
+ [self.train_data[:, 1], self.train_data[:, 1] + self.num_rels]))
+ return edge_type, norm
+
+ def get_model(self):
+ if self.p.n_layer > 0:
+ if self.p.score_func.lower() == 'transe':
+ model = GCN_TransE(num_ent=self.num_ent, num_rel=self.num_rels, num_base=self.p.num_bases,
+ init_dim=self.p.init_dim, gcn_dim=self.p.gcn_dim, embed_dim=self.p.embed_dim,
+ n_layer=self.p.n_layer, edge_type=self.edge_type, edge_norm=self.edge_norm,
+ bias=self.p.bias, gcn_drop=self.p.gcn_drop, opn=self.p.opn,
+ hid_drop=self.p.hid_drop, gamma=self.p.gamma, wni=self.p.wni, wsi=self.p.wsi,
+ encoder=self.p.encoder, use_bn=(not self.p.nobn), ltr=(not self.p.noltr))
+ elif self.p.encoder == 'gcn':
+ model = SEAL_GCN(self.num_ent, self.num_rels, self.p.init_dim, self.p.gcn_dim, self.p.embed_dim, self.p.n_layer, loss_type=self.p.loss_type)
+ elif self.p.score_func.lower() == 'distmult':
+ model = GCN_DistMult(num_ent=self.num_ent, num_rel=self.num_rels, num_base=self.p.num_bases,
+ init_dim=self.p.init_dim, gcn_dim=self.p.gcn_dim, embed_dim=self.p.embed_dim,
+ n_layer=self.p.n_layer, edge_type=self.edge_type, edge_norm=self.edge_norm,
+ bias=self.p.bias, gcn_drop=self.p.gcn_drop, opn=self.p.opn,
+ hid_drop=self.p.hid_drop, wni=self.p.wni, wsi=self.p.wsi, encoder=self.p.encoder,
+ use_bn=(not self.p.nobn), ltr=(not self.p.noltr))
+ elif self.p.score_func.lower() == 'conve':
+ model = GCN_ConvE(num_ent=self.num_ent, num_rel=self.num_rels, num_base=self.p.num_bases,
+ init_dim=self.p.init_dim, gcn_dim=self.p.gcn_dim, embed_dim=self.p.embed_dim,
+ n_layer=self.p.n_layer, edge_type=self.edge_type, edge_norm=self.edge_norm,
+ bias=self.p.bias, gcn_drop=self.p.gcn_drop, opn=self.p.opn,
+ hid_drop=self.p.hid_drop, input_drop=self.p.input_drop,
+ conve_hid_drop=self.p.conve_hid_drop, feat_drop=self.p.feat_drop,
+ num_filt=self.p.num_filt, ker_sz=self.p.ker_sz, k_h=self.p.k_h, k_w=self.p.k_w,
+ wni=self.p.wni, wsi=self.p.wsi, encoder=self.p.encoder, use_bn=(not self.p.nobn),
+ ltr=(not self.p.noltr))
+ elif self.p.score_func.lower() == 'conve_rel':
+ model = GCN_ConvE_Rel(num_ent=self.num_ent, num_rel=self.num_rels, num_base=self.p.num_bases,
+ init_dim=self.p.init_dim, gcn_dim=self.p.gcn_dim, embed_dim=self.p.embed_dim,
+ n_layer=self.p.n_layer, edge_type=self.edge_type, edge_norm=self.edge_norm,
+ bias=self.p.bias, gcn_drop=self.p.gcn_drop, opn=self.p.opn,
+ hid_drop=self.p.hid_drop, input_drop=self.p.input_drop,
+ conve_hid_drop=self.p.conve_hid_drop, feat_drop=self.p.feat_drop,
+ num_filt=self.p.num_filt, ker_sz=self.p.ker_sz, k_h=self.p.k_h, k_w=self.p.k_w,
+ wni=self.p.wni, wsi=self.p.wsi, encoder=self.p.encoder, use_bn=(not self.p.nobn),
+ ltr=(not self.p.noltr), input_type=self.p.input_type)
+ elif self.p.score_func.lower() == 'transformer':
+ model = GCN_Transformer(num_ent=self.num_ent, num_rel=self.num_rels, num_base=self.p.num_bases,
+ init_dim=self.p.init_dim, gcn_dim=self.p.gcn_dim, embed_dim=self.p.embed_dim,
+ n_layer=self.p.n_layer, edge_type=self.edge_type, edge_norm=self.edge_norm,
+ bias=self.p.bias, gcn_drop=self.p.gcn_drop, opn=self.p.opn,
+ hid_drop=self.p.hid_drop, input_drop=self.p.input_drop,
+ conve_hid_drop=self.p.conve_hid_drop, feat_drop=self.p.feat_drop,
+ num_filt=self.p.num_filt, ker_sz=self.p.ker_sz, k_h=self.p.k_h, k_w=self.p.k_w,
+ wni=self.p.wni, wsi=self.p.wsi, encoder=self.p.encoder,
+ use_bn=(not self.p.nobn),
+ ltr=(not self.p.noltr), input_type=self.p.input_type,
+ d_model=self.p.d_model, num_transformer_layers=self.p.num_transformer_layers,
+ nhead=self.p.nhead, dim_feedforward=self.p.dim_feedforward,
+ transformer_dropout=self.p.transformer_dropout,
+ transformer_activation=self.p.transformer_activation,
+ graph_pooling=self.p.graph_pooling_type, concat_type=self.p.concat_type,
+ max_input_len=self.p.subgraph_max_num_nodes, loss_type=self.p.loss_type)
+ elif self.p.score_func.lower() == 'none':
+ model = GCN_None(num_ent=self.num_ent, num_rel=self.num_rels, num_base=self.p.num_bases,
+ init_dim=self.p.init_dim, gcn_dim=self.p.gcn_dim, embed_dim=self.p.embed_dim,
+ n_layer=self.p.n_layer, edge_type=self.edge_type, edge_norm=self.edge_norm,
+ bias=self.p.bias, gcn_drop=self.p.gcn_drop, opn=self.p.opn,
+ hid_drop=self.p.hid_drop, input_drop=self.p.input_drop,
+ conve_hid_drop=self.p.conve_hid_drop, feat_drop=self.p.feat_drop,
+ num_filt=self.p.num_filt, ker_sz=self.p.ker_sz, k_h=self.p.k_h, k_w=self.p.k_w,
+ wni=self.p.wni, wsi=self.p.wsi, encoder=self.p.encoder,
+ use_bn=(not self.p.nobn),
+ ltr=(not self.p.noltr), input_type=self.p.input_type,
+ graph_pooling=self.p.graph_pooling_type, concat_type=self.p.concat_type,
+ loss_type=self.p.loss_type, add_reverse=self.p.add_reverse)
+ elif self.p.score_func.lower() == 'mlp' and self.p.encoder != 'searchgcn' and self.p.genotype is None:
+ model = GCN_MLP(num_ent=self.num_ent, num_rel=self.num_rels, num_base=self.p.num_bases,
+ init_dim=self.p.init_dim, gcn_dim=self.p.gcn_dim, embed_dim=self.p.embed_dim,
+ n_layer=self.p.n_layer, edge_type=self.edge_type, edge_norm=self.edge_norm,
+ bias=self.p.bias, gcn_drop=self.p.gcn_drop, opn=self.p.opn,
+ hid_drop=self.p.hid_drop, input_drop=self.p.input_drop,
+ conve_hid_drop=self.p.conve_hid_drop, feat_drop=self.p.feat_drop,
+ num_filt=self.p.num_filt, ker_sz=self.p.ker_sz, k_h=self.p.k_h, k_w=self.p.k_w,
+ wni=self.p.wni, wsi=self.p.wsi, encoder=self.p.encoder,
+ use_bn=(not self.p.nobn),
+ ltr=(not self.p.noltr), input_type=self.p.input_type,
+ graph_pooling=self.p.graph_pooling_type, combine_type=self.p.combine_type,
+ loss_type=self.p.loss_type, add_reverse=self.p.add_reverse)
+ elif self.p.score_func.lower() == 'mlp_ncn':
+ model = GCN_MLP_NCN(num_ent=self.num_ent, num_rel=self.num_rels, num_base=self.p.num_bases,
+ init_dim=self.p.init_dim, gcn_dim=self.p.gcn_dim, embed_dim=self.p.embed_dim,
+ n_layer=self.p.n_layer, edge_type=self.edge_type, edge_norm=self.edge_norm,
+ bias=self.p.bias, gcn_drop=self.p.gcn_drop, opn=self.p.opn,
+ hid_drop=self.p.hid_drop, input_drop=self.p.input_drop,
+ conve_hid_drop=self.p.conve_hid_drop, feat_drop=self.p.feat_drop,
+ num_filt=self.p.num_filt, ker_sz=self.p.ker_sz, k_h=self.p.k_h, k_w=self.p.k_w,
+ wni=self.p.wni, wsi=self.p.wsi, encoder=self.p.encoder,
+ use_bn=(not self.p.nobn),
+ ltr=(not self.p.noltr), input_type=self.p.input_type,
+ graph_pooling=self.p.graph_pooling_type, combine_type=self.p.combine_type,
+ loss_type=self.p.loss_type, add_reverse=self.p.add_reverse)
+ elif self.p.genotype is not None:
+ model = SearchedGCN_MLP(args=self.p, num_ent=self.num_ent, num_rel=self.num_rels,
+ num_base=self.p.num_bases,
+ init_dim=self.p.init_dim, gcn_dim=self.p.gcn_dim, embed_dim=self.p.embed_dim,
+ n_layer=self.p.n_layer, edge_type=self.edge_type, edge_norm=self.edge_norm,
+ bias=self.p.bias, gcn_drop=self.p.gcn_drop, hid_drop=self.p.hid_drop,
+ input_drop=self.p.input_drop,
+ wni=self.p.wni, wsi=self.p.wsi, use_bn=(not self.p.nobn),
+ ltr=(not self.p.noltr),
+ combine_type=self.p.combine_type, loss_type=self.p.loss_type,
+ genotype=self.p.genotype)
+            elif "spos" in self.p.search_algorithm or self.p.train_mode == 'vis_hop' or (self.p.train_mode == 'spos_tune' and self.p.weight_sharing):
+ model = SearchGCN_MLP_SPOS(args=self.p, num_ent=self.num_ent, num_rel=self.num_rels,
+ num_base=self.p.num_bases,
+ init_dim=self.p.init_dim, gcn_dim=self.p.gcn_dim, embed_dim=self.p.embed_dim,
+ n_layer=self.p.n_layer, edge_type=self.edge_type, edge_norm=self.edge_norm,
+ bias=self.p.bias, gcn_drop=self.p.gcn_drop, hid_drop=self.p.hid_drop,
+ input_drop=self.p.input_drop,
+ wni=self.p.wni, wsi=self.p.wsi, use_bn=(not self.p.nobn), ltr=(not self.p.noltr),
+ combine_type=self.p.combine_type, loss_type=self.p.loss_type)
+ elif self.p.score_func.lower() == 'mlp' and self.p.genotype is None:
+ model = SearchGCN_MLP(args=self.p, num_ent=self.num_ent, num_rel=self.num_rels,
+ num_base=self.p.num_bases,
+ init_dim=self.p.init_dim, gcn_dim=self.p.gcn_dim, embed_dim=self.p.embed_dim,
+ n_layer=self.p.n_layer, edge_type=self.edge_type, edge_norm=self.edge_norm,
+ bias=self.p.bias, gcn_drop=self.p.gcn_drop, hid_drop=self.p.hid_drop,
+ input_drop=self.p.input_drop,
+ wni=self.p.wni, wsi=self.p.wsi, use_bn=(not self.p.nobn), ltr=(not self.p.noltr),
+ combine_type=self.p.combine_type, loss_type=self.p.loss_type)
+ else:
+ raise KeyError(
+ f'score function {self.p.score_func} not recognized.')
+ else:
+ if self.p.score_func.lower() == 'transe':
+ model = TransE(self.num_ent, self.num_rels, params=self.p)
+ elif self.p.score_func.lower() == 'distmult':
+ model = DistMult(self.num_ent, self.num_rels, params=self.p)
+ elif self.p.score_func.lower() == 'conve':
+ model = ConvE(self.num_ent, self.num_rels, params=self.p)
+ else:
+ raise NotImplementedError
+
+ if self.p.gpu >= 0:
+ model.to("cuda:0")
+ return model
+
+ def get_subgraph(self):
+        subgraph_dir = f'subgraph/{self.p.dataset}/{self.p.subgraph_type}_{self.p.subgraph_hop}_{self.p.subgraph_max_num_nodes}_{self.p.subgraph_sample_type}_{self.p.seed}'
+ if not exists(subgraph_dir):
+ makedirs(subgraph_dir)
+
+ for mode in ['train_rel', 'valid_rel', 'test_rel']:
+ if self.p.subgraph_is_saved:
+ if self.p.save_mode == 'pickle':
+ with open(
+ f'{subgraph_dir}/{mode}_{self.p.subgraph_type}_{self.p.subgraph_hop}_{self.p.subgraph_max_num_nodes}_{self.p.subgraph_sample_type}_{self.p.seed}.pkl',
+ 'rb') as f:
+ sample_nodes = pickle.load(f)
+ # graph_list = dgl.load_graphs(f'{subgraph_dir}/{mode}_{self.p.subgraph_type}_{self.p.subgraph_hop}_{self.p.subgraph_max_num_nodes}_{self.p.subgraph_sample_type}_{self.p.seed}.bin')[0]
+ for idx, _ in enumerate(self.triplets[mode]):
+ self.triplets[mode][idx]['sample_nodes'] = sample_nodes[idx][0]
+ elif self.p.save_mode == 'graph':
+ if self.p.add_reverse:
+ graph_list = dgl.load_graphs(
+ f'{subgraph_dir}/{mode}_add_reverse.bin')[0]
+ else:
+ graph_list = dgl.load_graphs(
+ f'{subgraph_dir}/{mode}.bin')[0]
+ for idx, _ in enumerate(self.triplets[mode]):
+ self.triplets[mode][idx]['subgraph'] = graph_list[idx]
+
+    def random_search(self):
+ save_root = f'{self.prj_path}/checkpoints/{self.p.dataset}'
+ os.makedirs(save_root, exist_ok=True)
+ save_path = f'{save_root}/{self.p.name}_random.pt'
+ for epoch in range(self.p.max_epochs):
+ start_time = time.time()
+ train_loss = self.train_epoch_rs()
+ val_results, valid_loss = self.evaluate_epoch('valid', mode='random')
+ if self.p.dataset == 'drugbank':
+ if val_results['macro_f1'] > self.best_val_f1:
+ self.best_val_results = val_results
+ self.best_val_f1 = val_results['macro_f1']
+ self.best_epoch = epoch
+ self.save_model(save_path)
+ self.early_stop_cnt = 0
+ else:
+ self.early_stop_cnt += 1
+ self.logger.info(
+ f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Cost: {time.time() - start_time:.2f}s")
+ if self.p.dataset == 'drugbank':
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}]: Best Macro F1: {self.best_val_f1:.5}")
+ wandb.log({"train_loss": train_loss, "best_valid_f1": self.best_val_f1})
+ if self.early_stop_cnt == 15:
+ self.logger.info("Early stop!")
+ break
+ self.load_model(save_path)
+ self.logger.info(
+ f'Loading best model in {self.best_epoch} epoch, Evaluating on Test data')
+ start = time.time()
+ test_results, test_loss = self.evaluate_epoch('test', mode='random')
+ end = time.time()
+ if self.p.dataset == 'drugbank':
+ self.logger.info(
+ f"f1: Rel {test_results['left_f1']:.5}, Rel_rev {test_results['right_f1']:.5}, Avg {test_results['macro_f1']:.5}")
+ self.logger.info(
+ f"[Inference]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}")
+ wandb.log({
+ "test_acc": test_results['acc'],
+ "test_f1": test_results['macro_f1'],
+ "test_cohen": test_results['kappa']
+ })
+
+ def train(self):
+ save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.train_mode}/{self.p.name}/'
+ os.makedirs(save_root, exist_ok=True)
+ save_path = f'{save_root}/model_weight.pt'
+ for epoch in range(self.p.max_epochs):
+ start_time = time.time()
+ train_loss = self.train_epoch()
+ val_results, valid_loss = self.evaluate_epoch('valid', mode='normal')
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ if val_results['auroc'] > self.best_val_auroc:
+ self.best_val_results = val_results
+ self.best_val_auroc = val_results['auroc']
+ self.best_epoch = epoch
+ self.save_model(save_path)
+ self.early_stop_cnt = 0
+ else:
+ self.early_stop_cnt += 1
+ else:
+ if val_results['macro_f1'] > self.best_val_f1:
+ self.best_val_results = val_results
+ self.best_val_f1 = val_results['macro_f1']
+ self.best_epoch = epoch
+ self.save_model(save_path)
+ self.early_stop_cnt = 0
+ else:
+ self.early_stop_cnt += 1
+ self.logger.info(
+ f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Cost: {time.time() - start_time:.2f}s")
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}]: Best AUROC: {self.best_val_auroc:.5}")
+ # wandb.log({"train_loss": train_loss, "valid_loss": valid_loss,
+ # "best_valid_auroc": self.best_val_auroc})
+ else:
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}]: Best Macro F1: {self.best_val_f1:.5}")
+ # wandb.log({"train_loss": train_loss, "valid_loss": valid_loss,
+ # "best_valid_f1": self.best_val_f1})
+ if self.early_stop_cnt == 10:
+ self.logger.info("Early stop!")
+ break
+ # self.logger.info(vars(self.p))
+ self.load_model(save_path)
+ self.logger.info(
+ f'Loading best model in {self.best_epoch} epoch, Evaluating on Test data')
+ start = time.time()
+ test_results, test_loss = self.evaluate_epoch('test', mode='normal')
+ end = time.time()
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ self.logger.info(
+ f"[Inference]: Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}")
+ # wandb.log({
+ # "test_auroc": test_results['auroc'],
+ # "test_auprc": test_results['auprc'],
+ # "test_ap": test_results['ap']
+ # })
+ else:
+ self.logger.info(
+ f"[Inference]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}")
+ # wandb.log({
+ # "test_acc": test_results['acc'],
+ # "test_f1": test_results['macro_f1'],
+ # "test_cohen": test_results['kappa']
+ # })
+
+ def train_epoch(self):
+ self.model.train()
+ losses = []
+ train_iter = self.data_iter['train_rel']
+ # train_bar = tqdm(train_iter, ncols=0)
+ for step, batch in enumerate(train_iter):
+ if self.p.input_type == 'subgraph':
+ g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0")
+ else:
+ triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0")
+ g = self.g.to("cuda:0")
+ subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2]
+ if self.p.score_func == 'mlp_ncn':
+ cns = batch[2].to("cuda:0")
+ pred = self.model(g, subj, obj, cns)
+ else:
+ if self.p.encoder == 'gcn':
+ pred = self.model(g, g.ndata['z'])
+ else:
+ pred = self.model(g, subj, obj)
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ pos_neg = batch[2].to("cuda:0")
+ loss = self.model.calc_loss(pred, labels, pos_neg)
+ else:
+ loss = self.model.calc_loss(pred, labels)
+ self.optimizer.zero_grad()
+ loss.backward()
+ if self.p.clip_grad:
+ clip_grad_norm_(self.model.parameters(), max_norm=10, norm_type=2)
+ self.optimizer.step()
+ losses.append(loss.item())
+ loss = np.mean(losses)
+ return loss
+
+ def evaluate_epoch(self, split, mode='normal'):
+
+ def get_combined_results(left, right):
+ results = dict()
+ results['acc'] = round((left['acc'] + right['acc']) / 2, 5)
+ results['left_f1'] = round(left['macro_f1'], 5)
+ results['right_f1'] = round(right['macro_f1'], 5)
+ results['macro_f1'] = round((left['macro_f1'] + right['macro_f1']) / 2, 5)
+ results['kappa'] = round((left['kappa'] + right['kappa']) / 2, 5)
+ results['macro_f1_per_class'] = (np.array(left['macro_f1_per_class']) + np.array(
+ right['macro_f1_per_class'])) / 2.0
+ results['acc_per_class'] = (np.array(left['acc_per_class']) + np.array(right['acc_per_class'])) / 2.0
+ return results
+
+ def get_results(left):
+ results = dict()
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ results['auroc'] = round(left['auroc'], 5)
+ results['auprc'] = round(left['auprc'], 5)
+ results['ap'] = round(left['ap'], 5)
+ else:
+ results['acc'] = round(left['acc'], 5)
+ # results['auc_pr'] = round((left['auc_pr'] + right['auc_pr']) / 2, 5)
+ # results['micro_f1'] = round((left['micro_f1'] + right['micro_f1']) / 2, 5)
+ results['macro_f1'] = round(left['macro_f1'], 5)
+ results['kappa'] = round(left['kappa'], 5)
+ return results
+
+ self.model.eval()
+ if mode == 'normal':
+ if self.p.add_reverse:
+ left_result, left_loss = self.predict(split, '')
+ right_result, right_loss = self.predict(split, '_inv')
+ else:
+ left_result, left_loss = self.predict(split, '')
+ elif mode == 'normal_mix_hop':
+ left_result, left_loss = self.predict_mix_hop(split, '')
+ elif mode == 'random':
+ left_result, left_loss = self.predict_rs(split, '')
+ right_result, right_loss = self.predict_rs(split, '_inv')
+ elif mode == 'ps2':
+ if self.p.add_reverse:
+ left_result, left_loss = self.predict_search(split, '')
+ right_result, right_loss = self.predict_search(split, '_inv')
+ else:
+ left_result, left_loss = self.predict_search(split, '')
+ elif mode == 'arch_search':
+ left_result, left_loss = self.predict_arch_search(split, '')
+ elif mode == 'arch_search_s':
+ left_result, left_loss = self.predict_arch_search(split, '', 'evaluate_single_path')
+ elif mode == 'joint_search':
+ left_result, left_loss = self.predict_joint_search(split, '')
+ elif mode == 'joint_search_s':
+ left_result, left_loss = self.predict_joint_search(split, '', 'evaluate_single_path')
+ elif mode == 'spos_train_supernet':
+ left_result, left_loss = self.predict_spos_search(split, '')
+ elif mode == 'spos_arch_search':
+ left_result, left_loss = self.predict_spos_search(split, '', spos_mode='arch_search')
+ elif mode == 'spos_train_supernet_ps2':
+ left_result, left_loss = self.predict_spos_search_ps2(split, '')
+ elif mode == 'spos_arch_search_ps2':
+ left_result, left_loss = self.predict_spos_search_ps2(split, '', spos_mode='arch_search')
+ # res = get_results(left_result)
+ # return res, left_loss
+ if self.p.add_reverse:
+ res = get_combined_results(left_result, right_result)
+ return res, (left_loss + right_loss) / 2.0
+ else:
+ return left_result, left_loss
+
+ def predict(self, split, mode):
+ loss_list = []
+ pos_scores = []
+ pos_labels = []
+ pred_class = {}
+ self.model.eval()
+ with torch.no_grad():
+ results = dict()
+ eval_iter = self.data_iter[f'{split}_rel{mode}']
+ for step, batch in enumerate(eval_iter):
+ if self.p.input_type == 'subgraph':
+ g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0")
+ else:
+ triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0")
+ g = self.g.to("cuda:0")
+ subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2]
+ if self.p.score_func == 'mlp_ncn':
+ cns = batch[2].to("cuda:0")
+ pred = self.model(g, subj, obj, cns)
+ else:
+ if self.p.encoder == 'gcn':
+ pred = self.model(g, g.ndata['z'])
+ else:
+ pred = self.model(g, subj, obj)
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ pos_neg = batch[2].to("cuda:0")
+ eval_loss = self.model.calc_loss(pred, labels, pos_neg)
+ m = torch.nn.Sigmoid()
+ pred = m(pred)
+ labels = labels.detach().to('cpu').numpy()
+ preds = pred.detach().to('cpu').numpy()
+ pos_neg = pos_neg.detach().to('cpu').numpy()
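+                    # group predictions by interaction type: for every type marked in the
+                    # multi-hot label, record the predicted probability and the pos/neg flag
+                    # so AUROC / AUPRC / accuracy can be computed per type and then averaged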
+ for (label_ids, pred, label_t) in zip(labels, preds, pos_neg):
+ for i, (l, p) in enumerate(zip(label_ids, pred)):
+ if l == 1:
+ if i in pred_class:
+ pred_class[i]['pred'] += [p]
+ pred_class[i]['l'] += [label_t]
+ pred_class[i]['pred_label'] += [1 if p > 0.5 else 0]
+ else:
+ pred_class[i] = {'pred': [p], 'l': [label_t], 'pred_label': [1 if p > 0.5 else 0]}
+ else:
+ eval_loss = self.model.calc_loss(pred, labels)
+ pos_labels += rel.to('cpu').numpy().flatten().tolist()
+ pos_scores += torch.argmax(pred, dim=1).cpu().flatten().tolist()
+ loss_list.append(eval_loss.item())
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ roc_auc = [metrics.roc_auc_score(pred_class[l]['l'], pred_class[l]['pred']) for l in pred_class]
+ prc_auc = [metrics.average_precision_score(pred_class[l]['l'], pred_class[l]['pred']) for l in
+ pred_class]
+ ap = [metrics.accuracy_score(pred_class[l]['l'], pred_class[l]['pred_label']) for l in pred_class]
+ results['auroc'] = np.mean(roc_auc)
+ results['auprc'] = np.mean(prc_auc)
+ results['ap'] = np.mean(ap)
+ else:
+ results['acc'] = metrics.accuracy_score(pos_labels, pos_scores)
+ results['macro_f1'] = metrics.f1_score(pos_labels, pos_scores, average='macro')
+ results['kappa'] = metrics.cohen_kappa_score(pos_labels, pos_scores)
+ # dict_res = metrics.classification_report(pos_labels, pos_scores, output_dict=True, zero_division=1)
+ # results['macro_f1_per_class'] = get_f1_score_list(dict_res)
+ # results['acc_per_class'] = get_acc_list(dict_res)
+ # self.logger.info(f'Macro f1 per class: {results["macro_f1_per_class"]}')
+ # self.logger.info(f'Acc per class: {results["acc_per_class"]}')
+ loss = np.mean(loss_list)
+ return results, loss
+
+ def train_epoch_rs(self):
+ self.model.train()
+ loss_list = []
+ train_iter = self.data_iter['train_rel']
+ for step, batch in enumerate(train_iter):
+ if self.p.input_type == 'subgraph':
+ g, triplets, labels, input_ids = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0"), \
+ batch[3].to("cuda:0")
+ else:
+ triplets, labels, random_hops = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to('cuda:0')
+ g = self.g.to("cuda:0")
+ subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2]
+ hidden_all_ent, all_rel = self.model.forward_search(g, subj, obj)
+ pred = self.model.compute_pred_rs(hidden_all_ent, all_rel, subj, obj, random_hops)
+ # pred = self.model(g, subj, obj, random_hops) # [batch_size, num_ent]
+ loss = self.model.calc_loss(pred, labels)
+ self.optimizer.zero_grad()
+ loss.backward()
+ self.optimizer.step()
+ loss_list.append(loss.item())
+ loss = np.mean(loss_list)
+ return loss
+
+ def predict_rs(self, split, mode):
+ loss_list = []
+ pos_scores = []
+ pos_labels = []
+ self.model.eval()
+ with torch.no_grad():
+ results = dict()
+ eval_iter = self.data_iter[f'{split}_rel{mode}']
+ for step, batch in enumerate(eval_iter):
+ if self.p.input_type == 'subgraph':
+ g, triplets, labels, input_ids = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to(
+ "cuda:0"), batch[3].to("cuda:0")
+ else:
+ triplets, labels, random_hops = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to('cuda:0')
+ g = self.g.to("cuda:0")
+ subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2]
+ hidden_all_ent, all_rel = self.model.forward_search(g, subj, obj)
+ pred = self.model.compute_pred_rs(hidden_all_ent, all_rel, subj, obj, random_hops)
+ loss = self.model.calc_loss(pred, labels)
+ loss_list.append(loss.item())
+ pos_labels += rel.to('cpu').numpy().flatten().tolist()
+ pos_scores += torch.argmax(pred, dim=1).cpu().flatten().tolist()
+ results['acc'] = metrics.accuracy_score(pos_labels, pos_scores)
+ results['macro_f1'] = metrics.f1_score(pos_labels, pos_scores, average='macro')
+ results['kappa'] = metrics.cohen_kappa_score(pos_labels, pos_scores)
+ loss = np.mean(loss_list)
+ return results, loss
+
+ def fine_tune(self):
+ self.subgraph_selector = SubgraphSelector(self.p).to("cuda:0")
+ save_root = f'{self.prj_path}/checkpoints/{self.p.dataset}'
+ save_model_path = f'{save_root}/{self.p.name}.pt'
+ save_ss_path = f'{save_root}/{self.p.name}_ss.pt'
+ save_tune_path = f'{save_root}/{self.p.name}_tune.pt'
+ self.model.load_state_dict(torch.load(str(save_model_path)))
+ self.subgraph_selector.load_state_dict(torch.load(str(save_ss_path)))
+ self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='max', factor=0.2, patience=10, verbose=True)
+ val_results, val_loss = self.evaluate_epoch('valid', 'ps2')
+ test_results, test_loss = self.evaluate_epoch('test', 'ps2')
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ self.logger.info(
+ f"[Validation]: Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}")
+ self.logger.info(
+ f"[Inference]: Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}")
+ wandb.log({
+ "init_test_auroc": test_results['auroc'],
+ "init_test_auprc": test_results['auprc'],
+ "init_test_ap": test_results['ap']
+ })
+ else:
+ self.logger.info(
+ f"[Validation]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}")
+ self.logger.info(
+ f"[Inference]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}")
+ wandb.log({
+ "init_test_acc": test_results['acc'],
+ "init_test_f1": test_results['macro_f1'],
+ "init_test_cohen": test_results['kappa']
+ })
+ for epoch in range(self.p.max_epochs):
+ start_time = time.time()
+ train_loss = self.train_epoch_fine_tune()
+ val_results, valid_loss = self.evaluate_epoch('valid', 'ps2')
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ if val_results['auroc'] > self.best_val_auroc:
+ self.best_val_results = val_results
+ self.best_val_auroc = val_results['auroc']
+ self.best_epoch = epoch
+ self.save_model(save_tune_path)
+ self.early_stop_cnt = 0
+ else:
+ self.early_stop_cnt += 1
+ else:
+ if val_results['macro_f1'] > self.best_val_f1:
+ self.best_val_results = val_results
+ self.best_val_f1 = val_results['macro_f1']
+ self.best_epoch = epoch
+ self.save_model(save_tune_path)
+ self.early_stop_cnt = 0
+ else:
+ self.early_stop_cnt += 1
+ self.logger.info(
+ f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Cost: {time.time() - start_time:.2f}s")
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}]: Best AUROC: {self.best_val_auroc:.5}")
+ wandb.log({"train_loss": train_loss, "valid_loss": valid_loss,
+ "best_valid_auroc": self.best_val_auroc})
+ self.scheduler.step(self.best_val_auroc)
+ else:
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}]: Best Macro F1: {self.best_val_f1:.5}")
+ wandb.log({"train_loss": train_loss, "valid_loss": valid_loss,
+ "best_valid_f1": self.best_val_f1})
+ self.scheduler.step(self.best_val_f1)
+ if self.early_stop_cnt == 15:
+ self.logger.info("Early stop!")
+ break
+ self.load_model(save_tune_path)
+ self.logger.info(
+ f'Loading best model in {self.best_epoch} epoch, Evaluating on Test data')
+ test_results, test_loss = self.evaluate_epoch('test', 'ps2')
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ self.logger.info(
+ f"[Inference]: Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}")
+ wandb.log({
+ "test_auroc": test_results['auroc'],
+ "test_auprc": test_results['auprc'],
+ "test_ap": test_results['ap']
+ })
+ else:
+ self.logger.info(
+ f"[Inference]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}")
+ wandb.log({
+ "test_acc": test_results['acc'],
+ "test_f1": test_results['macro_f1'],
+ "test_cohen": test_results['kappa']
+ })
+
+ def train_epoch_fine_tune(self, mode=None):
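+ # One fine-tuning epoch: the subgraph selector stays in eval mode and the
+ # selection is taken by argmax, so only the encoder/decoder weights are
+ # updated. Returns the mean training loss over all batches.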
+ self.model.train()
+ self.subgraph_selector.eval()
+ loss_list = []
+ train_iter = self.data_iter['train_rel']
+ # train_bar = tqdm(train_iter, ncols=0)
+ for step, batch in enumerate(train_iter):
+ if self.p.input_type == 'subgraph':
+ g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0")
+ else:
+ triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0")
+ g = self.g.to("cuda:0")
+ subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2]
+ hidden_all_ent, all_ent = self.model.forward_search(g, mode=mode) # [batch_size, num_ent]
+ if self.p.score_func == 'mlp_ncn':
+ cns = batch[2].to("cuda:0")
+ pred = self.model.compute_pred(hidden_all_ent, all_ent, subj, obj, cns, self.subgraph_selector,
+ mode='argmax')
+ else:
+ pred = self.model.compute_pred(hidden_all_ent, subj, obj, self.subgraph_selector, mode='argmax', search_algorithm=self.p.ss_search_algorithm)
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ pos_neg = batch[2].to("cuda:0")
+ train_loss = self.model.calc_loss(pred, labels, pos_neg)
+ else:
+ train_loss = self.model.calc_loss(pred, labels)
+ loss_list.append(train_loss.item())
+ self.optimizer.zero_grad()
+ train_loss.backward()
+ self.optimizer.step()
+
+ loss = np.mean(loss_list)
+ return loss
+
+ def ps2(self):
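+ # Subgraph selection search (PS2): alternately update the model weights on
+ # training batches and the SubgraphSelector on validation batches (see
+ # search_epoch), saving the weights that achieve the best validation metric
+ # under exp/<dataset>/<search_mode>/<name>/.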
+ save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}'
+ os.makedirs(save_root, exist_ok=True)
+ self.logger = get_logger(f'{save_root}/', 'train')
+ save_model_path = f'{save_root}/weight.pt'
+ save_ss_path = f'{save_root}/weight_ss.pt'
+
+ self.subgraph_selector = SubgraphSelector(self.p).to("cuda:0")
+ self.subgraph_selector_optimizer = torch.optim.Adam(
+ self.subgraph_selector.parameters(), lr=self.p.ss_lr, weight_decay=self.p.l2)
+ # temp_scheduler = Temp_Scheduler(self.p.max_epochs, self.p.temperature, self.p.temperature, temp_min=self.p.temperature_min)
+ for epoch in range(self.p.max_epochs):
+ start_time = time.time()
+ # if self.p.cos_temp:
+ # self.p.temperature = temp_scheduler.step()
+ # else:
+ # self.p.temperature = self.p.temperature
+ train_loss = self.search_epoch()
+ val_results, valid_loss = self.evaluate_epoch('valid', 'ps2')
+ test_results, test_loss = self.evaluate_epoch('test', 'ps2')
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ if val_results['auroc'] > self.best_val_auroc:
+ self.best_val_results = val_results
+ self.best_test_results = test_results
+ self.best_val_auroc = val_results['auroc']
+ self.best_test_auroc = test_results['auroc']
+ self.best_epoch = epoch
+ torch.save(self.model.state_dict(), str(save_model_path))
+ torch.save(self.subgraph_selector.state_dict(), str(save_ss_path))
+ self.early_stop_cnt = 0
+ else:
+ self.early_stop_cnt += 1
+ else:
+ if val_results['macro_f1'] > self.best_val_f1:
+ self.best_val_results = val_results
+ self.best_test_results = test_results
+ self.best_val_f1 = val_results['macro_f1']
+ self.best_test_f1 = test_results['macro_f1']
+ self.best_epoch = epoch
+ torch.save(self.model.state_dict(), str(save_model_path))
+ torch.save(self.subgraph_selector.state_dict(), str(save_ss_path))
+ self.early_stop_cnt = 0
+ else:
+ self.early_stop_cnt += 1
+ self.logger.info(
+ f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Cost: {time.time() - start_time:.2f}s")
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}]: Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}]: Best AUROC: {self.best_val_auroc:.5}")
+ wandb.log({"train_loss": train_loss, "valid_loss": valid_loss,
+ "best_valid_auroc": self.best_val_auroc})
+ wandb.log({'best_test_auroc': self.best_test_auroc})
+ else:
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}]: Best Macro F1: {self.best_val_f1:.5}")
+ # wandb.log({"train_loss": train_loss, "valid_loss": valid_loss,
+ # "best_valid_f1": self.best_val_f1})
+ # wandb.log({'best_test_f1': self.best_test_f1})
+ if self.early_stop_cnt == 10:
+ self.logger.info("Early stop!")
+ break
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ # self.logger.info(
+ # f"[Inference]: Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}")
+ wandb.log({
+ "test_auroc": self.best_test_results['auroc'],
+ "test_auprc": self.best_test_results['auprc'],
+ "test_ap": self.best_test_results['ap']
+ })
+ else:
+ # self.logger.info(
+ # f"[Inference]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}")
+ wandb.log({
+ "test_acc": self.best_test_results['acc'],
+ "test_f1": self.best_test_results['macro_f1'],
+ "test_cohen": self.best_test_results['kappa']
+ })
+
+
+ def search_epoch(self):
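+ # Bi-level step: for every training batch, first update the model weights on
+ # the training loss, then draw a validation batch and update only the
+ # subgraph selector on the validation loss.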
+ self.model.train()
+ self.subgraph_selector.train()
+ loss_list = []
+ train_iter = self.data_iter['train_rel']
+ valid_iter = self.data_iter['valid_rel']
+ # train_bar = tqdm(train_iter, ncols=0)
+ for step, batch in enumerate(train_iter):
+ if self.p.input_type == 'subgraph':
+ g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0")
+ else:
+ triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0")
+ g = self.g.to("cuda:0")
+ subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2]
+ hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent]
+ # print(hidden_all_ent.size()) # [num_ent, encoder_layer, dim]
+ if self.p.score_func == 'mlp_ncn':
+ cns = batch[2].to("cuda:0")
+ pred = self.model.compute_pred(hidden_all_ent, all_ent, subj, obj, cns, self.subgraph_selector)
+ else:
+ pred = self.model.compute_pred(hidden_all_ent, subj, obj, self.subgraph_selector, search_algorithm=self.p.ss_search_algorithm)
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ pos_neg = batch[2].to("cuda:0")
+ train_loss = self.model.calc_loss(pred, labels, pos_neg)
+ else:
+ train_loss = self.model.calc_loss(pred, labels)
+ loss_list.append(train_loss.item())
+ self.optimizer.zero_grad()
+ train_loss.backward()
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+ self.subgraph_selector_optimizer.zero_grad()
+
+ batch = next(iter(valid_iter))
+ if self.p.input_type == 'subgraph':
+ g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0")
+ else:
+ triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0")
+ g = self.g.to("cuda:0")
+ subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2]
+ hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent]
+ # print(hidden_all_ent.size()) # [num_ent, encoder_layer, dim]
+ if self.p.score_func == 'mlp_ncn':
+ cns = batch[2].to("cuda:0")
+ pred = self.model.compute_pred(hidden_all_ent, all_ent, subj, obj, cns, self.subgraph_selector)
+ else:
+ pred = self.model.compute_pred(hidden_all_ent, subj, obj, self.subgraph_selector, search_algorithm=self.p.ss_search_algorithm)
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ pos_neg = batch[2].to("cuda:0")
+ valid_loss = self.model.calc_loss(pred, labels, pos_neg)
+ else:
+ valid_loss = self.model.calc_loss(pred, labels)
+ valid_loss.backward()
+ self.subgraph_selector_optimizer.step()
+ loss = np.mean(loss_list)
+ return loss
+
+ def predict_search(self, split, mode):
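+ # Evaluation for the PS2 stage with argmax subgraph selection. Multi-label
+ # datasets (twosides, ogbl_biokg) report per-relation AUROC/AUPRC/AP averages;
+ # multi-class datasets report accuracy, macro F1, and Cohen's kappa.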
+ loss_list = []
+ pos_scores = []
+ pos_labels = []
+ pred_class = {}
+ self.model.eval()
+ self.subgraph_selector.eval()
+ with torch.no_grad():
+ results = dict()
+ test_iter = self.data_iter[f'{split}_rel{mode}']
+ for step, batch in enumerate(test_iter):
+ if self.p.input_type == 'subgraph':
+ g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0")
+ else:
+ triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0")
+ g = self.g.to("cuda:0")
+ subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2]
+ hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent]
+ # print(hidden_all_ent.size()) # [num_ent, encoder_layer, dim]
+ if self.p.score_func == 'mlp_ncn':
+ cns = batch[2].to("cuda:0")
+ pred = self.model.compute_pred(hidden_all_ent, all_ent, subj, obj, cns, self.subgraph_selector,
+ mode='argmax')
+ else:
+ pred = self.model.compute_pred(hidden_all_ent, subj, obj, self.subgraph_selector, mode='argmax', search_algorithm=self.p.ss_search_algorithm)
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ pos_neg = batch[2].to("cuda:0")
+ eval_loss = self.model.calc_loss(pred, labels, pos_neg)
+ m = torch.nn.Sigmoid()
+ pred = m(pred)
+ labels = labels.detach().to('cpu').numpy()
+ preds = pred.detach().to('cpu').numpy()
+ pos_neg = pos_neg.detach().to('cpu').numpy()
+ for (label_ids, pred, label_t) in zip(labels, preds, pos_neg):
+ for i, (l, p) in enumerate(zip(label_ids, pred)):
+ if l == 1:
+ if i in pred_class:
+ pred_class[i]['pred'] += [p]
+ pred_class[i]['l'] += [label_t]
+ pred_class[i]['pred_label'] += [1 if p > 0.5 else 0]
+ else:
+ pred_class[i] = {'pred': [p], 'l': [label_t], 'pred_label': [1 if p > 0.5 else 0]}
+ else:
+ eval_loss = self.model.calc_loss(pred, labels)
+ pos_labels += rel.to('cpu').numpy().flatten().tolist()
+ pos_scores += torch.argmax(pred, dim=1).cpu().flatten().tolist()
+ loss_list.append(eval_loss.item())
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ roc_auc = [metrics.roc_auc_score(pred_class[l]['l'], pred_class[l]['pred']) for l in pred_class]
+ prc_auc = [metrics.average_precision_score(pred_class[l]['l'], pred_class[l]['pred']) for l in
+ pred_class]
+ ap = [metrics.accuracy_score(pred_class[l]['l'], pred_class[l]['pred_label']) for l in pred_class]
+ results['auroc'] = np.mean(roc_auc)
+ results['auprc'] = np.mean(prc_auc)
+ results['ap'] = np.mean(ap)
+ else:
+ results['acc'] = metrics.accuracy_score(pos_labels, pos_scores)
+ results['macro_f1'] = metrics.f1_score(pos_labels, pos_scores, average='macro')
+ results['kappa'] = metrics.cohen_kappa_score(pos_labels, pos_scores)
+ # dict_res = metrics.classification_report(pos_labels, pos_scores, output_dict=True, zero_division=1)
+ # results['macro_f1_per_class'] = get_f1_score_list(dict_res)
+ # results['acc_per_class'] = get_acc_list(dict_res)
+ loss = np.mean(loss_list)
+ return results, loss
+
+ def architecture_search(self):
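+ # Differentiable architecture search over the encoder: model weights and
+ # architecture parameters use separate Adam optimizers with cosine-annealed
+ # learning rates, and the relaxation temperature can be annealed by
+ # Temp_Scheduler. The current genotype is logged every epoch.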
+ save_root = f'{self.prj_path}/checkpoints/{self.p.dataset}'
+ os.makedirs(save_root, exist_ok=True)
+ save_model_path = f'{save_root}/{self.p.name}.pt'
+
+ self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, self.p.max_epochs, eta_min=self.p.lr_min)
+ self.arch_optimizer = torch.optim.Adam(self.model.arch_parameters(), lr=self.p.arch_lr, weight_decay=self.p.arch_weight_decay)
+ self.arch_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.arch_optimizer, self.p.max_epochs, eta_min=self.p.arch_lr_min)
+ temp_scheduler = Temp_Scheduler(self.p.max_epochs, self.p.temperature, self.p.temperature, temp_min=self.p.temperature_min)
+ for epoch in range(self.p.max_epochs):
+ start_time = time.time()
+ # genotype = self.model.genotype()
+ # self.logger.info(f'Genotype: {genotype}')
+ if self.p.cos_temp:
+ self.p.temperature = temp_scheduler.step()  # anneal the relaxation temperature; otherwise it stays fixed
+ # print(self.p.temperature)
+ train_loss = self.arch_search_epoch()
+ self.scheduler.step()
+ self.arch_scheduler.step()
+ val_results, valid_loss = self.evaluate_epoch('valid', 'arch_search')
+ s_val_results, s_valid_loss = self.evaluate_epoch('valid', 'arch_search_s')
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ if val_results['auroc'] > self.best_val_auroc:
+ self.best_val_results = val_results
+ self.best_val_auroc = val_results['auroc']
+ self.best_epoch = epoch
+ # self.save_model(save_path)
+ self.early_stop_cnt = 0
+ else:
+ self.early_stop_cnt += 1
+ else:
+ if val_results['macro_f1'] > self.best_val_f1:
+ self.best_val_results = val_results
+ self.best_val_f1 = val_results['macro_f1']
+ self.best_epoch = epoch
+ # self.save_model(save_path)
+ self.early_stop_cnt = 0
+ else:
+ self.early_stop_cnt += 1
+ genotype = self.model.genotype()
+ self.logger.info(f'[Epoch {epoch}]: LR: {self.scheduler.get_last_lr()[0]}, TEMP: {self.p.temperature}, Genotype: {genotype}')
+ self.logger.info(
+ f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Valid_S Loss: {s_valid_loss:.5}, Cost: {time.time() - start_time:.2f}s")
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid_S AUROC: {s_val_results['auroc']:.5}, Valid_S AUPRC: {s_val_results['auprc']:.5}, Valid_S AP@50: {s_val_results['ap']:.5}")
+ else:
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid_S ACC: {s_val_results['acc']:.5}, Valid_S Macro F1: {s_val_results['macro_f1']:.5}, Valid_S Cohen: {s_val_results['kappa']:.5}")
+ # if self.early_stop_cnt == 15:
+ # self.logger.info("Early stop!")
+ # break
+
+ def arch_search_epoch(self):
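+ # One search epoch: per training batch the model weights are updated
+ # w_update_epoch times; the architecture parameters are then stepped on the
+ # training loss (alpha_mode == 'train_loss') or on a freshly drawn validation
+ # batch (alpha_mode == 'valid_loss').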
+ self.model.train()
+ loss_list = []
+ train_iter = self.data_iter['train_rel']
+ valid_iter = self.data_iter['valid_rel']
+ # train_bar = tqdm(train_iter, ncols=0)
+ for step, batch in enumerate(train_iter):
+ for update_idx in range(self.p.w_update_epoch):
+ self.optimizer.zero_grad()
+ self.arch_optimizer.zero_grad()
+ if self.p.input_type == 'subgraph':
+ g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0")
+ else:
+ triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0")
+ g = self.g.to("cuda:0")
+ subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2]
+ # hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent]
+ # print(hidden_all_ent.size()) # [num_ent, encoder_layer, dim]
+ pred = self.model(g, subj, obj)
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ pos_neg = batch[2].to("cuda:0")
+ loss = self.model.calc_loss(pred, labels, pos_neg)
+ else:
+ loss = self.model.calc_loss(pred, labels)
+ self.arch_optimizer.zero_grad()
+ loss.backward(retain_graph=True)
+ if self.p.clip_grad:
+ clip_grad_norm_(self.model.parameters(), max_norm=10, norm_type=2)
+ self.optimizer.step()
+ loss_list.append(loss.item())
+ if self.p.alpha_mode == 'train_loss':
+ self.arch_optimizer.step()
+ elif self.p.alpha_mode == 'valid_loss':
+ self.optimizer.zero_grad()
+ self.arch_optimizer.zero_grad()
+ batch = next(iter(valid_iter))
+ if self.p.input_type == 'subgraph':
+ g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0")
+ else:
+ triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0")
+ g = self.g.to("cuda:0")
+ subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2]
+ pred = self.model(g, subj, obj)
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ pos_neg = batch[2].to("cuda:0")
+ valid_loss = self.model.calc_loss(pred, labels, pos_neg)
+ else:
+ valid_loss = self.model.calc_loss(pred, labels)
+ valid_loss.backward()
+ self.arch_optimizer.step()
+ loss = np.mean(loss_list)
+ return loss
+
+ def predict_arch_search(self, split, mode, eval_mode=None):
+ loss_list = []
+ pos_scores = []
+ pos_labels = []
+ pred_class = {}
+ self.model.eval()
+ with torch.no_grad():
+ results = dict()
+ test_iter = self.data_iter[f'{split}_rel{mode}']
+ for step, batch in enumerate(test_iter):
+ if self.p.input_type == 'subgraph':
+ g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0")
+ else:
+ triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0")
+ g = self.g.to("cuda:0")
+ subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2]
+ # hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent]
+ # print(hidden_all_ent.size()) # [num_ent, encoder_layer, dim]
+ pred = self.model(g, subj, obj, mode=eval_mode)
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ pos_neg = batch[2].to("cuda:0")
+ eval_loss = self.model.calc_loss(pred, labels, pos_neg)
+ m = torch.nn.Sigmoid()
+ pred = m(pred)
+ labels = labels.detach().to('cpu').numpy()
+ preds = pred.detach().to('cpu').numpy()
+ pos_neg = pos_neg.detach().to('cpu').numpy()
+ for (label_ids, pred, label_t) in zip(labels, preds, pos_neg):
+ for i, (l, p) in enumerate(zip(label_ids, pred)):
+ if l == 1:
+ if i in pred_class:
+ pred_class[i]['pred'] += [p]
+ pred_class[i]['l'] += [label_t]
+ pred_class[i]['pred_label'] += [1 if p > 0.5 else 0]
+ else:
+ pred_class[i] = {'pred': [p], 'l': [label_t], 'pred_label': [1 if p > 0.5 else 0]}
+ else:
+ eval_loss = self.model.calc_loss(pred, labels)
+ pos_labels += rel.to('cpu').numpy().flatten().tolist()
+ pos_scores += torch.argmax(pred, dim=1).cpu().flatten().tolist()
+ loss_list.append(eval_loss.item())
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ roc_auc = [metrics.roc_auc_score(pred_class[l]['l'], pred_class[l]['pred']) for l in pred_class]
+ prc_auc = [metrics.average_precision_score(pred_class[l]['l'], pred_class[l]['pred']) for l in
+ pred_class]
+ ap = [metrics.accuracy_score(pred_class[l]['l'], pred_class[l]['pred_label']) for l in pred_class]
+ results['auroc'] = np.mean(roc_auc)
+ results['auprc'] = np.mean(prc_auc)
+ results['ap'] = np.mean(ap)
+ else:
+ results['acc'] = metrics.accuracy_score(pos_labels, pos_scores)
+ results['macro_f1'] = metrics.f1_score(pos_labels, pos_scores, average='macro')
+ results['kappa'] = metrics.cohen_kappa_score(pos_labels, pos_scores)
+ # dict_res = metrics.classification_report(pos_labels, pos_scores, output_dict=True, zero_division=1)
+ # results['macro_f1_per_class'] = get_f1_score_list(dict_res)
+ # results['acc_per_class'] = get_acc_list(dict_res)
+ loss = np.mean(loss_list)
+ return results, loss
+
+ def arch_random_search_each(self, trial):
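+ # One Optuna trial of the random-search baseline: sample a genotype from the
+ # COMP/AGG/COMB/ACT primitive spaces for every layer, train the resulting
+ # model from scratch with early stopping, and report its best validation
+ # metric back to Optuna.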
+ genotype_space = []
+ for i in range(self.p.n_layer):
+ genotype_space.append(trial.suggest_categorical("mess"+ str(i), COMP_PRIMITIVES))
+ genotype_space.append(trial.suggest_categorical("agg"+ str(i), AGG_PRIMITIVES))
+ genotype_space.append(trial.suggest_categorical("comb"+ str(i), COMB_PRIMITIVES))
+ genotype_space.append(trial.suggest_categorical("act"+ str(i), ACT_PRIMITIVES))
+ self.best_val_f1 = 0.0
+ self.best_val_auroc = 0.0
+ self.early_stop_cnt = 0
+ # self.best_valid_metric, self.best_test_metric = 0.0, {}
+ self.p.genotype = "||".join(genotype_space)
+ # run = self.reinit_wandb()
+ self.model = self.get_model().to("cuda:0")
+ self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.p.lr, weight_decay=self.p.l2)
+ save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}'
+ os.makedirs(save_root, exist_ok=True)
+ save_path = f'{save_root}/random_search.pt'
+ for epoch in range(self.p.max_epochs):
+ start_time = time.time()
+ train_loss = self.train_epoch()
+ val_results, valid_loss = self.evaluate_epoch('valid', mode='normal')
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ if val_results['auroc'] > self.best_val_auroc:
+ self.best_val_results = val_results
+ self.best_val_auroc = val_results['auroc']
+ self.best_epoch = epoch
+ self.save_model(save_path)
+ self.early_stop_cnt = 0
+ # self.logger.info("Update best valid auroc!")
+ else:
+ self.early_stop_cnt += 1
+ else:
+ if val_results['macro_f1'] > self.best_val_f1:
+ self.best_val_results = val_results
+ self.best_val_f1 = val_results['macro_f1']
+ self.best_epoch = epoch
+ self.save_model(save_path)
+ self.early_stop_cnt = 0
+ # self.logger.info("Update best valid f1!")
+ else:
+ self.early_stop_cnt += 1
+ self.logger.info(
+ f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Cost: {time.time() - start_time:.2f}s")
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}]: Best AUROC: {self.best_val_auroc:.5}")
+ # wandb.log({"train_loss": train_loss, "valid_loss": valid_loss,
+ # "best_valid_auroc": self.best_val_auroc})
+ else:
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}]: Best Macro F1: {self.best_val_f1:.5}")
+ # wandb.log({"train_loss": train_loss, "valid_loss": valid_loss,
+ # "best_valid_f1": self.best_val_f1})
+ if self.early_stop_cnt == 10:
+ self.logger.info("Early stop!")
+ break
+ # self.logger.info(vars(self.p))
+ self.load_model(save_path)
+ self.logger.info(
+ f'Loading best model in {self.best_epoch} epoch, Evaluating on Test data')
+ self.logger.info(f'{self.p.genotype}')
+ start = time.time()
+ test_results, test_loss = self.evaluate_epoch('test', mode='normal')
+ end = time.time()
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ self.logger.info(
+ f"[Inference]: Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}")
+ # wandb.log({
+ # "test_auroc": test_results['auroc'],
+ # "test_auprc": test_results['auprc'],
+ # "test_ap": test_results['ap']
+ # })
+ # run.finish()
+ if self.best_val_auroc > self.best_valid_metric:
+ self.best_valid_metric = self.best_val_auroc
+ self.best_test_metric = test_results
+ with open(f'{save_root}/random_search_arch_list.csv', "a") as f:
+ writer = csv.writer(f)
+ writer.writerow([self.p.genotype, self.best_val_auroc, test_results['auroc']])
+ return self.best_val_auroc
+ else:
+ self.logger.info(
+ f"[Inference]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}")
+ # wandb.log({
+ # "test_acc": test_results['acc'],
+ # "test_f1": test_results['macro_f1'],
+ # "test_cohen": test_results['kappa']
+ # })
+ # run.finish()
+ if self.best_val_f1 > self.best_valid_metric:
+ self.best_valid_metric = self.best_val_f1
+ self.best_test_metric = test_results
+ with open(f'{save_root}/random_search_arch_list.csv', "a") as f:
+ writer = csv.writer(f)
+ writer.writerow([self.p.genotype, self.best_val_f1, test_results['macro_f1']])
+ return self.best_val_f1
+
+ def arch_random_search(self):
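+ # Random-search baseline driver: run baseline_sample_num trials with Optuna's
+ # RandomSampler and write the best validation/test performance to
+ # random_search_res.txt.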
+ self.best_valid_metric = 0.0
+ self.best_test_metric = {}
+ save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}/'
+ os.makedirs(save_root, exist_ok=True)
+ self.logger = get_logger(f'{save_root}/', f'random_search')
+ study = optuna.create_study(directions=["maximize"], sampler=RandomSampler())
+ study.optimize(self.arch_random_search_each, n_trials=self.p.baseline_sample_num)
+ save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}/'
+ with open(f'{save_root}/random_search_res.txt', "w") as f1:
+ f1.write(f'{self.p.__dict__}\n')
+ f1.write(f'{self.p.genotype}\n')
+ f1.write(f'Valid performance: {study.best_value}\n')
+ f1.write(f'Test performance: {self.best_test_metric}')
+
+ def train_parameter(self, parameter):
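+ # Train a single given genotype end-to-end: rebuild the model, start a fresh
+ # wandb run, early-stop on the validation metric, and evaluate on the test
+ # split. Returns -best_val_f1 in hyperopt's minimization convention.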
+ self.best_val_f1 = 0.0
+ self.best_val_auroc = 0.0
+ self.early_stop_cnt = 0
+ self.p.genotype = "||".join(parameter)
+ run = self.reinit_wandb()
+ self.model = self.get_model().to("cuda:0")
+ self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.p.lr, weight_decay=self.p.l2)
+ save_root = f'{self.prj_path}/checkpoints/{self.p.dataset}'
+ os.makedirs(save_root, exist_ok=True)
+ save_path = f'{save_root}/{self.p.name}.pt'
+ for epoch in range(self.p.max_epochs):
+ start_time = time.time()
+ train_loss = self.train_epoch()
+ val_results, valid_loss = self.evaluate_epoch('valid', mode='normal')
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ if val_results['auroc'] > self.best_val_auroc:
+ self.best_val_results = val_results
+ self.best_val_auroc = val_results['auroc']
+ self.best_epoch = epoch
+ self.save_model(save_path)
+ self.early_stop_cnt = 0
+ self.logger.info("Update best valid auroc!")
+ else:
+ self.early_stop_cnt += 1
+ else:
+ if val_results['macro_f1'] > self.best_val_f1:
+ self.best_val_results = val_results
+ self.best_val_f1 = val_results['macro_f1']
+ self.best_epoch = epoch
+ self.save_model(save_path)
+ self.early_stop_cnt = 0
+ self.logger.info("Update best valid f1!")
+ else:
+ self.early_stop_cnt += 1
+ self.logger.info(
+ f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Cost: {time.time() - start_time:.2f}s")
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}]: Best AUROC: {self.best_val_auroc:.5}")
+ wandb.log({"train_loss": train_loss, "valid_loss": valid_loss,
+ "best_valid_auroc": self.best_val_auroc})
+ else:
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}]: Best Macro F1: {self.best_val_f1:.5}")
+ wandb.log({"train_loss": train_loss, "valid_loss": valid_loss,
+ "best_valid_f1": self.best_val_f1})
+ if self.early_stop_cnt == 15:
+ self.logger.info("Early stop!")
+ break
+ # self.logger.info(vars(self.p))
+ self.load_model(save_path)
+ self.logger.info(
+ f'Loading best model in {self.best_epoch} epoch, Evaluating on Test data')
+ start = time.time()
+ test_results, test_loss = self.evaluate_epoch('test', mode='normal')
+ end = time.time()
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ self.logger.info(
+ f"[Inference]: Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}")
+ wandb.log({
+ "test_auroc": test_results['auroc'],
+ "test_auprc": test_results['auprc'],
+ "test_ap": test_results['ap']
+ })
+ else:
+ if self.p.add_reverse:
+ self.logger.info(
+ f"f1: Rel {test_results['left_f1']:.5}, Rel_rev {test_results['right_f1']:.5}, Avg {test_results['macro_f1']:.5}")
+ self.logger.info(
+ f"[Inference]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}")
+ wandb.log({
+ "test_acc": test_results['acc'],
+ "test_f1": test_results['macro_f1'],
+ "test_cohen": test_results['kappa']
+ })
+ run.finish()
+ return {'loss': -self.best_val_f1, "status": STATUS_OK}
+
+ def reinit_wandb(self):
+ if self.p.train_mode == 'spos_tune':
+ run = wandb.init(
+ reinit=True,
+ project=self.p.wandb_project,
+ settings=wandb.Settings(start_method="fork"),
+ config={
+ "dataset": self.p.dataset,
+ "encoder": self.p.encoder,
+ "score_function": self.p.score_func,
+ "batch_size": self.p.batch_size,
+ "learning_rate": self.p.lr,
+ "weight_decay":self.p.l2,
+ "encoder_layer_num": self.p.n_layer,
+ "epochs": self.p.max_epochs,
+ "seed": self.p.seed,
+ "train_mode": self.p.train_mode,
+ "init_dim": self.p.init_dim,
+ "embed_dim": self.p.embed_dim,
+ "input_type": self.p.input_type,
+ "loss_type": self.p.loss_type,
+ "search_mode": self.p.search_mode,
+ "combine_type": self.p.combine_type,
+ "genotype": self.p.genotype,
+ "exp_note": self.p.exp_note,
+ "alpha_mode": self.p.alpha_mode,
+ "few_shot_op": self.p.few_shot_op,
+ "tune_sample_num": self.p.tune_sample_num
+ }
+ )
+ elif self.p.search_mode == 'arch_random':
+ run = wandb.init(
+ reinit=True,
+ project=self.p.wandb_project,
+ settings=wandb.Settings(start_method="fork"),
+ config={
+ "dataset": self.p.dataset,
+ "encoder": self.p.encoder,
+ "score_function": self.p.score_func,
+ "batch_size": self.p.batch_size,
+ "learning_rate": self.p.lr,
+ "weight_decay":self.p.l2,
+ "encoder_layer_num": self.p.n_layer,
+ "epochs": self.p.max_epochs,
+ "seed": self.p.seed,
+ "train_mode": self.p.train_mode,
+ "init_dim": self.p.init_dim,
+ "embed_dim": self.p.embed_dim,
+ "input_type": self.p.input_type,
+ "loss_type": self.p.loss_type,
+ "search_mode": self.p.search_mode,
+ "combine_type": self.p.combine_type,
+ "genotype": self.p.genotype,
+ "exp_note": self.p.exp_note,
+ "tune_sample_num": self.p.tune_sample_num
+ }
+ )
+ return run
+
+ def joint_search(self):
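+ # Joint search: one upper-level Adam optimizer updates both the encoder
+ # architecture parameters and the subgraph selector on validation batches,
+ # while the lower level updates the model weights on training batches (see
+ # joint_search_epoch). The temperature is halved every 5 epochs when cos_temp
+ # is set.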
+ save_root = f'{self.prj_path}/checkpoints/{self.p.dataset}'
+ os.makedirs(save_root, exist_ok=True)
+ save_model_path = f'{save_root}/{self.p.name}.pt'
+ save_ss_path = f'{save_root}/{self.p.name}_ss.pt'
+ self.subgraph_selector = SubgraphSelector(self.p).to("cuda:0")
+ self.upper_optimizer = torch.optim.Adam([{'params': self.model.arch_parameters()},
+ {'params': self.subgraph_selector.parameters(), 'lr': self.p.ss_lr}],
+ lr=self.p.arch_lr, weight_decay=self.p.arch_weight_decay)
+ for epoch in range(self.p.max_epochs):
+ start_time = time.time()
+ # genotype = self.model.genotype()
+ # self.logger.info(f'Genotype: {genotype}')
+ if self.p.cos_temp and epoch % 5 == 0 and epoch != 0:
+ self.p.temperature = self.p.temperature * 0.5  # halve the temperature every 5 epochs; otherwise keep it unchanged
+ train_loss = self.joint_search_epoch()
+ # self.scheduler.step()
+ # self.arch_scheduler.step()
+ val_results, valid_loss = self.evaluate_epoch('valid', 'joint_search')
+ s_val_results, s_valid_loss = self.evaluate_epoch('valid', 'joint_search_s')
+ if val_results['macro_f1'] > self.best_val_f1:
+ self.best_val_results = val_results
+ self.best_val_f1 = val_results['macro_f1']
+ self.best_epoch = epoch
+ torch.save(self.model.state_dict(), str(save_model_path))
+ torch.save(self.subgraph_selector.state_dict(), str(save_ss_path))
+ self.early_stop_cnt = 0
+ else:
+ self.early_stop_cnt += 1
+ genotype = self.model.genotype()
+ # self.logger.info(f'[Epoch {epoch}]: LR: {self.scheduler.get_last_lr()[0]}, TEMP: {self.p.temperature}, Genotype: {genotype}')
+ self.logger.info(
+ f'[Epoch {epoch}]: TEMP: {self.p.temperature}, Genotype: {genotype}')
+ self.logger.info(
+ f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Valid_S Loss: {s_valid_loss:.5}, Cost: {time.time() - start_time:.2f}s")
+ if self.p.dataset == 'drugbank':
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid_S ACC: {s_val_results['acc']:.5}, Valid Macro F1: {s_val_results['macro_f1']:.5}, Valid Cohen: {s_val_results['kappa']:.5}")
+ if self.early_stop_cnt == 50:
+ self.logger.info("Early stop!")
+ break
+
+ def joint_search_epoch(self):
+ self.model.train()
+ self.subgraph_selector.train()
+ loss_list = []
+ train_iter = self.data_iter['train_rel']
+ valid_iter = self.data_iter['valid_rel']
+ # train_bar = tqdm(train_iter, ncols=0)
+ for step, batch in enumerate(train_iter):
+ if self.p.input_type == 'subgraph':
+ g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0")
+ else:
+ triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0")
+ g = self.g.to("cuda:0")
+ subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2]
+ hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent]
+ # print(hidden_all_ent.size()) # [num_ent, encoder_layer, dim]
+ if self.p.score_func == 'mlp_ncn':
+ cns = batch[2].to("cuda:0")
+ pred = self.model.compute_pred(hidden_all_ent, all_ent, subj, obj, cns, self.subgraph_selector)
+ else:
+ pred = self.model.compute_pred(hidden_all_ent, subj, obj, self.subgraph_selector, search_algorithm=self.p.search_algorithm)
+ train_loss = self.model.calc_loss(pred, labels)
+ loss_list.append(train_loss.item())
+ self.optimizer.zero_grad()
+ train_loss.backward()
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+ self.upper_optimizer.zero_grad()
+
+ batch = next(iter(valid_iter))
+ if self.p.input_type == 'subgraph':
+ g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0")
+ else:
+ triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0")
+ g = self.g.to("cuda:0")
+ subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2]
+ hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent]
+ # print(hidden_all_ent.size()) # [num_ent, encoder_layer, dim]
+ if self.p.score_func == 'mlp_ncn':
+ cns = batch[2].to("cuda:0")
+ pred = self.model.compute_pred(hidden_all_ent, all_ent, subj, obj, cns, self.subgraph_selector)
+ else:
+ pred = self.model.compute_pred(hidden_all_ent, subj, obj, self.subgraph_selector, search_algorithm=self.p.search_algorithm)
+ valid_loss = self.model.calc_loss(pred, labels)
+ valid_loss.backward()
+ self.upper_optimizer.step()
+ loss = np.mean(loss_list)
+ return loss
+
+ def predict_joint_search(self, split, mode, eval_mode=None):
+ loss_list = []
+ pos_scores = []
+ pos_labels = []
+ self.model.eval()
+ self.subgraph_selector.eval()
+ with torch.no_grad():
+ results = dict()
+ test_iter = self.data_iter[f'{split}_rel{mode}']
+ for step, batch in enumerate(test_iter):
+ if self.p.input_type == 'subgraph':
+ g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0")
+ else:
+ triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0")
+ g = self.g.to("cuda:0")
+ subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2]
+ hidden_all_ent, all_ent = self.model.forward_search(g, mode=eval_mode) # [batch_size, num_ent]
+ if self.p.score_func == 'mlp_ncn':
+ cns = batch[2].to("cuda:0")
+ pred = self.model.compute_pred(hidden_all_ent, all_ent, subj, obj, cns, self.subgraph_selector,
+ mode='argmax')
+ else:
+ pred = self.model.compute_pred(hidden_all_ent, subj, obj, self.subgraph_selector, mode='argmax', search_algorithm=self.p.search_algorithm)
+ eval_loss = self.model.calc_loss(pred, labels)
+ loss_list.append(eval_loss.item())
+ pos_labels += rel.to('cpu').numpy().flatten().tolist()
+ pos_scores += torch.argmax(pred, dim=1).cpu().flatten().tolist()
+ results['acc'] = metrics.accuracy_score(pos_labels, pos_scores)
+ results['macro_f1'] = metrics.f1_score(pos_labels, pos_scores, average='macro')
+ results['kappa'] = metrics.cohen_kappa_score(pos_labels, pos_scores)
+ loss = np.mean(loss_list)
+ return results, loss
+
+ def joint_tune(self):
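+ # Fine-tune the result of joint search: reload the model and selector
+ # checkpoints, retrain the weights along the single selected path
+ # (mode='evaluate_single_path'), and keep the checkpoint with the best
+ # validation macro F1 before the final test evaluation.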
+ self.subgraph_selector = SubgraphSelector(self.p).to("cuda:0")
+ save_root = f'{self.prj_path}/checkpoints/{self.p.dataset}'
+ save_model_path = f'{save_root}/{self.p.name}.pt'
+ save_ss_path = f'{save_root}/{self.p.name}_ss.pt'
+ save_tune_path = f'{save_root}/{self.p.name}_tune.pt'
+ self.model.load_state_dict(torch.load(str(save_model_path)))
+ self.subgraph_selector.load_state_dict(torch.load(str(save_ss_path)))
+ self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='max', factor=0.2, patience=10, verbose=True)
+
+ val_results, val_loss = self.evaluate_epoch('valid', 'joint_search_s')
+ test_results, test_loss = self.evaluate_epoch('test', 'joint_search_s')
+ if self.p.dataset == 'drugbank':
+ # if self.p.add_reverse:
+ # self.logger.info(
+ # f"f1: Rel {test_results['left_f1']:.5}, Rel_rev {test_results['right_f1']:.5}, Avg {test_results['macro_f1']:.5}")
+ self.logger.info(
+ f"[Validation]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}")
+ self.logger.info(
+ f"[Inference]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}")
+ wandb.log({
+ "init_test_acc": test_results['acc'],
+ "init_test_f1": test_results['macro_f1'],
+ "init_test_cohen": test_results['kappa']
+ })
+ for epoch in range(self.p.max_epochs):
+ start_time = time.time()
+ train_loss = self.train_epoch_fine_tune(mode='evaluate_single_path')
+ val_results, valid_loss = self.evaluate_epoch('valid', 'joint_search_s')
+ if val_results['macro_f1'] > self.best_val_f1:
+ self.best_val_results = val_results
+ self.best_val_f1 = val_results['macro_f1']
+ self.best_epoch = epoch
+ self.save_model(save_tune_path)
+ self.early_stop_cnt = 0
+ else:
+ self.early_stop_cnt += 1
+ # torch.save(self.model.state_dict(), str(save_model_path))
+ # torch.save(self.subgraph_selector.state_dict(), str(save_ss_path))
+ genotype = self.model.genotype()
+ self.logger.info(
+ f'[Epoch {epoch}]: Genotype: {genotype}')
+ self.logger.info(
+ f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Cost: {time.time() - start_time:.2f}s")
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}]: Best Macro F1: {self.best_val_f1:.5}")
+ wandb.log({"train_loss": train_loss, "best_valid_f1": self.best_val_f1})
+ if self.early_stop_cnt == 15:
+ self.logger.info("Early stop!")
+ break
+ self.scheduler.step(self.best_val_f1)
+ self.load_model(save_tune_path)
+ self.logger.info(
+ f'Loading best model in {self.best_epoch} epoch, Evaluating on Test data')
+ test_results, test_loss = self.evaluate_epoch('test', 'joint_search_s')
+ test_acc = test_results['acc']
+ test_f1 = test_results['macro_f1']
+ test_cohen = test_results['kappa']
+ wandb.log({
+ "test_acc": test_acc,
+ "test_f1": test_f1,
+ "test_cohen": test_cohen
+ })
+ if self.p.dataset == 'drugbank':
+ if self.p.add_reverse:
+ self.logger.info(
+ f"f1: Rel {test_results['left_f1']:.5}, Rel_rev {test_results['right_f1']:.5}, Avg {test_results['macro_f1']:.5}")
+ self.logger.info(
+ f"[Inference]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}")
+
+ def spos_train_supernet(self):
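+ # SPOS supernet training: a single path through the operation space is
+ # sampled for every batch (generate_single_path) and only that path's weights
+ # are updated. With --weight_sharing the supernet is warm-started from a
+ # previously trained 400-epoch checkpoint derived from the experiment name;
+ # snapshots are saved every 100 epochs.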
+ save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}'
+ print(save_root)
+ os.makedirs(save_root, exist_ok=True)
+ if self.p.weight_sharing:
+ self.logger = get_logger(f'{save_root}/', f'train_supernet_ws_{self.p.few_shot_op}')
+ save_model_path = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{"_".join(save_root.split("/")[-1].split("_")[:-3])}/400.pt'
+ self.model.load_state_dict(torch.load(save_model_path))
+ else:
+ self.logger = get_logger(f'{save_root}/', f'train_supernet')
+ for epoch in range(1, self.p.max_epochs+1):
+ start_time = time.time()
+ train_loss = self.architecture_search_spos_epoch()
+ val_results, valid_loss = self.evaluate_epoch('valid', 'spos_train_supernet')
+ self.logger.info(
+ f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Cost: {time.time() - start_time:.2f}s")
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}")
+ else:
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}")
+ # self.scheduler.step(train_loss)
+ wandb.log({
+ "train_loss": train_loss,
+ "valid_loss": valid_loss
+ })
+ if epoch % 100 == 0:
+ torch.save(self.model.state_dict(), f'{save_root}/{epoch}.pt')
+
+ def architecture_search_spos_epoch(self):
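+ # One SPOS training epoch: resample a single path before every batch, compute
+ # the loss along that path only, and clip gradients when clip_grad is set.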
+ self.model.train()
+ loss_list = []
+ train_iter = self.data_iter['train_rel']
+ # train_bar = tqdm(train_iter, ncols=0)
+ for step, batch in enumerate(train_iter):
+ self.optimizer.zero_grad()
+ self.generate_single_path()
+ if self.p.input_type == 'subgraph':
+ g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0")
+ else:
+ triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0")
+ g = self.g.to("cuda:0")
+ subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2]
+ # hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent]
+ # print(hidden_all_ent.size()) # [num_ent, encoder_layer, dim]
+ pred = self.model(g, subj, obj)
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ pos_neg = batch[2].to("cuda:0")
+ loss = self.model.calc_loss(pred, labels, pos_neg)
+ else:
+ loss = self.model.calc_loss(pred, labels)
+ loss.backward()
+ if self.p.clip_grad:
+ clip_grad_norm_(self.model.parameters(), max_norm=10, norm_type=2)
+ self.optimizer.step()
+ loss_list.append(loss.item())
+ loss = np.mean(loss_list)
+ return loss
+
+ def predict_spos_search(self, split, mode, eval_mode=None, spos_mode='train_supernet'):
+ loss_list = []
+ pos_scores = []
+ pos_labels = []
+ pred_class = {}
+ self.model.eval()
+ with torch.no_grad():
+ results = dict()
+ test_iter = self.data_iter[f'{split}_rel{mode}']
+ for step, batch in enumerate(test_iter):
+ if spos_mode == 'train_supernet':
+ self.generate_single_path()
+ if self.p.input_type == 'subgraph':
+ g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0")
+ else:
+ triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0")
+ g = self.g.to("cuda:0")
+ subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2]
+ # hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent]
+ # print(hidden_all_ent.size()) # [num_ent, encoder_layer, dim]
+ pred = self.model(g, subj, obj, mode=eval_mode)
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ pos_neg = batch[2].to("cuda:0")
+ eval_loss = self.model.calc_loss(pred, labels, pos_neg)
+ m = torch.nn.Sigmoid()
+ pred = m(pred)
+ labels = labels.detach().to('cpu').numpy()
+ preds = pred.detach().to('cpu').numpy()
+ pos_neg = pos_neg.detach().to('cpu').numpy()
+ for (label_ids, pred, label_t) in zip(labels, preds, pos_neg):
+ for i, (l, p) in enumerate(zip(label_ids, pred)):
+ if l == 1:
+ if i in pred_class:
+ pred_class[i]['pred'] += [p]
+ pred_class[i]['l'] += [label_t]
+ pred_class[i]['pred_label'] += [1 if p > 0.5 else 0]
+ else:
+ pred_class[i] = {'pred': [p], 'l': [label_t], 'pred_label': [1 if p > 0.5 else 0]}
+ else:
+ eval_loss = self.model.calc_loss(pred, labels)
+ pos_labels += rel.to('cpu').numpy().flatten().tolist()
+ pos_scores += torch.argmax(pred, dim=1).cpu().flatten().tolist()
+ loss_list.append(eval_loss.item())
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ roc_auc = [metrics.roc_auc_score(pred_class[l]['l'], pred_class[l]['pred']) for l in pred_class]
+ prc_auc = [metrics.average_precision_score(pred_class[l]['l'], pred_class[l]['pred']) for l in
+ pred_class]
+ ap = [metrics.accuracy_score(pred_class[l]['l'], pred_class[l]['pred_label']) for l in pred_class]
+ results['auroc'] = np.mean(roc_auc)
+ results['auprc'] = np.mean(prc_auc)
+ results['ap'] = np.mean(ap)
+ else:
+ results['acc'] = metrics.accuracy_score(pos_labels, pos_scores)
+ results['macro_f1'] = metrics.f1_score(pos_labels, pos_scores, average='macro')
+ results['kappa'] = metrics.cohen_kappa_score(pos_labels, pos_scores)
+ loss = np.mean(loss_list)
+ return results, loss
+
+ def spos_arch_search(self):
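+ # SPOS architecture search: for each saved supernet snapshot, sample
+ # spos_arch_sample_num single paths, evaluate every sampled architecture on
+ # the validation and test splits without retraining, and dump the candidates
+ # ranked by validation loss/metric to CSV files.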
+ save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}'
+ os.makedirs(save_root, exist_ok=True)
+ for save_epoch in [800,700,600,500,400,300,200,100]:
+ try:
+ self.model.load_state_dict(torch.load(f'{save_root}/{save_epoch}.pt'))
+ except (FileNotFoundError, RuntimeError):
+ # skip epochs whose supernet checkpoint is missing or incompatible
+ continue
+ self.logger = get_logger(f'{save_root}/', f'{save_epoch}_arch_search')
+ valid_loss_searched_arch_res = dict()
+ valid_f1_searched_arch_res = dict()
+ valid_auroc_searched_arch_res = dict()
+ search_time = 0.0
+ t_start = time.time()
+ for epoch in range(1, self.p.spos_arch_sample_num + 1):
+ self.generate_single_path()
+ arch = "||".join(self.model.ops)
+ val_results, valid_loss = self.evaluate_epoch('valid', 'spos_arch_search')
+ test_results, test_loss = self.evaluate_epoch('test', 'spos_arch_search')
+ valid_loss_searched_arch_res.setdefault(arch, valid_loss)
+ self.logger.info(f'[Epoch {epoch}]: Path:{arch}')
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid Loss: {valid_loss:.5}, Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}]: Test Loss: {test_loss:.5}, Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}")
+ valid_auroc_searched_arch_res.setdefault(arch, val_results['auroc'])
+ else:
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid Loss: {valid_loss:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid ACC: {val_results['acc']:.5}, Valid Cohen: {val_results['kappa']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}]: Test Loss: {test_loss:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test ACC: {test_results['acc']:.5}, Test Cohen: {test_results['kappa']:.5}")
+ valid_f1_searched_arch_res.setdefault(arch, val_results['macro_f1'])
+
+ t_end = time.time()
+ search_time = (t_end - t_start)
+
+ search_time = search_time / 3600
+ self.logger.info(f'The search process costs {search_time:.2f}h.')
+ import csv
+ with open(f'{save_root}/valid_loss_{save_epoch}.csv', 'w', newline='') as f:
+ writer = csv.writer(f)
+ writer.writerow(['arch', 'valid loss'])
+ valid_loss_searched_arch_res_sorted = sorted(valid_loss_searched_arch_res.items(), key=lambda x :x[1])
+ res = valid_loss_searched_arch_res_sorted
+ for i in range(len(res)):
+ writer.writerow([res[i][0], res[i][1]])
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ with open(f'{save_root}/valid_auroc_{save_epoch}.csv', 'w', newline='') as f:
+ writer = csv.writer(f)
+ writer.writerow(['arch', 'valid auroc'])
+ valid_auroc_searched_arch_res_sorted = sorted(valid_auroc_searched_arch_res.items(), key=lambda x: x[1],
+ reverse=True)
+ res = valid_auroc_searched_arch_res_sorted
+ for i in range(len(res)):
+ writer.writerow([res[i][0], res[i][1]])
+ else:
+ with open(f'{save_root}/valid_f1_{save_epoch}.csv', 'w', newline='') as f:
+ writer = csv.writer(f)
+ writer.writerow(['arch', 'valid f1'])
+ valid_f1_searched_arch_res_sorted = sorted(valid_f1_searched_arch_res.items(), key=lambda x: x[1], reverse=True)
+ res = valid_f1_searched_arch_res_sorted
+ for i in range(len(res)):
+ writer.writerow([res[i][0], res[i][1]])
+
+ def generate_single_path(self):
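+ # Sample one operation per layer and store the result in self.model.ops. The
+ # sampling routine depends on exp_note (e.g. restricting the search to
+ # particular operation types) and, for 'spfs', on the chosen few-shot
+ # composition operator.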
+ if self.p.exp_note is None:
+ self.model.ops = self.model.generate_single_path()
+ elif self.p.exp_note == 'only_search_act':
+ self.model.ops = self.model.generate_single_path_act()
+ elif self.p.exp_note == 'only_search_comb':
+ self.model.ops = self.model.generate_single_path_comb()
+ elif self.p.exp_note == 'only_search_comp':
+ self.model.ops = self.model.generate_single_path_comp()
+ elif self.p.exp_note == 'only_search_agg':
+ self.model.ops = self.model.generate_single_path_agg()
+ elif self.p.exp_note == 'only_search_agg_comb':
+ self.model.ops = self.model.generate_single_path_agg_comb()
+ elif self.p.exp_note == 'only_search_agg_comb_comp':
+ self.model.ops = self.model.generate_single_path_agg_comb_comp()
+ elif self.p.exp_note == 'only_search_agg_comb_act_rotate':
+ self.model.ops = self.model.generate_single_path_agg_comb_act_rotate()
+ elif self.p.exp_note == 'only_search_agg_comb_act_mult':
+ self.model.ops = self.model.generate_single_path_agg_comb_act_mult()
+ elif self.p.exp_note == 'only_search_agg_comb_act_ccorr':
+ self.model.ops = self.model.generate_single_path_agg_comb_act_ccorr()
+ elif self.p.exp_note == 'only_search_agg_comb_act_sub':
+ self.model.ops = self.model.generate_single_path_agg_comb_act_sub()
+ elif self.p.exp_note == 'spfs' and self.p.search_algorithm == 'spos_arch_search_ps2':
+ self.model.ops = self.model.generate_single_path()
+ elif self.p.exp_note == 'spfs' and self.p.few_shot_op is not None:
+ self.model.ops = self.model.generate_single_path_agg_comb_act_few_shot_comp(self.p.few_shot_op)
+ # print(1)
+
+ def spos_train_supernet_ps2(self):
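+ # Supernet training for the joint pipeline: SPOS single-path sampling for the
+ # encoder is combined with PS2 updates of the subgraph selector. With
+ # --weight_sharing both the supernet and the selector are warm-started from
+ # the shared 400-epoch checkpoints; snapshots are saved every 100 epochs.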
+ save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}'
+ log_root = f'{save_root}/log'
+ os.makedirs(log_root, exist_ok=True)
+ self.subgraph_selector = SubgraphSelector(self.p).to("cuda:0")
+ self.subgraph_selector_optimizer = torch.optim.Adam(
+ self.subgraph_selector.parameters(), lr=self.p.ss_lr, weight_decay=self.p.l2)
+ self.logger = get_logger(f'{log_root}/', f'train_supernet')
+ if self.p.weight_sharing:
+ save_model_path = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{"_".join(save_root.split("/")[-1].split("_")[:-3])}/400.pt'
+ save_ss_path = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{"_".join(save_root.split("/")[-1].split("_")[:-3])}/400_ss.pt'
+ print(save_model_path)
+ print(save_ss_path)
+ self.model.load_state_dict(torch.load(save_model_path))
+ self.subgraph_selector.load_state_dict(torch.load(save_ss_path))
+ for epoch in range(1, self.p.max_epochs+1):
+ start_time = time.time()
+ train_loss = self.architecture_search_spos_ps2_epoch()
+ val_results, valid_loss = self.evaluate_epoch('valid', 'spos_train_supernet_ps2')
+ self.logger.info(
+ f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Cost: {time.time() - start_time:.2f}s")
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}")
+ else:
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}")
+ # self.scheduler.step(train_loss)
+ wandb.log({
+ "train_loss": train_loss,
+ "valid_loss": valid_loss
+ })
+ if epoch % 100 == 0:
+ torch.save(self.model.state_dict(), f'{save_root}/{epoch}.pt')
+ torch.save(self.subgraph_selector.state_dict(), f'{save_root}/{epoch}_ss.pt')
+
+ def architecture_search_spos_ps2_epoch(self):
+ self.model.train()
+ self.subgraph_selector.train()
+ loss_list = []
+ train_iter = self.data_iter['train_rel']
+ valid_iter = self.data_iter['valid_rel']
+ # train_bar = tqdm(train_iter, ncols=0)
+ for step, batch in enumerate(train_iter):
+ self.optimizer.zero_grad()
+ self.generate_single_path()
+ if self.p.input_type == 'subgraph':
+ g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0")
+ else:
+ triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0")
+ g = self.g.to("cuda:0")
+ subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2]
+ hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent]
+ pred = self.model.compute_pred(hidden_all_ent, subj, obj, self.subgraph_selector,
+ search_algorithm=self.p.ss_search_algorithm)
+
+ # print(hidden_all_ent.size()) # [num_ent, encoder_layer, dim]
+ # pred = self.model(g, subj, obj)
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ pos_neg = batch[2].to("cuda:0")
+ train_loss = self.model.calc_loss(pred, labels, pos_neg)
+ else:
+ train_loss = self.model.calc_loss(pred, labels)
+ train_loss.backward()
+ if self.p.clip_grad:
+ clip_grad_norm_(self.model.parameters(), max_norm=10, norm_type=2)
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+ if self.p.alpha_mode == 'train_loss':
+ self.subgraph_selector_optimizer.step()
+ elif self.p.alpha_mode == 'valid_loss':
+ self.optimizer.zero_grad()
+ self.subgraph_selector_optimizer.zero_grad()
+ batch = next(iter(valid_iter))
+ if self.p.input_type == 'subgraph':
+ g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0")
+ else:
+ triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0")
+ g = self.g.to("cuda:0")
+ subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2]
+ hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent]
+ pred = self.model.compute_pred(hidden_all_ent, subj, obj, self.subgraph_selector,
+ search_algorithm=self.p.ss_search_algorithm)
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ pos_neg = batch[2].to("cuda:0")
+ valid_loss = self.model.calc_loss(pred, labels, pos_neg)
+ else:
+ valid_loss = self.model.calc_loss(pred, labels)
+ valid_loss.backward()
+ self.subgraph_selector_optimizer.step()
+ loss_list.append(train_loss.item())
+ loss = np.mean(loss_list)
+ return loss
+
+ def predict_spos_search_ps2(self, split, mode, eval_mode=None, spos_mode='train_supernet'):
+ loss_list = []
+ pos_scores = []
+ pos_labels = []
+ pred_class = {}
+ self.model.eval()
+ self.subgraph_selector.eval()
+ with torch.no_grad():
+ results = dict()
+ test_iter = self.data_iter[f'{split}_rel{mode}']
+ for step, batch in enumerate(test_iter):
+ if spos_mode == 'train_supernet':
+ self.generate_single_path()
+ if self.p.input_type == 'subgraph':
+ g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0")
+ else:
+ triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0")
+ g = self.g.to("cuda:0")
+ subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2]
+ hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent]
+ pred = self.model.compute_pred(hidden_all_ent, subj, obj, self.subgraph_selector, mode='argmax',
+ search_algorithm=self.p.ss_search_algorithm)
+ # print(hidden_all_ent.size()) # [num_ent, encoder_layer, dim]
+ # pred = self.model(g, subj, obj, mode=eval_mode)
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ pos_neg = batch[2].to("cuda:0")
+ eval_loss = self.model.calc_loss(pred, labels, pos_neg)
+ m = torch.nn.Sigmoid()
+ pred = m(pred)
+ labels = labels.detach().to('cpu').numpy()
+ preds = pred.detach().to('cpu').numpy()
+ pos_neg = pos_neg.detach().to('cpu').numpy()
+ for (label_ids, pred, label_t) in zip(labels, preds, pos_neg):
+ for i, (l, p) in enumerate(zip(label_ids, pred)):
+ if l == 1:
+ if i in pred_class:
+ pred_class[i]['pred'] += [p]
+ pred_class[i]['l'] += [label_t]
+ pred_class[i]['pred_label'] += [1 if p > 0.5 else 0]
+ else:
+ pred_class[i] = {'pred': [p], 'l': [label_t], 'pred_label': [1 if p > 0.5 else 0]}
+ else:
+ eval_loss = self.model.calc_loss(pred, labels)
+ pos_labels += rel.to('cpu').numpy().flatten().tolist()
+ pos_scores += torch.argmax(pred, dim=1).cpu().flatten().tolist()
+ loss_list.append(eval_loss.item())
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ roc_auc = [metrics.roc_auc_score(pred_class[l]['l'], pred_class[l]['pred']) for l in pred_class]
+ prc_auc = [metrics.average_precision_score(pred_class[l]['l'], pred_class[l]['pred']) for l in
+ pred_class]
+ ap = [metrics.accuracy_score(pred_class[l]['l'], pred_class[l]['pred_label']) for l in pred_class]
+ results['auroc'] = np.mean(roc_auc)
+ results['auprc'] = np.mean(prc_auc)
+ results['ap'] = np.mean(ap)
+ else:
+ results['acc'] = metrics.accuracy_score(pos_labels, pos_scores)
+ results['macro_f1'] = metrics.f1_score(pos_labels, pos_scores, average='macro')
+ results['kappa'] = metrics.cohen_kappa_score(pos_labels, pos_scores)
+ # dict_res = metrics.classification_report(pos_labels, pos_scores, output_dict=True, zero_division=1)
+ # results['macro_f1_per_class'] = get_f1_score_list(dict_res)
+ # results['acc_per_class'] = get_acc_list(dict_res)
+ loss = np.mean(loss_list)
+ return results, loss
+
+ def spos_arch_search_ps2(self):
+ res_list = []
+ sorted_list = []
+ exp_note = '_' + self.p.exp_note if self.p.exp_note is not None else ''
+ self.subgraph_selector = SubgraphSelector(self.p).to("cuda:0")
+ # save_root = f'{self.prj_path}/checkpoints/{self.p.dataset}/{self.p.search_mode}'
+ # save_model_path = f'{save_root}/{self.p.name}.pt'
+ # save_ss_path = f'{save_root}/{self.p.name}_ss.pt'
+ save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}'
+ if self.p.exp_note == 'spfs':
+ epoch_list = [400]
+ weight_sharing = '_ws'
+ else:
+            epoch_list = [800, 700, 600]
+ weight_sharing = ''
+ for save_epoch in epoch_list:
+ # try:
+ # self.model.load_state_dict(torch.load(f'{save_root}/{save_epoch}.pt'))
+ # self.subgraph_selector.load_state_dict(torch.load(f'{save_root}/{save_epoch}_ss.pt'))
+ # except:
+ # continue
+ self.logger = get_logger(f'{save_root}/log/', f'{save_epoch}_arch_search{exp_note}')
+ # res_root = f'{self.prj_path}/search_res/{self.p.dataset}/{self.p.search_mode}/{self.p.name}/'
+ # os.makedirs(res_root, exist_ok=True)
+ valid_loss_searched_arch_res = dict()
+ valid_f1_searched_arch_res = dict()
+ valid_auroc_searched_arch_res = dict()
+ t_start = time.time()
+ for epoch in range(0, self.p.spos_arch_sample_num):
+ for sample_idx in range(1, self.p.asng_sample_num+1):
+ self.generate_single_path()
+ arch = "||".join(self.model.ops)
+ if self.p.exp_note == 'spfs':
+ few_shot_op = self.model.ops[0]
+ else:
+ few_shot_op = ''
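+                    # under the 'spfs' setting, the sampled path's first operation selects which sub-supernet checkpoint to load; '_ws' marks weight-sharing checkpoints.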
+ self.model.load_state_dict(torch.load(f'{save_root}{exp_note}_{few_shot_op}{weight_sharing}/{save_epoch}.pt'))
+ self.subgraph_selector.load_state_dict(torch.load(f'{save_root}{exp_note}_{few_shot_op}{weight_sharing}/{save_epoch}_ss.pt'))
+ val_results, valid_loss = self.evaluate_epoch('valid', 'spos_arch_search_ps2')
+ test_results, test_loss = self.evaluate_epoch('test', 'spos_arch_search_ps2')
+ valid_loss_searched_arch_res.setdefault(arch, valid_loss)
+ self.logger.info(f'[Epoch {epoch*self.p.asng_sample_num+sample_idx}]: Path:{arch}')
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ self.logger.info(
+ f"[Epoch {epoch*self.p.asng_sample_num+sample_idx}]: Valid Loss: {valid_loss:.5}, Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch*self.p.asng_sample_num+sample_idx}]: Test Loss: {test_loss:.5}, Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}")
+ valid_auroc_searched_arch_res.setdefault(arch, val_results['auroc'])
+ sorted_list.append(val_results['auroc'])
+ else:
+ self.logger.info(
+ f"[Epoch {epoch*self.p.asng_sample_num+sample_idx}]: Valid Loss: {valid_loss:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid ACC: {val_results['acc']:.5}, Valid Cohen: {val_results['kappa']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch*self.p.asng_sample_num+sample_idx}]: Test Loss: {test_loss:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test ACC: {test_results['acc']:.5}, Test Cohen: {test_results['kappa']:.5}")
+ valid_f1_searched_arch_res.setdefault(arch, val_results['macro_f1'])
+ sorted_list.append(val_results['macro_f1'])
+ res_list.append(sorted(sorted_list, reverse=True)[:self.p.asng_sample_num])
+ with open(f"{save_root}/topK_{save_epoch}{exp_note}.pkl", "wb") as f:
+ pickle.dump(res_list, f)
+ t_end = time.time()
+ search_time = (t_end - t_start)
+
+ search_time = search_time / 3600
+ self.logger.info(f'The search process costs {search_time:.2f}h.')
+ import csv
+ with open(f'{save_root}/valid_loss_{save_epoch}{exp_note}.csv', 'w', newline='') as f:
+ writer = csv.writer(f)
+ writer.writerow(['arch', 'valid loss'])
+                valid_loss_searched_arch_res_sorted = sorted(valid_loss_searched_arch_res.items(), key=lambda x: x[1])
+ res = valid_loss_searched_arch_res_sorted
+ for i in range(len(res)):
+ writer.writerow([res[i][0], res[i][1]])
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ with open(f'{save_root}/valid_auroc_{save_epoch}{exp_note}.csv', 'w', newline='') as f:
+ writer = csv.writer(f)
+ writer.writerow(['arch', 'valid auroc'])
+ valid_auroc_searched_arch_res_sorted = sorted(valid_auroc_searched_arch_res.items(), key=lambda x: x[1],
+ reverse=True)
+ res = valid_auroc_searched_arch_res_sorted
+ for i in range(len(res)):
+ writer.writerow([res[i][0], res[i][1]])
+ else:
+ with open(f'{save_root}/valid_f1_{save_epoch}{exp_note}.csv', 'w', newline='') as f:
+ writer = csv.writer(f)
+ writer.writerow(['arch', 'valid f1'])
+ valid_f1_searched_arch_res_sorted = sorted(valid_f1_searched_arch_res.items(), key=lambda x: x[1], reverse=True)
+ res = valid_f1_searched_arch_res_sorted
+ for i in range(len(res)):
+ writer.writerow([res[i][0], res[i][1]])
+
+ def joint_spos_ps2_fine_tune(self):
+ self.best_valid_metric = 0.0
+ self.best_test_metric = {}
+ arch_rank = 1
+ exp_note = '_' + self.p.exp_note if self.p.exp_note is not None else ''
+ for save_epoch in [400]:
+ self.save_epoch = save_epoch
+ save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}'
+ res_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}/res'
+ log_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}/log'
+ tmp_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}/tmp'
+ os.makedirs(res_root, exist_ok=True)
+ os.makedirs(log_root, exist_ok=True)
+ os.makedirs(tmp_root, exist_ok=True)
+ print(save_root)
+ self.logger = get_logger(f'{log_root}/', f'fine_tune_e{save_epoch}_vmtop{arch_rank}{exp_note}')
+ metric = 'auroc' if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset else 'f1'
+ with open(f'{save_root}/valid_{metric}_{self.save_epoch}{exp_note}_ng.csv', 'r') as csv_file:
+ reader = csv.reader(csv_file)
+ for _ in range(arch_rank):
+                    next(reader)  # with arch_rank == 1 this skips only the CSV header row; the next read below returns the top-ranked architecture
+ self.p.genotype = next(reader)[0]
+ self.model.ops = self.p.genotype.split("||")
+ if self.p.exp_note == 'spfs':
+ weight_sharing = '_ws'
+ few_shot_op = self.model.ops[0]
+ else:
+ weight_sharing = ''
+ few_shot_op = ''
+ self.ss_path = f'{save_root}{exp_note}_{few_shot_op}{weight_sharing}/{self.save_epoch}_ss.pt'
+ study = optuna.create_study(directions=["maximize"])
+ study.optimize(self.spfs_fine_tune_each, n_trials=self.p.tune_sample_num)
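+            # best_params stores log10-scaled values (see spfs_fine_tune_each), so they are mapped back with 10 ** before being recorded.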
+ self.p.lr = 10 ** study.best_params["learning_rate"]
+ self.p.l2 = 10 ** study.best_params["weight_decay"]
+ with open(f'{res_root}/fine_tune_e{save_epoch}_vmtop{arch_rank}{exp_note}.txt', "w") as f1:
+ f1.write(f'{self.p.__dict__}\n')
+ f1.write(f'Valid performance: {study.best_value}\n')
+ f1.write(f'Test performance: {self.best_test_metric}')
+
+ def joint_spos_fine_tune(self, parameter):
+ self.best_val_f1 = 0.0
+ self.best_val_auroc = 0.0
+ self.early_stop_cnt = 0
+ self.p.lr = 10 ** parameter['learning_rate']
+ self.p.l2 = 10 ** parameter['weight_decay']
+ run = self.reinit_wandb()
+ run.finish()
+ return {'loss': -self.p.lr, 'test_metric':self.p.l2, "status": STATUS_OK}
+
+ def spfs_fine_tune_each(self, trial):
+ exp_note = '_' + self.p.exp_note if self.p.exp_note is not None else ''
+ self.best_val_f1 = 0.0
+ self.best_val_auroc = 0.0
+ self.early_stop_cnt = 0
+ learning_rate = trial.suggest_float("learning_rate", -3.05, -2.95)
+ weight_decay = trial.suggest_float("weight_decay", -5, -3)
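+        # both hyperparameters are searched in log10 space: lr roughly around 1e-3, weight decay between 1e-5 and 1e-3.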
+ self.p.lr = 10 ** learning_rate
+ self.p.l2 = 10 ** weight_decay
+ self.model = self.get_model().to("cuda:0")
+ print(type(self.model).__name__)
+ self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.p.lr, weight_decay=self.p.l2)
+ self.subgraph_selector = SubgraphSelector(self.p).to("cuda:0")
+ save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}'
+ save_model_path = f'{save_root}/tmp/fine_tune_e{self.save_epoch}{exp_note}.pt'
+ self.subgraph_selector.load_state_dict(torch.load(str(self.ss_path)))
+ self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='max', factor=0.2, patience=10, verbose=True)
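+        # reduce the learning rate when the monitored validation metric (passed to scheduler.step below) plateaus; mode='max' because higher is better.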
+ for epoch in range(self.p.max_epochs):
+ start_time = time.time()
+ train_loss = self.train_epoch_fine_tune()
+ val_results, valid_loss = self.evaluate_epoch('valid', 'spos_arch_search_ps2')
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ if val_results['auroc'] > self.best_val_auroc:
+ self.best_val_results = val_results
+ self.best_val_auroc = val_results['auroc']
+ self.best_epoch = epoch
+ self.save_model(save_model_path)
+ self.early_stop_cnt = 0
+ else:
+ self.early_stop_cnt += 1
+ else:
+ if val_results['macro_f1'] > self.best_val_f1:
+ self.best_val_results = val_results
+ self.best_val_f1 = val_results['macro_f1']
+ self.best_epoch = epoch
+ self.save_model(save_model_path)
+ self.early_stop_cnt = 0
+ else:
+ self.early_stop_cnt += 1
+ self.logger.info(
+ f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Cost: {time.time() - start_time:.2f}s")
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}]: Best AUROC: {self.best_val_auroc:.5}")
+ # wandb.log({"train_loss": train_loss, "valid_loss": valid_loss,
+ # "best_valid_auroc": self.best_val_auroc})
+ self.scheduler.step(self.best_val_auroc)
+ else:
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}]: Best Macro F1: {self.best_val_f1:.5}")
+ # wandb.log({"train_loss": train_loss, "valid_loss": valid_loss,
+ # "best_valid_f1": self.best_val_f1})
+ self.scheduler.step(self.best_val_f1)
+ if self.early_stop_cnt == 15:
+ self.logger.info("Early stop!")
+ break
+ self.load_model(save_model_path)
+ self.logger.info(
+ f'Loading best model in {self.best_epoch} epoch, Evaluating on Test data')
+ test_results, test_loss = self.evaluate_epoch('test', 'spos_arch_search_ps2')
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ self.logger.info(
+ f"[Inference]: Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}")
+ # wandb.log({
+ # "test_auroc": test_results['auroc'],
+ # "test_auprc": test_results['auprc'],
+ # "test_ap": test_results['ap']
+ # })
+ # run.finish()
+ if self.best_val_auroc > self.best_valid_metric:
+ self.best_valid_metric = self.best_val_auroc
+ self.best_test_metric = test_results
+ return self.best_val_auroc
+ else:
+ self.logger.info(
+ f"[Inference]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}")
+ # wandb.log({
+ # "test_acc": test_results['acc'],
+ # "test_f1": test_results['macro_f1'],
+ # "test_cohen": test_results['kappa']
+ # })
+ # run.finish()
+ if self.best_val_f1 > self.best_valid_metric:
+ self.best_valid_metric = self.best_val_f1
+ self.best_test_metric = test_results
+ return self.best_val_f1
+
+ def spos_fine_tune(self):
+ self.best_valid_metric = 0.0
+ self.best_test_metric = {}
+ arch_rank = 1
+ for save_epoch in [800]:
+ self.save_epoch = save_epoch
+ save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}/'
+ print(save_root)
+            self.logger = get_logger(f'{save_root}/', f'{save_epoch}_fine_tune_vmtop{arch_rank}')
+ metric = 'auroc' if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset else 'f1'
+ with open(f'{save_root}/valid_{metric}_{self.save_epoch}.csv', 'r') as csv_file:
+ reader = csv.reader(csv_file)
+ for _ in range(arch_rank):
+                    next(reader)  # with arch_rank == 1 this skips only the CSV header row; the next read below returns the top-ranked architecture
+ self.p.genotype = next(reader)[0]
+ self.model.ops = self.p.genotype.split("||")
+ study = optuna.create_study(directions=["maximize"])
+ study.optimize(self.spos_fine_tune_each, n_trials=self.p.tune_sample_num)
+ self.p.lr = 10 ** study.best_params["learning_rate"]
+ self.p.l2 = 10 ** study.best_params["weight_decay"]
+ save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}/'
+ with open(f'{save_root}/{save_epoch}_tune_res_vmtop{arch_rank}.txt', "w") as f1:
+ f1.write(f'{self.p.__dict__}\n')
+ f1.write(f'Valid performance: {study.best_value}\n')
+ f1.write(f'Test performance: {self.best_test_metric}')
+
+ def spos_fine_tune_each(self, trial):
+ self.best_val_f1 = 0.0
+ self.best_val_auroc = 0.0
+ self.early_stop_cnt = 0
+ learning_rate = trial.suggest_float("learning_rate", -3.05, -2.95)
+ weight_decay = trial.suggest_float("weight_decay", -5, -3)
+ self.p.lr = 10 ** learning_rate
+ self.p.l2 = 10 ** weight_decay
+ # run = self.reinit_wandb()
+ self.model = self.get_model().to("cuda:0")
+ print(type(self.model).__name__)
+ self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.p.lr, weight_decay=self.p.l2)
+ save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}/'
+ # for save_epoch in [100, 200, 300, 400]:
+ save_model_path = f'{save_root}/{self.save_epoch}_tune.pt'
+ # self.model.load_state_dict(torch.load(str(supernet_path)))
+ for epoch in range(self.p.max_epochs):
+ start_time = time.time()
+ train_loss = self.train_epoch()
+ val_results, valid_loss = self.evaluate_epoch('valid', 'spos_arch_search')
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ if val_results['auroc'] > self.best_val_auroc:
+ self.best_val_results = val_results
+ self.best_val_auroc = val_results['auroc']
+ self.best_epoch = epoch
+ self.save_model(save_model_path)
+ self.early_stop_cnt = 0
+ else:
+ self.early_stop_cnt += 1
+ else:
+ if val_results['macro_f1'] > self.best_val_f1:
+ self.best_val_results = val_results
+ self.best_val_f1 = val_results['macro_f1']
+ self.best_epoch = epoch
+ self.save_model(save_model_path)
+ self.early_stop_cnt = 0
+ else:
+ self.early_stop_cnt += 1
+ self.logger.info(
+ f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Cost: {time.time() - start_time:.2f}s")
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}]: Best AUROC: {self.best_val_auroc:.5}")
+ # wandb.log({"train_loss": train_loss, "valid_loss": valid_loss,
+ # "best_valid_auroc": self.best_val_auroc})
+ else:
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}]: Best Macro F1: {self.best_val_f1:.5}")
+ # wandb.log({"train_loss": train_loss, "valid_loss": valid_loss,
+ # "best_valid_f1": self.best_val_f1})
+ if self.early_stop_cnt == 10:
+ self.logger.info("Early stop!")
+ break
+ self.load_model(save_model_path)
+ self.logger.info(
+ f'Loading best model in {self.best_epoch} epoch, Evaluating on Test data')
+ test_results, test_loss = self.evaluate_epoch('test', 'spos_arch_search')
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ self.logger.info(
+ f"[Inference]: Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}")
+ # wandb.log({
+ # "test_auroc": test_results['auroc'],
+ # "test_auprc": test_results['auprc'],
+ # "test_ap": test_results['ap']
+ # })
+ # run.finish()
+ if self.best_val_auroc > self.best_valid_metric:
+ self.best_valid_metric = self.best_val_auroc
+ self.best_test_metric = test_results
+ return self.best_val_auroc
+ else:
+ self.logger.info(
+ f"[Inference]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}")
+ # wandb.log({
+ # "test_acc": test_results['acc'],
+ # "test_f1": test_results['macro_f1'],
+ # "test_cohen": test_results['kappa']
+ # })
+ # run.finish()
+ if self.best_val_f1 > self.best_valid_metric:
+ self.best_valid_metric = self.best_val_f1
+ self.best_test_metric = test_results
+ return self.best_val_f1
+
+ def train_mix_hop(self):
+ save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.train_mode}/{self.p.name}/'
+ os.makedirs(save_root, exist_ok=True)
+ self.logger = get_logger(f'{save_root}/', f'train')
+ save_path = f'{save_root}/model_weight.pt'
+ for epoch in range(self.p.max_epochs):
+ start_time = time.time()
+ train_loss = self.train_epoch_mix_hop()
+ val_results, valid_loss = self.evaluate_epoch('valid', mode='normal_mix_hop')
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ if val_results['auroc'] > self.best_val_auroc:
+ self.best_val_results = val_results
+ self.best_val_auroc = val_results['auroc']
+ self.best_epoch = epoch
+ self.save_model(save_path)
+ self.early_stop_cnt = 0
+ else:
+ self.early_stop_cnt += 1
+ else:
+ if val_results['macro_f1'] > self.best_val_f1:
+ self.best_val_results = val_results
+ self.best_val_f1 = val_results['macro_f1']
+ self.best_epoch = epoch
+ self.save_model(save_path)
+ self.early_stop_cnt = 0
+ else:
+ self.early_stop_cnt += 1
+ self.logger.info(
+ f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Cost: {time.time() - start_time:.2f}s")
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}]: Best AUROC: {self.best_val_auroc:.5}")
+ wandb.log({"train_loss": train_loss, "valid_loss": valid_loss,
+ "best_valid_auroc": self.best_val_auroc})
+ else:
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}]: Best Macro F1: {self.best_val_f1:.5}")
+ wandb.log({"train_loss": train_loss, "valid_loss": valid_loss,
+ "best_valid_f1": self.best_val_f1})
+ if self.early_stop_cnt == 10:
+ self.logger.info("Early stop!")
+ break
+ # self.logger.info(vars(self.p))
+ self.load_model(save_path)
+ self.logger.info(
+ f'Loading best model in {self.best_epoch} epoch, Evaluating on Test data')
+ start = time.time()
+ test_results, test_loss = self.evaluate_epoch('test', mode='normal_mix_hop')
+ end = time.time()
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ self.logger.info(
+ f"[Inference]: Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}")
+ wandb.log({
+ "test_auroc": test_results['auroc'],
+ "test_auprc": test_results['auprc'],
+ "test_ap": test_results['ap']
+ })
+ else:
+ self.logger.info(
+ f"[Inference]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}")
+ wandb.log({
+ "test_acc": test_results['acc'],
+ "test_f1": test_results['macro_f1'],
+ "test_cohen": test_results['kappa']
+ })
+
+ def train_epoch_mix_hop(self):
+ self.model.train()
+ losses = []
+ train_iter = self.data_iter['train_rel']
+ # train_bar = tqdm(train_iter, ncols=0)
+ for step, batch in enumerate(train_iter):
+ if self.p.input_type == 'subgraph':
+ g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0")
+ else:
+ triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0")
+ g = self.g.to("cuda:0")
+ subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2]
+ hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent]
+ mix_hop_index = self.transform_hop_index()
+ pred = self.model.compute_mix_hop_pred(hidden_all_ent, subj, obj, mix_hop_index)
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ pos_neg = batch[2].to("cuda:0")
+ loss = self.model.calc_loss(pred, labels, pos_neg)
+ else:
+ loss = self.model.calc_loss(pred, labels)
+ self.optimizer.zero_grad()
+ loss.backward()
+ if self.p.clip_grad:
+ clip_grad_norm_(self.model.parameters(), max_norm=10, norm_type=2)
+ self.optimizer.step()
+ losses.append(loss.item())
+ loss = np.mean(losses)
+ return loss
+
+ def transform_hop_index(self):
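+        # exp_note is expected to look like 'i_j' (1-based); the pair is flattened into a single index over an n_layer-by-n_layer grid of hop combinations.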
+ ij = self.p.exp_note.split("_")
+ return self.p.n_layer * (int(ij[0]) - 1) + int(ij[1]) - 1
+
+ def predict_mix_hop(self, split, mode):
+ loss_list = []
+ pos_scores = []
+ pos_labels = []
+ pred_class = {}
+ self.model.eval()
+ with torch.no_grad():
+ results = dict()
+ eval_iter = self.data_iter[f'{split}_rel{mode}']
+ for step, batch in enumerate(eval_iter):
+ if self.p.input_type == 'subgraph':
+ g, triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0"), batch[2].to("cuda:0")
+ else:
+ triplets, labels = batch[0].to("cuda:0"), batch[1].to("cuda:0")
+ g = self.g.to("cuda:0")
+ subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2]
+ hidden_all_ent, all_ent = self.model.forward_search(g, mode=self.p.input_type) # [batch_size, num_ent]
+ mix_hop_index = self.transform_hop_index()
+ pred = self.model.compute_mix_hop_pred(hidden_all_ent, subj, obj, mix_hop_index)
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ pos_neg = batch[2].to("cuda:0")
+ eval_loss = self.model.calc_loss(pred, labels, pos_neg)
+ m = torch.nn.Sigmoid()
+ pred = m(pred)
+ labels = labels.detach().to('cpu').numpy()
+ preds = pred.detach().to('cpu').numpy()
+ pos_neg = pos_neg.detach().to('cpu').numpy()
+ for (label_ids, pred, label_t) in zip(labels, preds, pos_neg):
+ for i, (l, p) in enumerate(zip(label_ids, pred)):
+ if l == 1:
+ if i in pred_class:
+ pred_class[i]['pred'] += [p]
+ pred_class[i]['l'] += [label_t]
+ pred_class[i]['pred_label'] += [1 if p > 0.5 else 0]
+ else:
+ pred_class[i] = {'pred': [p], 'l': [label_t], 'pred_label': [1 if p > 0.5 else 0]}
+ else:
+ eval_loss = self.model.calc_loss(pred, labels)
+ pos_labels += rel.to('cpu').numpy().flatten().tolist()
+ pos_scores += torch.argmax(pred, dim=1).cpu().flatten().tolist()
+ loss_list.append(eval_loss.item())
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ roc_auc = [metrics.roc_auc_score(pred_class[l]['l'], pred_class[l]['pred']) for l in pred_class]
+ prc_auc = [metrics.average_precision_score(pred_class[l]['l'], pred_class[l]['pred']) for l in
+ pred_class]
+ ap = [metrics.accuracy_score(pred_class[l]['l'], pred_class[l]['pred_label']) for l in pred_class]
+ results['auroc'] = np.mean(roc_auc)
+ results['auprc'] = np.mean(prc_auc)
+ results['ap'] = np.mean(ap)
+ else:
+ results['acc'] = metrics.accuracy_score(pos_labels, pos_scores)
+ results['macro_f1'] = metrics.f1_score(pos_labels, pos_scores, average='macro')
+ results['kappa'] = metrics.cohen_kappa_score(pos_labels, pos_scores)
+ loss = np.mean(loss_list)
+ return results, loss
+
+ def joint_random_ps2_each(self, trial):
+ genotype_space = []
+ for i in range(self.p.n_layer):
+ genotype_space.append(trial.suggest_categorical("mess"+ str(i), COMP_PRIMITIVES))
+ genotype_space.append(trial.suggest_categorical("agg"+ str(i), AGG_PRIMITIVES))
+ genotype_space.append(trial.suggest_categorical("comb"+ str(i), COMB_PRIMITIVES))
+ genotype_space.append(trial.suggest_categorical("act"+ str(i), ACT_PRIMITIVES))
+ self.best_val_f1 = 0.0
+ self.best_val_auroc = 0.0
+ self.early_stop_cnt = 0
+ # self.best_valid_metric, self.best_test_metric = 0.0, {}
+ self.p.genotype = "||".join(genotype_space)
+ self.model = self.get_model().to("cuda:0")
+ self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.p.lr, weight_decay=self.p.l2)
+ self.subgraph_selector = SubgraphSelector(self.p).to("cuda:0")
+ self.subgraph_selector_optimizer = torch.optim.Adam(
+ self.subgraph_selector.parameters(), lr=self.p.ss_lr, weight_decay=self.p.l2)
+ save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}'
+ os.makedirs(save_root, exist_ok=True)
+ save_model_path = f'{save_root}/model.pt'
+ save_ss_path = f'{save_root}/model_ss.pt'
+ save_model_best_path = f'{save_root}/model_best.pt'
+ save_ss_best_path = f'{save_root}/model_best_ss.pt'
+ for epoch in range(self.p.max_epochs):
+ start_time = time.time()
+ train_loss = self.search_epoch()
+ val_results, valid_loss = self.evaluate_epoch('valid', mode='ps2')
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ if val_results['auroc'] > self.best_val_auroc:
+ self.best_val_results = val_results
+ self.best_val_auroc = val_results['auroc']
+ self.best_epoch = epoch
+ torch.save(self.model.state_dict(), str(save_model_path))
+ torch.save(self.subgraph_selector.state_dict(), str(save_ss_path))
+ self.early_stop_cnt = 0
+ # self.logger.info("Update best valid auroc!")
+ else:
+ self.early_stop_cnt += 1
+ else:
+ if val_results['macro_f1'] > self.best_val_f1:
+ self.best_val_results = val_results
+ self.best_val_f1 = val_results['macro_f1']
+ self.best_epoch = epoch
+ torch.save(self.model.state_dict(), str(save_model_path))
+ torch.save(self.subgraph_selector.state_dict(), str(save_ss_path))
+ self.early_stop_cnt = 0
+ # self.logger.info("Update best valid f1!")
+ else:
+ self.early_stop_cnt += 1
+ self.logger.info(
+ f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid Loss: {valid_loss:.5}, Cost: {time.time() - start_time:.2f}s")
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}]: Best AUROC: {self.best_val_auroc:.5}")
+ else:
+ self.logger.info(
+ f"[Epoch {epoch}]: Valid ACC: {val_results['acc']:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid Cohen: {val_results['kappa']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}]: Best Macro F1: {self.best_val_f1:.5}")
+ if self.early_stop_cnt == 10:
+ self.logger.info("Early stop!")
+ break
+ self.model.load_state_dict(torch.load(str(save_model_path)))
+ self.subgraph_selector.load_state_dict(torch.load(str(save_ss_path)))
+ self.logger.info(
+ f'Loading best model in {self.best_epoch} epoch, Evaluating on Test data')
+ self.logger.info(f'{self.p.genotype}')
+ start = time.time()
+ test_results, test_loss = self.evaluate_epoch('test', mode='ps2')
+ end = time.time()
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ self.logger.info(
+ f"[Inference]: Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}")
+ if self.best_val_auroc > self.best_valid_metric:
+ self.best_valid_metric = self.best_val_auroc
+ self.best_test_metric = test_results
+ torch.save(self.model.state_dict(), str(save_model_best_path))
+ torch.save(self.subgraph_selector.state_dict(), str(save_ss_best_path))
+ with open(f'{save_root}/random_ps2_arch_list.csv', "a") as f:
+ writer = csv.writer(f)
+ writer.writerow([self.p.genotype, self.best_val_auroc, test_results['auroc']])
+ return self.best_val_auroc
+ else:
+ self.logger.info(
+ f"[Inference]: Test ACC: {test_results['acc']:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test Cohen: {test_results['kappa']:.5}")
+ if self.best_val_f1 > self.best_valid_metric:
+ self.best_valid_metric = self.best_val_f1
+ self.best_test_metric = test_results
+ torch.save(self.model.state_dict(), str(save_model_best_path))
+ torch.save(self.subgraph_selector.state_dict(), str(save_ss_best_path))
+ with open(f'{save_root}/random_ps2_arch_list.csv', "a") as f:
+ writer = csv.writer(f)
+ writer.writerow([self.p.genotype, self.best_val_f1, test_results['macro_f1']])
+ return self.best_val_f1
+
+ def joint_random_ps2(self):
+ self.best_valid_metric = 0.0
+ self.best_test_metric = {}
+ save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}/'
+ os.makedirs(save_root, exist_ok=True)
+ print(save_root)
+ self.logger = get_logger(f'{save_root}/', f'random_ps2')
+ study = optuna.create_study(directions=["maximize"], sampler=RandomSampler())
+ study.optimize(self.joint_random_ps2_each, n_trials=self.p.baseline_sample_num)
+ with open(f'{save_root}/random_ps2_res.txt', "w") as f1:
+ f1.write(f'{self.p.__dict__}\n')
+ f1.write(f'{self.p.genotype}\n')
+ f1.write(f'Valid performance: {study.best_value}\n')
+ f1.write(f'Test performance: {self.best_test_metric}')
+
+ def spos_arch_search_ps2_ng(self):
+ res_list = []
+ sorted_list = []
+ arch_list = []
+ for _ in range(self.p.n_layer):
+ arch_list.append(len(COMP_PRIMITIVES))
+ arch_list.append(len(AGG_PRIMITIVES))
+ arch_list.append(len(COMB_PRIMITIVES))
+ arch_list.append(len(ACT_PRIMITIVES))
+ asng = CategoricalASNG(np.array(arch_list), alpha=1.5, delta_init=1)
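+        # ASNG keeps one categorical distribution per operation slot (composition, aggregation, combination, activation for each layer) and adapts it with natural-gradient updates from sampled-architecture scores.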
+ exp_note = '_' + self.p.exp_note if self.p.exp_note is not None else ''
+ self.subgraph_selector = SubgraphSelector(self.p).to("cuda:0")
+ save_root = f'{self.prj_path}/exp/{self.p.dataset}/{self.p.search_mode}/{self.p.name}'
+ if self.p.exp_note == 'spfs':
+ epoch_list = [400]
+ weight_sharing = '_ws'
+ else:
+            epoch_list = [800, 700, 600]
+ weight_sharing = ''
+ for save_epoch in epoch_list:
+ valid_metric = 0.0
+ self.logger = get_logger(f'{save_root}/', f'{save_epoch}_arch_search_ng{exp_note}')
+ valid_loss_searched_arch_res = dict()
+ valid_f1_searched_arch_res = dict()
+ valid_auroc_searched_arch_res = dict()
+ search_time = 0.0
+ t_start = time.time()
+ for epoch in range(1, self.p.spos_arch_sample_num + 1):
+ Ms = []
+ ma_structs = []
+ scores = []
+ for i in range(self.p.asng_sample_num):
+ M = asng.sampling()
+ struct = np.argmax(M, axis=1)
+ Ms.append(M)
+ ma_structs.append(list(struct))
+ # print(list(struct))
+ self.generate_single_path_ng(list(struct))
+ arch = "||".join(self.model.ops)
+ if self.p.exp_note == 'spfs':
+ few_shot_op = self.model.ops[0]
+ else:
+ few_shot_op = ''
+ self.model.load_state_dict(
+ torch.load(f'{save_root}{exp_note}_{few_shot_op}{weight_sharing}/{save_epoch}.pt'))
+ self.subgraph_selector.load_state_dict(
+ torch.load(f'{save_root}{exp_note}_{few_shot_op}{weight_sharing}/{save_epoch}_ss.pt'))
+ val_results, valid_loss = self.evaluate_epoch('valid', 'spos_arch_search_ps2')
+ test_results, test_loss = self.evaluate_epoch('test', 'spos_arch_search_ps2')
+ valid_loss_searched_arch_res.setdefault(arch, valid_loss)
+ self.logger.info(f'[Epoch {epoch}, {i}-th arch]: Path:{arch}')
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ self.logger.info(
+ f"[Epoch {epoch}, {i}-th arch]: Valid Loss: {valid_loss:.5}, Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}, {i}-th arch]: Test Loss: {test_loss:.5}, Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}")
+ scores.append(val_results['auroc'])
+ sorted_list.append(val_results['auroc'])
+ valid_auroc_searched_arch_res.setdefault(arch, val_results['auroc'])
+ else:
+ self.logger.info(
+ f"[Epoch {epoch}, {i}-th arch]: Valid Loss: {valid_loss:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid ACC: {val_results['acc']:.5}, Valid Cohen: {val_results['kappa']:.5}")
+ self.logger.info(
+ f"[Epoch {epoch}, {i}-th arch]: Test Loss: {test_loss:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test ACC: {test_results['acc']:.5}, Test Cohen: {test_results['kappa']:.5}")
+ valid_f1_searched_arch_res.setdefault(arch, val_results['macro_f1'])
+ scores.append(val_results['macro_f1'])
+ sorted_list.append(val_results['macro_f1'])
+ res_list.append(sorted(sorted_list, reverse=True)[:self.p.asng_sample_num])
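+                # scores are negated because CategoricalASNG.utility ranks smaller values as better, while the validation metrics are higher-is-better.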
+ asng.update(np.array(Ms), -np.array(scores), True)
+ best_struct = list(asng.theta.argmax(axis=1))
+ self.generate_single_path_ng(best_struct)
+ arch = "||".join(self.model.ops)
+ val_results, valid_loss = self.evaluate_epoch('valid', 'spos_arch_search_ps2')
+ test_results, test_loss = self.evaluate_epoch('test', 'spos_arch_search_ps2')
+ self.logger.info(f'Path:{arch}')
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ self.logger.info(
+ f"Valid Loss: {valid_loss:.5}, Valid AUROC: {val_results['auroc']:.5}, Valid AUPRC: {val_results['auprc']:.5}, Valid AP@50: {val_results['ap']:.5}")
+ self.logger.info(
+ f"Test Loss: {test_loss:.5}, Test AUROC: {test_results['auroc']:.5}, Test AUPRC: {test_results['auprc']:.5}, Test AP@50: {test_results['ap']:.5}")
+ else:
+ self.logger.info(
+ f"Valid Loss: {valid_loss:.5}, Valid Macro F1: {val_results['macro_f1']:.5}, Valid ACC: {val_results['acc']:.5}, Valid Cohen: {val_results['kappa']:.5}")
+ self.logger.info(
+ f"Test Loss: {test_loss:.5}, Test Macro F1: {test_results['macro_f1']:.5}, Test ACC: {test_results['acc']:.5}, Test Cohen: {test_results['kappa']:.5}")
+ with open(f"{save_root}/topK_{save_epoch}{exp_note}_ng.pkl", "wb") as f:
+ pickle.dump(res_list, f)
+ t_end = time.time()
+ search_time = (t_end - t_start)
+
+ search_time = search_time / 3600
+ self.logger.info(f'The search process costs {search_time:.2f}h.')
+ import csv
+ with open(f'{save_root}/valid_loss_{save_epoch}{exp_note}_ng.csv', 'w', newline='') as f:
+ writer = csv.writer(f)
+ writer.writerow(['arch', 'valid loss'])
+                valid_loss_searched_arch_res_sorted = sorted(valid_loss_searched_arch_res.items(), key=lambda x: x[1])
+ res = valid_loss_searched_arch_res_sorted
+ for i in range(len(res)):
+ writer.writerow([res[i][0], res[i][1]])
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ with open(f'{save_root}/valid_auroc_{save_epoch}{exp_note}_ng.csv', 'w', newline='') as f:
+ writer = csv.writer(f)
+ writer.writerow(['arch', 'valid auroc'])
+ valid_auroc_searched_arch_res_sorted = sorted(valid_auroc_searched_arch_res.items(), key=lambda x: x[1],
+ reverse=True)
+ res = valid_auroc_searched_arch_res_sorted
+ for i in range(len(res)):
+ writer.writerow([res[i][0], res[i][1]])
+ else:
+ with open(f'{save_root}/valid_f1_{save_epoch}{exp_note}_ng.csv', 'w', newline='') as f:
+ writer = csv.writer(f)
+ writer.writerow(['arch', 'valid f1'])
+ valid_f1_searched_arch_res_sorted = sorted(valid_f1_searched_arch_res.items(), key=lambda x: x[1], reverse=True)
+ res = valid_f1_searched_arch_res_sorted
+ for i in range(len(res)):
+ writer.writerow([res[i][0], res[i][1]])
+
+    def generate_single_path_ng(self, struct):
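+        # struct holds one operation index per slot; consecutive groups of four entries correspond to one layer's composition, aggregation, combination, and activation choices.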
+ single_path = []
+ for ops_index, index in enumerate(struct):
+ if ops_index % 4 == 0:
+ single_path.append(COMP_PRIMITIVES[index])
+ elif ops_index % 4 == 1:
+ single_path.append(AGG_PRIMITIVES[index])
+ elif ops_index % 4 == 2:
+ single_path.append(COMB_PRIMITIVES[index])
+ elif ops_index % 4 == 3:
+ single_path.append(ACT_PRIMITIVES[index])
+ self.model.ops = single_path
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='Parser For Arguments',
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+ parser.add_argument('--name', default='test_run',
+ help='Set run name for saving/restoring models')
+ parser.add_argument('--dataset', default='drugbank',
+                        help='Dataset to use, default: drugbank')
+ parser.add_argument('--input_type', type=str, default='allgraph', choices=['subgraph', 'allgraph'])
+ parser.add_argument('--score_func', dest='score_func', default='none',
+ help='Score Function for Link prediction')
+ parser.add_argument('--opn', dest='opn', default='corr',
+ help='Composition Operation to be used in CompGCN')
+
+ parser.add_argument('--batch', dest='batch_size',
+ default=256, type=int, help='Batch size')
+ parser.add_argument('--gpu', type=int, default=0,
+ help='Set GPU Ids : Eg: For CPU = -1, For Single GPU = 0')
+ parser.add_argument('--epoch', dest='max_epochs',
+ type=int, default=500, help='Number of epochs')
+ parser.add_argument('--l2', type=float, default=5e-4,
+ help='L2 Regularization for Optimizer')
+ parser.add_argument('--lr', type=float, default=0.001,
+ help='Starting Learning Rate')
+ parser.add_argument('--lbl_smooth', dest='lbl_smooth',
+ type=float, default=0.1, help='Label Smoothing')
+ parser.add_argument('--num_workers', type=int, default=0,
+ help='Number of processes to construct batches')
+ parser.add_argument('--seed', dest='seed', default=12345,
+ type=int, help='Seed for randomization')
+
+ parser.add_argument('--restore', dest='restore', action='store_true',
+ help='Restore from the previously saved model')
+ parser.add_argument('--bias', dest='bias', action='store_true',
+ help='Whether to use bias in the model')
+
+ parser.add_argument('--num_bases', dest='num_bases', default=-1, type=int,
+ help='Number of basis relation vectors to use')
+ parser.add_argument('--init_dim', dest='init_dim', default=100, type=int,
+ help='Initial dimension size for entities and relations')
+ parser.add_argument('--gcn_dim', dest='gcn_dim', default=200,
+ type=int, help='Number of hidden units in GCN')
+ parser.add_argument('--embed_dim', dest='embed_dim', default=None, type=int,
+ help='Embedding dimension to give as input to score function')
+ parser.add_argument('--n_layer', dest='n_layer', default=1,
+ type=int, help='Number of GCN Layers to use')
+ parser.add_argument('--gcn_drop', dest='gcn_drop', default=0.1,
+ type=float, help='Dropout to use in GCN Layer')
+ parser.add_argument('--hid_drop', dest='hid_drop',
+ default=0.3, type=float, help='Dropout after GCN')
+
+ # ConvE specific hyperparameters
+ parser.add_argument('--conve_hid_drop', dest='conve_hid_drop', default=0.3, type=float,
+ help='ConvE: Hidden dropout')
+ parser.add_argument('--feat_drop', dest='feat_drop',
+ default=0.2, type=float, help='ConvE: Feature Dropout')
+ parser.add_argument('--input_drop', dest='input_drop', default=0.2,
+ type=float, help='ConvE: Stacked Input Dropout')
+ parser.add_argument('--k_w', dest='k_w', default=20,
+ type=int, help='ConvE: k_w')
+ parser.add_argument('--k_h', dest='k_h', default=10,
+ type=int, help='ConvE: k_h')
+ parser.add_argument('--num_filt', dest='num_filt', default=200, type=int,
+ help='ConvE: Number of filters in convolution')
+ parser.add_argument('--ker_sz', dest='ker_sz', default=7,
+ type=int, help='ConvE: Kernel size to use')
+
+ parser.add_argument('--gamma', dest='gamma', default=9.0,
+ type=float, help='TransE: Gamma to use')
+
+ parser.add_argument('--rat', action='store_true',
+                        default=False, help='random adjacency tensors')
+ parser.add_argument('--wni', action='store_true',
+ default=False, help='without neighbor information')
+ parser.add_argument('--wsi', action='store_true',
+ default=False, help='without self-loop information')
+ parser.add_argument('--ss', dest='ss', default=-1,
+ type=int, help='sample size (sample neighbors)')
+ parser.add_argument('--nobn', action='store_true',
+ default=False, help='no use of batch normalization in aggregation')
+ parser.add_argument('--noltr', action='store_true',
+ default=False, help='no use of linear transformations for relation embeddings')
+
+ parser.add_argument('--encoder', dest='encoder',
+ default='compgcn', type=str, help='which encoder to use')
+
+ # for lte models
+ parser.add_argument('--x_ops', dest='x_ops', default="")
+ parser.add_argument('--r_ops', dest='r_ops', default="")
+
+ parser.add_argument("--ss_num_layer", default=2, type=int)
+ parser.add_argument("--ss_input_dim", default=200, type=int)
+ parser.add_argument("--ss_hidden_dim", default=200, type=int)
+ parser.add_argument("--ss_lr", default=0.001, type=float)
+ parser.add_argument('--train_mode', default='', type=str,
+ choices=["train", "tune", "DEBUG", "vis_hop", "inference", "vis_class","joint_tune","spos_tune","vis_hop_pred","vis_rank_ccorelation","vis_rank_ccorelation_spfs"])
+ parser.add_argument("--ss_model_path", type=str)
+ parser.add_argument("--ss_search_algorithm", default='darts', type=str)
+ parser.add_argument("--search_algorithm", default='darts', type=str)
+ parser.add_argument("--temperature", default=0.07, type=float)
+ parser.add_argument("--temperature_min", default=0.005, type=float)
+ parser.add_argument("--lr_min", type=float, default=0.001)
+ parser.add_argument("--arch_lr", default=0.001, type=float)
+ parser.add_argument("--arch_lr_min", default=0.001, type=float)
+ parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
+ parser.add_argument("--w_update_epoch", type=int, default=1)
+ parser.add_argument("--alpha_mode", type=str, default='valid_loss')
+ parser.add_argument('--loc_mean', type=float, default=10.0, help='initial mean value to generate the location')
+ parser.add_argument('--loc_std', type=float, default=0.01, help='initial std to generate the location')
+ parser.add_argument("--genotype", type=str, default=None)
+ parser.add_argument("--baseline_sample_num", type=int, default=200)
+ parser.add_argument("--cos_temp", action='store_true', default=False, help='temp decay')
+ parser.add_argument("--spos_arch_sample_num", default=1000, type=int)
+ parser.add_argument("--tune_sample_num", default=10, type=int)
+
+ # subgraph config
+ parser.add_argument("--subgraph_type", type=str, default='seal')
+ parser.add_argument("--subgraph_hop", type=int, default=2)
+ parser.add_argument("--subgraph_edge_sample_ratio", type=float, default=1)
+ parser.add_argument("--subgraph_is_saved", type=bool, default=True)
+ parser.add_argument("--subgraph_max_num_nodes", type=int, default=100)
+ parser.add_argument("--subgraph_sample_type", type=str, default='enclosing_subgraph')
+ parser.add_argument("--save_mode", type=str, default='graph')
+ parser.add_argument("--num_neg_samples_per_link", type=int, default=0)
+
+ # transformer config
+ parser.add_argument("--d_model", type=int, default=100)
+ parser.add_argument("--num_transformer_layers", type=int, default=2)
+ parser.add_argument("--nhead", type=int, default=4)
+ parser.add_argument("--dim_feedforward", type=int, default=100)
+ parser.add_argument("--transformer_dropout", type=float, default=0.1)
+ parser.add_argument("--transformer_activation", type=str, default='relu')
+ parser.add_argument("--concat_type", type=str, default='so')
+ parser.add_argument("--graph_pooling_type", type=str, default='mean')
+
+ parser.add_argument("--loss_type", type=str, default='ce')
+ parser.add_argument("--eval_mode", type=str, default='rel')
+ parser.add_argument("--wandb_project", type=str, default='')
+ parser.add_argument("--search_mode", type=str, default='', choices=["ps2", "ps2_random", "", "arch_search", "arch_random","joint_search","arch_spos"])
+ parser.add_argument("--add_reverse", action='store_true', default=False)
+ parser.add_argument("--clip_grad", action='store_true', default=False)
+ parser.add_argument("--fine_tune_with_implicit_subgraph", action='store_true', default=False)
+ parser.add_argument("--combine_type", type=str, default='concat')
+ parser.add_argument("--exp_note", type=str, default=None)
+ parser.add_argument("--few_shot_op", type=str, default=None)
+ parser.add_argument("--weight_sharing", action='store_true', default=False)
+
+ parser.add_argument("--asng_sample_num", type=int, default=16)
+ parser.add_argument("--arch_search_mode", type=str, default='random')
+
+ args = parser.parse_args()
+ opn = '_' + args.opn if args.encoder == 'compgcn' else ''
+ reverse = '_add_reverse' if args.add_reverse else ''
+ genotype = '_'+args.genotype if args.genotype is not None else ''
+ input_type = '_'+args.input_type if args.input_type == 'subgraph' else ''
+ exp_note = '_'+args.exp_note if args.exp_note is not None else ''
+ num_bases = '_b'+str(args.num_bases) if args.num_bases!=-1 else ''
+ alpha_mode = '_'+str(args.alpha_mode)
+ few_shot_op = '_'+args.few_shot_op if args.few_shot_op is not None else ''
+ weight_sharing = '_ws' if args.weight_sharing else ''
+ ss_search_algorithm = '_snas' if args.ss_search_algorithm == 'snas' else ''
+
+ if args.input_type == 'subgraph':
+ args.name = 'seal'
+ else:
+ if args.train_mode in ['train', 'vis_hop_pred']:
+ args.name = f'{args.encoder}{opn}_{args.score_func}_{args.combine_type}_train_layer{args.n_layer}_seed{args.seed}{num_bases}{reverse}{genotype}{exp_note}'
+ elif args.search_mode in ['ps2']:
+ args.name = f'{args.encoder}{opn}_{args.score_func}_{args.combine_type}_{args.search_mode}{ss_search_algorithm}_layer{args.n_layer}_seed{args.seed}{num_bases}{reverse}{genotype}{exp_note}'
+ elif args.search_mode in ['arch_random']:
+ args.name = f'{args.encoder}_{args.score_func}_{args.combine_type}_{args.search_mode}_layer{args.n_layer}_seed{args.seed}{reverse}{exp_note}'
+ elif args.search_algorithm == 'spos_arch_search':
+ args.name = f'{args.encoder}_{args.score_func}_{args.combine_type}_spos_train_supernet_layer{args.n_layer}_seed{args.seed}{reverse}{genotype}{exp_note}{few_shot_op}{weight_sharing}'
+ elif args.search_mode == 'joint_search' and args.train_mode == 'vis_hop':
+ args.name = f'{args.encoder}_{args.score_func}_{args.combine_type}_spos_train_supernet_ps2{ss_search_algorithm}_layer{args.n_layer}_seed{args.seed}{reverse}{exp_note}{few_shot_op}{weight_sharing}'
+ elif args.search_algorithm == 'spos_arch_search_ps2':
+ args.name = f'{args.encoder}_{args.score_func}_{args.combine_type}_spos_train_supernet_ps2{ss_search_algorithm}_layer{args.n_layer}_seed{args.seed}{reverse}{genotype}'
+ elif args.search_mode == 'joint_search' and args.train_mode == 'spos_tune':
+ args.name = f'{args.encoder}_{args.score_func}_{args.combine_type}_spos_train_supernet_ps2{ss_search_algorithm}_layer{args.n_layer}_seed{args.seed}{reverse}{genotype}'
+ elif args.search_mode == 'joint_search' and args.search_algorithm == 'random_ps2':
+ args.name = f'{args.encoder}_{args.score_func}_{args.combine_type}_{args.search_algorithm}{ss_search_algorithm}_layer{args.n_layer}_seed{args.seed}'
+ elif args.search_algorithm == 'spos_train_supernet_ps2':
+ args.name = f'{args.encoder}_{args.score_func}_{args.combine_type}_{args.search_algorithm}{ss_search_algorithm}_layer{args.n_layer}_seed{args.seed}{reverse}{genotype}{exp_note}{few_shot_op}{weight_sharing}'
+        elif (args.search_algorithm == 'spos_train_supernet' and args.train_mode != 'spos_tune') or args.train_mode == 'vis_rank_ccorelation' or args.train_mode == 'vis_rank_ccorelation_spfs':
+ args.name = f'{args.encoder}_{args.score_func}_{args.combine_type}_{args.search_algorithm}_layer{args.n_layer}_seed{args.seed}{reverse}{genotype}{exp_note}{few_shot_op}{weight_sharing}'
+ else:
+ args.name = f'{args.encoder}_{args.score_func}_{args.combine_type}_spos_train_supernet_layer{args.n_layer}_seed{args.seed}{reverse}{genotype}{exp_note}{few_shot_op}{weight_sharing}'
+
+    args.embed_dim = args.k_w * args.k_h if args.embed_dim is None else args.embed_dim
+
+ # os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ torch.cuda.manual_seed(args.seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ os.environ.setdefault("HYPEROPT_FMIN_SEED", str(args.seed))
+
+ runner = Runner(args)
+
+    if args.search_mode == 'joint_search' and args.search_algorithm != 'spos_arch_search_ps2' and args.train_mode != 'spos_tune':
+ wandb.init(
+ project=args.wandb_project,
+ config={
+ "dataset": args.dataset,
+ "encoder": args.encoder,
+ "score_function": args.score_func,
+ "batch_size": args.batch_size,
+ "learning_rate": args.lr,
+ "encoder_layer_num": args.n_layer,
+ "epochs": args.max_epochs,
+ "seed": args.seed,
+ "init_dim": args.init_dim,
+ "embed_dim": args.embed_dim,
+ "loss_type": args.loss_type,
+ "search_mode": args.search_mode,
+ "combine_type": args.combine_type,
+ "note": args.exp_note,
+ "search_algorithm": args.search_algorithm,
+ "weight_sharing": args.weight_sharing,
+ "few_shot_op": args.few_shot_op,
+ "ss_search_algorithm": args.ss_search_algorithm
+ })
+
+ if args.train_mode == 'train':
+ if args.exp_note is not None:
+ runner.train_mix_hop()
+ else:
+ runner.train()
+ elif args.train_mode == 'tune' and args.search_mode in ['ps2', 'joint_search']:
+ runner.fine_tune()
+ elif args.train_mode == 'vis_hop' and args.search_mode in ['ps2', "ps2_random", "joint_search"]:
+ runner.vis_hop_distrubution()
+ elif args.train_mode == 'vis_hop_pred':
+ runner.vis_hop_pred()
+ elif args.train_mode == 'inference' and args.search_mode in ['ps2']:
+ runner.inference()
+ elif args.train_mode == 'vis_class':
+ runner.vis_class_distribution()
+ elif args.train_mode == 'joint_tune':
+ runner.joint_tune()
+ else:
+ if args.search_mode == 'ps2_random':
+ runner.random_search()
+ elif args.search_mode == 'ps2':
+ runner.ps2()
+ elif args.search_mode == 'arch_search':
+ if args.train_mode == 'spos_tune':
+ runner.spos_fine_tune()
+ elif args.train_mode in ["vis_rank_ccorelation"]:
+ runner.vis_rank_ccorelation()
+ elif args.train_mode in ["vis_rank_ccorelation_spfs"]:
+ runner.vis_rank_ccorelation_spfs()
+ elif args.search_algorithm in ["darts", "snas"]:
+ runner.architecture_search()
+ elif args.search_algorithm in ["spos_train_supernet"]:
+ runner.spos_train_supernet()
+ elif args.search_algorithm in ["spos_arch_search"]:
+ runner.spos_arch_search()
+ elif args.search_mode == 'arch_random':
+ runner.arch_random_search()
+ elif args.search_mode == 'joint_search':
+ if args.search_algorithm in ["spos_train_supernet_ps2"]:
+ runner.spos_train_supernet_ps2()
+ elif args.search_algorithm in ["spos_arch_search_ps2"] and args.arch_search_mode == 'random':
+ runner.spos_arch_search_ps2()
+ elif args.search_algorithm in ["spos_arch_search_ps2"] and args.arch_search_mode == 'ng':
+ runner.spos_arch_search_ps2_ng()
+ elif args.train_mode in ["spos_tune"]:
+ runner.joint_spos_ps2_fine_tune()
+ elif args.search_algorithm == 'random_ps2':
+ runner.joint_random_ps2()
+ else:
+ runner.joint_search()
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..7b83ce9
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1,7 @@
+from .process_data import process, process_multi_label
+from .data_set import TestDataset, TrainDataset, GraphTrainDataset, GraphTestDataset, NCNDataset
+from .logger import get_logger
+from .sampler import TripleSampler, SEALSampler, SEALSampler_NG, GrailSampler
+from .utils import get_current_memory_gb, get_f1_score_list, get_acc_list, deserialize, Temp_Scheduler
+from .dgl_utils import get_neighbor_nodes
+from .asng import CategoricalASNG
\ No newline at end of file
diff --git a/utils/asng.py b/utils/asng.py
new file mode 100644
index 0000000..139b33c
--- /dev/null
+++ b/utils/asng.py
@@ -0,0 +1,169 @@
+import numpy as np
+
+
+class CategoricalASNG:
+ """
+ code refers to https://github.com/shirakawas/ASNG-NAS
+ """
+
+    def __init__(self, categories, alpha=1.5, delta_init=1., lam=2., Delta_max=10, init_theta=None):
+        self.N = np.sum(categories - 1)
+        self.d = len(categories)
+        self.C = categories
+        self.Cmax = np.max(categories)
+        self.theta = np.zeros((self.d, self.Cmax))
+
+        # initialize each categorical distribution uniformly over its valid choices
+        for i in range(self.d):
+            self.theta[i, :self.C[i]] = 1. / self.C[i]
+
+        for i in range(self.d):
+            self.theta[i, self.C[i]:] = 0.
+
+        # an explicitly provided theta overrides the uniform initialization
+        if init_theta is not None:
+            self.theta = init_theta
+
+ self.valid_param_num = int(np.sum(self.C - 1))
+ self.valid_d = len(self.C[self.C > 1])
+
+ self.alpha = alpha
+ self.delta_init = delta_init
+ self.lam = lam
+ self.Delta_max = Delta_max
+
+ self.Delta = 1.
+ self.gamma = 0.0
+ self.delta = self.delta_init / self.Delta
+ self.eps = self.delta
+ self.s = np.zeros(self.N)
+
+ def get_lam(self):
+ return self.lam
+
+ def get_delta(self):
+ return self.delta / self.Delta
+
+ def sampling(self):
+ rand = np.random.rand(self.d, 1)
+ cum_theta = self.theta.cumsum(axis=1)
+
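+        # draw a one-hot sample per dimension: each uniform random value falls into exactly one cumulative-probability interval of theta.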
+ c = (cum_theta - self.theta <= rand) & (rand < cum_theta)
+ return c
+
+ def sampling_lam(self, lam):
+ rand = np.random.rand(lam, self.d, 1)
+ cum_theta = self.theta.cumsum(axis=1)
+
+ c = (cum_theta - self.theta <= rand) & (rand < cum_theta)
+ return c
+
+ def mle(self):
+ m = self.theta.argmax(axis=1)
+ x = np.zeros((self.d, self.Cmax))
+ for i, c in enumerate(m):
+ x[i, c] = 1
+ return x
+
+ def update(self, c_one, fxc, range_restriction=True):
+ delta = self.get_delta()
+ # print('delta:', delta)
+ beta = self.delta * self.N ** (-0.5)
+
+ aru, idx = self.utility(fxc)
+ mu_W, var_W = aru.mean(), aru.var()
+ # print(fxc, idx, aru)
+ if var_W == 0:
+ return
+
+ ngrad = np.mean((aru - mu_W)[:, np.newaxis, np.newaxis] * (c_one[idx] - self.theta), axis=0)
+
+ if (np.abs(ngrad) < 1e-18).all():
+ # print('skip update')
+ return
+
+ sl = []
+ for i, K in enumerate(self.C):
+ theta_i = self.theta[i, :K - 1]
+ theta_K = self.theta[i, K - 1]
+ s_i = 1. / np.sqrt(theta_i) * ngrad[i, :K - 1]
+ s_i += np.sqrt(theta_i) * ngrad[i, :K - 1].sum() / (theta_K + np.sqrt(theta_K))
+ sl += list(s_i)
+ sl = np.array(sl)
+
+ ngnorm = np.sqrt(np.dot(sl, sl)) + 1e-8
+ dp = ngrad / ngnorm
+ assert not np.isnan(dp).any(), (ngrad, ngnorm)
+
+ self.theta += delta * dp
+
+ self.s = (1 - beta) * self.s + np.sqrt(beta * (2 - beta)) * sl / ngnorm
+ self.gamma = (1 - beta) ** 2 * self.gamma + beta * (2 - beta)
+ self.Delta *= np.exp(beta * (self.gamma - np.dot(self.s, self.s) / self.alpha))
+ self.Delta = min(self.Delta, self.Delta_max)
+
+ for i in range(self.d):
+ ci = self.C[i]
+ theta_min = 1. / (self.valid_d * (ci - 1)) if range_restriction and ci > 1 else 0.
+ self.theta[i, :ci] = np.maximum(self.theta[i, :ci], theta_min)
+ theta_sum = self.theta[i, :ci].sum()
+ tmp = theta_sum - theta_min * ci
+ self.theta[i, :ci] -= (theta_sum - 1.) * (self.theta[i, :ci] - theta_min) / tmp
+ self.theta[i, :ci] /= self.theta[i, :ci].sum()
+
+ def get_arch(self, ):
+ return np.argmax(self.theta, axis=1)
+
+ def get_max(self, ):
+ return np.max(self.theta, axis=1)
+
+ def get_entropy(self, ):
+ ent = 0
+ for i, K in enumerate(self.C):
+ the = self.theta[i, :K]
+ ent += np.sum(the * np.log(the))
+ return -ent
+
+ @staticmethod
+ def utility(f, rho=0.25, negative=True):
+ eps = 1e-3
+ idx = np.argsort(f)
+ lam = len(f)
+ mu = int(np.ceil(lam * rho))
+ _w = np.zeros(lam)
+ _w[:mu] = 1 / mu
+ _w[lam - mu:] = -1 / mu if negative else 0
+ w = np.zeros(lam)
+ istart = 0
+ for i in range(len(f) - 1):
+ if f[idx[i + 1]] - f[idx[i]] < eps * f[idx[i]]:
+ pass
+ elif istart < i:
+ w[istart:i + 1] = np.mean(_w[istart:i + 1])
+ istart = i + 1
+ else:
+ w[i] = _w[i]
+ istart = i + 1
+ w[istart:] = np.mean(_w[istart:])
+ return w, idx
+
+ def log_header(self, theta_log=False):
+        header_list = ['delta', 'eps', 'snorm_alpha', 'theta_converge']
+ if theta_log:
+ for i in range(self.d):
+ header_list += ['theta%d_%d' % (i, j) for j in range(self.C[i])]
+ return header_list
+
+ def log(self, theta_log=False):
+ log_list = [self.delta, self.eps, np.dot(self.s, self.s) / self.alpha, self.theta.max(axis=1).mean()]
+
+ if theta_log:
+ for i in range(self.d):
+ log_list += ['%f' % self.theta[i, j] for j in range(self.C[i])]
+ return log_list
+
+ def load_theta_from_log(self, theta_log):
+ self.theta = np.zeros((self.d, self.Cmax))
+ k = 0
+ for i in range(self.d):
+ for j in range(self.C[i]):
+ self.theta[i, j] = theta_log[k]
+ k += 1
\ No newline at end of file
diff --git a/utils/data_set.py b/utils/data_set.py
new file mode 100644
index 0000000..005dcbe
--- /dev/null
+++ b/utils/data_set.py
@@ -0,0 +1,367 @@
+from torch.utils.data import Dataset
+import numpy as np
+import torch
+import dgl
+from dgl import NID
+from scipy.sparse.csgraph import shortest_path
+import lmdb
+import pickle
+from utils.dgl_utils import get_neighbor_nodes
+
+
+class TrainDataset(Dataset):
+ def __init__(self, triplets, num_ent, num_rel, params):
+ super(TrainDataset, self).__init__()
+ self.p = params
+ self.triplets = triplets
+ self.label_smooth = params.lbl_smooth
+ self.num_ent = num_ent
+ self.num_rel = num_rel
+
+ def __len__(self):
+ return len(self.triplets)
+
+ def __getitem__(self, item):
+ ele = self.triplets[item]
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ triple, label, pos_neg = torch.tensor(ele['triple'], dtype=torch.long), torch.tensor(ele['label'], dtype=torch.long), torch.tensor(ele['pos_neg'], dtype=torch.float)
+ return triple, label, pos_neg
+ else:
+ triple, label, random_hop = torch.tensor(ele['triple'], dtype=torch.long), np.int32(ele['label']), torch.tensor(ele['random_hop'])
+ label = triple[1]
+ # label = torch.tensor(label, dtype=torch.long)
+ # label = self.get_label_rel(label)
+ # if self.label_smooth != 0.0:
+ # label = (1.0 - self.label_smooth) * label + (1.0 / self.num_rel)
+ if self.p.search_mode == 'random':
+ return triple, label, random_hop
+ else:
+ return triple, label
+
+ def get_label_rel(self, label):
+ """
+        get the label corresponding to a (sub, obj) pair
+        :param label: a list containing indices of relations observed for the (sub, obj) pair
+        :return: a multi-hot tensor of shape [num_rel * 2]
+ """
+ y = np.zeros([self.num_rel*2], dtype=np.float32)
+ y[label] = 1
+ return torch.tensor(y, dtype=torch.float32)
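+
+    # Illustrative note (hedged example, hypothetical values): with num_rel = 2 and
+    # label = [0, 3], get_label_rel returns the multi-hot vector tensor([1., 0., 0., 1.]),
+    # one slot per relation and per inverse relation.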
+
+
+class TestDataset(Dataset):
+ def __init__(self, triplets, num_ent, num_rel, params):
+ super(TestDataset, self).__init__()
+ self.p = params
+ self.triplets = triplets
+ self.num_ent = num_ent
+ self.num_rel = num_rel
+
+ def __len__(self):
+ return len(self.triplets)
+
+ def __getitem__(self, item):
+ ele = self.triplets[item]
+ if 'twosides' in self.p.dataset or 'ogbl_biokg' in self.p.dataset:
+ triple, label, pos_neg = torch.tensor(ele['triple'], dtype=torch.long), torch.tensor(ele['label'], dtype=torch.long), torch.tensor(ele['pos_neg'], dtype=torch.float)
+ return triple, label, pos_neg
+ else:
+ triple, label, random_hop = torch.tensor(ele['triple'], dtype=torch.long), np.int32(ele['label']), torch.tensor(ele['random_hop'])
+ label = triple[1]
+ # label = torch.tensor(label, dtype=torch.long)
+ # label = self.get_label_rel(label)
+ if self.p.search_mode == 'random':
+ return triple, label, random_hop
+ else:
+ return triple, label
+
+ def get_label_rel(self, label):
+ """
+        get the label corresponding to a (sub, obj) pair
+        :param label: a list containing indices of relations observed for the (sub, obj) pair
+        :return: a multi-hot tensor of shape [num_rel * 2]
+ """
+ y = np.zeros([self.num_rel*2], dtype=np.float32)
+ y[label] = 1
+ return torch.tensor(y, dtype=torch.float32)
+
+
+class GraphTrainDataset(Dataset):
+ def __init__(self, triplets, num_ent, num_rel, params, all_graph, db_name_pos=None):
+ super(GraphTrainDataset, self).__init__()
+ self.p = params
+ self.triplets = triplets
+ self.label_smooth = params.lbl_smooth
+ self.num_ent = num_ent
+ self.num_rel = num_rel
+ self.g = all_graph
+ db_path = f'subgraph/{self.p.dataset}/{self.p.subgraph_type}_neg_{self.p.num_neg_samples_per_link}_hop_{self.p.subgraph_hop}_seed_{self.p.seed}'
+ if self.p.save_mode == 'mdb':
+ self.main_env = lmdb.open(db_path, readonly=True, max_dbs=3, lock=False)
+ self.db_pos = self.main_env.open_db(db_name_pos.encode())
+ # with self.main_env.begin(db=self.db_pos) as txn:
+ # num_graphs_pos = int.from_bytes(txn.get('num_graphs'.encode()), byteorder='little')
+ # print(num_graphs_pos)
+
+ def __len__(self):
+ return len(self.triplets)
+
+ def __getitem__(self, item):
+ ele = self.triplets[item]
+ if self.p.save_mode == 'pickle':
+ sample_nodes, triple, label = ele['sample_nodes'], torch.tensor(ele['triple'], dtype=torch.long), np.int32(ele['label'])
+ subgraph = self.subgraph_sample(triple[0], triple[2], sample_nodes)
+ elif self.p.save_mode == 'graph':
+ subgraph, triple, label = ele['subgraph'], torch.tensor(ele['triple'], dtype=torch.long), torch.tensor(ele['triple'][1], dtype=torch.long)
+ elif self.p.save_mode == 'mdb':
+ # triple, label = torch.tensor(ele['triple'], dtype=torch.long), np.int32(ele['label'])
+ with self.main_env.begin(db=self.db_pos) as txn:
+ str_id = '{:08}'.format(item).encode('ascii')
+ nodes_pos, r_label_pos, g_label_pos, n_labels_pos = deserialize(txn.get(str_id)).values()
+ # print(nodes_pos)
+ # head_id = np.argwhere([label[0] == 0 and label[1] == 1 for label in n_labels_pos])
+ # tail_id = np.argwhere([label[0] == 1 and label[1] == 0 for label in n_labels_pos])
+ # print(head_id)
+ # print(nodes_pos[head_id[0]], nodes_pos[tail_id[0]])
+ # exit(0)
+ # n_ids = np.zeros(len(nodes_pos))
+ # n_ids[head_id] = 1 # head
+ # n_ids[tail_id] = 2 # tail
+ # subgraph.ndata['id'] = torch.FloatTensor(n_ids)
+ if self.p.train_mode == 'tune':
+ subgraph = torch.zeros(self.num_ent, dtype=torch.bool)
+ subgraph[nodes_pos] = 1
+ else:
+ subgraph = self.subgraph_sample(nodes_pos[0], nodes_pos[1], nodes_pos)
+ triple = torch.tensor([nodes_pos[0], r_label_pos,nodes_pos[1]], dtype=torch.long)
+ return subgraph, triple, torch.tensor(r_label_pos, dtype=torch.long)
+ # input_ids = self.convert_subgraph_to_tokens(subgraph, self.max_seq_length)
+ # label = self.get_label_rel(label)
+ # if self.label_smooth != 0.0:
+ # label = (1.0 - self.label_smooth) * label + (1.0 / self.num_rel)
+ return subgraph, triple, label
+
+ def get_label_rel(self, label):
+ """
+        get the label corresponding to a (sub, obj) pair
+        :param label: a list containing indices of relations observed for the (sub, obj) pair
+        :return: a multi-hot tensor of shape [num_rel * 2]
+ """
+ y = np.zeros([self.num_rel*2], dtype=np.float32)
+ y[label] = 1
+ return torch.tensor(y, dtype=torch.float32)
+
+ def subgraph_sample(self, u, v, sample_nodes):
+ subgraph = dgl.node_subgraph(self.g, sample_nodes)
+ n_ids = np.zeros(len(sample_nodes))
+ n_ids[0] = 1 # head
+ n_ids[1] = 2 # tail
+ subgraph.ndata['id'] = torch.tensor(n_ids, dtype=torch.long)
+ # print(sample_nodes)
+ # print(u,v)
+ # Each node should have unique node id in the new subgraph
+ u_id = int(torch.nonzero(subgraph.ndata[NID] == int(u), as_tuple=False))
+ # print(torch.nonzero(subgraph.ndata[NID] == int(u), as_tuple=False))
+ # print(torch.nonzero(subgraph.ndata[NID] == int(v), as_tuple=False))
+ v_id = int(torch.nonzero(subgraph.ndata[NID] == int(v), as_tuple=False))
+
+ if subgraph.has_edges_between(u_id, v_id):
+ link_id = subgraph.edge_ids(u_id, v_id, return_uv=True)[2]
+ subgraph.remove_edges(link_id)
+ if subgraph.has_edges_between(v_id, u_id):
+ link_id = subgraph.edge_ids(v_id, u_id, return_uv=True)[2]
+ subgraph.remove_edges(link_id)
+
+ n_ids = np.zeros(len(sample_nodes))
+ n_ids[0] = 1 # head
+ n_ids[1] = 2 # tail
+ subgraph.ndata['id'] = torch.tensor(n_ids, dtype=torch.long)
+
+ # z = drnl_node_labeling(subgraph, u_id, v_id)
+ # subgraph.ndata['z'] = z
+ return subgraph
+
+
+class GraphTestDataset(Dataset):
+ def __init__(self, triplets, num_ent, num_rel, params, all_graph, db_name_pos=None):
+ super(GraphTestDataset, self).__init__()
+ self.p = params
+ self.triplets = triplets
+ self.num_ent = num_ent
+ self.num_rel = num_rel
+ self.g = all_graph
+ db_path = f'subgraph/{self.p.dataset}/{self.p.subgraph_type}_neg_{self.p.num_neg_samples_per_link}_hop_{self.p.subgraph_hop}_seed_{self.p.seed}'
+ if self.p.save_mode == 'mdb':
+ self.main_env = lmdb.open(db_path, readonly=True, max_dbs=3, lock=False)
+ self.db_pos = self.main_env.open_db(db_name_pos.encode())
+
+ def __len__(self):
+ return len(self.triplets)
+
+ def __getitem__(self, item):
+ ele = self.triplets[item]
+ if self.p.save_mode == 'pickle':
+ sample_nodes, triple, label = ele['sample_nodes'], torch.tensor(ele['triple'], dtype=torch.long), np.int32(ele['label'])
+ subgraph = self.subgraph_sample(triple[0], triple[2], sample_nodes)
+ elif self.p.save_mode == 'graph':
+ subgraph, triple, label = ele['subgraph'], torch.tensor(ele['triple'], dtype=torch.long), torch.tensor(ele['triple'][1], dtype=torch.long)
+ elif self.p.save_mode == 'mdb':
+ # triple, label = torch.tensor(ele['triple'], dtype=torch.long), np.int32(ele['label'])
+ with self.main_env.begin(db=self.db_pos) as txn:
+ str_id = '{:08}'.format(item).encode('ascii')
+ nodes_pos, r_label_pos, g_label_pos, n_labels_pos = deserialize(txn.get(str_id)).values()
+ if self.p.train_mode == 'tune':
+ subgraph = torch.zeros(self.num_ent, dtype=torch.bool)
+ subgraph[nodes_pos] = 1
+ else:
+ subgraph = self.subgraph_sample(nodes_pos[0], nodes_pos[1], nodes_pos)
+ triple = torch.tensor([nodes_pos[0], r_label_pos, nodes_pos[1]], dtype=torch.long)
+ return subgraph, triple, torch.tensor(r_label_pos, dtype=torch.long)
+ # label = self.get_label_rel(label)
+ return subgraph, triple, label
+
+ def get_label_rel(self, label):
+ """
+        get the label corresponding to a (sub, obj) pair
+        :param label: a list containing indices of relations observed for the (sub, obj) pair
+        :return: a multi-hot tensor of shape [num_rel * 2]
+ """
+ y = np.zeros([self.num_rel*2], dtype=np.float32)
+ y[label] = 1
+ return torch.tensor(y, dtype=torch.float32)
+
+ def subgraph_sample(self, u, v, sample_nodes):
+ subgraph = dgl.node_subgraph(self.g, sample_nodes)
+ n_ids = np.zeros(len(sample_nodes))
+ n_ids[0] = 1 # head
+ n_ids[1] = 2 # tail
+ subgraph.ndata['id'] = torch.tensor(n_ids, dtype=torch.long)
+ # Each node should have unique node id in the new subgraph
+ u_id = int(torch.nonzero(subgraph.ndata[NID] == int(u), as_tuple=False))
+ v_id = int(torch.nonzero(subgraph.ndata[NID] == int(v), as_tuple=False))
+
+ if subgraph.has_edges_between(u_id, v_id):
+ link_id = subgraph.edge_ids(u_id, v_id, return_uv=True)[2]
+ subgraph.remove_edges(link_id)
+ if subgraph.has_edges_between(v_id, u_id):
+ link_id = subgraph.edge_ids(v_id, u_id, return_uv=True)[2]
+ subgraph.remove_edges(link_id)
+
+ # z = drnl_node_labeling(subgraph, u_id, v_id)
+ # subgraph.ndata['z'] = z
+ return subgraph
+
+
+def drnl_node_labeling(subgraph, src, dst):
+ """
+    Double Radius Node Labeling (DRNL)
+    d = r(i,u) + r(i,v)
+    label = 1 + min(r(i,u), r(i,v)) + (d//2) * (d//2 + d%2 - 1)
+    Isolated nodes in the subgraph are labeled zero.
+    Extremely large graphs may cause memory errors.
+
+    Args:
+        subgraph(DGLGraph): the subgraph
+        src(int): node id of the source target node in the new subgraph
+        dst(int): node id of the destination target node in the new subgraph
+    Returns:
+        z(Tensor): node labeling tensor
+ """
+ adj = subgraph.adj().to_dense().numpy()
+ src, dst = (dst, src) if src > dst else (src, dst)
+ if src != dst:
+ idx = list(range(src)) + list(range(src + 1, adj.shape[0]))
+ adj_wo_src = adj[idx, :][:, idx]
+
+ idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))
+ adj_wo_dst = adj[idx, :][:, idx]
+
+ dist2src = shortest_path(adj_wo_dst, directed=False, unweighted=True, indices=src)
+ dist2src = np.insert(dist2src, dst, 0, axis=0)
+ dist2src = torch.from_numpy(dist2src)
+
+ dist2dst = shortest_path(adj_wo_src, directed=False, unweighted=True, indices=dst - 1)
+ dist2dst = np.insert(dist2dst, src, 0, axis=0)
+ dist2dst = torch.from_numpy(dist2dst)
+ else:
+ dist2src = shortest_path(adj, directed=False, unweighted=True, indices=src)
+ # dist2src = np.insert(dist2src, dst, 0, axis=0)
+ dist2src = torch.from_numpy(dist2src)
+
+ dist2dst = shortest_path(adj, directed=False, unweighted=True, indices=dst)
+ # dist2dst = np.insert(dist2dst, src, 0, axis=0)
+ dist2dst = torch.from_numpy(dist2dst)
+
+ dist = dist2src + dist2dst
+ dist_over_2, dist_mod_2 = dist // 2, dist % 2
+
+ z = 1 + torch.min(dist2src, dist2dst)
+ z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)
+ z[src] = 1.
+ z[dst] = 1.
+ z[torch.isnan(z)] = 0.
+
+ return z.to(torch.long)
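+
+# Worked example of the DRNL formula above (illustrative sketch with hypothetical values):
+# for a node i with r(i, u) = 1 and r(i, v) = 2, d = 3 and
+# label = 1 + min(1, 2) + (3 // 2) * (3 // 2 + 3 % 2 - 1) = 3.
+# A minimal self-check on a 3-node path graph u - w - v could look like:
+#
+#     g = dgl.graph(([0, 1, 1, 2], [1, 0, 2, 1]))   # undirected path 0 - 1 - 2
+#     drnl_node_labeling(g, 0, 2)                   # tensor([1, 2, 1])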
+
+def deserialize(data):
+ data_tuple = pickle.loads(data)
+ keys = ('nodes', 'r_label', 'g_label', 'n_label', 'common_neighbor')
+ return dict(zip(keys, data_tuple))
+
+class NCNDataset(Dataset):
+ def __init__(self, triplets, num_ent, num_rel, params, adj, db_name_pos=None):
+ super(NCNDataset, self).__init__()
+ self.p = params
+ self.triplets = triplets
+ self.label_smooth = params.lbl_smooth
+ self.num_ent = num_ent
+ self.num_rel = num_rel
+ self.adj = adj
+ db_path = f'subgraph/{self.p.dataset}/{self.p.subgraph_type}_neg_{self.p.num_neg_samples_per_link}_hop_{self.p.subgraph_hop}_seed_{self.p.seed}'
+ if self.p.save_mode == 'mdb':
+ self.main_env = lmdb.open(db_path, readonly=True, max_dbs=3, lock=False)
+ self.db_pos = self.main_env.open_db(db_name_pos.encode())
+ # with self.main_env.begin(db=self.db_pos) as txn:
+ # num_graphs_pos = int.from_bytes(txn.get('num_graphs'.encode()), byteorder='little')
+ # print(num_graphs_pos)
+
+ def __len__(self):
+ return len(self.triplets)
+
+ def __getitem__(self, item):
+ with self.main_env.begin(db=self.db_pos) as txn:
+ str_id = '{:08}'.format(item).encode('ascii')
+ nodes_pos, r_label_pos, g, n, cn = deserialize(txn.get(str_id)).values()
+ triple = torch.tensor([nodes_pos[0], r_label_pos,nodes_pos[1]], dtype=torch.long)
+ label = torch.tensor(r_label_pos, dtype=torch.long)
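+        # cn_index is a boolean mask over all entities marking the common neighbours of the
+        # pair; the extra slot at index num_ent acts as a sentinel when no common neighbour exists.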
+ cn_index = torch.zeros([self.num_ent+1], dtype=torch.bool)
+ if len(cn) == 0:
+ cn_index[self.num_ent] = 1
+ else:
+ cn_index[list(cn)] = 1
+ return triple, label, cn_index
+
+ def get_common_neighbors(self, u, v):
+ cns_list = []
+ # for i_u in range(1, self.p.n_layer+1):
+ # for i_v in range(1, self.p.n_layer+1):
+ # root_u_nei = get_neighbor_nodes({u}, self.adj, i_u)
+ # root_v_nei = get_neighbor_nodes({v}, self.adj, i_v)
+ # subgraph_nei_nodes_int = root_u_nei.intersection(root_v_nei)
+ # ng = list(subgraph_nei_nodes_int)
+ # subgraph = torch.zeros([1, self.num_ent], dtype=torch.bool)
+ # if len(ng) == 0:
+ # cns_list.append(subgraph)
+ # continue
+ # subgraph[:,ng] = 1
+ # cns_list.append(subgraph)
+ # root_u_nei = get_neighbor_nodes({u}, self.adj, 1)
+ # root_v_nei = get_neighbor_nodes({v}, self.adj, 1)
+ # subgraph_nei_nodes_int = root_u_nei.intersection(root_v_nei)
+ # ng = list(subgraph_nei_nodes_int)
+ # subgraph = torch.zeros([1, self.num_ent], dtype=torch.bool)
+ # if len(ng) == 0:
+ # return subgraph
+ # else:
+ # subgraph[:,ng] = 1
+ return 1
\ No newline at end of file
diff --git a/utils/dgl_utils.py b/utils/dgl_utils.py
new file mode 100644
index 0000000..da99c42
--- /dev/null
+++ b/utils/dgl_utils.py
@@ -0,0 +1,131 @@
+import numpy as np
+import scipy.sparse as ssp
+import random
+from scipy.sparse import csc_matrix
+
+"""All functions in this file are from dgl.contrib.data.knowledge_graph"""
+
+
+def _bfs_relational(adj, roots, max_nodes_per_hop=None):
+ """
+ BFS for graphs.
+    Modified from dgl.contrib.data.knowledge_graph to accommodate node sampling
+ """
+ visited = set()
+ current_lvl = set(roots)
+
+ next_lvl = set()
+
+ while current_lvl:
+
+ for v in current_lvl:
+ visited.add(v)
+
+ next_lvl = _get_neighbors(adj, current_lvl)
+ next_lvl -= visited # set difference
+
+ if max_nodes_per_hop and max_nodes_per_hop < len(next_lvl):
+ next_lvl = set(random.sample(next_lvl, max_nodes_per_hop))
+
+ yield next_lvl
+
+ current_lvl = set.union(next_lvl)
+
+
+def _get_neighbors(adj, nodes):
+ """Takes a set of nodes and a graph adjacency matrix and returns a set of neighbors.
+ Directly copied from dgl.contrib.data.knowledge_graph"""
+ sp_nodes = _sp_row_vec_from_idx_list(list(nodes), adj.shape[1])
+ sp_neighbors = sp_nodes.dot(adj)
+ neighbors = set(ssp.find(sp_neighbors)[1]) # convert to set of indices
+ return neighbors
+
+
+def _sp_row_vec_from_idx_list(idx_list, dim):
+ """Create sparse vector of dimensionality dim from a list of indices."""
+ shape = (1, dim)
+ data = np.ones(len(idx_list))
+ row_ind = np.zeros(len(idx_list))
+ col_ind = list(idx_list)
+ return ssp.csr_matrix((data, (row_ind, col_ind)), shape=shape)
+
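+# Example (illustrative sketch with hypothetical values): _sp_row_vec_from_idx_list([1, 3], 5)
+# yields the 1 x 5 sparse row vector [0, 1, 0, 1, 0]; _get_neighbors multiplies such a vector
+# with the adjacency matrix to collect the neighbours of a node set in one sparse product.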
+
+def process_files_ddi(files, triple_file, saved_relation2id=None, keeptrainone = False):
+ entity2id = {}
+ relation2id = {} if saved_relation2id is None else saved_relation2id
+
+ triplets = {}
+ kg_triple = []
+ ent = 0
+ rel = 0
+
+ for file_type, file_path in files.items():
+ data = []
+ # with open(file_path) as f:
+ # file_data = [line.split() for line in f.read().split('\n')[:-1]]
+ file_data = np.loadtxt(file_path)
+ for triplet in file_data:
+ #print(triplet)
+ triplet[0], triplet[1], triplet[2] = int(triplet[0]), int(triplet[1]), int(triplet[2])
+ if triplet[0] not in entity2id:
+ entity2id[triplet[0]] = triplet[0]
+ #ent += 1
+ if triplet[1] not in entity2id:
+ entity2id[triplet[1]] = triplet[1]
+ #ent += 1
+ if not saved_relation2id and triplet[2] not in relation2id:
+ if keeptrainone:
+ triplet[2] = 0
+ relation2id[triplet[2]] = 0
+ rel = 1
+ else:
+ relation2id[triplet[2]] = triplet[2]
+ rel += 1
+
+ # Save the triplets corresponding to only the known relations
+ if triplet[2] in relation2id:
+ data.append([entity2id[triplet[0]], entity2id[triplet[1]], relation2id[triplet[2]]])
+
+ triplets[file_type] = np.array(data)
+ #print(rel)
+ triplet_kg = np.loadtxt(triple_file)
+ # print(np.max(triplet_kg[:, -1]))
+ for (h, t, r) in triplet_kg:
+ h, t, r = int(h), int(t), int(r)
+ if h not in entity2id:
+ entity2id[h] = h
+ if t not in entity2id:
+ entity2id[t] = t
+ if not saved_relation2id and rel+r not in relation2id:
+ relation2id[rel+r] = rel + r
+ kg_triple.append([h, t, r])
+ kg_triple = np.array(kg_triple)
+ id2entity = {v: k for k, v in entity2id.items()}
+ id2relation = {v: k for k, v in relation2id.items()}
+ #print(relation2id, rel)
+
+    # Construct the list of adjacency matrices, one per relation. Note that this is constructed only from the train data.
+ adj_list = []
+ #print(kg_triple)
+ #for i in range(len(relation2id)):
+ for i in range(rel):
+ idx = np.argwhere(triplets['train'][:, 2] == i)
+ adj_list.append(csc_matrix((np.ones(len(idx), dtype=np.uint8), (triplets['train'][:, 0][idx].squeeze(1), triplets['train'][:, 1][idx].squeeze(1))), shape=(34124, 34124)))
+ for i in range(rel, len(relation2id)):
+ idx = np.argwhere(kg_triple[:, 2] == i-rel)
+ #print(len(idx), i)
+ adj_list.append(csc_matrix((np.ones(len(idx), dtype=np.uint8), (kg_triple[:, 0][idx].squeeze(1), kg_triple[:, 1][idx].squeeze(1))), shape=(34124, 34124)))
+ #print(adj_list)
+ #assert 0
+ return adj_list, triplets, entity2id, relation2id, id2entity, id2relation, rel
+
+
+def get_neighbor_nodes(roots, adj, h=1, max_nodes_per_hop=None):
+ bfs_generator = _bfs_relational(adj, roots, max_nodes_per_hop)
+ lvls = list()
+ for _ in range(h):
+ try:
+ lvls.append(next(bfs_generator))
+ except StopIteration:
+ pass
+ return set().union(*lvls)
\ No newline at end of file
diff --git a/utils/graph_utils.py b/utils/graph_utils.py
new file mode 100644
index 0000000..ced0178
--- /dev/null
+++ b/utils/graph_utils.py
@@ -0,0 +1,185 @@
+import statistics
+import numpy as np
+import scipy.sparse as ssp
+import torch
+import networkx as nx
+import dgl
+import pickle
+
+
+def serialize(data):
+ data_tuple = tuple(data.values())
+ return pickle.dumps(data_tuple)
+
+
+def deserialize(data):
+ data_tuple = pickle.loads(data)
+ keys = ('nodes', 'r_label', 'g_label', 'n_label')
+ return dict(zip(keys, data_tuple))
+
+
+def get_edge_count(adj_list):
+ count = []
+ for adj in adj_list:
+ count.append(len(adj.tocoo().row.tolist()))
+ return np.array(count)
+
+
+def incidence_matrix(adj_list):
+ '''
+ adj_list: List of sparse adjacency matrices
+ '''
+
+ rows, cols, dats = [], [], []
+ dim = adj_list[0].shape
+ for adj in adj_list:
+ adjcoo = adj.tocoo()
+ rows += adjcoo.row.tolist()
+ cols += adjcoo.col.tolist()
+ dats += adjcoo.data.tolist()
+ row = np.array(rows)
+ col = np.array(cols)
+ data = np.array(dats)
+ return ssp.csc_matrix((data, (row, col)), shape=dim)
+
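+# Illustration (hedged example): incidence_matrix collapses the per-relation adjacency matrices
+# into a single graph-level matrix, e.g. two 3 x 3 relation matrices holding edges (0, 1) and
+# (1, 2) respectively merge into one csc_matrix containing both edges.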
+
+def remove_nodes(A_incidence, nodes):
+ idxs_wo_nodes = list(set(range(A_incidence.shape[1])) - set(nodes))
+ return A_incidence[idxs_wo_nodes, :][:, idxs_wo_nodes]
+
+
+def ssp_to_torch(A, device, dense=False):
+ '''
+ A : Sparse adjacency matrix
+ '''
+ idx = torch.LongTensor([A.tocoo().row, A.tocoo().col])
+ dat = torch.FloatTensor(A.tocoo().data)
+ A = torch.sparse.FloatTensor(idx, dat, torch.Size([A.shape[0], A.shape[1]])).to(device=device)
+ return A
+
+
+# def ssp_multigraph_to_dgl(graph, n_feats=None):
+# """
+# Converting ssp multigraph (i.e. list of adjs) to dgl multigraph.
+# """
+
+# g_nx = nx.MultiDiGraph()
+# g_nx.add_nodes_from(list(range(graph[0].shape[0])))
+# # Add edges
+# for rel, adj in enumerate(graph):
+# # Convert adjacency matrix to tuples for nx0
+# nx_triplets = []
+# for src, dst in list(zip(adj.tocoo().row, adj.tocoo().col)):
+# nx_triplets.append((src, dst, {'type': rel}))
+# g_nx.add_edges_from(nx_triplets)
+
+# # make dgl graph
+# g_dgl = dgl.DGLGraph(multigraph=True)
+# g_dgl.from_networkx(g_nx, edge_attrs=['type'])
+# # add node features
+# if n_feats is not None:
+# g_dgl.ndata['feat'] = torch.tensor(n_feats)
+
+# return g_dgl
+
+def ssp_multigraph_to_dgl(graph):
+ """
+    Convert an ssp multigraph (i.e. a list of per-relation adjacency matrices) to a dgl multigraph.
+ """
+
+ g_nx = nx.MultiDiGraph()
+ g_nx.add_nodes_from(list(range(graph[0].shape[0])))
+ # Add edges
+ for rel, adj in enumerate(graph):
+        # Convert adjacency matrix to edge tuples for networkx
+ nx_triplets = []
+ for src, dst in list(zip(adj.tocoo().row, adj.tocoo().col)):
+ nx_triplets.append((src, dst, {'type': rel}))
+ g_nx.add_edges_from(nx_triplets)
+
+    # make dgl graph directly from the networkx multigraph
+    g_dgl = dgl.from_networkx(g_nx, edge_attrs=['type'])
+
+ return g_dgl
+
+
+def collate_dgl(samples):
+ # The input `samples` is a list of pairs
+ graphs_pos, g_labels_pos, r_labels_pos = map(list, zip(*samples))
+ # print(graphs_pos, g_labels_pos, r_labels_pos, samples)
+ batched_graph_pos = dgl.batch(graphs_pos)
+ # print(batched_graph_pos)
+
+ # graphs_neg = [item for sublist in graphs_negs for item in sublist]
+ # g_labels_neg = [item for sublist in g_labels_negs for item in sublist]
+ # r_labels_neg = [item for sublist in r_labels_negs for item in sublist]
+
+ # batched_graph_neg = dgl.batch(graphs_neg)
+ return (batched_graph_pos, r_labels_pos), g_labels_pos # , drug_idx
+
+
+def move_batch_to_device_dgl(batch, device):
+ (g_dgl_pos, r_labels_pos), targets_pos = batch
+
+ targets_pos = torch.LongTensor(targets_pos).to(device=device)
+ r_labels_pos = torch.LongTensor(r_labels_pos).to(device=device)
+ # drug_idx = torch.LongTensor(drug_idx).to(device=device)
+ # targets_neg = torch.LongTensor(targets_neg).to(device=device)
+ # r_labels_neg = torch.LongTensor(r_labels_neg).to(device=device)
+
+ g_dgl_pos = g_dgl_pos.to(device)
+ # g_dgl_neg = send_graph_to_device(g_dgl_neg, device)
+
+ return g_dgl_pos, r_labels_pos, targets_pos
+
+
+def move_batch_to_device_dgl_ddi2(batch, device):
+ (g_dgl_pos, r_labels_pos), targets_pos = batch
+
+ targets_pos = torch.LongTensor(targets_pos).to(device=device)
+ r_labels_pos = torch.FloatTensor(r_labels_pos).to(device=device)
+ # drug_idx = torch.LongTensor(drug_idx).to(device=device)
+ # targets_neg = torch.LongTensor(targets_neg).to(device=device)
+ # r_labels_neg = torch.LongTensor(r_labels_neg).to(device=device)
+
+ g_dgl_pos = send_graph_to_device(g_dgl_pos, device)
+ # g_dgl_neg = send_graph_to_device(g_dgl_neg, device)
+
+ return g_dgl_pos, r_labels_pos, targets_pos
+
+
+def send_graph_to_device(g, device):
+ # nodes
+ labels = g.node_attr_schemes()
+ for l in labels.keys():
+ g.ndata[l] = g.ndata.pop(l).to(device)
+
+ # edges
+ labels = g.edge_attr_schemes()
+ for l in labels.keys():
+ g.edata[l] = g.edata.pop(l).to(device)
+ return g
+
+
+# The following three functions are modified from the networkx source code to
+# accommodate diameter and radius for directed graphs
+
+
+def eccentricity(G):
+ e = {}
+ for n in G.nbunch_iter():
+ length = nx.single_source_shortest_path_length(G, n)
+ e[n] = max(length.values())
+ return e
+
+
+def radius(G):
+ e = eccentricity(G)
+ e = np.where(np.array(list(e.values())) > 0, list(e.values()), np.inf)
+ return min(e)
+
+
+def diameter(G):
+ e = eccentricity(G)
+ return max(e.values())
diff --git a/utils/logger.py b/utils/logger.py
new file mode 100644
index 0000000..50fa9cc
--- /dev/null
+++ b/utils/logger.py
@@ -0,0 +1,12 @@
+import json
+import logging
+import logging.config
+from os import makedirs
+
+def get_logger(log_dir, name):
+    with open('./config/logger_config.json') as f:
+        config_dict = json.load(f)
+ config_dict['handlers']['file_handler']['filename'] = log_dir + name + '.log'
+ makedirs(log_dir, exist_ok=True)
+ logging.config.dictConfig(config_dict)
+ logger = logging.getLogger(name)
+ return logger
\ No newline at end of file
diff --git a/utils/process_data.py b/utils/process_data.py
new file mode 100644
index 0000000..cfe07f0
--- /dev/null
+++ b/utils/process_data.py
@@ -0,0 +1,78 @@
+from collections import defaultdict as ddict
+import random
+
+
+def process(dataset, num_rel, n_layer, add_reverse):
+ """
+    pre-process the dataset into per-split triplet dictionaries
+    :param dataset: a dictionary containing 'train', 'valid' and 'test' triplets
+    :param num_rel: number of relations
+    :param n_layer: number of GNN layers, used to draw random hop counts per pair
+    :param add_reverse: whether to also generate inverse-relation triplets
+    :return: a dict of processed triplets per split and a dict counting samples per relation class
+ """
+
+ so2r = ddict(set)
+ so2randomhop = ddict()
+ class2num = ddict()
+ # print(len(dataset['train']))
+ # index = 0
+ # cnt = 0
+ for subj, rel, obj in dataset['train']:
+ class2num[rel] = class2num.setdefault(rel, 0) + 1
+ so2r[(subj,obj)].add(rel)
+ subj_hop, obj_hop = random.randint(1, n_layer), random.randint(1, n_layer)
+ so2randomhop.setdefault((subj, obj), (subj_hop, obj_hop))
+ if add_reverse:
+ so2r[(obj, subj)].add(rel + num_rel)
+ class2num[rel + num_rel] = class2num.setdefault(rel + num_rel, 0) + 1
+ so2randomhop.setdefault((obj, subj), (obj_hop, subj_hop))
+ # index+=1
+ # print("______________________")
+ # print(subj, rel, obj)
+ # print(so2r[(subj, obj)])
+ # print(so2r[(obj, subj)])
+ # print(len(so2r))
+ # print(2*index)
+ # assert len(so2r) == 2*index
+ # print(len(so2r))
+ # print(cnt)
+ so2r_train = {k: list(v) for k, v in so2r.items()}
+ for split in ['valid', 'test']:
+ for subj, rel, obj in dataset[split]:
+ so2r[(subj, obj)].add(rel)
+ so2r[(obj, subj)].add(rel + num_rel)
+ subj_hop, obj_hop = random.randint(1, n_layer), random.randint(1, n_layer)
+ so2randomhop.setdefault((subj, obj), (subj_hop, obj_hop))
+ so2randomhop.setdefault((obj, subj), (obj_hop, subj_hop))
+ so2r_all = {k: list(v) for k, v in so2r.items()}
+ triplets = ddict(list)
+
+ # for (subj, obj), rel in so2r_train.items():
+ # triplets['train_rel'].append({'triple': (subj, rel, obj), 'label': so2r_train[(subj, obj)], 'random_hop':so2randomhop[(subj, obj)]})
+ # FOR DDI
+ for subj, rel, obj in dataset['train']:
+ triplets['train_rel'].append(
+ {'triple': (subj, rel, obj), 'label': so2r_train[(subj, obj)], 'random_hop': so2randomhop[(subj, obj)]})
+ if add_reverse:
+ triplets['train_rel'].append(
+ {'triple': (obj, rel+num_rel, subj), 'label': so2r_train[(obj, subj)], 'random_hop': so2randomhop[(obj, subj)]})
+ for split in ['valid', 'test']:
+ for subj, rel, obj in dataset[split]:
+ triplets[f"{split}_rel"].append({'triple': (subj, rel, obj), 'label': so2r_all[(subj, obj)], 'random_hop':so2randomhop[(subj, obj)]})
+ triplets[f"{split}_rel_inv"].append(
+ {'triple': (obj, rel + num_rel, subj), 'label': so2r_all[(obj, subj)], 'random_hop':so2randomhop[(obj, subj)]})
+ triplets = dict(triplets)
+ return triplets, class2num
+
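+# Example sketch (illustrative, with hypothetical values): with num_rel = 2, add_reverse=True and
+# a single training triple (0, 1, 5) (subject 0, relation 1, object 5), 'train_rel' holds
+# {'triple': (0, 1, 5), 'label': [1], ...} plus the reverse entry
+# {'triple': (5, 3, 0), 'label': [3], ...}, since the inverse relation id is rel + num_rel.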
+
+def process_multi_label(input, multi_label, pos_neg):
+ triplets = ddict(list)
+ for index, data in enumerate(input['train']):
+ subj, _, obj = data
+ triplets['train_rel'].append(
+ {'triple': (subj, -1, obj), 'label': multi_label['train'][index], 'pos_neg': pos_neg['train'][index]})
+ for split in ['valid', 'test']:
+ for index, data in enumerate(input[split]):
+ subj, _, obj = data
+ triplets[f"{split}_rel"].append({'triple': (subj, -1, obj), 'label': multi_label[split][index],'pos_neg': pos_neg[split][index]})
+ triplets = dict(triplets)
+ return triplets
\ No newline at end of file
diff --git a/utils/sampler.py b/utils/sampler.py
new file mode 100644
index 0000000..36fdf23
--- /dev/null
+++ b/utils/sampler.py
@@ -0,0 +1,1148 @@
+import os.path as osp
+from tqdm import tqdm
+from copy import deepcopy
+import torch
+import dgl
+from torch.utils.data import DataLoader, Dataset
+from dgl import DGLGraph, NID
+from dgl.dataloading.negative_sampler import Uniform
+from dgl import add_self_loop
+import numpy as np
+from scipy.sparse.csgraph import shortest_path
+# GraIL
+from scipy.sparse import csc_matrix
+import os
+import json
+import matplotlib.pyplot as plt
+import logging
+from scipy.special import softmax
+import lmdb
+import multiprocessing as mp
+import scipy.sparse as ssp
+from utils.graph_utils import serialize, incidence_matrix, remove_nodes
+from utils.dgl_utils import _bfs_relational
+import struct
+
+
+class GraphDataSet(Dataset):
+ """
+ GraphDataset for torch DataLoader
+ """
+
+ def __init__(self, graph_list, tensor, n_ent, n_rel, max_seq_length):
+ self.graph_list = graph_list
+ self.tensor = tensor
+ self.n_ent = n_ent
+ self.n_rel = n_rel
+ self.max_seq_length = max_seq_length
+
+ def __len__(self):
+ return len(self.graph_list)
+
+ def __getitem__(self, index):
+ input_ids, input_mask, mask_position, mask_label = self.convert_subgraph_to_feature(self.graph_list[index],
+ self.max_seq_length,
+ self.tensor[index])
+ rela_mat = torch.zeros(self.max_seq_length, self.max_seq_length, self.n_rel)
+ rela_mat[self.graph_list[index].edges()[0], self.graph_list[index].edges()[1], self.graph_list[index].edata['rel']] = 1
+ return (self.graph_list[index], self.tensor[index], input_ids, input_mask, mask_position, mask_label, rela_mat)
+
+ def convert_subgraph_to_feature(self, subgraph, max_seq_length, triple):
+ input_ids = [subgraph.ndata[NID]]
+ input_mask = [1 for _ in range(subgraph.num_nodes())]
+ if max_seq_length - subgraph.num_nodes() > 0:
+ input_ids.append(torch.tensor([self.n_ent for _ in range(max_seq_length - subgraph.num_nodes())]))
+ input_mask += [0 for _ in range(max_seq_length - subgraph.num_nodes())]
+ input_ids = torch.cat(input_ids)
+ input_mask = torch.tensor(input_mask)
+ # input_ids = torch.cat(
+ # [subgraph.ndata[NID], torch.tensor([self.n_ent for _ in range(max_seq_length - subgraph.num_nodes())])])
+ # input_mask = [1 for _ in range(subgraph.num_nodes())]
+ # input_mask += [0 for _ in range(max_seq_length - subgraph.num_nodes())]
+ # while len(input_mask) < max_seq_length:
+ # # input_ids.append(self.n_ent)
+ # input_mask.append(0)
+        # TODO: also support predicting the head entity; currently only the tail entity is predicted
+ mask_position = ((subgraph.ndata[NID] == triple[2]).nonzero().flatten())
+ input_ids[mask_position.item()] = self.n_ent + 1
+ mask_label = triple[2]
+ # for position in list(torch.where(subgraph.ndata['z']==1)[0]):
+ # mask_position = position
+ # input_ids[mask_position]=self.n_ent+1
+ # mask_label = subgraph.ndata[NID][mask_position]
+ # break
+ assert input_ids.size(0) == max_seq_length
+ assert input_mask.size(0) == max_seq_length
+
+ return input_ids, input_mask, mask_position, mask_label
+
+
+class GraphDataSetRP(Dataset):
+ """
+ GraphDataset for torch DataLoader
+ """
+
+ def __init__(self, graph_list, tensor, n_ent, n_rel, max_seq_length):
+ self.graph_list = graph_list
+ self.tensor = tensor
+ self.n_ent = n_ent
+ self.n_rel = n_rel
+ self.max_seq_length = max_seq_length
+
+ def __len__(self):
+ return len(self.graph_list)
+
+ def __getitem__(self, index):
+ input_ids, num_nodes, head_position, tail_position = self.convert_subgraph_to_feature(self.graph_list[index],self.tensor[index],self.max_seq_length)
+ rela_mat = torch.zeros(self.max_seq_length, self.max_seq_length, self.n_rel)
+ rela_mat[self.graph_list[index].edges()[0], self.graph_list[index].edges()[1], self.graph_list[index].edata['rel']] = 1
+ return (self.graph_list[index], self.tensor[index], input_ids, num_nodes, head_position, tail_position, rela_mat)
+
+ def convert_subgraph_to_feature(self, subgraph, triple, max_seq_length):
+ input_ids = [subgraph.ndata[NID]]
+ if max_seq_length - subgraph.num_nodes() > 0:
+ input_ids.append(torch.tensor([self.n_ent for _ in range(max_seq_length - subgraph.num_nodes())]))
+ input_ids = torch.cat(input_ids)
+ # print(subgraph)
+ # print(input_ids)
+ # print(triple)
+ head_position = torch.where(subgraph.ndata[NID] == triple[0])[0]
+ tail_position = torch.where(subgraph.ndata[NID] == triple[2])[0]
+ # print(head_position)
+ # print(tail_position)
+ # exit(0)
+ # input_ids = torch.cat(
+ # [subgraph.ndata[NID], torch.tensor([self.n_ent for _ in range(max_seq_length - subgraph.num_nodes())])])
+ # input_mask = [1 for _ in range(subgraph.num_nodes())]
+ # input_mask += [0 for _ in range(max_seq_length - subgraph.num_nodes())]
+ # while len(input_mask) < max_seq_length:
+ # # input_ids.append(self.n_ent)
+ # input_mask.append(0)
+ # for position in list(torch.where(subgraph.ndata['z']==1)[0]):
+ # mask_position = position
+ # input_ids[mask_position]=self.n_ent+1
+ # mask_label = subgraph.ndata[NID][mask_position]
+ # break
+ assert input_ids.size(0) == max_seq_length
+
+ return input_ids, subgraph.num_nodes(), head_position, tail_position
+
+class GraphDataSetGCN(Dataset):
+ """
+ GraphDataset for torch DataLoader
+ """
+
+ def __init__(self, graph_list, tensor, n_ent, n_rel):
+ self.graph_list = graph_list
+ self.tensor = tensor
+
+ def __len__(self):
+ return len(self.graph_list)
+
+ def __getitem__(self, index):
+ head_idx = torch.where(self.graph_list[index].ndata[NID] == self.tensor[index][0])[0]
+ tail_idx = torch.where(self.graph_list[index].ndata[NID] == self.tensor[index][2])[0]
+ return (self.graph_list[index], self.tensor[index], head_idx, tail_idx)
+
+class PosNegEdgesGenerator(object):
+ """
+ Generate positive and negative samples
+ Attributes:
+ g(dgl.DGLGraph): graph
+ split_edge(dict): split edge
+        neg_samples(int): number of negative samples per positive sample
+        subsample_ratio(float): subsampling ratio
+        shuffle(bool): whether to shuffle the generated edge list
+ """
+
+ def __init__(self, g, split_edge, neg_samples=1, subsample_ratio=0.1, shuffle=True):
+ self.neg_sampler = Uniform(neg_samples)
+ self.subsample_ratio = subsample_ratio
+ self.split_edge = split_edge
+ self.g = g
+ self.shuffle = shuffle
+
+ def __call__(self, split_type):
+
+ if split_type == 'train':
+ subsample_ratio = self.subsample_ratio
+ else:
+ subsample_ratio = 1
+
+ pos_edges = self.g.edges()
+ pos_edges = torch.stack((pos_edges[0], pos_edges[1]), 1)
+
+ if split_type == 'train':
+ # Adding self loop in train avoids sampling the source node itself.
+ g = add_self_loop(self.g)
+ eids = g.edge_ids(pos_edges[:, 0], pos_edges[:, 1])
+ neg_edges = torch.stack(self.neg_sampler(g, eids), dim=1)
+ else:
+ neg_edges = self.split_edge[split_type]['edge_neg']
+ pos_edges = self.subsample(pos_edges, subsample_ratio).long()
+ neg_edges = self.subsample(neg_edges, subsample_ratio).long()
+
+ edges = torch.cat([pos_edges, neg_edges])
+ labels = torch.cat([torch.ones(pos_edges.size(0), 1), torch.zeros(neg_edges.size(0), 1)])
+ if self.shuffle:
+ perm = torch.randperm(edges.size(0))
+ edges = edges[perm]
+ labels = labels[perm]
+ return edges, labels
+
+ def subsample(self, edges, subsample_ratio):
+ """
+ Subsample generated edges.
+ Args:
+ edges(Tensor): edges to subsample
+ subsample_ratio(float): ratio of subsample
+
+ Returns:
+ edges(Tensor): edges
+
+ """
+
+ num_edges = edges.size(0)
+ perm = torch.randperm(num_edges)
+ perm = perm[:int(subsample_ratio * num_edges)]
+ edges = edges[perm]
+ return edges
+
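+# Usage sketch (hedged example, assuming a DGLGraph `g` and an OGB-style `split_edge` dict
+# with an 'edge_neg' entry for the valid/test splits):
+#
+#     generator = PosNegEdgesGenerator(g, split_edge, neg_samples=1, subsample_ratio=0.1)
+#     edges, labels = generator('train')   # (N, 2) node pairs and (N, 1) 0/1 labels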
+
+class EdgeDataSet(Dataset):
+ """
+    Auxiliary dataset used to parallelize subgraph sampling for SEALSampler
+ """
+
+ def __init__(self, triples, transform):
+ self.transform = transform
+ self.triples = triples
+
+ def __len__(self):
+ return len(self.triples)
+
+ def __getitem__(self, index):
+ edge = torch.tensor([self.triples[index]['triple'][0], self.triples[index]['triple'][2]])
+ subgraph = self.transform(edge)
+ return (subgraph)
+
+
+class SEALSampler(object):
+ """
+    Sampler for SEAL (no-block version from the paper).
+    The strategy is to sample the k-hop neighbors around the two target nodes,
+    subject to a maximum node budget.
+    Attributes:
+        graph(DGLGraph): the graph
+        hop(int): number of hops
+        num_workers(int): number of workers
+
+ """
+
+ def __init__(self, graph, hop=1, max_num_nodes=100, num_workers=32, type='greedy', print_fn=print):
+ self.graph = graph
+ self.hop = hop
+ self.print_fn = print_fn
+ self.num_workers = num_workers
+ self.threshold = None
+ self.max_num_nodes = max_num_nodes
+ self.sample_type = type
+
+ def sample_subgraph(self, target_nodes, mode='valid'):
+ """
+ Args:
+ target_nodes(Tensor): Tensor of two target nodes
+ Returns:
+ subgraph(DGLGraph): subgraph
+ """
+        # TODO: Add a sampling constraint on each hop
+ sample_nodes = [target_nodes]
+ frontiers = target_nodes
+ if self.sample_type == 'greedy':
+ for i in range(self.hop):
+ # get sampled node number
+ tmp_sample_nodes = torch.cat(sample_nodes)
+ tmp_sample_nodes = torch.unique(tmp_sample_nodes)
+ if tmp_sample_nodes.size(0) < self.max_num_nodes: # whether sample or not
+ frontiers = self.graph.out_edges(frontiers)[1]
+ frontiers = torch.unique(frontiers)
+ if frontiers.size(0) > self.max_num_nodes - tmp_sample_nodes.size(0):
+ frontiers = np.random.choice(frontiers.numpy(), self.max_num_nodes - tmp_sample_nodes.size(0), replace=False)
+ frontiers = torch.unique(torch.tensor(frontiers))
+ sample_nodes.append(frontiers)
+ elif self.sample_type == 'greedy_set':
+ for i in range(self.hop):
+ # get sampled node number
+ tmp_sample_nodes = torch.cat(sample_nodes)
+ tmp_sample_nodes_set = set(tmp_sample_nodes.tolist())
+ # tmp_sample_nodes = torch.unique(tmp_sample_nodes)
+ if len(tmp_sample_nodes_set) < self.max_num_nodes: # whether sample or not
+ tmp_frontiers = self.graph.out_edges(frontiers)[1]
+ tmp_frontiers_set = set(tmp_frontiers.tolist())
+ tmp_frontiers_set = tmp_frontiers_set - tmp_sample_nodes_set
+ if not tmp_frontiers_set:
+ break
+ else:
+ frontiers_set = tmp_frontiers_set
+ frontiers = torch.tensor(list(frontiers_set))
+ if frontiers.size(0) > self.max_num_nodes - len(tmp_sample_nodes_set):
+ frontiers = np.random.choice(frontiers.numpy(), self.max_num_nodes - len(tmp_sample_nodes_set), replace=False)
+ frontiers = torch.unique(torch.tensor(frontiers))
+ sample_nodes.append(frontiers)
+ elif self.sample_type == 'average_set':
+ for i in range(self.hop):
+ # get sampled node number
+ tmp_sample_nodes = torch.cat(sample_nodes)
+ tmp_sample_nodes_set = set(tmp_sample_nodes.tolist())
+ num_nodes = int(self.max_num_nodes / self.hop * (i + 1))
+ if len(tmp_sample_nodes_set) < num_nodes: # whether sample or not
+ tmp_frontiers = self.graph.out_edges(frontiers)[1]
+ tmp_frontiers_set = set(tmp_frontiers.tolist())
+ tmp_frontiers_set = tmp_frontiers_set - tmp_sample_nodes_set
+ if not tmp_frontiers_set:
+ break
+ else:
+ frontiers_set = tmp_frontiers_set
+ frontiers = torch.tensor(list(frontiers_set))
+ if frontiers.size(0) > num_nodes - len(tmp_sample_nodes_set):
+ frontiers = np.random.choice(frontiers.numpy(), num_nodes - len(tmp_sample_nodes_set), replace=False)
+ frontiers = torch.unique(torch.tensor(frontiers))
+ sample_nodes.append(frontiers)
+ elif self.sample_type == 'average':
+ for i in range(self.hop):
+ tmp_sample_nodes = torch.cat(sample_nodes)
+ tmp_sample_nodes = torch.unique(tmp_sample_nodes)
+ num_nodes = int(self.max_num_nodes/self.hop*(i+1))
+ if tmp_sample_nodes.size(0) < num_nodes: # whether sample or not
+ frontiers = self.graph.out_edges(frontiers)[1]
+ frontiers = torch.unique(frontiers)
+ if frontiers.size(0) > num_nodes - tmp_sample_nodes.size(0):
+ frontiers = np.random.choice(frontiers.numpy(), num_nodes - tmp_sample_nodes.size(0),
+ replace=False)
+ frontiers = torch.unique(torch.tensor(frontiers))
+ sample_nodes.append(frontiers)
+ elif self.sample_type == 'enclosing_subgraph':
+ u, v = target_nodes[0], target_nodes[1]
+ u_neighbour, v_neighbour = u, v
+ # print(target_nodes)
+ u_sample_nodes = [u_neighbour.reshape(1)]
+ v_sample_nodes = [v_neighbour.reshape(1)]
+ graph = self.graph
+ if graph.has_edges_between(u, v):
+ link_id = graph.edge_ids(u, v, return_uv=True)[2]
+ graph.remove_edges(link_id)
+ if graph.has_edges_between(v, u):
+ link_id = graph.edge_ids(v, u, return_uv=True)[2]
+ graph.remove_edges(link_id)
+ for i in range(self.hop):
+ u_frontiers = graph.out_edges(u_neighbour)[1]
+ # v_frontiers = self.graph.out_edges(v_neighbour)[1]
+ u_neighbour = u_frontiers
+ # set(u_frontiers.tolist())
+ if u_frontiers.size(0) > self.max_num_nodes:
+ u_frontiers = np.random.choice(u_frontiers.numpy(), self.max_num_nodes, replace=False)
+ u_frontiers = torch.tensor(u_frontiers)
+ u_sample_nodes.append(u_frontiers)
+ for i in range(self.hop):
+ v_frontiers = graph.out_edges(v_neighbour)[1]
+ # v_frontiers = self.graph.out_edges(v_neighbour)[1]
+ v_neighbour = v_frontiers
+ if v_frontiers.size(0) > self.max_num_nodes:
+ v_frontiers = np.random.choice(v_frontiers.numpy(), self.max_num_nodes, replace=False)
+ v_frontiers = torch.tensor(v_frontiers)
+ v_sample_nodes.append(v_frontiers)
+ # print('U', u_sample_nodes)
+ # print('V', v_sample_nodes)
+ u_sample_nodes = torch.cat(u_sample_nodes)
+ u_sample_nodes = torch.unique(u_sample_nodes)
+ v_sample_nodes = torch.cat(v_sample_nodes)
+ v_sample_nodes = torch.unique(v_sample_nodes)
+ # print('U', u_sample_nodes)
+ # print('V', v_sample_nodes)
+ u_sample_nodes_set = set(u_sample_nodes.tolist())
+ v_sample_nodes_set = set(v_sample_nodes.tolist())
+ uv_inter_neighbour = u_sample_nodes_set.intersection(v_sample_nodes_set)
+ frontiers = torch.tensor(list(uv_inter_neighbour), dtype=torch.int64)
+ # print(frontiers)
+ sample_nodes.append(frontiers)
+ else:
+ raise NotImplementedError
+ sample_nodes = torch.cat(sample_nodes)
+ sample_nodes = torch.unique(sample_nodes)
+ subgraph = dgl.node_subgraph(self.graph, sample_nodes)
+
+ # Each node should have unique node id in the new subgraph
+ u_id = int(torch.nonzero(subgraph.ndata[NID] == int(target_nodes[0]), as_tuple=False))
+ v_id = int(torch.nonzero(subgraph.ndata[NID] == int(target_nodes[1]), as_tuple=False))
+ # remove link between target nodes in positive subgraphs.
+        # Removing edges rearranges NID and EID, losing the original NID and EID.
+
+ # if dgl.__version__[:5] < '0.6.0':
+ # nids = subgraph.ndata[NID]
+ # eids = subgraph.edata[EID]
+ # if subgraph.has_edges_between(u_id, v_id):
+ # link_id = subgraph.edge_ids(u_id, v_id, return_uv=True)[2]
+ # subgraph.remove_edges(link_id)
+ # eids = eids[subgraph.edata[EID]]
+ # if subgraph.has_edges_between(v_id, u_id):
+ # link_id = subgraph.edge_ids(v_id, u_id, return_uv=True)[2]
+ # subgraph.remove_edges(link_id)
+ # eids = eids[subgraph.edata[EID]]
+ # subgraph.ndata[NID] = nids
+ # subgraph.edata[EID] = eids
+
+ if subgraph.has_edges_between(u_id, v_id):
+ link_id = subgraph.edge_ids(u_id, v_id, return_uv=True)[2]
+ subgraph.remove_edges(link_id)
+ if subgraph.has_edges_between(v_id, u_id):
+ link_id = subgraph.edge_ids(v_id, u_id, return_uv=True)[2]
+ subgraph.remove_edges(link_id)
+
+ z = drnl_node_labeling(subgraph, u_id, v_id)
+ subgraph.ndata['z'] = z
+ return subgraph
+
+ def _collate(self, batch_graphs):
+
+ # batch_graphs = map(list, zip(*batch))
+ # print(batch_graphs)
+ # print(batch_triples)
+ # print(batch_labels)
+
+ batch_graphs = dgl.batch(batch_graphs)
+ return batch_graphs
+
+ def __call__(self, triples):
+ subgraph_list = []
+ triples_list = []
+ labels_list = []
+ edge_dataset = EdgeDataSet(triples, transform=self.sample_subgraph)
+ self.print_fn('Using {} workers in sampling job.'.format(self.num_workers))
+ sampler = DataLoader(edge_dataset, batch_size=128, num_workers=self.num_workers,
+ shuffle=False, collate_fn=self._collate)
+ for subgraph in tqdm(sampler, ncols=100):
+ subgraph = dgl.unbatch(subgraph)
+
+ subgraph_list += subgraph
+
+ return subgraph_list
+
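+# Usage sketch (hedged example, assuming a prepared DGLGraph `g` and a list of dicts `triples`,
+# each holding a 'triple' entry of the form (head, rel, tail)):
+#
+#     sampler = SEALSampler(g, hop=2, max_num_nodes=100, num_workers=4, type='greedy')
+#     subgraphs = sampler(triples)   # one DGLGraph per triple, DRNL labels in ndata['z']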
+
+class SEALSampler_NG(object):
+ """
+    Sampler for SEAL (no-block version from the paper).
+    The strategy is to sample the k-hop neighbors around the two target nodes,
+    subject to a maximum node budget; this variant returns the sampled node ids
+    rather than materialized subgraphs.
+    Attributes:
+        graph(DGLGraph): the graph
+        hop(int): number of hops
+        num_workers(int): number of workers
+
+ """
+
+ def __init__(self, graph, hop=1, max_num_nodes=100, num_workers=32, type='greedy', print_fn=print):
+ self.graph = graph
+ self.hop = hop
+ self.print_fn = print_fn
+ self.num_workers = num_workers
+ self.threshold = None
+ self.max_num_nodes = max_num_nodes
+ self.sample_type = type
+
+ def sample_subgraph(self, target_nodes, mode='valid'):
+ """
+ Args:
+ target_nodes(Tensor): Tensor of two target nodes
+ Returns:
+ subgraph(DGLGraph): subgraph
+ """
+        # TODO: Add a sampling constraint on each hop
+ sample_nodes = [target_nodes]
+ frontiers = target_nodes
+ if self.sample_type == 'greedy':
+ for i in range(self.hop):
+ # get sampled node number
+ tmp_sample_nodes = torch.cat(sample_nodes)
+ tmp_sample_nodes = torch.unique(tmp_sample_nodes)
+ if tmp_sample_nodes.size(0) < self.max_num_nodes: # whether sample or not
+ frontiers = self.graph.out_edges(frontiers)[1]
+ frontiers = torch.unique(frontiers)
+ if frontiers.size(0) > self.max_num_nodes - tmp_sample_nodes.size(0):
+ frontiers = np.random.choice(frontiers.numpy(), self.max_num_nodes - tmp_sample_nodes.size(0), replace=False)
+ frontiers = torch.unique(torch.tensor(frontiers))
+ sample_nodes.append(frontiers)
+ elif self.sample_type == 'greedy_set':
+ for i in range(self.hop):
+ # get sampled node number
+ tmp_sample_nodes = torch.cat(sample_nodes)
+ tmp_sample_nodes_set = set(tmp_sample_nodes.tolist())
+ # tmp_sample_nodes = torch.unique(tmp_sample_nodes)
+ if len(tmp_sample_nodes_set) < self.max_num_nodes: # whether sample or not
+ tmp_frontiers = self.graph.out_edges(frontiers)[1]
+ tmp_frontiers_set = set(tmp_frontiers.tolist())
+ tmp_frontiers_set = tmp_frontiers_set - tmp_sample_nodes_set
+ if not tmp_frontiers_set:
+ break
+ else:
+ frontiers_set = tmp_frontiers_set
+ frontiers = torch.tensor(list(frontiers_set))
+ if frontiers.size(0) > self.max_num_nodes - len(tmp_sample_nodes_set):
+ frontiers = np.random.choice(frontiers.numpy(), self.max_num_nodes - len(tmp_sample_nodes_set), replace=False)
+ frontiers = torch.unique(torch.tensor(frontiers))
+ sample_nodes.append(frontiers)
+ elif self.sample_type == 'average_set':
+ for i in range(self.hop):
+ # get sampled node number
+ tmp_sample_nodes = torch.cat(sample_nodes)
+ tmp_sample_nodes_set = set(tmp_sample_nodes.tolist())
+ num_nodes = int(self.max_num_nodes / self.hop * (i + 1))
+ if len(tmp_sample_nodes_set) < num_nodes: # whether sample or not
+ tmp_frontiers = self.graph.out_edges(frontiers)[1]
+ tmp_frontiers_set = set(tmp_frontiers.tolist())
+ tmp_frontiers_set = tmp_frontiers_set - tmp_sample_nodes_set
+ if not tmp_frontiers_set:
+ break
+ else:
+ frontiers_set = tmp_frontiers_set
+ frontiers = torch.tensor(list(frontiers_set))
+ if frontiers.size(0) > num_nodes - len(tmp_sample_nodes_set):
+ frontiers = np.random.choice(frontiers.numpy(), num_nodes - len(tmp_sample_nodes_set), replace=False)
+ frontiers = torch.unique(torch.tensor(frontiers))
+ sample_nodes.append(frontiers)
+ elif self.sample_type == 'average':
+ for i in range(self.hop):
+ tmp_sample_nodes = torch.cat(sample_nodes)
+ tmp_sample_nodes = torch.unique(tmp_sample_nodes)
+ num_nodes = int(self.max_num_nodes/self.hop*(i+1))
+ if tmp_sample_nodes.size(0) < num_nodes: # whether sample or not
+ frontiers = self.graph.out_edges(frontiers)[1]
+ frontiers = torch.unique(frontiers)
+ if frontiers.size(0) > num_nodes - tmp_sample_nodes.size(0):
+ frontiers = np.random.choice(frontiers.numpy(), num_nodes - tmp_sample_nodes.size(0),
+ replace=False)
+ frontiers = torch.unique(torch.tensor(frontiers))
+ sample_nodes.append(frontiers)
+ elif self.sample_type == 'enclosing_subgraph':
+ u, v = target_nodes[0], target_nodes[1]
+ u_neighbour, v_neighbour = u, v
+ # print(target_nodes)
+ u_sample_nodes = [u_neighbour.reshape(1)]
+ v_sample_nodes = [v_neighbour.reshape(1)]
+ graph = self.graph
+ if graph.has_edges_between(u, v):
+ link_id = graph.edge_ids(u, v, return_uv=True)[2]
+ graph.remove_edges(link_id)
+ if graph.has_edges_between(v, u):
+ link_id = graph.edge_ids(v, u, return_uv=True)[2]
+ graph.remove_edges(link_id)
+ for i in range(self.hop):
+ u_frontiers = self.graph.out_edges(u_neighbour)[1]
+ # v_frontiers = self.graph.out_edges(v_neighbour)[1]
+ u_neighbour = u_frontiers
+ # set(u_frontiers.tolist())
+ if u_frontiers.size(0) > self.max_num_nodes:
+ u_frontiers = np.random.choice(u_frontiers.numpy(), self.max_num_nodes, replace=False)
+ u_frontiers = torch.tensor(u_frontiers)
+ u_sample_nodes.append(u_frontiers)
+ for i in range(self.hop):
+ v_frontiers = self.graph.out_edges(v_neighbour)[1]
+ # v_frontiers = self.graph.out_edges(v_neighbour)[1]
+ v_neighbour = v_frontiers
+ if v_frontiers.size(0) > self.max_num_nodes:
+ v_frontiers = np.random.choice(v_frontiers.numpy(), self.max_num_nodes, replace=False)
+ v_frontiers = torch.tensor(v_frontiers)
+ v_sample_nodes.append(v_frontiers)
+ # print('U', u_sample_nodes)
+ # print('V', v_sample_nodes)
+ u_sample_nodes = torch.cat(u_sample_nodes)
+ u_sample_nodes = torch.unique(u_sample_nodes)
+ v_sample_nodes = torch.cat(v_sample_nodes)
+ v_sample_nodes = torch.unique(v_sample_nodes)
+ # print('U', u_sample_nodes)
+ # print('V', v_sample_nodes)
+ u_sample_nodes_set = set(u_sample_nodes.tolist())
+ v_sample_nodes_set = set(v_sample_nodes.tolist())
+ uv_inter_neighbour = u_sample_nodes_set.intersection(v_sample_nodes_set)
+ frontiers = torch.tensor(list(uv_inter_neighbour), dtype=torch.int64)
+ # print(frontiers)
+ sample_nodes.append(frontiers)
+ # print(sample_nodes)
+ # print("____________________________________")
+ # exit(0)
+ # v_neighbour = v_frontiers
+ else:
+ raise NotImplementedError
+
+ sample_nodes = torch.cat(sample_nodes)
+ sample_nodes = torch.unique(sample_nodes)
+ # print(sample_nodes)
+ # print("____________________________________")
+
+ return sample_nodes
+
+ def _collate(self, batch_nodes):
+
+ # batch_graphs = map(list, zip(*batch))
+ # print(batch_graphs)
+ # print(batch_triples)
+ # print(batch_labels)
+
+ return batch_nodes
+
+ def __call__(self, triples):
+ sample_nodes_list = []
+ edge_dataset = EdgeDataSet(triples, transform=self.sample_subgraph)
+ self.print_fn('Using {} workers in sampling job.'.format(self.num_workers))
+ sampler = DataLoader(edge_dataset, batch_size=1, num_workers=self.num_workers,
+ shuffle=False, collate_fn=self._collate)
+ for sample_nodes in tqdm(sampler):
+ sample_nodes_list.append(sample_nodes)
+ # for sample_nodes in tqdm(sampler, ncols=100):
+ # print(sample_nodes)
+ # exit(0)
+ # sample_nodes_list.append(sample_nodes)
+
+ return sample_nodes_list
+
+class SEALData(object):
+ """
+ 1. Generate positive and negative samples
+ 2. Subgraph sampling
+
+ Attributes:
+ g(dgl.DGLGraph): graph
+ split_edge(dict): split edge
+        hop(int): number of hops
+        neg_samples(int): number of negative samples per positive sample
+        subsample_ratio(float): subsampling ratio
+        use_coalesce(bool): True to coalesce the graph; graphs with multi-edges need to be coalesced
+ """
+
+ def __init__(self, g, split_edge, hop=1, neg_samples=1, subsample_ratio=1, prefix=None, save_dir=None,
+ num_workers=32, shuffle=True, use_coalesce=True, print_fn=print):
+ self.g = g
+ self.hop = hop
+ self.subsample_ratio = subsample_ratio
+ self.prefix = prefix
+ self.save_dir = save_dir
+ self.print_fn = print_fn
+
+ self.generator = PosNegEdgesGenerator(g=self.g,
+ split_edge=split_edge,
+ neg_samples=neg_samples,
+ subsample_ratio=subsample_ratio,
+ shuffle=shuffle)
+ # if use_coalesce:
+ # for k, v in g.edata.items():
+ # g.edata[k] = v.float() # dgl.to_simple() requires data is float
+ # self.g = dgl.to_simple(g, copy_ndata=True, copy_edata=True, aggregator='sum')
+ #
+ # self.ndata = {k: v for k, v in self.g.ndata.items()}
+ # self.edata = {k: v for k, v in self.g.edata.items()}
+ # self.g.ndata.clear()
+ # self.g.edata.clear()
+ # self.print_fn("Save ndata and edata in class.")
+ # self.print_fn("Clear ndata and edata in graph.")
+ #
+ self.sampler = SEALSampler(graph=self.g,
+ hop=hop,
+ num_workers=num_workers,
+ print_fn=print_fn)
+
+ def __call__(self, split_type):
+
+ if split_type == 'train':
+ subsample_ratio = self.subsample_ratio
+ else:
+ subsample_ratio = 1
+
+ path = osp.join(self.save_dir or '', '{}_{}_{}-hop_{}-subsample.bin'.format(self.prefix, split_type,
+ self.hop, subsample_ratio))
+
+ if osp.exists(path):
+ self.print_fn("Load existing processed {} files".format(split_type))
+ graph_list, data = dgl.load_graphs(path)
+ dataset = GraphDataSet(graph_list, data['labels'])
+
+ else:
+ self.print_fn("Processed {} files not exist.".format(split_type))
+
+ edges, labels = self.generator(split_type)
+ self.print_fn("Generate {} edges totally.".format(edges.size(0)))
+
+ graph_list, labels = self.sampler(edges, labels)
+ dataset = GraphDataSet(graph_list, labels)
+ dgl.save_graphs(path, graph_list, {'labels': labels})
+ self.print_fn("Save preprocessed subgraph to {}".format(path))
+ return dataset
+
+
+class TripleSampler(object):
+ def __init__(self, g, data, sample_ratio=0.1):
+ self.g = g
+ self.sample_ratio = sample_ratio
+ self.data = data
+
+ def __call__(self, split_type):
+
+ if split_type == 'train':
+ sample_ratio = self.sample_ratio
+ self.shuffle = True
+ pos_edges = self.g.edges()
+ pos_edges = torch.stack((pos_edges[0], pos_edges[1]), 1)
+
+ g = add_self_loop(self.g)
+ pos_edges = self.sample(pos_edges, sample_ratio).long()
+ eids = g.edge_ids(pos_edges[:, 0], pos_edges[:, 1])
+ edges = pos_edges
+ # labels = torch.cat([torch.ones(pos_edges.size(0), 1), torch.zeros(neg_edges.size(0), 1)])
+ triples = torch.stack(([edges[:, 0], g.edata['rel'][eids], edges[:, 1]]), 1)
+ else:
+ self.shuffle = False
+ triples = torch.tensor(self.data[split_type])
+ edges = torch.stack((triples[:, 0], triples[:, 2]), 1)
+
+ if self.shuffle:
+ perm = torch.randperm(edges.size(0))
+ edges = edges[perm]
+ triples = triples[perm]
+ return edges, triples
+
+ def sample(self, edges, sample_ratio):
+ """
+ Subsample generated edges.
+ Args:
+ edges(Tensor): edges to subsample
+            sample_ratio(float): subsampling ratio
+
+ Returns:
+ edges(Tensor): edges
+
+ """
+
+ num_edges = edges.size(0)
+ perm = torch.randperm(num_edges)
+ perm = perm[:int(sample_ratio * num_edges)]
+ edges = edges[perm]
+ return edges
+
+
+def drnl_node_labeling(subgraph, src, dst):
+ """
+    Double Radius Node Labeling (DRNL)
+    d = r(i,u) + r(i,v)
+    label = 1 + min(r(i,u), r(i,v)) + (d//2) * (d//2 + d%2 - 1)
+    Isolated nodes in the subgraph are labeled zero.
+    Extremely large graphs may cause memory errors.
+
+    Args:
+        subgraph(DGLGraph): the subgraph
+        src(int): node id of the source target node in the new subgraph
+        dst(int): node id of the destination target node in the new subgraph
+    Returns:
+        z(Tensor): node labeling tensor
+ """
+ adj = subgraph.adj().to_dense().numpy()
+ src, dst = (dst, src) if src > dst else (src, dst)
+ if src != dst:
+ idx = list(range(src)) + list(range(src + 1, adj.shape[0]))
+ adj_wo_src = adj[idx, :][:, idx]
+
+ idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))
+ adj_wo_dst = adj[idx, :][:, idx]
+
+ dist2src = shortest_path(adj_wo_dst, directed=False, unweighted=True, indices=src)
+ dist2src = np.insert(dist2src, dst, 0, axis=0)
+ dist2src = torch.from_numpy(dist2src)
+
+ dist2dst = shortest_path(adj_wo_src, directed=False, unweighted=True, indices=dst - 1)
+ dist2dst = np.insert(dist2dst, src, 0, axis=0)
+ dist2dst = torch.from_numpy(dist2dst)
+ else:
+ dist2src = shortest_path(adj, directed=False, unweighted=True, indices=src)
+ # dist2src = np.insert(dist2src, dst, 0, axis=0)
+ dist2src = torch.from_numpy(dist2src)
+
+ dist2dst = shortest_path(adj, directed=False, unweighted=True, indices=dst)
+ # dist2dst = np.insert(dist2dst, src, 0, axis=0)
+ dist2dst = torch.from_numpy(dist2dst)
+
+ dist = dist2src + dist2dst
+ dist_over_2, dist_mod_2 = dist // 2, dist % 2
+
+ z = 1 + torch.min(dist2src, dist2dst)
+ z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)
+ z[src] = 1.
+ z[dst] = 1.
+ z[torch.isnan(z)] = 0.
+
+ return z.to(torch.long)
+
+class GrailSampler(object):
+ def __init__(self, dataset, file_paths, external_kg_file, db_path, hop, enclosing_sub_graph, max_nodes_per_hop):
+ self.dataset = dataset
+ self.file_paths = file_paths
+ self.external_kg_file = external_kg_file
+ self.db_path = db_path
+ self.max_links = 2500000
+ self.params = dict()
+ self.params['hop'] = hop
+ self.params['enclosing_sub_graph'] = enclosing_sub_graph
+ self.params['max_nodes_per_hop'] = max_nodes_per_hop
+
+ def generate_subgraph_datasets(self, num_neg_samples_per_link, constrained_neg_prob, splits=['train', 'valid', 'test'], saved_relation2id=None, max_label_value=None):
+
+
+ testing = 'test' in splits
+ #adj_list, triplets, entity2id, relation2id, id2entity, id2relation, rel = process_files(params.file_paths, saved_relation2id)
+
+ # triple_file = f'data/{}/{}.txt'.format(params.dataset,params.BKG_file_name)
+ if self.dataset == 'drugbank':
+ adj_list, triplets, entity2id, relation2id, id2entity, id2relation, rel = self.process_files_ddi(self.file_paths, self.external_kg_file, saved_relation2id)
+ # else:
+ # adj_list, triplets, entity2id, relation2id, id2entity, id2relation, rel, triplets_mr, polarity_mr = self.process_files_decagon(params.file_paths, triple_file, saved_relation2id)
+ # self.plot_rel_dist(adj_list, f'rel_dist.png')
+ #print(triplets.keys(), triplets_mr.keys())
+ data_path = f'datasets/{self.dataset}/relation2id.json'
+ if not os.path.isdir(data_path) and testing:
+ with open(data_path, 'w') as f:
+ json.dump(relation2id, f)
+
+ graphs = {}
+
+ for split_name in splits:
+ if self.dataset == 'drugbank':
+ graphs[split_name] = {'triplets': triplets[split_name], 'max_size': self.max_links}
+ # elif self.dataset == 'BioSNAP':
+ # graphs[split_name] = {'triplets': triplets_mr[split_name], 'max_size': params.max_links, "polarity_mr": polarity_mr[split_name]}
+ # Sample train and valid/test links
+ for split_name, split in graphs.items():
+ print(f"Sampling negative links for {split_name}")
+ split['pos'], split['neg'] = self.sample_neg(adj_list, split['triplets'], num_neg_samples_per_link, max_size=split['max_size'], constrained_neg_prob=constrained_neg_prob)
+ #print(graphs.keys())
+ # if testing:
+ # directory = os.path.join(params.main_dir, 'data/{}/'.format(params.dataset))
+ # save_to_file(directory, f'neg_{params.test_file}_{params.constrained_neg_prob}.txt', graphs['test']['neg'], id2entity, id2relation)
+
+ self.links2subgraphs(adj_list, graphs, self.params, max_label_value)
+
+ def process_files_ddi(self, files, triple_file, saved_relation2id=None, keeptrainone = False):
+ entity2id = {}
+ relation2id = {} if saved_relation2id is None else saved_relation2id
+
+ triplets = {}
+ kg_triple = []
+ ent = 0
+ rel = 0
+
+ for file_type, file_path in files.items():
+ data = []
+ # with open(file_path) as f:
+ # file_data = [line.split() for line in f.read().split('\n')[:-1]]
+ file_data = np.loadtxt(file_path)
+ for triplet in file_data:
+ #print(triplet)
+ triplet[0], triplet[1], triplet[2] = int(triplet[0]), int(triplet[1]), int(triplet[2])
+ if triplet[0] not in entity2id:
+ entity2id[triplet[0]] = triplet[0]
+ #ent += 1
+ if triplet[1] not in entity2id:
+ entity2id[triplet[1]] = triplet[1]
+ #ent += 1
+ if not saved_relation2id and triplet[2] not in relation2id:
+ if keeptrainone:
+ triplet[2] = 0
+ relation2id[triplet[2]] = 0
+ rel = 1
+ else:
+ relation2id[triplet[2]] = triplet[2]
+ rel += 1
+
+ # Save the triplets corresponding to only the known relations
+ if triplet[2] in relation2id:
+ data.append([entity2id[triplet[0]], entity2id[triplet[1]], relation2id[triplet[2]]])
+
+ triplets[file_type] = np.array(data)
+ #print(rel)
+ triplet_kg = np.loadtxt(triple_file)
+ # print(np.max(triplet_kg[:, -1]))
+ for (h, t, r) in triplet_kg:
+ h, t, r = int(h), int(t), int(r)
+ if h not in entity2id:
+ entity2id[h] = h
+ if t not in entity2id:
+ entity2id[t] = t
+ if not saved_relation2id and rel+r not in relation2id:
+ relation2id[rel+r] = rel + r
+ kg_triple.append([h, t, r])
+ kg_triple = np.array(kg_triple)
+ id2entity = {v: k for k, v in entity2id.items()}
+ id2relation = {v: k for k, v in relation2id.items()}
+ #print(relation2id, rel)
+
+        # Construct the list of adjacency matrices, one per relation. Note that these are built only from the train data (plus the external KG relations below).
+ adj_list = []
+ #print(kg_triple)
+ #for i in range(len(relation2id)):
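+        # Note: the fixed shape 34124 x 34124 is assumed here to be the total number of
+        # entities (drugs plus external-KG entities) shared by all per-relation matrices.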
+ for i in range(rel):
+ idx = np.argwhere(triplets['train'][:, 2] == i)
+ adj_list.append(csc_matrix((np.ones(len(idx), dtype=np.uint8), (triplets['train'][:, 0][idx].squeeze(1), triplets['train'][:, 1][idx].squeeze(1))), shape=(34124, 34124)))
+ for i in range(rel, len(relation2id)):
+ idx = np.argwhere(kg_triple[:, 2] == i-rel)
+ #print(len(idx), i)
+ adj_list.append(csc_matrix((np.ones(len(idx), dtype=np.uint8), (kg_triple[:, 0][idx].squeeze(1), kg_triple[:, 1][idx].squeeze(1))), shape=(34124, 34124)))
+ #print(adj_list)
+ #assert 0
+ return adj_list, triplets, entity2id, relation2id, id2entity, id2relation, rel
+
+ def plot_rel_dist(self, adj_list, filename):
+ rel_count = []
+ for adj in adj_list:
+ rel_count.append(adj.count_nonzero())
+
+ fig = plt.figure(figsize=(12, 8))
+ plt.plot(rel_count)
+ fig.savefig(filename, dpi=fig.dpi)
+
+ def get_edge_count(self, adj_list):
+ count = []
+ for adj in adj_list:
+ count.append(len(adj.tocoo().row.tolist()))
+ return np.array(count)
+
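+    # Negative sampling sketch: each negative is built from a positive triplet by corrupting
+    # either the head or the tail; with probability `constrained_neg_prob` the replacement is
+    # drawn only from entities seen as a head/tail of that relation, otherwise uniformly from
+    # all entities. Candidates already present in the relation's adjacency matrix are rejected.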
+ def sample_neg(self, adj_list, edges, num_neg_samples_per_link=1, max_size=1000000, constrained_neg_prob=0):
+ pos_edges = edges
+ neg_edges = []
+
+ # if max_size is set, randomly sample train links
+ if max_size < len(pos_edges):
+ perm = np.random.permutation(len(pos_edges))[:max_size]
+ pos_edges = pos_edges[perm]
+
+ # sample negative links for train/test
+ n, r = adj_list[0].shape[0], len(adj_list)
+
+        # distribution of edges across relations
+ theta = 0.001
+ edge_count = self.get_edge_count(adj_list)
+ rel_dist = np.zeros(edge_count.shape)
+ idx = np.nonzero(edge_count)
+ rel_dist[idx] = softmax(theta * edge_count[idx])
+
+ # possible head and tails for each relation
+ valid_heads = [adj.tocoo().row.tolist() for adj in adj_list]
+ valid_tails = [adj.tocoo().col.tolist() for adj in adj_list]
+
+ pbar = tqdm(total=len(pos_edges))
+ while len(neg_edges) < num_neg_samples_per_link * len(pos_edges):
+ neg_head, neg_tail, rel = pos_edges[pbar.n % len(pos_edges)][0], pos_edges[pbar.n % len(pos_edges)][1], \
+ pos_edges[pbar.n % len(pos_edges)][2]
+ if np.random.uniform() < constrained_neg_prob:
+ if np.random.uniform() < 0.5:
+ neg_head = np.random.choice(valid_heads[rel])
+ else:
+ neg_tail = np.random.choice(valid_tails[rel])
+ else:
+ if np.random.uniform() < 0.5:
+ neg_head = np.random.choice(n)
+ else:
+ neg_tail = np.random.choice(n)
+
+ if neg_head != neg_tail and adj_list[rel][neg_head, neg_tail] == 0:
+ neg_edges.append([neg_head, neg_tail, rel])
+ pbar.update(1)
+
+ pbar.close()
+
+ neg_edges = np.array(neg_edges)
+ return pos_edges, neg_edges
+
+ def links2subgraphs(self, A, graphs, params, max_label_value=None):
+        '''
+        Extract enclosing subgraphs for all links and write them to named LMDB databases
+        (one per split and link polarity).
+        '''
+ max_n_label = {'value': np.array([0, 0])}
+ subgraph_sizes = []
+ enc_ratios = []
+ num_pruned_nodes = []
+
+ BYTES_PER_DATUM = self.get_average_subgraph_size(100, list(graphs.values())[0]['pos'], A, params) * 1.5
+ links_length = 0
+ for split_name, split in graphs.items():
+ links_length += (len(split['pos']) + len(split['neg'])) * 2
+ map_size = links_length * BYTES_PER_DATUM
+
+ env = lmdb.open(self.db_path, map_size=map_size, max_dbs=6)
+
+ def extraction_helper(A, links, g_labels, split_env):
+
+ with env.begin(write=True, db=split_env) as txn:
+ txn.put('num_graphs'.encode(), (len(links)).to_bytes(int.bit_length(len(links)), byteorder='little'))
+
+ with mp.Pool(processes=None, initializer=self.intialize_worker, initargs=(A, params, max_label_value)) as p:
+ args_ = zip(range(len(links)), links, g_labels)
+ for (str_id, datum) in tqdm(p.imap(self.extract_save_subgraph, args_), total=len(links)):
+ max_n_label['value'] = np.maximum(np.max(datum['n_labels'], axis=0), max_n_label['value'])
+ subgraph_sizes.append(datum['subgraph_size'])
+ enc_ratios.append(datum['enc_ratio'])
+ num_pruned_nodes.append(datum['num_pruned_nodes'])
+
+ with env.begin(write=True, db=split_env) as txn:
+ txn.put(str_id, serialize(datum))
+
+ for split_name, split in graphs.items():
+ logging.info(f"Extracting enclosing subgraphs for positive links in {split_name} set")
+ if self.dataset == 'BioSNAP':
+ labels = np.array(split["polarity_mr"])
+ else:
+ labels = np.ones(len(split['pos']))
+ db_name_pos = split_name + '_pos'
+ split_env = env.open_db(db_name_pos.encode())
+ extraction_helper(A, split['pos'], labels, split_env)
+
+ logging.info(f"Extracting enclosing subgraphs for negative links in {split_name} set")
+ if self.dataset == 'BioSNAP':
+ labels = np.array(split["polarity_mr"])
+ else:
+                labels = np.ones(len(split['neg']))
+ db_name_neg = split_name + '_neg'
+ split_env = env.open_db(db_name_neg.encode())
+ extraction_helper(A, split['neg'], labels, split_env)
+
+ max_n_label['value'] = max_label_value if max_label_value is not None else max_n_label['value']
+
+ with env.begin(write=True) as txn:
+ bit_len_label_sub = int.bit_length(int(max_n_label['value'][0]))
+ bit_len_label_obj = int.bit_length(int(max_n_label['value'][1]))
+ txn.put('max_n_label_sub'.encode(),
+ (int(max_n_label['value'][0])).to_bytes(bit_len_label_sub, byteorder='little'))
+ txn.put('max_n_label_obj'.encode(),
+ (int(max_n_label['value'][1])).to_bytes(bit_len_label_obj, byteorder='little'))
+
+ txn.put('avg_subgraph_size'.encode(), struct.pack('f', float(np.mean(subgraph_sizes))))
+ txn.put('min_subgraph_size'.encode(), struct.pack('f', float(np.min(subgraph_sizes))))
+ txn.put('max_subgraph_size'.encode(), struct.pack('f', float(np.max(subgraph_sizes))))
+ txn.put('std_subgraph_size'.encode(), struct.pack('f', float(np.std(subgraph_sizes))))
+
+ txn.put('avg_enc_ratio'.encode(), struct.pack('f', float(np.mean(enc_ratios))))
+ txn.put('min_enc_ratio'.encode(), struct.pack('f', float(np.min(enc_ratios))))
+ txn.put('max_enc_ratio'.encode(), struct.pack('f', float(np.max(enc_ratios))))
+ txn.put('std_enc_ratio'.encode(), struct.pack('f', float(np.std(enc_ratios))))
+
+ txn.put('avg_num_pruned_nodes'.encode(), struct.pack('f', float(np.mean(num_pruned_nodes))))
+ txn.put('min_num_pruned_nodes'.encode(), struct.pack('f', float(np.min(num_pruned_nodes))))
+ txn.put('max_num_pruned_nodes'.encode(), struct.pack('f', float(np.max(num_pruned_nodes))))
+ txn.put('std_num_pruned_nodes'.encode(), struct.pack('f', float(np.std(num_pruned_nodes))))
+
+ def get_average_subgraph_size(self, sample_size, links, A, params):
+ total_size = 0
+ # print(links, len(links))
+ lst = np.random.choice(len(links), sample_size)
+ for idx in lst:
+ (n1, n2, r_label) = links[idx]
+ # for (n1, n2, r_label) in links[np.random.choice(len(links), sample_size)]:
+ nodes, n_labels, subgraph_size, enc_ratio, num_pruned_nodes, common_neighbor = self.subgraph_extraction_labeling((n1, n2),
+ r_label, A,
+ params['hop'],
+ params['enclosing_sub_graph'],
+ params['max_nodes_per_hop'])
+ datum = {'nodes': nodes, 'r_label': r_label, 'g_label': 0, 'n_labels': n_labels, 'common_neighbor': common_neighbor,
+ 'subgraph_size': subgraph_size, 'enc_ratio': enc_ratio, 'num_pruned_nodes': num_pruned_nodes}
+ total_size += len(serialize(datum))
+ return total_size / sample_size
+
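+    # The adjacency list and extraction parameters are stored as module-level globals so that
+    # each worker process spawned by mp.Pool can reuse them without re-pickling the (large)
+    # sparse matrices for every extracted link.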
+ def intialize_worker(self, A, params, max_label_value):
+ global A_, params_, max_label_value_
+ A_, params_, max_label_value_ = A, params, max_label_value
+
+ def extract_save_subgraph(self, args_):
+ idx, (n1, n2, r_label), g_label = args_
+ nodes, n_labels, subgraph_size, enc_ratio, num_pruned_nodes, common_neighbor = self.subgraph_extraction_labeling((n1, n2), r_label,
+ A_, params_['hop'],
+ params_['enclosing_sub_graph'],
+ params_['max_nodes_per_hop'])
+
+            # max_label_value_ caps the node labels produced during double-radius labelling.
+ if max_label_value_ is not None:
+ n_labels = np.array([np.minimum(label, max_label_value_).tolist() for label in n_labels])
+
+ datum = {'nodes': nodes, 'r_label': r_label, 'g_label': g_label, 'n_labels': n_labels, 'common_neighbor': common_neighbor,
+ 'subgraph_size': subgraph_size, 'enc_ratio': enc_ratio, 'num_pruned_nodes': num_pruned_nodes}
+ str_id = '{:08}'.format(idx).encode('ascii')
+
+ return (str_id, datum)
+
+ def get_neighbor_nodes(self, roots, adj, h=1, max_nodes_per_hop=None):
+ bfs_generator = _bfs_relational(adj, roots, max_nodes_per_hop)
+ lvls = list()
+ for _ in range(h):
+ try:
+ lvls.append(next(bfs_generator))
+ except StopIteration:
+ pass
+ return set().union(*lvls)
+
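+    # Extraction sketch: take the h-hop neighborhoods of both endpoints on the relation-agnostic
+    # incidence graph, keep their intersection (enclosing subgraph) or union, place the two
+    # endpoints at the front of the node list, label every node with its distances to the two
+    # roots, and prune nodes that end up farther than h from either root.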
+ def subgraph_extraction_labeling(self, ind, rel, A_list, h=1, enclosing_sub_graph=False, max_nodes_per_hop=None,
+ max_node_label_value=None):
+ # extract the h-hop enclosing subgraphs around link 'ind'
+ A_incidence = incidence_matrix(A_list)
+ A_incidence += A_incidence.T
+ ind = list(ind)
+ ind[0], ind[1] = int(ind[0]), int(ind[1])
+ ind = (ind[0], ind[1])
+ root1_nei = self.get_neighbor_nodes(set([ind[0]]), A_incidence, h, max_nodes_per_hop)
+ root2_nei = self.get_neighbor_nodes(set([ind[1]]), A_incidence, h, max_nodes_per_hop)
+ subgraph_nei_nodes_int = root1_nei.intersection(root2_nei)
+ subgraph_nei_nodes_un = root1_nei.union(root2_nei)
+
+ root1_nei_1 = self.get_neighbor_nodes(set([ind[0]]), A_incidence, 1, max_nodes_per_hop)
+ root2_nei_1 = self.get_neighbor_nodes(set([ind[1]]), A_incidence, 1, max_nodes_per_hop)
+ common_neighbor = root1_nei_1.intersection(root2_nei_1)
+
+        # Extract the subgraph. Keeping the two roots at the front of the node list is essential for the labelling and for the model to work properly.
+ if enclosing_sub_graph:
+ if ind[0] in subgraph_nei_nodes_int:
+ subgraph_nei_nodes_int.remove(ind[0])
+ if ind[1] in subgraph_nei_nodes_int:
+ subgraph_nei_nodes_int.remove(ind[1])
+ subgraph_nodes = list(ind) + list(subgraph_nei_nodes_int)
+ else:
+ if ind[0] in subgraph_nei_nodes_un:
+ subgraph_nei_nodes_un.remove(ind[0])
+ if ind[1] in subgraph_nei_nodes_un:
+ subgraph_nei_nodes_un.remove(ind[1])
+ subgraph_nodes = list(ind) + list(subgraph_nei_nodes_un) # list(set(ind).union(subgraph_nei_nodes_un))
+
+ subgraph = [adj[subgraph_nodes, :][:, subgraph_nodes] for adj in A_list]
+
+ labels, enclosing_subgraph_nodes = self.node_label(incidence_matrix(subgraph), max_distance=h)
+ # print(ind, subgraph_nodes[:32],enclosing_subgraph_nodes[:32], labels)
+ pruned_subgraph_nodes = np.array(subgraph_nodes)[enclosing_subgraph_nodes].tolist()
+ pruned_labels = labels[enclosing_subgraph_nodes]
+ # pruned_subgraph_nodes = subgraph_nodes
+ # pruned_labels = labels
+
+ if max_node_label_value is not None:
+ pruned_labels = np.array([np.minimum(label, max_node_label_value).tolist() for label in pruned_labels])
+
+ subgraph_size = len(pruned_subgraph_nodes)
+ enc_ratio = len(subgraph_nei_nodes_int) / (len(subgraph_nei_nodes_un) + 1e-3)
+ num_pruned_nodes = len(subgraph_nodes) - len(pruned_subgraph_nodes)
+ # print(pruned_subgraph_nodes)
+ # import time
+ # time.sleep(10)
+ return pruned_subgraph_nodes, pruned_labels, subgraph_size, enc_ratio, num_pruned_nodes, common_neighbor
+
+ def node_label(self, subgraph, max_distance=1):
+ # implementation of the node labeling scheme described in the paper
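+        # GraIL-style double-radius labeling: distances to each root are computed on the
+        # subgraph with the opposite root removed; nodes farther than max_distance from
+        # either root are dropped from the enclosing subgraph.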
+ roots = [0, 1]
+ sgs_single_root = [remove_nodes(subgraph, [root]) for root in roots]
+ dist_to_roots = [
+ np.clip(ssp.csgraph.dijkstra(sg, indices=[0], directed=False, unweighted=True, limit=1e6)[:, 1:], 0, 1e7)
+ for r, sg in enumerate(sgs_single_root)]
+ dist_to_roots = np.array(list(zip(dist_to_roots[0][0], dist_to_roots[1][0])), dtype=int)
+
+ target_node_labels = np.array([[0, 1], [1, 0]])
+ labels = np.concatenate((target_node_labels, dist_to_roots)) if dist_to_roots.size else target_node_labels
+
+ enclosing_subgraph_nodes = np.where(np.max(labels, axis=1) <= max_distance)[0]
+ return labels, enclosing_subgraph_nodes
\ No newline at end of file
diff --git a/utils/utils.py b/utils/utils.py
new file mode 100644
index 0000000..7b96b2e
--- /dev/null
+++ b/utils/utils.py
@@ -0,0 +1,233 @@
+import argparse
+from scipy.sparse.csgraph import shortest_path
+import numpy as np
+import pandas as pd
+import torch
+import dgl
+import json
+import logging
+import logging.config
+import os
+import psutil
+from statistics import mean
+import pickle
+
+
+def parse_arguments():
+ """
+ Parse arguments
+ """
+ parser = argparse.ArgumentParser(description='SEAL')
+ parser.add_argument('--dataset', type=str, default='ogbl-collab')
+ parser.add_argument('--gpu_id', type=int, default=0)
+ parser.add_argument('--hop', type=int, default=1)
+ parser.add_argument('--model', type=str, default='dgcnn')
+ parser.add_argument('--gcn_type', type=str, default='gcn')
+ parser.add_argument('--num_layers', type=int, default=3)
+ parser.add_argument('--hidden_units', type=int, default=32)
+ parser.add_argument('--sort_k', type=int, default=30)
+ parser.add_argument('--pooling', type=str, default='sum')
+    parser.add_argument('--dropout', type=float, default=0.5)
+ parser.add_argument('--hits_k', type=int, default=50)
+ parser.add_argument('--lr', type=float, default=0.0001)
+ parser.add_argument('--neg_samples', type=int, default=1)
+ parser.add_argument('--subsample_ratio', type=float, default=0.1)
+ parser.add_argument('--epochs', type=int, default=60)
+ parser.add_argument('--batch_size', type=int, default=32)
+ parser.add_argument('--eval_steps', type=int, default=5)
+ parser.add_argument('--num_workers', type=int, default=32)
+ parser.add_argument('--random_seed', type=int, default=2023)
+ parser.add_argument('--save_dir', type=str, default='./processed')
+ args = parser.parse_args()
+
+ return args
+
+
+def coalesce_graph(graph, aggr_type='sum', copy_data=False):
+ """
+ Coalesce multi-edge graph
+ Args:
+ graph(DGLGraph): graph
+ aggr_type(str): type of aggregator for multi edge weights
+ copy_data(bool): if copy ndata and edata in new graph
+
+ Returns:
+ graph(DGLGraph): graph
+
+
+ """
+ src, dst = graph.edges()
+ graph_df = pd.DataFrame({'src': src, 'dst': dst})
+ graph_df['edge_weight'] = graph.edata['edge_weight'].numpy()
+
+ if aggr_type == 'sum':
+ tmp = graph_df.groupby(['src', 'dst'])['edge_weight'].sum().reset_index()
+ elif aggr_type == 'mean':
+ tmp = graph_df.groupby(['src', 'dst'])['edge_weight'].mean().reset_index()
+ else:
+ raise ValueError("aggr type error")
+
+ if copy_data:
+ graph = dgl.to_simple(graph, copy_ndata=True, copy_edata=True)
+ else:
+ graph = dgl.to_simple(graph)
+
+ src, dst = graph.edges()
+ graph_df = pd.DataFrame({'src': src, 'dst': dst})
+ graph_df = pd.merge(graph_df, tmp, how='left', on=['src', 'dst'])
+ graph.edata['edge_weight'] = torch.from_numpy(graph_df['edge_weight'].values).unsqueeze(1)
+
+ graph.edata.pop('count')
+ return graph
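+
+# A minimal usage sketch (hypothetical graph): if `g` is a DGLGraph with parallel edges and a
+# per-edge 'edge_weight' feature, `coalesce_graph(g, aggr_type='sum')` returns a simple graph
+# whose 'edge_weight' holds the summed weight of each merged edge set.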
+
+
+def drnl_node_labeling(subgraph, src, dst):
+ """
+    Double-Radius Node Labeling (DRNL):
+        d = r(i,u) + r(i,v)
+        label = 1 + min(r(i,u), r(i,v)) + (d//2) * (d//2 + d%2 - 1)
+    Isolated nodes in the subgraph are labeled zero.
+    Extremely large graphs may cause memory errors.
+
+    Args:
+        subgraph(DGLGraph): the enclosing subgraph
+        src(int): node id of the source node in the subgraph
+        dst(int): node id of the destination node in the subgraph
+ Returns:
+ z(Tensor): node labeling tensor
+ """
+ adj = subgraph.adj().to_dense().numpy()
+ src, dst = (dst, src) if src > dst else (src, dst)
+ if src != dst:
+ idx = list(range(src)) + list(range(src + 1, adj.shape[0]))
+ adj_wo_src = adj[idx, :][:, idx]
+
+ idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))
+ adj_wo_dst = adj[idx, :][:, idx]
+
+ dist2src = shortest_path(adj_wo_dst, directed=False, unweighted=True, indices=src)
+ dist2src = np.insert(dist2src, dst, 0, axis=0)
+ dist2src = torch.from_numpy(dist2src)
+
+ dist2dst = shortest_path(adj_wo_src, directed=False, unweighted=True, indices=dst - 1)
+ dist2dst = np.insert(dist2dst, src, 0, axis=0)
+ dist2dst = torch.from_numpy(dist2dst)
+ else:
+ dist2src = shortest_path(adj, directed=False, unweighted=True, indices=src)
+ # dist2src = np.insert(dist2src, dst, 0, axis=0)
+ dist2src = torch.from_numpy(dist2src)
+
+ dist2dst = shortest_path(adj, directed=False, unweighted=True, indices=dst)
+ # dist2dst = np.insert(dist2dst, src, 0, axis=0)
+ dist2dst = torch.from_numpy(dist2dst)
+
+ dist = dist2src + dist2dst
+ dist_over_2, dist_mod_2 = dist // 2, dist % 2
+
+ z = 1 + torch.min(dist2src, dist2dst)
+ z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)
+ z[src] = 1.
+ z[dst] = 1.
+ z[torch.isnan(z)] = 0.
+
+ return z.to(torch.long)
+
+
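+# Ranking metrics: `cal_ranks` turns per-sample class probabilities into the 1-based rank of
+# the true label, and `cal_performance` reduces those ranks to MRR, mean rank, Hits@1, and
+# Hits@10.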
+def cal_ranks(probs, label):
+ sorted_idx = np.argsort(probs, axis=1)[:,::-1]
+ find_target = sorted_idx == np.expand_dims(label, 1)
+ ranks = np.nonzero(find_target)[1] + 1
+ return ranks
+
+def cal_performance(ranks):
+ mrr = (1. / ranks).sum() / len(ranks)
+ m_r = sum(ranks) * 1.0 / len(ranks)
+ h_1 = sum(ranks<=1) * 1.0 / len(ranks)
+ h_10 = sum(ranks<=10) * 1.0 / len(ranks)
+ return mrr, m_r, h_1, h_10
+
+def get_logger(name, log_dir):
+    config_dict = json.load(open('./config/' + 'logger_config.json'))
+ config_dict['handlers']['file_handler']['filename'] = log_dir + name + '.log'
+ logging.config.dictConfig(config_dict)
+ logger = logging.getLogger(name)
+ return logger
+
+def get_current_memory_gb() -> float:
+    # Return the current process's memory usage (USS) in GB.
+    pid = os.getpid()
+    p = psutil.Process(pid)
+    info = p.memory_full_info()
+    return info.uss / 1024. / 1024. / 1024.
+
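+# Per-frequency breakdown: relation classes are bucketed by their support
+# (<10, 10-50, 50-100, 100-1000, 1000-100000) and the mean F1 of each bucket is returned,
+# so that performance on rare relation types can be inspected separately.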
+def get_f1_score_list(class_dict):
+ f1_score_list = [[],[],[],[],[]]
+ for key in class_dict:
+ if key.isdigit():
+ if class_dict[key]['support'] < 10:
+ f1_score_list[0].append(class_dict[key]['f1-score'])
+ elif 10 <= class_dict[key]['support'] < 50:
+ f1_score_list[1].append(class_dict[key]['f1-score'])
+ elif 50 <= class_dict[key]['support'] < 100:
+ f1_score_list[2].append(class_dict[key]['f1-score'])
+ elif 100 <= class_dict[key]['support'] < 1000:
+ f1_score_list[3].append(class_dict[key]['f1-score'])
+ elif 1000 <= class_dict[key]['support'] < 100000:
+ f1_score_list[4].append(class_dict[key]['f1-score'])
+    for index, scores in enumerate(f1_score_list):
+        f1_score_list[index] = mean(scores)
+ return f1_score_list
+
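+# Same support buckets as above, but aggregating support-weighted recall (i.e. accuracy)
+# within each bucket instead of mean F1.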
+def get_acc_list(class_dict):
+ acc_list = [0.0,0.0,0.0,0.0,0.0]
+ support_list = [0,0,0,0,0]
+ proportion_list = [0.0,0.0,0.0,0.0,0.0]
+ for key in class_dict:
+ if key.isdigit():
+ if class_dict[key]['support'] < 10:
+ acc_list[0] += (class_dict[key]['recall']*class_dict[key]['support'])
+ support_list[0]+=class_dict[key]['support']
+ elif 10 <= class_dict[key]['support'] < 50:
+ acc_list[1] += (class_dict[key]['recall']*class_dict[key]['support'])
+ support_list[1] += class_dict[key]['support']
+ elif 50 <= class_dict[key]['support'] < 100:
+ acc_list[2] += (class_dict[key]['recall']*class_dict[key]['support'])
+ support_list[2] += class_dict[key]['support']
+ elif 100 <= class_dict[key]['support'] < 1000:
+ acc_list[3] += (class_dict[key]['recall']*class_dict[key]['support'])
+ support_list[3] += class_dict[key]['support']
+ elif 1000 <= class_dict[key]['support'] < 100000:
+ acc_list[4] += (class_dict[key]['recall'] * class_dict[key]['support'])
+ support_list[4] += class_dict[key]['support']
+ for index, _ in enumerate(acc_list):
+ acc_list[index] = acc_list[index] / support_list[index]
+ # proportion_list[index] = support_list[index] / class_dict['macro avg']['support']
+ return acc_list
+
+def deserialize(data):
+ data_tuple = pickle.loads(data)
+ keys = ('nodes', 'r_label', 'g_label', 'n_label')
+ return dict(zip(keys, data_tuple))
+
+
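+# Linearly anneals a temperature (e.g. for Gumbel-softmax sampling) from `base_temp` down to
+# `temp_min` over `total_epochs`. A minimal usage sketch with hypothetical values:
+#   scheduler = Temp_Scheduler(total_epochs=400, curr_temp=1.0, base_temp=1.0, temp_min=0.33)
+#   for epoch in range(400):
+#       temp = scheduler.decay_whole_process(epoch)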
+class Temp_Scheduler(object):
+ def __init__(self, total_epochs, curr_temp, base_temp, temp_min=0.33, last_epoch=-1):
+ self.curr_temp = curr_temp
+ self.base_temp = base_temp
+ self.temp_min = temp_min
+ self.last_epoch = last_epoch
+ self.total_epochs = total_epochs
+ self.step(last_epoch + 1)
+
+ def step(self, epoch=None):
+        return self.decay_whole_process(epoch)
+
+ def decay_whole_process(self, epoch=None):
+ if epoch is None:
+ epoch = self.last_epoch + 1
+ self.last_epoch = epoch
+ # self.total_epochs = 150
+ self.curr_temp = (1 - self.last_epoch / self.total_epochs) * (self.base_temp - self.temp_min) + self.temp_min
+ if self.curr_temp < self.temp_min:
+ self.curr_temp = self.temp_min
+ return self.curr_temp
\ No newline at end of file