diff --git a/examples/pytorch/__temporary__/cluster_gcn/cluster_gcn.py b/examples/pytorch/__temporary__/cluster_gcn/cluster_gcn.py index 7f23626731a4..17f4ed54c1ba 100644 --- a/examples/pytorch/__temporary__/cluster_gcn/cluster_gcn.py +++ b/examples/pytorch/__temporary__/cluster_gcn/cluster_gcn.py @@ -8,8 +8,6 @@ import numpy as np from ogb.nodeproppred import DglNodePropPredDataset -USE_WRAPPER = True - class SAGE(nn.Module): def __init__(self, in_feats, n_hidden, n_classes): super().__init__() @@ -40,11 +38,6 @@ def forward(self, sg, x): model = SAGE(graph.ndata['feat'].shape[1], 256, dataset.num_classes).cuda() opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4) -if USE_WRAPPER: - import dglnew - graph.create_formats_() - graph = dglnew.graph.wrapper.DGLGraphStorage(graph) - num_partitions = 1000 sampler = dgl.dataloading.ClusterGCNSampler( graph, num_partitions, @@ -61,14 +54,13 @@ def forward(self, sg, x): batch_size=100, shuffle=True, drop_last=False, - pin_memory=True, - num_workers=8, - persistent_workers=True, - use_prefetch_thread=True) # TBD: could probably remove this argument + num_workers=0, + use_uva=True) durations = [] for _ in range(10): t0 = time.time() + model.train() for it, sg in enumerate(dataloader): x = sg.ndata['feat'] y = sg.ndata['label'][:, 0] @@ -85,4 +77,27 @@ def forward(self, sg, x): tt = time.time() print(tt - t0) durations.append(tt - t0) + + model.eval() + with torch.no_grad(): + val_preds, test_preds = [], [] + val_labels, test_labels = [], [] + for it, sg in enumerate(dataloader): + x = sg.ndata['feat'] + y = sg.ndata['label'][:, 0] + m_val = sg.ndata['valid_mask'] + m_test = sg.ndata['test_mask'] + y_hat = model(sg, x) + val_preds.append(y_hat[m_val]) + val_labels.append(y[m_val]) + test_preds.append(y_hat[m_test]) + test_labels.append(y[m_test]) + val_preds = torch.cat(val_preds, 0) + val_labels = torch.cat(val_labels, 0) + test_preds = torch.cat(test_preds, 0) + test_labels = torch.cat(test_labels, 0) + val_acc = MF.accuracy(val_preds, val_labels) + test_acc = MF.accuracy(test_preds, test_labels) + print('Validation acc:', val_acc.item(), 'Test acc:', test_acc.item()) + print(np.mean(durations[4:]), np.std(durations[4:])) diff --git a/examples/pytorch/__temporary__/graphsage/ddp.py b/examples/pytorch/__temporary__/graphsage/ddp.py index d1366f1e1d70..f7e16fe8898f 100644 --- a/examples/pytorch/__temporary__/graphsage/ddp.py +++ b/examples/pytorch/__temporary__/graphsage/ddp.py @@ -9,8 +9,9 @@ import time import numpy as np from ogb.nodeproppred import DglNodePropPredDataset +import tqdm -USE_WRAPPER = False +USE_UVA = True class SAGE(nn.Module): def __init__(self, in_feats, n_hidden, n_classes): @@ -20,6 +21,8 @@ def __init__(self, in_feats, n_hidden, n_classes): self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean')) self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean')) self.dropout = nn.Dropout(0.5) + self.n_hidden = n_hidden + self.n_classes = n_classes def forward(self, blocks, x): h = x @@ -30,41 +33,66 @@ def forward(self, blocks, x): h = self.dropout(h) return h + def inference(self, g, device, batch_size, num_workers, buffer_device=None): + # The difference between this inference function and the one in the official + # example is that the intermediate results can also benefit from prefetching. 
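+ # Layer-wise inference: each GNN layer is computed for all nodes before moving
+ # on to the next layer, and the layer output is written back into g.ndata['h']
+ # so that the next layer's dataloader can prefetch it like the input features.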
+ g.ndata['h'] = g.ndata['feat'] + sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h']) + dataloader = dgl.dataloading.NodeDataLoader( + g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device, + batch_size=1000, shuffle=False, drop_last=False, num_workers=num_workers, + persistent_workers=(num_workers > 0)) + if buffer_device is None: + buffer_device = device -def train(rank, world_size, graph, num_classes, split_idx): + for l, layer in enumerate(self.layers): + y = torch.zeros( + g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes, + device=buffer_device) + for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader): + x = blocks[0].srcdata['h'] + h = layer(blocks[0], x) + if l != len(self.layers) - 1: + h = F.relu(h) + h = self.dropout(h) + y[output_nodes] = h.to(buffer_device) + g.ndata['h'] = y + return y + + +def train(rank, world_size, shared_memory_name, features, num_classes, split_idx): torch.cuda.set_device(rank) dist.init_process_group('nccl', 'tcp://127.0.0.1:12347', world_size=world_size, rank=rank) + graph = dgl.hetero_from_shared_memory(shared_memory_name) + feat, labels = features + graph.ndata['feat'] = feat + graph.ndata['label'] = labels + model = SAGE(graph.ndata['feat'].shape[1], 256, num_classes).cuda() model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank) opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4) train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test'] - if USE_WRAPPER: - import dglnew - graph = dglnew.graph.wrapper.DGLGraphStorage(graph) + + if USE_UVA: + train_idx = train_idx.to('cuda') sampler = dgl.dataloading.NeighborSampler( - [5, 5, 5], output_device='cpu', prefetch_node_feats=['feat'], - prefetch_labels=['label']) - dataloader = dgl.dataloading.NodeDataLoader( - graph, - train_idx, - sampler, - device='cuda', - batch_size=1000, - shuffle=True, - drop_last=False, - pin_memory=True, - num_workers=4, - persistent_workers=True, - use_ddp=True, - use_prefetch_thread=True) # TBD: could probably remove this argument + [15, 10, 5], prefetch_node_feats=['feat'], prefetch_labels=['label']) + train_dataloader = dgl.dataloading.NodeDataLoader( + graph, train_idx, sampler, + device='cuda', batch_size=1000, shuffle=True, drop_last=False, + num_workers=0, use_ddp=True, use_uva=USE_UVA) + valid_dataloader = dgl.dataloading.NodeDataLoader( + graph, valid_idx, sampler, device='cuda', batch_size=1024, shuffle=True, + drop_last=False, num_workers=0, use_uva=USE_UVA) durations = [] for _ in range(10): + model.train() t0 = time.time() - for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader): + for it, (input_nodes, output_nodes, blocks) in enumerate(train_dataloader): x = blocks[0].srcdata['feat'] y = blocks[-1].dstdata['label'][:, 0] y_hat = model(blocks, x) @@ -80,27 +108,38 @@ def train(rank, world_size, graph, num_classes, split_idx): if rank == 0: print(tt - t0) durations.append(tt - t0) + + model.eval() + ys = [] + y_hats = [] + for it, (input_nodes, output_nodes, blocks) in enumerate(valid_dataloader): + with torch.no_grad(): + x = blocks[0].srcdata['feat'] + ys.append(blocks[-1].dstdata['label']) + y_hats.append(model.module(blocks, x)) + acc = MF.accuracy(torch.cat(y_hats), torch.cat(ys)) + print('Validation acc:', acc.item()) + dist.barrier() + if rank == 0: print(np.mean(durations[4:]), np.std(durations[4:])) + model.eval() + with torch.no_grad(): + pred = 
model.module.inference(graph, 'cuda', 1000, 12, graph.device) + acc = MF.accuracy(pred.to(graph.device), graph.ndata['label']) + print('Test acc:', acc.item()) if __name__ == '__main__': dataset = DglNodePropPredDataset('ogbn-products') graph, labels = dataset[0] - graph.ndata['label'] = labels - graph.create_formats_() + shared_memory_name = 'shm' # can be any string + feat = graph.ndata['feat'] + graph = graph.shared_memory(shared_memory_name) split_idx = dataset.get_idx_split() num_classes = dataset.num_classes n_procs = 4 # Tested with mp.spawn and fork. Both worked and got 4s per epoch with 4 GPUs # and 3.86s per epoch with 8 GPUs on p2.8x, compared to 5.2s from official examples. - #import torch.multiprocessing as mp - #mp.spawn(train, args=(n_procs, graph, num_classes, split_idx), nprocs=n_procs) - import dgl.multiprocessing as mp - procs = [] - for i in range(n_procs): - p = mp.Process(target=train, args=(i, n_procs, graph, num_classes, split_idx)) - p.start() - procs.append(p) - for p in procs: - p.join() + import torch.multiprocessing as mp + mp.spawn(train, args=(n_procs, shared_memory_name, (feat, labels), num_classes, split_idx), nprocs=n_procs) diff --git a/examples/pytorch/__temporary__/graphsage/link_pred.py b/examples/pytorch/__temporary__/graphsage/link_pred.py index 601044a69f38..7cdd2d71a478 100644 --- a/examples/pytorch/__temporary__/graphsage/link_pred.py +++ b/examples/pytorch/__temporary__/graphsage/link_pred.py @@ -6,20 +6,54 @@ import dgl.nn as dglnn import time import numpy as np +import tqdm # OGB must follow DGL if both DGL and PyG are installed. Otherwise DataLoader will hang. # (This is a long-standing issue) -from ogb.nodeproppred import DglNodePropPredDataset +from ogb.linkproppred import DglLinkPropPredDataset -USE_WRAPPER = True +USE_UVA = False +device = 'cuda' + +def to_bidirected_with_reverse_mapping(g): + """Makes a graph bidirectional, and returns a mapping array ``mapping`` where ``mapping[i]`` + is the reverse edge of edge ID ``i``. + Does not work with graphs that have self-loops. 
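+ The reverse mapping is recovered from the writeback mapping returned by
+ ``dgl.to_simple`` after adding reverse edges, so duplicate edges are also
+ merged in the returned graph.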
+ """ + g_simple, mapping = dgl.to_simple( + dgl.add_reverse_edges(g), return_counts='count', writeback_mapping=True) + c = g_simple.edata['count'] + num_edges = g.num_edges() + mapping_offset = torch.zeros(g_simple.num_edges() + 1, dtype=g_simple.idtype) + mapping_offset[1:] = c.cumsum(0) + idx = mapping.argsort() + idx_uniq = idx[mapping_offset[:-1]] + reverse_idx = torch.where(idx_uniq >= num_edges, idx_uniq - num_edges, idx_uniq + num_edges) + reverse_mapping = mapping[reverse_idx] + + # Correctness check + src1, dst1 = g_simple.edges() + src2, dst2 = g_simple.find_edges(reverse_mapping) + assert torch.equal(src1, dst2) + assert torch.equal(src2, dst1) + return g_simple, reverse_mapping class SAGE(nn.Module): - def __init__(self, in_feats, n_hidden, n_classes): + def __init__(self, in_feats, n_hidden): super().__init__() + self.n_hidden = n_hidden self.layers = nn.ModuleList() self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean')) self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean')) - self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean')) - self.dropout = nn.Dropout(0.5) + self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean')) + self.predictor = nn.Sequential( + nn.Linear(n_hidden, n_hidden), + nn.ReLU(), + nn.Linear(n_hidden, n_hidden), + nn.ReLU(), + nn.Linear(n_hidden, 1)) + + def predict(self, h_src, h_dst): + return self.predictor(h_src * h_dst) def forward(self, pair_graph, neg_pair_graph, blocks, x): h = x @@ -27,50 +61,88 @@ def forward(self, pair_graph, neg_pair_graph, blocks, x): h = layer(block, h) if l != len(self.layers) - 1: h = F.relu(h) - h = self.dropout(h) - with pair_graph.local_scope(), neg_pair_graph.local_scope(): - pair_graph.ndata['h'] = neg_pair_graph.ndata['h'] = h - pair_graph.apply_edges(dgl.function.u_dot_v('h', 'h', 's')) - neg_pair_graph.apply_edges(dgl.function.u_dot_v('h', 'h', 's')) - return pair_graph.edata['s'], neg_pair_graph.edata['s'] - -dataset = DglNodePropPredDataset('ogbn-products') -graph, labels = dataset[0] -graph.ndata['label'] = labels -split_idx = dataset.get_idx_split() -train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test'] - -model = SAGE(graph.ndata['feat'].shape[1], 256, dataset.num_classes).cuda() + pos_src, pos_dst = pair_graph.edges() + neg_src, neg_dst = neg_pair_graph.edges() + h_pos = self.predict(h[pos_src], h[pos_dst]) + h_neg = self.predict(h[neg_src], h[neg_dst]) + return h_pos, h_neg + + def inference(self, g, device, batch_size, num_workers, buffer_device=None): + # The difference between this inference function and the one in the official + # example is that the intermediate results can also benefit from prefetching. 
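+ # Every layer here outputs n_hidden-dimensional embeddings (there is no
+ # classification head), so each layer writes into a (num_nodes, n_hidden) buffer;
+ # passing buffer_device='cpu' keeps these full-graph embeddings out of GPU memory.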
+ g.ndata['h'] = g.ndata['feat'] + sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h']) + dataloader = dgl.dataloading.NodeDataLoader( + g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device, + batch_size=1000, shuffle=False, drop_last=False, num_workers=num_workers) + if buffer_device is None: + buffer_device = device + + for l, layer in enumerate(self.layers): + y = torch.zeros(g.num_nodes(), self.n_hidden, device=buffer_device) + for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader): + x = blocks[0].srcdata['h'] + h = layer(blocks[0], x) + if l != len(self.layers) - 1: + h = F.relu(h) + y[output_nodes] = h.to(buffer_device) + g.ndata['h'] = y + return y + + +def compute_mrr(model, node_emb, src, dst, neg_dst, device, batch_size=500): + rr = torch.zeros(src.shape[0]) + for start in tqdm.trange(0, src.shape[0], batch_size): + end = min(start + batch_size, src.shape[0]) + all_dst = torch.cat([dst[start:end, None], neg_dst[start:end]], 1) + h_src = node_emb[src[start:end]][:, None, :].to(device) + h_dst = node_emb[all_dst.view(-1)].view(*all_dst.shape, -1).to(device) + pred = model.predict(h_src, h_dst).squeeze(-1) + relevance = torch.zeros(*pred.shape, dtype=torch.bool) + relevance[:, 0] = True + rr[start:end] = MF.retrieval_reciprocal_rank(pred, relevance) + return rr.mean() + + +def evaluate(model, edge_split, device, num_workers): + with torch.no_grad(): + node_emb = model.inference(graph, device, 4096, num_workers, 'cpu') + results = [] + for split in ['valid', 'test']: + src = edge_split[split]['source_node'].to(device) + dst = edge_split[split]['target_node'].to(device) + neg_dst = edge_split[split]['target_node_neg'].to(device) + results.append(compute_mrr(model, node_emb, src, dst, neg_dst, device)) + return results + + +dataset = DglLinkPropPredDataset('ogbl-citation2') +graph = dataset[0] +graph, reverse_eids = to_bidirected_with_reverse_mapping(graph) +seed_edges = torch.arange(graph.num_edges()) +edge_split = dataset.get_edge_split() + +model = SAGE(graph.ndata['feat'].shape[1], 256).to(device) opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4) -num_edges = graph.num_edges() -train_eids = torch.arange(num_edges) -if USE_WRAPPER: - import dglnew - graph.create_formats_() - graph = dglnew.graph.wrapper.DGLGraphStorage(graph) +if not USE_UVA: + graph = graph.to(device) + reverse_eids = reverse_eids.to(device) + seed_edges = torch.arange(graph.num_edges()).to(device) -sampler = dgl.dataloading.NeighborSampler( - [5, 5, 5], output_device='cpu', prefetch_node_feats=['feat'], - prefetch_labels=['label']) +sampler = dgl.dataloading.NeighborSampler([15, 10, 5], prefetch_node_feats=['feat']) dataloader = dgl.dataloading.EdgeDataLoader( - graph, - train_eids, - sampler, - device='cuda', - batch_size=1000, - shuffle=True, - drop_last=False, - pin_memory=True, - num_workers=8, - persistent_workers=True, - use_prefetch_thread=True, # TBD: could probably remove this argument + graph, seed_edges, sampler, + device=device, batch_size=512, shuffle=True, + drop_last=False, num_workers=0, exclude='reverse_id', - reverse_eids=torch.arange(num_edges) ^ 1, - negative_sampler=dgl.dataloading.negative_sampler.Uniform(5)) + reverse_eids=reverse_eids, + negative_sampler=dgl.dataloading.negative_sampler.Uniform(1), + use_uva=USE_UVA) durations = [] -for _ in range(10): +for epoch in range(10): + model.train() t0 = time.time() for it, (input_nodes, pair_graph, neg_pair_graph, blocks) in enumerate(dataloader): x = 
blocks[0].srcdata['feat'] @@ -83,12 +155,16 @@ def forward(self, pair_graph, neg_pair_graph, blocks, x): opt.zero_grad() loss.backward() opt.step() - if it % 20 == 0: - acc = MF.auroc(score, labels.long()) + if (it + 1) % 20 == 0: mem = torch.cuda.max_memory_allocated() / 1000000 - print('Loss', loss.item(), 'Acc', acc.item(), 'GPU Mem', mem, 'MB') - tt = time.time() - print(tt - t0) - t0 = time.time() - durations.append(tt - t0) + print('Loss', loss.item(), 'GPU Mem', mem, 'MB') + if (it + 1) == 1000: + tt = time.time() + print(tt - t0) + durations.append(tt - t0) + break + if epoch % 10 == 0: + model.eval() + valid_mrr, test_mrr = evaluate(model, edge_split, device, 12) + print('Validation MRR:', valid_mrr.item(), 'Test MRR:', test_mrr.item()) print(np.mean(durations[4:]), np.std(durations[4:])) diff --git a/examples/pytorch/__temporary__/graphsage/normal.py b/examples/pytorch/__temporary__/graphsage/normal.py index 6f82809b4d3a..eee395104c60 100644 --- a/examples/pytorch/__temporary__/graphsage/normal.py +++ b/examples/pytorch/__temporary__/graphsage/normal.py @@ -7,8 +7,10 @@ import time import numpy as np from ogb.nodeproppred import DglNodePropPredDataset +import tqdm +import argparse -USE_WRAPPER = True +USE_UVA = True # Set to True for UVA sampling class SAGE(nn.Module): def __init__(self, in_feats, n_hidden, n_classes): @@ -18,6 +20,8 @@ def __init__(self, in_feats, n_hidden, n_classes): self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean')) self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean')) self.dropout = nn.Dropout(0.5) + self.n_hidden = n_hidden + self.n_classes = n_classes def forward(self, blocks, x): h = x @@ -28,42 +32,64 @@ def forward(self, blocks, x): h = self.dropout(h) return h + def inference(self, g, device, batch_size, num_workers, buffer_device=None): + # The difference between this inference function and the one in the official + # example is that the intermediate results can also benefit from prefetching. 
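+ # buffer_device lets the full-graph output of each layer live on the CPU (useful
+ # for ogbn-products, whose per-layer node embeddings may not fit on the GPU),
+ # while the per-batch layer computation still runs on the GPU.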
+ g.ndata['h'] = g.ndata['feat'] + sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h']) + dataloader = dgl.dataloading.NodeDataLoader( + g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device, + batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers, + persistent_workers=(num_workers > 0)) + if buffer_device is None: + buffer_device = device + + for l, layer in enumerate(self.layers): + y = torch.zeros( + g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes, + device=buffer_device) + for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader): + x = blocks[0].srcdata['h'] + h = layer(blocks[0], x) + if l != len(self.layers) - 1: + h = F.relu(h) + h = self.dropout(h) + y[output_nodes] = h.to(buffer_device) + g.ndata['h'] = y + return y + dataset = DglNodePropPredDataset('ogbn-products') graph, labels = dataset[0] -graph.ndata['label'] = labels +graph.ndata['label'] = labels.squeeze() split_idx = dataset.get_idx_split() train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test'] -model = SAGE(graph.ndata['feat'].shape[1], 256, dataset.num_classes).cuda() -opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4) +if not USE_UVA: + graph = graph.to('cuda') + train_idx = train_idx.to('cuda') + valid_idx = valid_idx.to('cuda') + test_idx = test_idx.to('cuda') +device = 'cuda' -if USE_WRAPPER: - import dglnew - graph.create_formats_() - graph = dglnew.graph.wrapper.DGLGraphStorage(graph) +model = SAGE(graph.ndata['feat'].shape[1], 256, dataset.num_classes).to(device) +opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4) sampler = dgl.dataloading.NeighborSampler( - [5, 5, 5], output_device='cpu', prefetch_node_feats=['feat'], - prefetch_labels=['label']) -dataloader = dgl.dataloading.NodeDataLoader( - graph, - train_idx, - sampler, - device='cuda', - batch_size=1000, - shuffle=True, - drop_last=False, - pin_memory=True, - num_workers=16, - persistent_workers=True, - use_prefetch_thread=True) # TBD: could probably remove this argument + [15, 10, 5], prefetch_node_feats=['feat'], prefetch_labels=['label']) +train_dataloader = dgl.dataloading.NodeDataLoader( + graph, train_idx, sampler, device=device, batch_size=1024, shuffle=True, + drop_last=False, num_workers=0, use_uva=USE_UVA) +valid_dataloader = dgl.dataloading.NodeDataLoader( + graph, valid_idx, sampler, device=device, batch_size=1024, shuffle=True, + drop_last=False, num_workers=0, use_uva=USE_UVA) durations = [] for _ in range(10): + model.train() t0 = time.time() - for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader): + for it, (input_nodes, output_nodes, blocks) in enumerate(train_dataloader): x = blocks[0].srcdata['feat'] - y = blocks[-1].dstdata['label'][:, 0] + y = blocks[-1].dstdata['label'] y_hat = model(blocks, x) loss = F.cross_entropy(y_hat, y) opt.zero_grad() @@ -76,4 +102,23 @@ def forward(self, blocks, x): tt = time.time() print(tt - t0) durations.append(tt - t0) + + model.eval() + ys = [] + y_hats = [] + for it, (input_nodes, output_nodes, blocks) in enumerate(valid_dataloader): + with torch.no_grad(): + x = blocks[0].srcdata['feat'] + ys.append(blocks[-1].dstdata['label']) + y_hats.append(model(blocks, x)) + acc = MF.accuracy(torch.cat(y_hats), torch.cat(ys)) + print('Validation acc:', acc.item()) + print(np.mean(durations[4:]), np.std(durations[4:])) + +# Test accuracy and offline inference of all nodes +model.eval() +with torch.no_grad(): + pred = 
model.inference(graph, device, 4096, 12 if USE_UVA else 0, graph.device) + acc = MF.accuracy(pred.to(graph.device), graph.ndata['label']) + print('Test acc:', acc.item()) diff --git a/examples/pytorch/__temporary__/rgat/rgat.py b/examples/pytorch/__temporary__/rgat/rgat.py index 3a4364b2ccc7..a3760c989a15 100644 --- a/examples/pytorch/__temporary__/rgat/rgat.py +++ b/examples/pytorch/__temporary__/rgat/rgat.py @@ -11,7 +11,7 @@ from ogb.nodeproppred import DglNodePropPredDataset import tqdm -USE_WRAPPER = True +USE_UVA = False class HeteroGAT(nn.Module): def __init__(self, etypes, in_feats, n_hidden, n_classes, n_heads=4): @@ -52,44 +52,62 @@ def forward(self, blocks, x): graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'), etype='rev_writes') graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'), etype='has_topic') graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'), etype='affiliated_with') -graph.edges['cites'].data['weight'] = torch.ones(graph.num_edges('cites')) # dummy edge weights model = HeteroGAT(graph.etypes, graph.ndata['feat']['paper'].shape[1], 256, dataset.num_classes).cuda() opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4) -if USE_WRAPPER: - import dglnew - graph.create_formats_() - graph = dglnew.graph.wrapper.DGLGraphStorage(graph) - split_idx = dataset.get_idx_split() train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test'] -sampler = dgl.dataloading.NeighborSampler( - [5, 5, 5], output_device='cpu', +if not USE_UVA: + graph = graph.to('cuda') + train_idx = recursive_apply(train_idx, lambda x: x.to('cuda')) + valid_idx = recursive_apply(valid_idx, lambda x: x.to('cuda')) + test_idx = recursive_apply(test_idx, lambda x: x.to('cuda')) + +train_sampler = dgl.dataloading.NeighborSampler( + [5, 5, 5], + prefetch_node_feats={k: ['feat'] for k in graph.ntypes}, + prefetch_labels={'paper': ['label']}) +valid_sampler = dgl.dataloading.NeighborSampler( + [10, 10, 10], # Slightly more prefetch_node_feats={k: ['feat'] for k in graph.ntypes}, - prefetch_labels={'paper': ['label']}, - prefetch_edge_feats={'cites': ['weight']}) -dataloader = dgl.dataloading.NodeDataLoader( - graph, - train_idx, - sampler, - device='cuda', - batch_size=1000, - shuffle=True, - drop_last=False, - pin_memory=True, - num_workers=8, - persistent_workers=True, - use_prefetch_thread=True) # TBD: could probably remove this argument + prefetch_labels={'paper': ['label']}) +train_dataloader = dgl.dataloading.NodeDataLoader( + graph, train_idx, train_sampler, + device='cuda', batch_size=1000, shuffle=True, + drop_last=False, num_workers=0, use_uva=USE_UVA) +valid_dataloader = dgl.dataloading.NodeDataLoader( + graph, valid_idx, valid_sampler, + device='cuda', batch_size=1000, shuffle=False, + drop_last=False, num_workers=0, use_uva=USE_UVA) +test_dataloader = dgl.dataloading.NodeDataLoader( + graph, test_idx, valid_sampler, + device='cuda', batch_size=1000, shuffle=False, + drop_last=False, num_workers=0, use_uva=USE_UVA) + +def evaluate(model, dataloader): + preds = [] + labels = [] + with torch.no_grad(): + for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader): + x = blocks[0].srcdata['feat'] + y = blocks[-1].dstdata['label']['paper'][:, 0] + y_hat = model(blocks, x) + preds.append(y_hat) + labels.append(y) + preds = torch.cat(preds, 0) + labels = torch.cat(labels, 0) + acc = MF.accuracy(preds, labels) + return acc durations = [] for _ in range(10): + model.train() t0 = time.time() - for it, (input_nodes, output_nodes, 
blocks) in enumerate(dataloader): + for it, (input_nodes, output_nodes, blocks) in enumerate(train_dataloader): x = blocks[0].srcdata['feat'] y = blocks[-1].dstdata['label']['paper'][:, 0] - assert y.min() >= 0 and y.max() < dataset.num_classes y_hat = model(blocks, x) loss = F.cross_entropy(y_hat, y) opt.zero_grad() @@ -102,4 +120,9 @@ def forward(self, blocks, x): tt = time.time() print(tt - t0) durations.append(tt - t0) + + model.eval() + valid_acc = evaluate(model, valid_dataloader) + test_acc = evaluate(model, test_dataloader) + print('Validation acc:', valid_acc, 'Test acc:', test_acc) print(np.mean(durations[4:]), np.std(durations[4:])) diff --git a/include/dgl/aten/coo.h b/include/dgl/aten/coo.h index 0830c21810e3..f9e187f9d38e 100644 --- a/include/dgl/aten/coo.h +++ b/include/dgl/aten/coo.h @@ -134,7 +134,7 @@ struct COOMatrix { * \brief Pin the row, col and data (if not Null) of the matrix. * \note This is an in-place method. Behavior depends on the current context, * kDLCPU: will be pinned; - * kDLCPUPinned: directly return; + * IsPinned: directly return; * kDLGPU: invalid, will throw an error. * The context check is deferred to pinning the NDArray. */ @@ -149,7 +149,7 @@ struct COOMatrix { /*! * \brief Unpin the row, col and data (if not Null) of the matrix. * \note This is an in-place method. Behavior depends on the current context, - * kDLCPUPinned: will be unpinned; + * IsPinned: will be unpinned; * others: directly return. * The context check is deferred to unpinning the NDArray. */ diff --git a/include/dgl/aten/csr.h b/include/dgl/aten/csr.h index 049adb9392a1..d194951f73bb 100644 --- a/include/dgl/aten/csr.h +++ b/include/dgl/aten/csr.h @@ -127,7 +127,7 @@ struct CSRMatrix { * \brief Pin the indptr, indices and data (if not Null) of the matrix. * \note This is an in-place method. Behavior depends on the current context, * kDLCPU: will be pinned; - * kDLCPUPinned: directly return; + * IsPinned: directly return; * kDLGPU: invalid, will throw an error. * The context check is deferred to pinning the NDArray. */ @@ -142,7 +142,7 @@ struct CSRMatrix { /*! * \brief Unpin the indptr, indices and data (if not Null) of the matrix. * \note This is an in-place method. Behavior depends on the current context, - * kDLCPUPinned: will be unpinned; + * IsPinned: will be unpinned; * others: directly return. * The context check is deferred to unpinning the NDArray. */ diff --git a/include/dgl/aten/macro.h b/include/dgl/aten/macro.h index 89dc8033886c..c845f4cfe876 100644 --- a/include/dgl/aten/macro.h +++ b/include/dgl/aten/macro.h @@ -43,7 +43,7 @@ */ #ifdef DGL_USE_CUDA #define ATEN_XPU_SWITCH_CUDA(val, XPU, op, ...) do { \ - if ((val) == kDLCPU || (val) == kDLCPUPinned) { \ + if ((val) == kDLCPU) { \ constexpr auto XPU = kDLCPU; \ {__VA_ARGS__} \ } else if ((val) == kDLGPU) { \ @@ -233,7 +233,7 @@ }); #define CHECK_VALID_CONTEXT(VAR1, VAR2) \ - CHECK(((VAR1)->ctx == (VAR2)->ctx) || ((VAR1)->ctx.device_type == kDLCPUPinned)) \ + CHECK(((VAR1)->ctx == (VAR2)->ctx) || (VAR1).IsPinned()) \ << "Expected " << (#VAR2) << "(" << (VAR2)->ctx << ")" << " to have the same device " \ << "context as " << (#VAR1) << "(" << (VAR1)->ctx << "). " \ << "Or " << (#VAR1) << "(" << (VAR1)->ctx << ")" << " is pinned"; @@ -246,7 +246,7 @@ * If csr is pinned, array's context will conduct the actual operation. */ #define ATEN_CSR_SWITCH_CUDA_UVA(csr, array, XPU, IdType, op, ...) 
do { \ - CHECK_VALID_CONTEXT(csr.indices, array); \ + CHECK_VALID_CONTEXT(csr.indices, array); \ ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, op, { \ ATEN_ID_TYPE_SWITCH((csr).indptr->dtype, IdType, { \ {__VA_ARGS__} \ @@ -264,7 +264,7 @@ }); // Macro to dispatch according to device context and index type. -#define ATEN_COO_SWITCH_CUDA(coo, XPU, IdType, op, ...) \ +#define ATEN_COO_SWITCH_CUDA(coo, XPU, IdType, op, ...) \ ATEN_XPU_SWITCH_CUDA((coo).row->ctx.device_type, XPU, op, { \ ATEN_ID_TYPE_SWITCH((coo).row->dtype, IdType, { \ {__VA_ARGS__} \ diff --git a/include/dgl/base_heterograph.h b/include/dgl/base_heterograph.h index e05db8a49efa..24b96002eed3 100644 --- a/include/dgl/base_heterograph.h +++ b/include/dgl/base_heterograph.h @@ -23,6 +23,7 @@ namespace dgl { // Forward declaration class BaseHeteroGraph; +class HeteroPickleStates; typedef std::shared_ptr HeteroGraphPtr; struct FlattenedHeteroGraph; @@ -436,6 +437,21 @@ class BaseHeteroGraph : public runtime::Object { */ virtual aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const = 0; + /*! + * \brief Set the COO matrix representation for a given edge type. + */ + virtual void SetCOOMatrix(dgl_type_t etype, aten::COOMatrix coo) = 0; + + /*! + * \brief Set the CSR matrix representation for a given edge type. + */ + virtual void SetCSRMatrix(dgl_type_t etype, aten::CSRMatrix csr) = 0; + + /*! + * \brief Set the CSC matrix representation for a given edge type. + */ + virtual void SetCSCMatrix(dgl_type_t etype, aten::CSRMatrix csc) = 0; + /*! * \brief Extract the induced subgraph by the given vertices. * @@ -864,6 +880,25 @@ HeteroPickleStates HeteroPickle(HeteroGraphPtr graph); */ HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates& states); +/*! + * \brief Create heterograph from pickling states pickled by ForkingPickler. + * + * This is different from HeteroUnpickle where + * (1) Backward compatibility is not required, + * (2) All graph formats are pickled instead of only one. + */ +HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates& states); + +/*! + * \brief Get the pickling states of the relation graph structure in backend tensors for + * ForkingPickler. + * + * This is different from HeteroPickle where + * (1) Backward compatibility is not required, + * (2) All graph formats are pickled instead of only one. + */ +HeteroPickleStates HeteroForkingPickle(HeteroGraphPtr graph); + #define FORMAT_HAS_CSC(format) \ ((format) & CSC_CODE) diff --git a/include/dgl/runtime/device_api.h b/include/dgl/runtime/device_api.h index 83a00a0d3629..5151140500f6 100644 --- a/include/dgl/runtime/device_api.h +++ b/include/dgl/runtime/device_api.h @@ -160,6 +160,13 @@ class DeviceAPI { */ DGL_DLL virtual void UnpinData(void* ptr); + /*! + * \brief Check whether the memory is in pinned memory. + */ + DGL_DLL virtual bool IsPinned(const void* ptr) { + return false; + } + /*! * \brief Allocate temporal workspace for backend execution. * diff --git a/include/dgl/runtime/ndarray.h b/include/dgl/runtime/ndarray.h index 599a26212e61..520f4479c1f6 100644 --- a/include/dgl/runtime/ndarray.h +++ b/include/dgl/runtime/ndarray.h @@ -176,7 +176,7 @@ class NDArray { * on the underlying DLTensor. * \note This is an in-place method. Behavior depends on the current context, * kDLCPU: will be pinned; - * kDLCPUPinned: directly return; + * IsPinned: directly return; * kDLGPU: invalid, will throw an error. 
*/ inline void PinMemory_(); @@ -184,7 +184,7 @@ class NDArray { * \brief In-place method to unpin the current array by calling UnpinData * on the underlying DLTensor. * \note This is an in-place method. Behavior depends on the current context, - * kDLCPUPinned: will be unpinned; + * IsPinned: will be unpinned; * others: directly return. */ inline void UnpinMemory_(); @@ -299,7 +299,7 @@ class NDArray { * \note Data of the given array will be pinned inplace. * Behavior depends on the current context, * kDLCPU: will be pinned; - * kDLCPUPinned: directly return; + * IsPinned: directly return; * kDLGPU: invalid, will throw an error. */ DGL_DLL static void PinData(DLTensor* tensor); @@ -309,11 +309,18 @@ class NDArray { * \param tensor The array to be unpinned. * \note Data of the given array will be unpinned inplace. * Behavior depends on the current context, - * kDLCPUPinned: will be unpinned; + * IsPinned: will be unpinned; * others: directly return. */ DGL_DLL static void UnpinData(DLTensor* tensor); + /*! + * \brief Function check if the data of a DLTensor is pinned. + * \param tensor The array to be checked. + * \return true if pinned. + */ + DGL_DLL static bool IsDataPinned(DLTensor* tensor); + // internal namespace struct Internal; private: @@ -485,7 +492,7 @@ inline void NDArray::UnpinMemory_() { inline bool NDArray::IsPinned() const { CHECK(data_ != nullptr); - return data_->dl_tensor.ctx.device_type == kDLCPUPinned; + return IsDataPinned(&(data_->dl_tensor)); } inline int NDArray::use_count() const { diff --git a/python/dgl/_ffi/_ctypes/ndarray.py b/python/dgl/_ffi/_ctypes/ndarray.py index 0055ae20da11..900a40fcf617 100644 --- a/python/dgl/_ffi/_ctypes/ndarray.py +++ b/python/dgl/_ffi/_ctypes/ndarray.py @@ -82,8 +82,6 @@ def to_dlpack(self, alignment=0): Indicates the alignment requirement when converting to dlpack. Will copy to a new tensor if the alignment requirement is not satisfied. 0 means no alignment requirement. - Will copy to a new tensor if the array is pinned because some backends, - e.g., pytorch, do not support kDLCPUPinned device type. Returns diff --git a/python/dgl/_ffi/ndarray.py b/python/dgl/_ffi/ndarray.py index dd588c2f6a7d..3a90e046c224 100644 --- a/python/dgl/_ffi/ndarray.py +++ b/python/dgl/_ffi/ndarray.py @@ -316,25 +316,15 @@ def copyto(self, target): raise ValueError("Unsupported target type %s" % str(type(target))) return target - def pin_memory_(self, ctx): + def pin_memory_(self): """Pin host memory and map into GPU address space (in-place) - - Parameters - ---------- - ctx : DGLContext - The target GPU to map the host memory space """ - check_call(_LIB.DGLArrayPinData(self.handle, ctx)) + check_call(_LIB.DGLArrayPinData(self.handle)) - def unpin_memory_(self, ctx): + def unpin_memory_(self): """Unpin host memory pinned by pin_memory_() - - Parameters - ---------- - ctx : DGLContext - The target GPU to map the host memory space """ - check_call(_LIB.DGLArrayUnpinData(self.handle, ctx)) + check_call(_LIB.DGLArrayUnpinData(self.handle)) def free_extension_handle(handle, type_code): diff --git a/python/dgl/backend/backend.py b/python/dgl/backend/backend.py index 0f8b2657a074..7e3501e3c3be 100644 --- a/python/dgl/backend/backend.py +++ b/python/dgl/backend/backend.py @@ -330,6 +330,21 @@ def copy_to(input, ctx, **kwargs): """ pass +def is_pinned(input): + """Check whether the tensor is in pinned memory. + + Parameters + ---------- + input : Tensor + The tensor. + + Returns + ------- + bool + Whether the tensor is in pinned memory. 
+ """ + pass + ############################################################################### # Tensor functions on feature data # -------------------------------- diff --git a/python/dgl/backend/mxnet/tensor.py b/python/dgl/backend/mxnet/tensor.py index 5c35c5373fa0..1f713cc13f8d 100644 --- a/python/dgl/backend/mxnet/tensor.py +++ b/python/dgl/backend/mxnet/tensor.py @@ -144,6 +144,9 @@ def asnumpy(input): def copy_to(input, ctx, **kwargs): return input.as_in_context(ctx) +def is_pinned(input): + return input.context == mx.cpu_pinned() + def sum(input, dim, keepdims=False): if len(input) == 0: return nd.array([0.], dtype=input.dtype, ctx=input.context) diff --git a/python/dgl/backend/pytorch/tensor.py b/python/dgl/backend/pytorch/tensor.py index 2359894cd0f0..254e8d853d5a 100644 --- a/python/dgl/backend/pytorch/tensor.py +++ b/python/dgl/backend/pytorch/tensor.py @@ -120,6 +120,9 @@ def copy_to(input, ctx, **kwargs): else: raise RuntimeError('Invalid context', ctx) +def is_pinned(input): + return input.is_pinned() + def sum(input, dim, keepdims=False): return th.sum(input, dim=dim, keepdim=keepdims) diff --git a/python/dgl/backend/tensorflow/tensor.py b/python/dgl/backend/tensorflow/tensor.py index fed68dd8956e..cc22029399ed 100644 --- a/python/dgl/backend/tensorflow/tensor.py +++ b/python/dgl/backend/tensorflow/tensor.py @@ -162,6 +162,8 @@ def copy_to(input, ctx, **kwargs): new_tensor = tf.identity(input) return new_tensor +def is_pinned(input): + return False # not sure how to do this def sum(input, dim, keepdims=False): if input.dtype == tf.bool: diff --git a/python/dgl/contrib/unified_tensor.py b/python/dgl/contrib/unified_tensor.py index 49b5c6a852fa..dedb0ccbd041 100644 --- a/python/dgl/contrib/unified_tensor.py +++ b/python/dgl/contrib/unified_tensor.py @@ -75,7 +75,7 @@ def __init__(self, input, device): self._array = F.zerocopy_to_dgl_ndarray(self._input) self._device = device - self._array.pin_memory_(utils.to_dgl_context(self._device)) + self._array.pin_memory_() def __len__(self): return len(self._array) @@ -105,7 +105,7 @@ def __setitem__(self, key, val): def __del__(self): if hasattr(self, '_array') and self._array != None: - self._array.unpin_memory_(utils.to_dgl_context(self._device)) + self._array.unpin_memory_() self._array = None if hasattr(self, '_input'): diff --git a/python/dgl/dataloading/base.py b/python/dgl/dataloading/base.py index b0ebe2c368b0..bc2c7e2aa2d6 100644 --- a/python/dgl/dataloading/base.py +++ b/python/dgl/dataloading/base.py @@ -124,6 +124,8 @@ def _find_exclude_eids_with_reverse_types(g, eids, reverse_etype_map): def _find_exclude_eids(g, exclude_mode, eids, **kwargs): if exclude_mode is None: return None + elif callable(exclude_mode): + return exclude_mode(eids) elif F.is_tensor(exclude_mode) or ( isinstance(exclude_mode, Mapping) and all(F.is_tensor(v) for v in exclude_mode.values())): @@ -151,9 +153,6 @@ def find_exclude_eids(g, seed_edges, exclude, reverse_eids=None, reverse_etypes= None (default) Does not exclude any edge. - Tensor or dict[etype, Tensor] - Exclude the given edge IDs. - 'self' Exclude the given edges themselves but nothing else. @@ -176,6 +175,10 @@ def find_exclude_eids(g, seed_edges, exclude, reverse_eids=None, reverse_etypes= This mode assumes that the reverse of an edge with ID ``e`` and type ``etype`` will have ID ``e`` and type ``reverse_etype_map[etype]``. + + callable + Any function that takes in a single argument :attr:`seed_edges` and returns + a tensor or dict of tensors. 
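+
+ For example, ``exclude=lambda eids: torch.cat([eids, reverse_map[eids]])`` would
+ exclude the sampled edges together with their reverses, where ``reverse_map`` is
+ assumed to be a user-provided tensor mapping every edge ID to its reverse.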
eids : Tensor or dict[etype, Tensor] The edge IDs. reverse_eids : Tensor or dict[etype, Tensor] @@ -191,9 +194,8 @@ def find_exclude_eids(g, seed_edges, exclude, reverse_eids=None, reverse_etypes= seed_edges, reverse_eid_map=reverse_eids, reverse_etype_map=reverse_etypes) - if exclude_eids is not None: - exclude_eids = recursive_apply( - exclude_eids, lambda x: x.to(output_device)) + if exclude_eids is not None and output_device is not None: + exclude_eids = recursive_apply(exclude_eids, lambda x: F.copy_to(x, output_device)) return exclude_eids @@ -202,8 +204,8 @@ class EdgeBlockSampler(object): classification and link prediction. """ def __init__(self, block_sampler, exclude=None, reverse_eids=None, - reverse_etypes=None, negative_sampler=None, prefetch_node_feats=None, - prefetch_labels=None, prefetch_edge_feats=None): + reverse_etypes=None, negative_sampler=None, + prefetch_node_feats=None, prefetch_labels=None, prefetch_edge_feats=None,): self.reverse_eids = reverse_eids self.reverse_etypes = reverse_etypes self.exclude = exclude @@ -249,6 +251,8 @@ def sample(self, g, seed_edges): If :attr:`negative_sampler` is given, also returns another graph containing the negative pairs as edges. """ + if isinstance(seed_edges, Mapping): + seed_edges = {g.to_canonical_etype(k): v for k, v in seed_edges.items()} exclude = self.exclude pair_graph = g.edge_subgraph( seed_edges, relabel_nodes=False, output_device=self.output_device) diff --git a/python/dgl/dataloading/cluster_gcn.py b/python/dgl/dataloading/cluster_gcn.py index 1cb41ab8614d..401954b5e8fe 100644 --- a/python/dgl/dataloading/cluster_gcn.py +++ b/python/dgl/dataloading/cluster_gcn.py @@ -55,7 +55,7 @@ def __init__(self, g, k, balance_ntypes=None, balance_edges=False, mode='k-way', partition_node_ids = np.argsort(partition_ids) partition_size = F.zerocopy_from_numpy(np.bincount(partition_ids, minlength=k)) partition_offset = F.zerocopy_from_numpy(np.insert(np.cumsum(partition_size), 0, 0)) - partition_node_ids = F.zerocopy_from_numpy(partition_ids) + partition_node_ids = F.zerocopy_from_numpy(partition_node_ids) with open(cache_path, 'wb') as f: pickle.dump((partition_offset, partition_node_ids), f) self.partition_offset = partition_offset diff --git a/python/dgl/dataloading/dataloader.py b/python/dgl/dataloading/dataloader.py index 45c22125c441..45395e6f9563 100644 --- a/python/dgl/dataloading/dataloader.py +++ b/python/dgl/dataloading/dataloader.py @@ -1,6 +1,6 @@ """DGL PyTorch DataLoaders""" from collections.abc import Mapping, Sequence -from queue import Queue +from queue import Queue, Empty, Full import itertools import threading from distutils.version import LooseVersion @@ -8,23 +8,33 @@ import math import inspect import re +import atexit +import os import torch import torch.distributed as dist from torch.utils.data.distributed import DistributedSampler -from ..base import NID, EID, dgl_warning +from ..base import NID, EID from ..batch import batch as batch_graphs from ..heterograph import DGLHeteroGraph from .. import ndarray as nd from ..utils import ( recursive_apply, ExceptionWrapper, recursive_apply_pair, set_num_threads, - create_shared_mem_array, get_shared_mem_array) + create_shared_mem_array, get_shared_mem_array, context_of, pin_memory_inplace) from ..frame import LazyFeature from ..storages import wrap_storage from .base import BlockSampler, EdgeBlockSampler from .. 
import backend as F +PYTHON_EXIT_STATUS = False +def _set_python_exit_flag(): + global PYTHON_EXIT_STATUS + PYTHON_EXIT_STATUS = True +atexit.register(_set_python_exit_flag) + +prefetcher_timeout = int(os.environ.get('DGL_PREFETCHER_TIMEOUT', '10')) + class _TensorizedDatasetIter(object): def __init__(self, dataset, batch_size, drop_last, mapping_keys): self.dataset = dataset @@ -54,7 +64,8 @@ def _next_indices(self): def __next__(self): batch = self._next_indices() if self.mapping_keys is None: - return batch + # clone() fixes #3755, probably. Not sure why. Need to take a look afterwards. + return batch.clone() # convert the type-ID pairs to dictionary type_ids = batch[:, 0] @@ -67,28 +78,31 @@ def __next__(self): type_id_offset = type_id_count.cumsum(0).tolist() type_id_offset.insert(0, 0) id_dict = { - self.mapping_keys[type_id_uniq[i]]: indices[type_id_offset[i]:type_id_offset[i+1]] + self.mapping_keys[type_id_uniq[i]]: + indices[type_id_offset[i]:type_id_offset[i+1]].clone() for i in range(len(type_id_uniq))} return id_dict def _get_id_tensor_from_mapping(indices, device, keys): lengths = torch.LongTensor([ - (indices[k].shape[0] if k in indices else 0) for k in keys], device=device) + (indices[k].shape[0] if k in indices else 0) for k in keys]).to(device) type_ids = torch.arange(len(keys), device=device).repeat_interleave(lengths) all_indices = torch.cat([indices[k] for k in keys if k in indices]) return torch.stack([type_ids, all_indices], 1) -def _divide_by_worker(dataset): +def _divide_by_worker(dataset, batch_size, drop_last): num_samples = dataset.shape[0] worker_info = torch.utils.data.get_worker_info() if worker_info: - chunk_size = num_samples // worker_info.num_workers - left_over = num_samples % worker_info.num_workers - start = (chunk_size * worker_info.id) + min(left_over, worker_info.id) - end = start + chunk_size + (worker_info.id < left_over) - assert worker_info.id < worker_info.num_workers - 1 or end == num_samples + num_batches = (num_samples + (0 if drop_last else batch_size - 1)) // batch_size + num_batches_per_worker = num_batches // worker_info.num_workers + left_over = num_batches % worker_info.num_workers + start = (num_batches_per_worker * worker_info.id) + min(left_over, worker_info.id) + end = start + num_batches_per_worker + (worker_info.id < left_over) + start *= batch_size + end = min(end * batch_size, num_samples) dataset = dataset[start:end] return dataset @@ -98,31 +112,39 @@ class TensorizedDataset(torch.utils.data.IterableDataset): When the dataset is on the GPU, this significantly reduces the overhead. """ def __init__(self, indices, batch_size, drop_last): + name, _ = _generate_shared_mem_name_id() if isinstance(indices, Mapping): self._mapping_keys = list(indices.keys()) self._device = next(iter(indices.values())).device - self._tensor_dataset = _get_id_tensor_from_mapping( + self._id_tensor = _get_id_tensor_from_mapping( indices, self._device, self._mapping_keys) else: - self._tensor_dataset = indices + self._id_tensor = indices self._device = indices.device self._mapping_keys = None + # Use a shared memory array to permute indices for shuffling. This is to make sure that + # the worker processes can see it when persistent_workers=True, where self._indices + # would not be duplicated every epoch. 
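+ # The actual node/edge IDs stay in self._id_tensor (possibly on GPU); only this
+ # int64 permutation lives in shared memory and is re-applied in __iter__().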
+ self._indices = create_shared_mem_array(name, (self._id_tensor.shape[0],), torch.int64) + self._indices[:] = torch.arange(self._id_tensor.shape[0]) self.batch_size = batch_size self.drop_last = drop_last + self.shared_mem_name = name + self.shared_mem_size = self._indices.shape[0] def shuffle(self): """Shuffle the dataset.""" # TODO: may need an in-place shuffle kernel - perm = torch.randperm(self._tensor_dataset.shape[0], device=self._device) - self._tensor_dataset[:] = self._tensor_dataset[perm] + self._indices[:] = self._indices[torch.randperm(self._indices.shape[0])] def __iter__(self): - dataset = _divide_by_worker(self._tensor_dataset) + indices = _divide_by_worker(self._indices, self.batch_size, self.drop_last) + id_tensor = self._id_tensor[indices.to(self._device)] return _TensorizedDatasetIter( - dataset, self.batch_size, self.drop_last, self._mapping_keys) + id_tensor, self.batch_size, self.drop_last, self._mapping_keys) def __len__(self): - num_samples = self._tensor_dataset.shape[0] + num_samples = self._id_tensor.shape[0] return (num_samples + (0 if self.drop_last else (self.batch_size - 1))) // self.batch_size def _get_shared_mem_name(id_): @@ -168,20 +190,20 @@ def __init__(self, indices, batch_size, drop_last, ddp_seed): self.shared_mem_size = self.total_size if not self.drop_last else len(indices) self.num_indices = len(indices) + if isinstance(indices, Mapping): + self._device = next(iter(indices.values())).device + self._id_tensor = _get_id_tensor_from_mapping( + indices, self._device, self._mapping_keys) + else: + self._id_tensor = indices + self._device = self._id_tensor.device + if self.rank == 0: name, id_ = _generate_shared_mem_name_id() - if isinstance(indices, Mapping): - device = next(iter(indices.values())).device - id_tensor = _get_id_tensor_from_mapping(indices, device, self._mapping_keys) - self._tensor_dataset = create_shared_mem_array( - name, (self.shared_mem_size, 2), torch.int64) - self._tensor_dataset[:id_tensor.shape[0], :] = id_tensor - else: - self._tensor_dataset = create_shared_mem_array( - name, (self.shared_mem_size,), torch.int64) - self._tensor_dataset[:len(indices)] = indices - self._device = self._tensor_dataset.device - meta_info = torch.LongTensor([id_, self._tensor_dataset.shape[0]]) + self._indices = create_shared_mem_array( + name, (self.shared_mem_size,), torch.int64) + self._indices[:self._id_tensor.shape[0]] = torch.arange(self._id_tensor.shape[0]) + meta_info = torch.LongTensor([id_, self._indices.shape[0]]) else: meta_info = torch.LongTensor([0, 0]) @@ -194,43 +216,41 @@ def __init__(self, indices, batch_size, drop_last, ddp_seed): if self.rank != 0: id_, num_samples = meta_info.tolist() name = _get_shared_mem_name(id_) - if isinstance(indices, Mapping): - indices_shared = get_shared_mem_array(name, (num_samples, 2), torch.int64) - else: - indices_shared = get_shared_mem_array(name, (num_samples,), torch.int64) - self._tensor_dataset = indices_shared - self._device = indices_shared.device + indices_shared = get_shared_mem_array(name, (num_samples,), torch.int64) + self._indices = indices_shared + self.shared_mem_name = name def shuffle(self): """Shuffles the dataset.""" # Only rank 0 does the actual shuffling. The other ranks wait for it. 
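+ # The permutation is written into the shared-memory self._indices array, so the
+ # other ranks observe the shuffled order once the barrier below is reached.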
if self.rank == 0: - self._tensor_dataset[:self.num_indices] = self._tensor_dataset[ + self._indices[:self.num_indices] = self._indices[ torch.randperm(self.num_indices, device=self._device)] if not self.drop_last: # pad extra - self._tensor_dataset[self.num_indices:] = \ - self._tensor_dataset[:self.total_size - self.num_indices] + self._indices[self.num_indices:] = \ + self._indices[:self.total_size - self.num_indices] dist.barrier() def __iter__(self): start = self.num_samples * self.rank end = self.num_samples * (self.rank + 1) - dataset = _divide_by_worker(self._tensor_dataset[start:end]) + indices = _divide_by_worker(self._indices[start:end], self.batch_size, self.drop_last) + id_tensor = self._id_tensor[indices.to(self._device)] return _TensorizedDatasetIter( - dataset, self.batch_size, self.drop_last, self._mapping_keys) + id_tensor, self.batch_size, self.drop_last, self._mapping_keys) def __len__(self): return (self.num_samples + (0 if self.drop_last else (self.batch_size - 1))) // \ self.batch_size -def _prefetch_update_feats(feats, frames, types, get_storage_func, id_name, device, pin_memory): +def _prefetch_update_feats(feats, frames, types, get_storage_func, id_name, device, pin_prefetcher): for tid, frame in enumerate(frames): type_ = types[tid] default_id = frame.get(id_name, None) for key in frame.keys(): - column = frame[key] + column = frame._columns[key] if isinstance(column, LazyFeature): parent_key = column.name or key if column.id_ is None and default_id is None: @@ -238,7 +258,7 @@ def _prefetch_update_feats(feats, frames, types, get_storage_func, id_name, devi 'Found a LazyFeature with no ID specified, ' 'and the graph does not have dgl.NID or dgl.EID columns') feats[tid, key] = get_storage_func(parent_key, type_).fetch( - column.id_ or default_id, device, pin_memory) + column.id_ or default_id, device, pin_prefetcher) # This class exists to avoid recursion into the feature dictionary returned by the @@ -254,10 +274,10 @@ def _prefetch_for_subgraph(subg, dataloader): node_feats, edge_feats = {}, {} _prefetch_update_feats( node_feats, subg._node_frames, subg.ntypes, dataloader.graph.get_node_storage, - NID, dataloader.device, dataloader.pin_memory) + NID, dataloader.device, dataloader.pin_prefetcher) _prefetch_update_feats( edge_feats, subg._edge_frames, subg.canonical_etypes, dataloader.graph.get_edge_storage, - EID, dataloader.device, dataloader.pin_memory) + EID, dataloader.device, dataloader.pin_prefetcher) return _PrefetchedGraphFeatures(node_feats, edge_feats) @@ -266,7 +286,7 @@ def _prefetch_for(item, dataloader): return _prefetch_for_subgraph(item, dataloader) elif isinstance(item, LazyFeature): return dataloader.other_storages[item.name].fetch( - item.id_, dataloader.device, dataloader.pin_memory) + item.id_, dataloader.device, dataloader.pin_prefetcher) else: return None @@ -313,8 +333,17 @@ def _assign_for(item, feat): else: return item - -def _prefetcher_entry(dataloader_it, dataloader, queue, num_threads, use_alternate_streams): +def _put_if_event_not_set(queue, result, event): + while not event.is_set(): + try: + queue.put(result, timeout=1.0) + break + except Full: + continue + +def _prefetcher_entry( + dataloader_it, dataloader, queue, num_threads, use_alternate_streams, + done_event): # PyTorch will set the number of threads to 1 which slows down pin_memory() calls # in main process if a prefetching thread is created. 
if num_threads is not None: @@ -327,20 +356,27 @@ def _prefetcher_entry(dataloader_it, dataloader, queue, num_threads, use_alterna stream = None try: - for batch in dataloader_it: + while not done_event.is_set(): + try: + batch = next(dataloader_it) + except StopIteration: + break batch = recursive_apply(batch, restore_parent_storage_columns, dataloader.graph) feats = _prefetch(batch, dataloader, stream) - queue.put(( + _put_if_event_not_set(queue, ( # batch will be already in pinned memory as per the behavior of # PyTorch DataLoader. - recursive_apply(batch, lambda x: x.to(dataloader.device, non_blocking=True)), + recursive_apply( + batch, lambda x: x.to(dataloader.device, non_blocking=True)), feats, stream.record_event() if stream is not None else None, - None)) - queue.put((None, None, None, None)) + None), + done_event) + _put_if_event_not_set(queue, (None, None, None, None), done_event) except: # pylint: disable=bare-except - queue.put((None, None, None, ExceptionWrapper(where='in prefetcher'))) + _put_if_event_not_set( + queue, (None, None, None, ExceptionWrapper(where='in prefetcher')), done_event) # DGLHeteroGraphs have the semantics of lazy feature slicing with subgraphs. Such behavior depends @@ -400,15 +436,18 @@ def __init__(self, dataloader, dataloader_it, use_thread=False, use_alternate_st self.dataloader_it = dataloader_it self.dataloader = dataloader self.graph_sampler = self.dataloader.graph_sampler - self.pin_memory = self.dataloader.pin_memory + self.pin_prefetcher = self.dataloader.pin_prefetcher self.num_threads = num_threads self.use_thread = use_thread self.use_alternate_streams = use_alternate_streams + self._shutting_down = False if use_thread: + self._done_event = threading.Event() thread = threading.Thread( target=_prefetcher_entry, - args=(dataloader_it, dataloader, self.queue, num_threads, use_alternate_streams), + args=(dataloader_it, dataloader, self.queue, num_threads, + use_alternate_streams, self._done_event), daemon=True) thread.start() self.thread = thread @@ -416,6 +455,31 @@ def __init__(self, dataloader, dataloader_it, use_thread=False, use_alternate_st def __iter__(self): return self + def _shutdown(self): + # Sometimes when Python is exiting complicated operations like + # self.queue.get_nowait() will hang. So we set it to no-op and let Python handle + # the rest since the thread is daemonic. + # PyTorch takes the same solution. + if PYTHON_EXIT_STATUS is True or PYTHON_EXIT_STATUS is None: + return + if not self._shutting_down: + try: + self._shutting_down = True + self._done_event.set() + + try: + self.queue.get_nowait() # In case the thread is blocking on put(). 
+ except: # pylint: disable=bare-except + pass + + self.thread.join() + except: # pylint: disable=bare-except + pass + + def __del__(self): + if self.use_thread: + self._shutdown() + def _next_non_threaded(self): batch = next(self.dataloader_it) batch = recursive_apply(batch, restore_parent_storage_columns, self.dataloader.graph) @@ -430,7 +494,11 @@ def _next_non_threaded(self): return batch, feats, stream_event def _next_threaded(self): - batch, feats, stream_event, exception = self.queue.get() + try: + batch, feats, stream_event, exception = self.queue.get(timeout=prefetcher_timeout) + except Empty: + raise RuntimeError( + f'Prefetcher thread timed out at {prefetcher_timeout} seconds.') if batch is None: self.thread.join() if exception is None: @@ -485,23 +553,100 @@ def create_tensorized_dataset(indices, batch_size, drop_last, use_ddp, ddp_seed) return TensorizedDataset(indices, batch_size, drop_last) +def _get_device(device): + device = torch.device(device) + if device.type == 'cuda' and device.index is None: + device = torch.device('cuda', torch.cuda.current_device()) + return device + class DataLoader(torch.utils.data.DataLoader): """DataLoader class.""" def __init__(self, graph, indices, graph_sampler, device='cpu', use_ddp=False, ddp_seed=0, batch_size=1, drop_last=False, shuffle=False, - use_prefetch_thread=False, use_alternate_streams=True, **kwargs): + use_prefetch_thread=None, use_alternate_streams=None, + pin_prefetcher=None, use_uva=False, **kwargs): + # (BarclayII) I hoped that pin_prefetcher can be merged into PyTorch's native + # pin_memory argument. But our neighbor samplers and subgraph samplers + # return indices, which could be CUDA tensors (e.g. during UVA sampling) + # hence cannot be pinned. PyTorch's native pin memory thread does not ignore + # CUDA tensors when pinning and will crash. To enable pin memory for prefetching + # features and disable pin memory for sampler's return value, I had to use + # a different argument. Of course I could change the meaning of pin_memory + # to pinning prefetched features and disable pin memory for sampler's returns + # no matter what, but I doubt if it's reasonable. self.graph = graph + self.indices = indices # For PyTorch-Lightning + num_workers = kwargs.get('num_workers', 0) try: if isinstance(indices, Mapping): indices = {k: (torch.tensor(v) if not torch.is_tensor(v) else v) for k, v in indices.items()} + indices_device = next(iter(indices.values())).device else: indices = torch.tensor(indices) if not torch.is_tensor(indices) else indices + indices_device = indices.device except: # pylint: disable=bare-except # ignore when it fails to convert to torch Tensors. pass + self.device = _get_device(device) + + # Sanity check - we only check for DGLGraphs. + if isinstance(self.graph, DGLHeteroGraph): + # Check graph and indices device as well as num_workers + if use_uva: + if self.graph.device.type != 'cpu': + raise ValueError('Graph must be on CPU if UVA sampling is enabled.') + if num_workers > 0: + raise ValueError('num_workers must be 0 if UVA sampling is enabled.') + + # Create all the formats and pin the features - custom GraphStorages + # will need to do that themselves. 
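+ # With UVA, the graph structure and features stay in pinned host memory and are
+ # accessed directly by the GPU during sampling and feature slicing, which is why
+ # the seed indices are moved to the output device below.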
+            self.graph.create_formats_()
+            self.graph.pin_memory_()
+            for frame in itertools.chain(self.graph._node_frames, self.graph._edge_frames):
+                for col in frame._columns.values():
+                    pin_memory_inplace(col.data)
+
+            indices = recursive_apply(indices, lambda x: x.to(self.device))
+        else:
+            if self.graph.device != indices_device:
+                raise ValueError(
+                    'Expect graph and indices to be on the same device. '
+                    'If you wish to use UVA sampling, please set use_uva=True.')
+            if self.graph.device.type == 'cuda':
+                if num_workers > 0:
+                    raise ValueError('num_workers must be 0 if graph and indices are on CUDA.')
+
+            # Check pin_prefetcher and use_prefetch_thread - they are only effective
+            # when sampling on CPU with a CUDA output device
+            if not (self.device.type == 'cuda' and self.graph.device.type == 'cpu'):
+                if pin_prefetcher is True:
+                    raise ValueError(
+                        'pin_prefetcher=True is only effective when device=cuda and '
+                        'sampling is performed on CPU.')
+                if pin_prefetcher is None:
+                    pin_prefetcher = False
+
+                if use_prefetch_thread is True:
+                    raise ValueError(
+                        'use_prefetch_thread=True is only effective when device=cuda and '
+                        'sampling is performed on CPU.')
+                if use_prefetch_thread is None:
+                    use_prefetch_thread = False
+            else:
+                if pin_prefetcher is None:
+                    pin_prefetcher = True
+                if use_prefetch_thread is None:
+                    use_prefetch_thread = True
+
+            # Check use_alternate_streams
+            if use_alternate_streams is None:
+                use_alternate_streams = (
+                    self.device.type == 'cuda' and self.graph.device.type == 'cpu' and
+                    not use_uva)
+
         if (torch.is_tensor(indices) or (
                 isinstance(indices, Mapping) and
                 all(torch.is_tensor(v) for v in indices.values()))):
@@ -511,17 +656,18 @@ def __init__(self, graph, indices, graph_sampler, device='cpu', use_ddp=False,
             self.dataset = indices

         self.ddp_seed = ddp_seed
-        self._shuffle_dataset = shuffle
+        self.use_ddp = use_ddp
+        self.use_uva = use_uva
+        self.shuffle = shuffle
+        self.drop_last = drop_last
         self.graph_sampler = graph_sampler
-        self.device = torch.device(device)
         self.use_alternate_streams = use_alternate_streams
-        if self.device.type == 'cuda' and self.device.index is None:
-            self.device = torch.device('cuda', torch.cuda.current_device())
+        self.pin_prefetcher = pin_prefetcher
         self.use_prefetch_thread = use_prefetch_thread
         worker_init_fn = WorkerInitWrapper(kwargs.get('worker_init_fn', None))

         # Instantiate all the formats if the number of workers is greater than 0.
-        if kwargs.get('num_workers', 0) > 0 and hasattr(self.graph, 'create_formats_'):
+        if num_workers > 0 and hasattr(self.graph, 'create_formats_'):
             self.graph.create_formats_()

         self.other_storages = {}
@@ -534,7 +680,7 @@ def __init__(self, graph, indices, graph_sampler, device='cpu', use_ddp=False,
             **kwargs)

     def __iter__(self):
-        if self._shuffle_dataset:
+        if self.shuffle:
             self.dataset.shuffle()
         # When using multiprocessing PyTorch sometimes set the number of PyTorch threads to 1
         # when spawning new Python threads. This drastically slows down pinning features.
@@ -551,30 +697,377 @@ def attach_data(self, name, data):

 # Alias
 class NodeDataLoader(DataLoader):
-    """NodeDataLoader class."""
+    """PyTorch dataloader for batch-iterating over a set of nodes, generating the list
+    of message flow graphs (MFGs) as computation dependency of the said minibatch.
+
+    Parameters
+    ----------
+    graph : DGLGraph
+        The graph.
+    indices : Tensor or dict[ntype, Tensor]
+        The node set to compute outputs.
+    graph_sampler : object
+        The neighborhood sampler.  It could be any object that has a :attr:`sample`
+        method.
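As a side note on the constructor logic above: passing use_uva=True is roughly equivalent to the manual preparation sketched below, assuming g is a CPU DGLGraph with dense feature tensors and the output device is CUDA. The helper prepare_for_uva is illustrative only; the real constructor additionally validates devices and forbids num_workers > 0.

    import torch
    import dgl
    from dgl.utils import pin_memory_inplace   # utility added in this patch

    def prepare_for_uva(g, indices, device='cuda'):
        # New sparse formats cannot be materialized once the graph is pinned,
        # so create them first, then pin the structure in place.
        g.create_formats_()
        g.pin_memory_()
        # Pin every node/edge feature column in place (no copy) so the GPU can
        # gather rows from it directly during sampling.
        for frame in list(g._node_frames) + list(g._edge_frames):
            for col in frame._columns.values():
                pin_memory_inplace(col.data)
        # Seed indices are consumed on the GPU during UVA sampling.
        if isinstance(indices, dict):
            return {k: v.to(device) for k, v in indices.items()}
        return indices.to(device)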
The :attr:`sample` methods must take in a graph object and either a tensor + of node indices or a dict of such tensors. + device : device context, optional + The device of the generated MFGs in each iteration, which should be a + PyTorch device object (e.g., ``torch.device``). + + By default this value is the same as the device of :attr:`g`. + use_ddp : boolean, optional + If True, tells the DataLoader to split the training set for each + participating process appropriately using + :class:`torch.utils.data.distributed.DistributedSampler`. + + Overrides the :attr:`sampler` argument of :class:`torch.utils.data.DataLoader`. + ddp_seed : int, optional + The seed for shuffling the dataset in + :class:`torch.utils.data.distributed.DistributedSampler`. + + Only effective when :attr:`use_ddp` is True. + use_uva : bool, optional + Whether to use Unified Virtual Addressing (UVA) to directly sample the graph + and slice the features from CPU into GPU. Setting it to True will pin the + graph and feature tensors into pinned memory. + + Default: False. + use_prefetch_thread : bool, optional + (Advanced option) + Spawns a new Python thread to perform feature slicing + asynchronously. Can make things faster at the cost of GPU memory. + + Default: True if the graph is on CPU and :attr:`device` is CUDA. False otherwise. + use_alternate_streams : bool, optional + (Advanced option) + Whether to slice and transfers the features to GPU on a non-default stream. + + Default: True if the graph is on CPU, :attr:`device` is CUDA, and :attr:`use_uva` + is False. False otherwise. + pin_prefetcher : bool, optional + (Advanced option) + Whether to pin the feature tensors into pinned memory. + + Default: True if the graph is on CPU and :attr:`device` is CUDA. False otherwise. + batch_size : int, optional + drop_last : bool, optional + shuffle : bool, optional + kwargs : dict + Arguments being passed to :py:class:`torch.utils.data.DataLoader`. + + Examples + -------- + To train a 3-layer GNN for node classification on a set of nodes ``train_nid`` on + a homogeneous graph where each node takes messages from all neighbors (assume + the backend is PyTorch): + + >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5]) + >>> dataloader = dgl.dataloading.NodeDataLoader( + ... g, train_nid, sampler, + ... batch_size=1024, shuffle=True, drop_last=False, num_workers=4) + >>> for input_nodes, output_nodes, blocks in dataloader: + ... train_on(input_nodes, output_nodes, blocks) + + **Using with Distributed Data Parallel** + + If you are using PyTorch's distributed training (e.g. when using + :mod:`torch.nn.parallel.DistributedDataParallel`), you can train the model by turning + on the `use_ddp` option: + + >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5]) + >>> dataloader = dgl.dataloading.NodeDataLoader( + ... g, train_nid, sampler, use_ddp=True, + ... batch_size=1024, shuffle=True, drop_last=False, num_workers=4) + >>> for epoch in range(start_epoch, n_epochs): + ... for input_nodes, output_nodes, blocks in dataloader: + ... train_on(input_nodes, output_nodes, blocks) + + Notes + ----- + Please refer to + :doc:`Minibatch Training Tutorials ` + and :ref:`User Guide Section 6 ` for usage. + + **Tips for selecting the proper device** + + * If the input graph :attr:`g` is on GPU, the output device :attr:`device` must be the same GPU + and :attr:`num_workers` must be zero. In this case, the sampling and subgraph construction + will take place on the GPU. 
This is the recommended setting when using a single-GPU and + the whole graph fits in GPU memory. + + * If the input graph :attr:`g` is on CPU while the output device :attr:`device` is GPU, then + depending on the value of :attr:`use_uva`: + + - If :attr:`use_uva` is set to True, the sampling and subgraph construction will happen + on GPU even if the GPU itself cannot hold the entire graph. This is the recommended + setting unless there are operations not supporting UVA. :attr:`num_workers` must be 0 + in this case. + + - Otherwise, both the sampling and subgraph construction will take place on the CPU. + """ class EdgeDataLoader(DataLoader): - """EdgeDataLoader class.""" + """PyTorch dataloader for batch-iterating over a set of edges, generating the list + of message flow graphs (MFGs) as computation dependency of the said minibatch for + edge classification, edge regression, and link prediction. + + For each iteration, the object will yield + + * A tensor of input nodes necessary for computing the representation on edges, or + a dictionary of node type names and such tensors. + + * A subgraph that contains only the edges in the minibatch and their incident nodes. + Note that the graph has an identical metagraph with the original graph. + + * If a negative sampler is given, another graph that contains the "negative edges", + connecting the source and destination nodes yielded from the given negative sampler. + + * A list of MFGs necessary for computing the representation of the incident nodes + of the edges in the minibatch. + + For more details, please refer to :ref:`guide-minibatch-edge-classification-sampler` + and :ref:`guide-minibatch-link-classification-sampler`. + + Parameters + ---------- + g : DGLGraph + The graph. + indices : Tensor or dict[etype, Tensor] + The edge set in graph :attr:`g` to compute outputs. + graph_sampler : object + The neighborhood sampler. It could be any object that has a :attr:`sample` + method. The :attr:`sample` methods must take in a graph object and either a tensor + of node indices or a dict of such tensors. + device : device context, optional + The device of the generated MFGs and graphs in each iteration, which should be a + PyTorch device object (e.g., ``torch.device``). + + By default this value is the same as the device of :attr:`g`. + use_ddp : boolean, optional + If True, tells the DataLoader to split the training set for each + participating process appropriately using + :class:`torch.utils.data.distributed.DistributedSampler`. + + Overrides the :attr:`sampler` argument of :class:`torch.utils.data.DataLoader`. + ddp_seed : int, optional + The seed for shuffling the dataset in + :class:`torch.utils.data.distributed.DistributedSampler`. + + Only effective when :attr:`use_ddp` is True. + use_prefetch_thread : bool, optional + (Advanced option) + Spawns a new Python thread to perform feature slicing + asynchronously. Can make things faster at the cost of GPU memory. + + Default: True if the graph is on CPU and :attr:`device` is CUDA. False otherwise. + use_alternate_streams : bool, optional + (Advanced option) + Whether to slice and transfers the features to GPU on a non-default stream. + + Default: True if the graph is on CPU, :attr:`device` is CUDA, and :attr:`use_uva` + is False. False otherwise. + pin_prefetcher : bool, optional + (Advanced option) + Whether to pin the feature tensors into pinned memory. + + Default: True if the graph is on CPU and :attr:`device` is CUDA. False otherwise. 
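The device and use_uva options documented above boil down to two common single-GPU configurations; the sketch below shows both on a small random graph. All names here are illustrative, and the prefetch_node_feats argument follows the sampler API introduced in this patch.

    import torch
    import dgl

    src = torch.randint(0, 1000, (5000,))
    dst = torch.randint(0, 1000, (5000,))
    g = dgl.graph((src, dst), num_nodes=1000)
    g.ndata['feat'] = torch.randn(1000, 16)
    train_nid = torch.arange(1000)
    sampler = dgl.dataloading.NeighborSampler([15, 10, 5], prefetch_node_feats=['feat'])

    # (a) The graph fits in GPU memory: move graph and seeds to CUDA and sample there.
    #     num_workers must stay 0 because sampling happens on the GPU.
    dl_gpu = dgl.dataloading.NodeDataLoader(
        g.to('cuda'), train_nid.to('cuda'), sampler,
        device='cuda', batch_size=256, shuffle=True, num_workers=0)

    # (b) The graph stays on CPU but is pinned via use_uva=True, so sampling and
    #     feature slicing still run on the GPU; num_workers must also be 0 here.
    dl_uva = dgl.dataloading.NodeDataLoader(
        g, train_nid, sampler,
        device='cuda', batch_size=256, shuffle=True, num_workers=0, use_uva=True)

    for input_nodes, output_nodes, blocks in dl_uva:
        feat = blocks[0].srcdata['feat']   # already on the GPU thanks to prefetching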
+ exclude : str, optional + Whether and how to exclude dependencies related to the sampled edges in the + minibatch. Possible values are + + * None, for not excluding any edges. + + * ``self``, for excluding only the edges sampled as seed edges in this minibatch. + + * ``reverse_id``, for excluding not only the edges sampled in the minibatch but + also their reverse edges of the same edge type. Requires the argument + :attr:`reverse_eids`. + + * ``reverse_types``, for excluding not only the edges sampled in the minibatch + but also their reverse edges of different types but with the same IDs. + Requires the argument :attr:`reverse_etypes`. + + * A callable which takes in a tensor or a dictionary of tensors and their + canonical edge types and returns a tensor or dictionary of tensors to + exclude. + reverse_eids : Tensor or dict[etype, Tensor], optional + A tensor of reverse edge ID mapping. The i-th element indicates the ID of + the i-th edge's reverse edge. + + If the graph is heterogeneous, this argument requires a dictionary of edge + types and the reverse edge ID mapping tensors. + + See the description of the argument with the same name in the docstring of + :class:`~dgl.dataloading.EdgeCollator` for more details. + reverse_etypes : dict[etype, etype], optional + The mapping from the original edge types to their reverse edge types. + + See the description of the argument with the same name in the docstring of + :class:`~dgl.dataloading.EdgeCollator` for more details. + negative_sampler : callable, optional + The negative sampler. + + See the description of the argument with the same name in the docstring of + :class:`~dgl.dataloading.EdgeCollator` for more details. + use_uva : bool, optional + Whether to use Unified Virtual Addressing (UVA) to directly sample the graph + and slice the features from CPU into GPU. Setting it to True will pin the + graph and feature tensors into pinned memory. + + Default: False. + batch_size : int, optional + drop_last : bool, optional + shuffle : bool, optional + kwargs : dict + Arguments being passed to :py:class:`torch.utils.data.DataLoader`. + + Examples + -------- + The following example shows how to train a 3-layer GNN for edge classification on a + set of edges ``train_eid`` on a homogeneous undirected graph. Each node takes + messages from all neighbors. + + Say that you have an array of source node IDs ``src`` and another array of destination + node IDs ``dst``. One can make it bidirectional by adding another set of edges + that connects from ``dst`` to ``src``: + + >>> g = dgl.graph((torch.cat([src, dst]), torch.cat([dst, src]))) + + One can then know that the ID difference of an edge and its reverse edge is ``|E|``, + where ``|E|`` is the length of your source/destination array. The reverse edge + mapping can be obtained by + + >>> E = len(src) + >>> reverse_eids = torch.cat([torch.arange(E, 2 * E), torch.arange(0, E)]) + + Note that the sampled edges as well as their reverse edges are removed from + computation dependencies of the incident nodes. That is, the edge will not + involve in neighbor sampling and message aggregation. This is a common trick + to avoid information leakage. + + >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5]) + >>> dataloader = dgl.dataloading.EdgeDataLoader( + ... g, train_eid, sampler, exclude='reverse_id', + ... reverse_eids=reverse_eids, + ... batch_size=1024, shuffle=True, drop_last=False, num_workers=4) + >>> for input_nodes, pair_graph, blocks in dataloader: + ... 
train_on(input_nodes, pair_graph, blocks)
+
+    To train a 3-layer GNN for link prediction on a set of edges ``train_eid`` on a
+    homogeneous graph where each node takes messages from all neighbors (assume the
+    backend is PyTorch), with 5 uniformly chosen negative samples per edge:
+
+    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
+    >>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
+    >>> dataloader = dgl.dataloading.EdgeDataLoader(
+    ...     g, train_eid, sampler, exclude='reverse_id',
+    ...     reverse_eids=reverse_eids, negative_sampler=neg_sampler,
+    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
+    >>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
+    ...     train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)
+
+    For heterogeneous graphs, the reverse of an edge may have a different edge type
+    from the original edge.  For instance, consider that you have an array of
+    user-item clicks, represented by a user array ``user`` and an item array ``item``.
+    You may want to build a heterogeneous graph with a user-click-item relation and an
+    item-clicked-by-user relation.
+
+    >>> g = dgl.heterograph({
+    ...     ('user', 'click', 'item'): (user, item),
+    ...     ('item', 'clicked-by', 'user'): (item, user)})
+
+    To train a 3-layer GNN for edge classification on a set of edges ``train_eid`` with
+    type ``click``, you can write
+
+    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
+    >>> dataloader = dgl.dataloading.EdgeDataLoader(
+    ...     g, {'click': train_eid}, sampler, exclude='reverse_types',
+    ...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'},
+    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
+    >>> for input_nodes, pair_graph, blocks in dataloader:
+    ...     train_on(input_nodes, pair_graph, blocks)
+
+    To train a 3-layer GNN for link prediction on a set of edges ``train_eid`` with type
+    ``click``, you can write
+
+    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
+    >>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
+    >>> dataloader = dgl.dataloading.EdgeDataLoader(
+    ...     g, {'click': train_eid}, sampler, exclude='reverse_types',
+    ...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'},
+    ...     negative_sampler=neg_sampler,
+    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
+    >>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
+    ...     train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)
+
+    **Using with Distributed Data Parallel**
+
+    If you are using PyTorch's distributed training (e.g. when using
+    :mod:`torch.nn.parallel.DistributedDataParallel`), you can train the model by
+    turning on the :attr:`use_ddp` option:
+
+    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
+    >>> dataloader = dgl.dataloading.EdgeDataLoader(
+    ...     g, train_eid, sampler, use_ddp=True, exclude='reverse_id',
+    ...     reverse_eids=reverse_eids,
+    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
+    >>> for epoch in range(start_epoch, n_epochs):
+    ...     for input_nodes, pair_graph, blocks in dataloader:
+    ...         train_on(input_nodes, pair_graph, blocks)
+
+    Notes
+    -----
+    Please refer to
+    :doc:`Minibatch Training Tutorials `
+    and :ref:`User Guide Section 6 ` for usage.
+
+    **Tips for selecting the proper device**
+
+    * If the input graph :attr:`g` is on GPU, the output device :attr:`device` must be the same GPU
+      and :attr:`num_workers` must be zero.
In this case, the sampling and subgraph construction + will take place on the GPU. This is the recommended setting when using a single-GPU and + the whole graph fits in GPU memory. + + * If the input graph :attr:`g` is on CPU while the output device :attr:`device` is GPU, then + depending on the value of :attr:`use_uva`: + + - If :attr:`use_uva` is set to True, the sampling and subgraph construction will happen + on GPU even if the GPU itself cannot hold the entire graph. This is the recommended + setting unless there are operations not supporting UVA. :attr:`num_workers` must be 0 + in this case. + + - Otherwise, both the sampling and subgraph construction will take place on the CPU. + """ def __init__(self, graph, indices, graph_sampler, device='cpu', use_ddp=False, ddp_seed=0, batch_size=1, drop_last=False, shuffle=False, use_prefetch_thread=False, use_alternate_streams=True, + pin_prefetcher=False, exclude=None, reverse_eids=None, reverse_etypes=None, negative_sampler=None, - g_sampling=None, **kwargs): - if g_sampling is not None: - dgl_warning( - "g_sampling is deprecated. " - "Please merge g_sampling and the original graph into one graph and use " - "the exclude argument to specify which edges you don't want to sample.") + use_uva=False, **kwargs): + device = _get_device(device) + if isinstance(graph_sampler, BlockSampler): + if reverse_eids is not None: + if use_uva: + reverse_eids = recursive_apply(reverse_eids, lambda x: x.to(device)) + else: + reverse_eids_device = context_of(reverse_eids) + indices_device = context_of(indices) + if indices_device != reverse_eids_device: + raise ValueError('Expect the same device for indices and reverse_eids') graph_sampler = EdgeBlockSampler( graph_sampler, exclude=exclude, reverse_eids=reverse_eids, - reverse_etypes=reverse_etypes, negative_sampler=negative_sampler) + reverse_etypes=reverse_etypes, negative_sampler=negative_sampler, + prefetch_node_feats=graph_sampler.prefetch_node_feats, + prefetch_labels=graph_sampler.prefetch_labels, + prefetch_edge_feats=graph_sampler.prefetch_edge_feats) super().__init__( graph, indices, graph_sampler, device=device, use_ddp=use_ddp, ddp_seed=ddp_seed, batch_size=batch_size, drop_last=drop_last, shuffle=shuffle, use_prefetch_thread=use_prefetch_thread, use_alternate_streams=use_alternate_streams, + pin_prefetcher=pin_prefetcher, use_uva=use_uva, **kwargs) diff --git a/python/dgl/frame.py b/python/dgl/frame.py index 820b2be2a36c..817f2a84d5c9 100644 --- a/python/dgl/frame.py +++ b/python/dgl/frame.py @@ -56,6 +56,12 @@ def data(self): """No-op. For compatibility of :meth:`Frame.__repr__` method.""" return self + def pin_memory_(self): + """No-op. For compatibility of :meth:`Frame.pin_memory_` method.""" + + def unpin_memory_(self): + """No-op. For compatibility of :meth:`Frame.unpin_memory_` method.""" + class Scheme(namedtuple('Scheme', ['shape', 'dtype'])): """The column scheme. @@ -142,6 +148,7 @@ def __init__(self, storage, scheme=None, index=None, device=None): self.scheme = scheme if scheme else infer_scheme(storage) self.index = index self.device = device + self.pinned = False def __len__(self): """The number of features (number of rows) in this column.""" @@ -183,6 +190,7 @@ def data(self, val): """Update the column data.""" self.index = None self.storage = val + self.pinned = False def to(self, device, **kwargs): # pylint: disable=invalid-name """ Return a new column with columns copy to the targeted device (cpu/gpu). 
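As a sanity check on the ``reverse_eids`` convention used in the EdgeDataLoader examples and validated in the constructor above, the following sketch builds a bidirected graph from made-up src/dst arrays and verifies that edge i and edge reverse_eids[i] connect the same endpoints in opposite directions; it is an illustration of the convention, not part of the API.

    import torch
    import dgl

    src = torch.tensor([0, 1, 2, 3])
    dst = torch.tensor([1, 2, 3, 0])
    E = len(src)

    # Append the reversed edges so every edge has a reverse copy offset by |E|.
    g = dgl.graph((torch.cat([src, dst]), torch.cat([dst, src])))
    reverse_eids = torch.cat([torch.arange(E, 2 * E), torch.arange(0, E)])

    # Edge i goes u -> v; its declared reverse must go v -> u.
    u, v = g.find_edges(torch.arange(2 * E))
    ru, rv = g.find_edges(reverse_eids)
    assert torch.equal(u, rv) and torch.equal(v, ru)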
@@ -330,6 +338,10 @@ def __getstate__(self): def __copy__(self): return self.clone() + def fetch(self, indices, device, pin_memory=False): + _ = self.data # materialize in case of lazy slicing & data transfer + return super().fetch(indices, device, pin_memory=False) + class Frame(MutableMapping): """The columnar storage for node/edge features. @@ -702,3 +714,15 @@ def to(self, device, **kwargs): # pylint: disable=invalid-name def __repr__(self): return repr(dict(self)) + + def pin_memory_(self): + """Registers the data of every column into pinned memory, materializing them if + necessary.""" + for column in self._columns.values(): + column.pin_memory_() + + def unpin_memory_(self): + """Unregisters the data of every column from pinned memory, materializing them + if necessary.""" + for column in self._columns.values(): + column.unpin_memory_() diff --git a/python/dgl/heterograph.py b/python/dgl/heterograph.py index 36398b020770..327d715f6119 100644 --- a/python/dgl/heterograph.py +++ b/python/dgl/heterograph.py @@ -5474,7 +5474,7 @@ def pin_memory_(self): Materialization of new sparse formats for pinned graphs is not allowed. To avoid implicit formats materialization during training, - you should create all the needed formats before pinnning. + you should create all the needed formats before pinning. But cloning and materialization is fine. See the examples below. Returns @@ -5530,6 +5530,7 @@ def pin_memory_(self): if F.device_type(self.device) != 'cpu': raise DGLError("The graph structure must be on CPU to be pinned.") self._graph.pin_memory_() + return self def unpin_memory_(self): @@ -5546,6 +5547,7 @@ def unpin_memory_(self): if not self._graph.is_pinned(): return self self._graph.unpin_memory_() + return self def is_pinned(self): diff --git a/python/dgl/heterograph_index.py b/python/dgl/heterograph_index.py index b6331fd18eee..7db664b271c6 100644 --- a/python/dgl/heterograph_index.py +++ b/python/dgl/heterograph_index.py @@ -1,6 +1,7 @@ """Module for heterogeneous graph index class definition.""" from __future__ import absolute_import +import sys import itertools import numpy as np import scipy @@ -1365,4 +1366,27 @@ def __setstate__(self, state): self.__init_handle_by_constructor__( _CAPI_DGLCreateHeteroPickleStatesOld, metagraph, num_nodes_per_type, adjs) +def _forking_rebuild(pk_state): + meta, arrays = pk_state + arrays = [F.to_dgl_nd(arr) for arr in arrays] + states = _CAPI_DGLCreateHeteroPickleStates(meta, arrays) + return _CAPI_DGLHeteroForkingUnpickle(states) + +def _forking_reduce(graph_index): + states = _CAPI_DGLHeteroForkingPickle(graph_index) + arrays = [F.from_dgl_nd(arr) for arr in states.arrays] + # Similar to what being mentioned in HeteroGraphIndex.__getstate__, we need to save + # the tensors as an attribute of the original graph index object. Otherwise + # PyTorch will throw weird errors like bad value(s) in fds_to_keep or unable to + # resize file. + graph_index._forking_pk_state = (states.meta, arrays) + return _forking_rebuild, (graph_index._forking_pk_state,) + + +if not (F.get_preferred_backend() == 'mxnet' and sys.version_info.minor <= 6): + # Python 3.6 MXNet crashes with the following statement; remove until we no longer support + # 3.6 (which is EOL anyway). 
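The registration that follows uses multiprocessing's ForkingPickler so that graph indices handed to worker or spawned processes are rebuilt from their underlying arrays instead of going through the regular, much slower pickle path. The toy example below shows the same register/reduce/rebuild mechanism on a stand-in class; Payload and its helpers are invented for illustration and are not part of DGL.

    import multiprocessing as mp
    from multiprocessing.reduction import ForkingPickler

    class Payload:
        """Toy stand-in for an object with cheap-to-share state."""
        def __init__(self, meta, arrays):
            self.meta = meta
            self.arrays = arrays

    def _rebuild(state):
        meta, arrays = state
        return Payload(meta, arrays)

    def _reduce(obj):
        # Keep the reduced state alive on the object itself, mirroring the comment in
        # the diff about avoiding errors from buffers being freed mid-transfer.
        obj._pk_state = (obj.meta, obj.arrays)
        return _rebuild, (obj._pk_state,)

    ForkingPickler.register(Payload, _reduce)

    def _child(payload, q):
        q.put((payload.meta, len(payload.arrays)))

    if __name__ == '__main__':
        ctx = mp.get_context('spawn')
        q = ctx.Queue(1)
        proc = ctx.Process(target=_child, args=(Payload('meta', [1, 2, 3]), q))
        proc.start()
        print(q.get())   # ('meta', 3)
        proc.join()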
+ from multiprocessing.reduction import ForkingPickler + ForkingPickler.register(HeteroGraphIndex, _forking_reduce) + _init_api("dgl.heterograph_index") diff --git a/python/dgl/ndarray.py b/python/dgl/ndarray.py index cc01feb50e44..3709e435815f 100644 --- a/python/dgl/ndarray.py +++ b/python/dgl/ndarray.py @@ -222,6 +222,7 @@ def __repr__(self): _set_class_ndarray(NDArray) _init_api("dgl.ndarray") +_init_api("dgl.ndarray.uvm", __name__) # An array representing null (no value) that can be safely converted to # other backend tensors. diff --git a/python/dgl/storages/__init__.py b/python/dgl/storages/__init__.py index bfa19e496e85..c24f5588b5be 100644 --- a/python/dgl/storages/__init__.py +++ b/python/dgl/storages/__init__.py @@ -3,7 +3,8 @@ from .base import * from .numpy import * +# Defines the name TensorStorage if F.get_preferred_backend() == 'pytorch': - from .pytorch_tensor import * + from .pytorch_tensor import PyTorchTensorStorage as TensorStorage else: - from .tensor import * + from .tensor import BaseTensorStorage as TensorStorage diff --git a/python/dgl/storages/pytorch_tensor.py b/python/dgl/storages/pytorch_tensor.py index 8fdb30a03b8f..82f1158ff5f8 100644 --- a/python/dgl/storages/pytorch_tensor.py +++ b/python/dgl/storages/pytorch_tensor.py @@ -1,7 +1,9 @@ """Feature storages for PyTorch tensors.""" import torch -from .base import FeatureStorage, register_storage_wrapper +from .base import register_storage_wrapper +from .tensor import BaseTensorStorage +from ..utils import gather_pinned_tensor_rows def _fetch_cpu(indices, tensor, feature_shape, device, pin_memory): result = torch.empty( @@ -15,18 +17,26 @@ def _fetch_cuda(indices, tensor, device): return torch.index_select(tensor, 0, indices).to(device) @register_storage_wrapper(torch.Tensor) -class TensorStorage(FeatureStorage): +class PyTorchTensorStorage(BaseTensorStorage): """Feature storages for slicing a PyTorch tensor.""" - def __init__(self, tensor): - self.storage = tensor - self.feature_shape = tensor.shape[1:] - self.is_cuda = (tensor.device.type == 'cuda') - def fetch(self, indices, device, pin_memory=False): device = torch.device(device) - if not self.is_cuda: + storage_device_type = self.storage.device.type + indices_device_type = indices.device.type + if storage_device_type != 'cuda': + if indices_device_type == 'cuda': + if self.storage.is_pinned(): + return gather_pinned_tensor_rows(self.storage, indices) + else: + raise ValueError( + f'Got indices on device {indices.device} whereas the feature tensor ' + f'is on {self.storage.device}. Please either (1) move the graph ' + f'to GPU with to() method, or (2) pin the graph with ' + f'pin_memory_() method.') # CPU to CPU or CUDA - use pin_memory and async transfer if possible - return _fetch_cpu(indices, self.storage, self.feature_shape, device, pin_memory) + else: + return _fetch_cpu(indices, self.storage, self.storage.shape[1:], device, + pin_memory) else: # CUDA to CUDA or CPU return _fetch_cuda(indices, self.storage, device) diff --git a/python/dgl/storages/tensor.py b/python/dgl/storages/tensor.py index d454119ec3c3..bf7026c1b7dd 100644 --- a/python/dgl/storages/tensor.py +++ b/python/dgl/storages/tensor.py @@ -1,12 +1,8 @@ """Feature storages for tensors across different frameworks.""" from .base import FeatureStorage from .. 
import backend as F
-from ..utils import recursive_apply_pair
 
-def _fetch(indices, tensor, device):
-    return F.copy_to(F.gather_row(tensor, indices), device)
-
-class TensorStorage(FeatureStorage):
+class BaseTensorStorage(FeatureStorage):
     """FeatureStorage that synchronously slices features from a tensor and transfers
     it to the given device.
     """
@@ -14,4 +10,4 @@ def __init__(self, tensor):
         self.storage = tensor
 
     def fetch(self, indices, device, pin_memory=False):  # pylint: disable=unused-argument
-        return recursive_apply_pair(indices, self.storage, _fetch, device)
+        return F.copy_to(F.gather_row(self.storage, indices), device)
diff --git a/python/dgl/utils/__init__.py b/python/dgl/utils/__init__.py
index 6e15327b9824..f789a6eaedb3 100644
--- a/python/dgl/utils/__init__.py
+++ b/python/dgl/utils/__init__.py
@@ -5,3 +5,4 @@
 from .shared_mem import *
 from .filter import *
 from .exception import *
+from .pin_memory import *
diff --git a/python/dgl/utils/internal.py b/python/dgl/utils/internal.py
index b232e7299717..9a942316d3de 100644
--- a/python/dgl/utils/internal.py
+++ b/python/dgl/utils/internal.py
@@ -937,4 +937,8 @@ def recursive_apply_pair(data1, data2, fn, *args, **kwargs):
     else:
         return fn(data1, data2, *args, **kwargs)
 
+def context_of(data):
+    """Return the device of the data, which can be either a tensor or a dict of tensors."""
+    return F.context(next(iter(data.values())) if isinstance(data, Mapping) else data)
+
 _init_api("dgl.utils.internal")
diff --git a/python/dgl/utils/pin_memory.py b/python/dgl/utils/pin_memory.py
new file mode 100644
index 000000000000..45c1ddaa718d
--- /dev/null
+++ b/python/dgl/utils/pin_memory.py
@@ -0,0 +1,32 @@
+"""Utility functions related to pinned memory tensors."""
+
+from .. import backend as F
+from .._ffi.function import _init_api
+
+def pin_memory_inplace(tensor):
+    """Register the tensor into pinned memory in-place (i.e. without copying)."""
+    F.to_dgl_nd(tensor).pin_memory_()
+
+def unpin_memory_inplace(tensor):
+    """Unregister the tensor from pinned memory in-place (i.e. without copying)."""
+    F.to_dgl_nd(tensor).unpin_memory_()
+
+def gather_pinned_tensor_rows(tensor, rows):
+    """Directly gather rows from a CPU tensor given an indices array on a CUDA device,
+    and return the result on that same CUDA device without an intermediate copy.
+
+    Parameters
+    ----------
+    tensor : Tensor
+        The tensor.  Must be in pinned memory.
+    rows : Tensor
+        The rows to gather.  Must be a CUDA tensor.
+
+    Returns
+    -------
+    Tensor
+        The result with the same device as :attr:`rows`.
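A short usage sketch for the utilities defined in pin_memory.py above, assuming a CUDA-enabled build with this patch applied; the shapes and sizes are arbitrary.

    import torch
    from dgl.utils import pin_memory_inplace, unpin_memory_inplace, gather_pinned_tensor_rows

    feat = torch.randn(10000, 128)                # CPU feature matrix
    pin_memory_inplace(feat)                      # register with CUDA in place, no copy

    rows = torch.randint(0, 10000, (256,), device='cuda')
    out = gather_pinned_tensor_rows(feat, rows)   # rows gathered directly by the GPU
    assert out.device == rows.device and out.shape == (256, 128)

    unpin_memory_inplace(feat)                    # undo the registration when done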
+ """ + return F.from_dgl_nd(_CAPI_DGLIndexSelectCPUFromGPU(F.to_dgl_nd(tensor), F.to_dgl_nd(rows))) + +_init_api("dgl.ndarray.uvm", __name__) diff --git a/src/array/cuda/array_index_select.cu b/src/array/cuda/array_index_select.cu index cc8b35bd4315..816b143e58f7 100644 --- a/src/array/cuda/array_index_select.cu +++ b/src/array/cuda/array_index_select.cu @@ -27,7 +27,7 @@ NDArray IndexSelect(NDArray array, IdArray index) { shape.emplace_back(array->shape[d]); } - // use index->ctx for kDLCPUPinned array + // use index->ctx for pinned array NDArray ret = NDArray::Empty(shape, array->dtype, index->ctx); if (len == 0) return ret; diff --git a/src/array/cuda/uvm/array_index_select_uvm.cu b/src/array/cuda/uvm/array_index_select_uvm.cu index dc6abe54e1a7..1a4aaa23d501 100644 --- a/src/array/cuda/uvm/array_index_select_uvm.cu +++ b/src/array/cuda/uvm/array_index_select_uvm.cu @@ -24,7 +24,7 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) { int64_t num_feat = 1; std::vector shape{len}; - CHECK_EQ(array->ctx.device_type, kDLCPUPinned); + CHECK(array.IsPinned()); CHECK_EQ(index->ctx.device_type, kDLGPU); for (int d = 1; d < array->ndim; ++d) { @@ -72,6 +72,8 @@ template NDArray IndexSelectCPUFromGPU(NDArray, IdArray); template NDArray IndexSelectCPUFromGPU(NDArray, IdArray); template NDArray IndexSelectCPUFromGPU(NDArray, IdArray); template NDArray IndexSelectCPUFromGPU(NDArray, IdArray); +template NDArray IndexSelectCPUFromGPU(NDArray, IdArray); +template NDArray IndexSelectCPUFromGPU(NDArray, IdArray); } // namespace impl } // namespace aten diff --git a/src/array/uvm_array.cc b/src/array/uvm_array.cc index b580319aecdc..f78a7713bde9 100644 --- a/src/array/uvm_array.cc +++ b/src/array/uvm_array.cc @@ -15,7 +15,7 @@ namespace aten { NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) { #ifdef DGL_USE_CUDA - CHECK_EQ(array->ctx.device_type, kDLCPUPinned) + CHECK(array.IsPinned()) << "Only the CPUPinned device type input array is supported"; CHECK_EQ(index->ctx.device_type, kDLGPU) << "Only the GPU device type input index is supported"; diff --git a/src/bcast.cc b/src/bcast.cc index 3149b60aab04..bf03221e5beb 100644 --- a/src/bcast.cc +++ b/src/bcast.cc @@ -83,12 +83,6 @@ BcastOff CalcBcastOff(const std::string& op, NDArray lhs, NDArray rhs) { rst.out_len /= rst.reduce_size; // out_len is divied by reduce_size in dot. } } -#ifdef DEBUG - LOG(INFO) << "lhs_len: " << rst.lhs_len << " " << - "rhs_len: " << rst.rhs_len << " " << - "out_len: " << rst.out_len << " " << - "reduce_size: " << rst.reduce_size << std::endl; -#endif return rst; } diff --git a/src/graph/heterograph.h b/src/graph/heterograph.h index 240b81884ab5..31ad2f99a60e 100644 --- a/src/graph/heterograph.h +++ b/src/graph/heterograph.h @@ -236,7 +236,7 @@ class HeteroGraph : public BaseHeteroGraph { * \brief Pin all relation graphs of the current graph. * \note The graph will be pinned inplace. Behavior depends on the current context, * kDLCPU: will be pinned; - * kDLCPUPinned: directly return; + * IsPinned: directly return; * kDLGPU: invalid, will throw an error. * The context check is deferred to pinning the NDArray. */ @@ -245,7 +245,7 @@ class HeteroGraph : public BaseHeteroGraph { /*! * \brief Unpin all relation graphs of the current graph. * \note The graph will be unpinned inplace. Behavior depends on the current context, - * kDLCPUPinned: will be unpinned; + * IsPinned: will be unpinned; * others: directly return. * The context check is deferred to unpinning the NDArray. 
*/ @@ -272,6 +272,18 @@ class HeteroGraph : public BaseHeteroGraph { return relation_graphs_; } + void SetCOOMatrix(dgl_type_t etype, aten::COOMatrix coo) override { + GetRelationGraph(etype)->SetCOOMatrix(0, coo); + } + + void SetCSRMatrix(dgl_type_t etype, aten::CSRMatrix csr) override { + GetRelationGraph(etype)->SetCSRMatrix(0, csr); + } + + void SetCSCMatrix(dgl_type_t etype, aten::CSRMatrix csc) override { + GetRelationGraph(etype)->SetCSCMatrix(0, csc); + } + private: // To create empty class friend class Serializer; diff --git a/src/graph/heterograph_capi.cc b/src/graph/heterograph_capi.cc index 114c05fa6b3c..6ec678af4726 100644 --- a/src/graph/heterograph_capi.cc +++ b/src/graph/heterograph_capi.cc @@ -173,13 +173,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDataType") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroContext") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; - // The Python side only recognizes CPU and GPU device type. - // Use is_pinned() to checked whether the object is - // on page-locked memory - if (hg->Context().device_type == kDLCPUPinned) - *rv = DLContext{kDLCPU, 0}; - else - *rv = hg->Context(); + *rv = hg->Context(); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsPinned") diff --git a/src/graph/pickle.cc b/src/graph/pickle.cc index 36e0880c4074..b751083f15ba 100644 --- a/src/graph/pickle.cc +++ b/src/graph/pickle.cc @@ -51,6 +51,42 @@ HeteroPickleStates HeteroPickle(HeteroGraphPtr graph) { return states; } +HeteroPickleStates HeteroForkingPickle(HeteroGraphPtr graph) { + HeteroPickleStates states; + dmlc::MemoryStringStream ofs(&states.meta); + dmlc::Stream *strm = &ofs; + strm->Write(ImmutableGraph::ToImmutable(graph->meta_graph())); + strm->Write(graph->NumVerticesPerType()); + for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) { + auto created_formats = graph->GetCreatedFormats(); + auto allowed_formats = graph->GetAllowedFormats(); + strm->Write(created_formats); + strm->Write(allowed_formats); + if (created_formats & COO_CODE) { + const auto &coo = graph->GetCOOMatrix(etype); + strm->Write(coo.row_sorted); + strm->Write(coo.col_sorted); + states.arrays.push_back(coo.row); + states.arrays.push_back(coo.col); + } + if (created_formats & CSR_CODE) { + const auto &csr = graph->GetCSRMatrix(etype); + strm->Write(csr.sorted); + states.arrays.push_back(csr.indptr); + states.arrays.push_back(csr.indices); + states.arrays.push_back(csr.data); + } + if (created_formats & CSC_CODE) { + const auto &csc = graph->GetCSCMatrix(etype); + strm->Write(csc.sorted); + states.arrays.push_back(csc.indptr); + states.arrays.push_back(csc.indices); + states.arrays.push_back(csc.data); + } + } + return states; +} + HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) { char *buf = const_cast(states.meta.c_str()); // a readonly stream? dmlc::MemoryFixedSizeStream ifs(buf, states.meta.size()); @@ -137,6 +173,76 @@ HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates& states) { return CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type); } +HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) { + char *buf = const_cast(states.meta.c_str()); // a readonly stream? 
+ dmlc::MemoryFixedSizeStream ifs(buf, states.meta.size()); + dmlc::Stream *strm = &ifs; + auto meta_imgraph = Serializer::make_shared(); + CHECK(strm->Read(&meta_imgraph)) << "Invalid meta graph"; + GraphPtr metagraph = meta_imgraph; + std::vector relgraphs(metagraph->NumEdges()); + std::vector num_nodes_per_type; + CHECK(strm->Read(&num_nodes_per_type)) << "Invalid num_nodes_per_type"; + + auto array_itr = states.arrays.begin(); + for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) { + const auto& pair = metagraph->FindEdge(etype); + const dgl_type_t srctype = pair.first; + const dgl_type_t dsttype = pair.second; + const int64_t num_vtypes = (srctype == dsttype) ? 1 : 2; + int64_t num_src = num_nodes_per_type[srctype]; + int64_t num_dst = num_nodes_per_type[dsttype]; + + dgl_format_code_t created_formats, allowed_formats; + CHECK(strm->Read(&created_formats)) << "Invalid code for created formats"; + CHECK(strm->Read(&allowed_formats)) << "Invalid code for allowed formats"; + HeteroGraphPtr relgraph = nullptr; + + if (created_formats & COO_CODE) { + CHECK_GE(states.arrays.end() - array_itr, 2); + const auto &row = *(array_itr++); + const auto &col = *(array_itr++); + bool rsorted; + bool csorted; + CHECK(strm->Read(&rsorted)) << "Invalid flag 'rsorted'"; + CHECK(strm->Read(&csorted)) << "Invalid flag 'csorted'"; + auto coo = aten::COOMatrix(num_src, num_dst, row, col, aten::NullArray(), rsorted, csorted); + if (!relgraph) + relgraph = CreateFromCOO(num_vtypes, coo, allowed_formats); + else + relgraph->SetCOOMatrix(0, coo); + } + if (created_formats & CSR_CODE) { + CHECK_GE(states.arrays.end() - array_itr, 3); + const auto &indptr = *(array_itr++); + const auto &indices = *(array_itr++); + const auto &edge_id = *(array_itr++); + bool sorted; + CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'"; + auto csr = aten::CSRMatrix(num_src, num_dst, indptr, indices, edge_id, sorted); + if (!relgraph) + relgraph = CreateFromCSR(num_vtypes, csr, allowed_formats); + else + relgraph->SetCSRMatrix(0, csr); + } + if (created_formats & CSC_CODE) { + CHECK_GE(states.arrays.end() - array_itr, 3); + const auto &indptr = *(array_itr++); + const auto &indices = *(array_itr++); + const auto &edge_id = *(array_itr++); + bool sorted; + CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'"; + auto csc = aten::CSRMatrix(num_dst, num_src, indptr, indices, edge_id, sorted); + if (!relgraph) + relgraph = CreateFromCSC(num_vtypes, csc, allowed_formats); + else + relgraph->SetCSCMatrix(0, csc); + } + relgraphs[etype] = relgraph; + } + return CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type); +} + DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetVersion") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroPickleStatesRef st = args[0]; @@ -186,6 +292,14 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickle") *rv = HeteroPickleStatesRef(st); }); +DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroForkingPickle") +.set_body([] (DGLArgs args, DGLRetValue* rv) { + HeteroGraphRef ref = args[0]; + std::shared_ptr st( new HeteroPickleStates ); + *st = HeteroForkingPickle(ref.sptr()); + *rv = HeteroPickleStatesRef(st); + }); + DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpickle") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroPickleStatesRef ref = args[0]; @@ -203,6 +317,13 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpickle") *rv = HeteroGraphRef(graph); }); +DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroForkingUnpickle") 
+.set_body([] (DGLArgs args, DGLRetValue* rv) { + HeteroPickleStatesRef ref = args[0]; + HeteroGraphPtr graph = HeteroForkingUnpickle(*ref.sptr()); + *rv = HeteroGraphRef(graph); + }); + DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStatesOld") .set_body([] (DGLArgs args, DGLRetValue* rv) { GraphRef metagraph = args[0]; diff --git a/src/graph/sampling/neighbor/neighbor.cc b/src/graph/sampling/neighbor/neighbor.cc index 8cc342d36c6b..2cc9c89dab12 100644 --- a/src/graph/sampling/neighbor/neighbor.cc +++ b/src/graph/sampling/neighbor/neighbor.cc @@ -31,7 +31,7 @@ HeteroSubgraph ExcludeCertainEdges( sg.induced_edges[etype]->shape[0], sg.induced_edges[etype]->dtype.bits, sg.induced_edges[etype]->ctx); - if (exclude_edges[etype].GetSize() == 0) { + if (exclude_edges[etype].GetSize() == 0 || edge_ids.GetSize() == 0) { remain_edges[etype] = edge_ids; remain_induced_edges[etype] = sg.induced_edges[etype]; continue; diff --git a/src/graph/sampling/randomwalks/frequency_hashmap.cu b/src/graph/sampling/randomwalks/frequency_hashmap.cu index ee7460210ee7..03348cfd5c19 100644 --- a/src/graph/sampling/randomwalks/frequency_hashmap.cu +++ b/src/graph/sampling/randomwalks/frequency_hashmap.cu @@ -4,12 +4,12 @@ * \brief frequency hashmap - used to select top-k frequency edges of each node */ -#include #include #include #include #include "../../../runtime/cuda/cuda_common.h" #include "../../../array/cuda/atomic.cuh" +#include "../../../array/cuda/dgl_cub.cuh" #include "frequency_hashmap.cuh" namespace dgl { diff --git a/src/graph/unit_graph.cc b/src/graph/unit_graph.cc index 288500271809..460171efd802 100644 --- a/src/graph/unit_graph.cc +++ b/src/graph/unit_graph.cc @@ -359,6 +359,18 @@ class UnitGraph::COO : public BaseHeteroGraph { return aten::CSRMatrix(); } + void SetCOOMatrix(dgl_type_t etype, aten::COOMatrix coo) override { + adj_ = coo; + } + + void SetCSRMatrix(dgl_type_t etype, aten::CSRMatrix csr) override { + LOG(FATAL) << "Not enabled for COO graph"; + } + + void SetCSCMatrix(dgl_type_t etype, aten::CSRMatrix csc) override { + LOG(FATAL) << "Not enabled for COO graph"; + } + SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override { LOG(FATAL) << "Not enabled for COO graph"; return SparseFormat::kCOO; @@ -779,6 +791,18 @@ class UnitGraph::CSR : public BaseHeteroGraph { return adj_; } + void SetCOOMatrix(dgl_type_t etype, aten::COOMatrix coo) override { + LOG(FATAL) << "Not enabled for CSR graph"; + } + + void SetCSRMatrix(dgl_type_t etype, aten::CSRMatrix csr) override { + adj_ = csr; + } + + void SetCSCMatrix(dgl_type_t etype, aten::CSRMatrix csc) override { + LOG(FATAL) << "Please use in_csr_->SetCSRMatrix(etype, csc) instead."; + } + SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override { LOG(FATAL) << "Not enabled for CSR graph"; return SparseFormat::kCSR; @@ -1243,7 +1267,7 @@ HeteroGraphPtr UnitGraph::CreateFromCSC( if (num_vtypes == 1) CHECK_EQ(num_src, num_dst); auto mg = CreateUnitGraphMetaGraph(num_vtypes); - CSRPtr csc(new CSR(mg, num_src, num_dst, indptr, indices, edge_ids)); + CSRPtr csc(new CSR(mg, num_dst, num_src, indptr, indices, edge_ids)); return HeteroGraphPtr(new UnitGraph(mg, csc, nullptr, nullptr, formats)); } @@ -1488,6 +1512,54 @@ aten::COOMatrix UnitGraph::GetCOOMatrix(dgl_type_t etype) const { return GetCOO()->adj(); } +void UnitGraph::SetCOOMatrix(dgl_type_t etype, COOMatrix coo) { + if (!(formats_ & COO_CODE)) { + LOG(FATAL) << "The graph have restricted sparse format 
" << + CodeToStr(formats_) << ", cannot set COO matrix."; + return; + } + if (IsPinned()) { + LOG(FATAL) << "Cannot set COOMatrix if the graph is pinned, please unpin the graph."; + return; + } + if (!coo_->defined()) + *(const_cast(this)->coo_) = COO(meta_graph(), coo); + else + coo_->SetCOOMatrix(0, coo); +} + +void UnitGraph::SetCSRMatrix(dgl_type_t etype, CSRMatrix csr) { + if (!(formats_ & CSR_CODE)) { + LOG(FATAL) << "The graph have restricted sparse format " << + CodeToStr(formats_) << ", cannot set CSR matrix."; + return; + } + if (IsPinned()) { + LOG(FATAL) << "Cannot set CSRMatrix if the graph is pinned, please unpin the graph."; + return; + } + if (!out_csr_->defined()) + *(const_cast(this)->out_csr_) = CSR(meta_graph(), csr); + else + out_csr_->SetCSRMatrix(0, csr); +} + +void UnitGraph::SetCSCMatrix(dgl_type_t etype, CSRMatrix csc) { + if (!(formats_ & CSC_CODE)) { + LOG(FATAL) << "The graph have restricted sparse format " << + CodeToStr(formats_) << ", cannot set CSC matrix."; + return; + } + if (IsPinned()) { + LOG(FATAL) << "Cannot set CSCMatrix if the graph is pinned, please unpin the graph."; + return; + } + if (!in_csr_->defined()) + *(const_cast(this)->in_csr_) = CSR(meta_graph(), csc); + else + in_csr_->SetCSRMatrix(0, csc); +} + HeteroGraphPtr UnitGraph::GetAny() const { if (in_csr_->defined()) { return in_csr_; diff --git a/src/graph/unit_graph.h b/src/graph/unit_graph.h index 7ac73357a377..d21b0cff2ae5 100644 --- a/src/graph/unit_graph.h +++ b/src/graph/unit_graph.h @@ -214,7 +214,7 @@ class UnitGraph : public BaseHeteroGraph { * \brief Pin the in_csr_, out_scr_ and coo_ of the current graph. * \note The graph will be pinned inplace. Behavior depends on the current context, * kDLCPU: will be pinned; - * kDLCPUPinned: directly return; + * IsPinned: directly return; * kDLGPU: invalid, will throw an error. * The context check is deferred to pinning the NDArray. */ @@ -223,7 +223,7 @@ class UnitGraph : public BaseHeteroGraph { /*! * \brief Unpin the in_csr_, out_scr_ and coo_ of the current graph. * \note The graph will be unpinned inplace. Behavior depends on the current context, - * kDLCPUPinned: will be unpinned; + * IsPinned: will be unpinned; * others: directly return. * The context check is deferred to unpinning the NDArray. */ @@ -305,6 +305,10 @@ class UnitGraph : public BaseHeteroGraph { void InvalidateCOO(); + void SetCOOMatrix(dgl_type_t etype, aten::COOMatrix coo) override; + void SetCSRMatrix(dgl_type_t etype, aten::CSRMatrix csr) override; + void SetCSCMatrix(dgl_type_t etype, aten::CSRMatrix csc) override; + private: friend class Serializer; friend class HeteroGraph; diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 84e7aec1cbe4..02554e60754d 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -187,6 +187,42 @@ class CUDADeviceAPI final : public DeviceAPI { CUDA_CALL(cudaHostUnregister(ptr)); } + bool IsPinned(const void* ptr) override { + // can't be a pinned tensor if CUDA context is unavailable. 
+ if (!is_available_) + return false; + + cudaPointerAttributes attr; + cudaError_t status = cudaPointerGetAttributes(&attr, ptr); + bool result = false; + + switch (status) { + case cudaErrorInvalidValue: + // might be a normal CPU tensor in CUDA 10.2- + cudaGetLastError(); // clear error + break; + case cudaSuccess: + result = (attr.type == cudaMemoryTypeHost); + break; + case cudaErrorInitializationError: + case cudaErrorNoDevice: + case cudaErrorInsufficientDriver: + case cudaErrorInvalidDevice: + // We don't want to fail in these particular cases since this function can be called + // when users only want to run on CPU even if CUDA API is enabled, or in a forked + // subprocess where CUDA context cannot be initialized. So we just mark the CUDA + // context to unavailable and return. + is_available_ = false; + cudaGetLastError(); // clear error + break; + default: + LOG(FATAL) << "error while determining memory status: " << cudaGetErrorString(status); + break; + } + + return result; + } + void* AllocWorkspace(DGLContext ctx, size_t size, DGLType type_hint) final { return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size); } @@ -213,6 +249,8 @@ class CUDADeviceAPI final : public DeviceAPI { CUDA_CALL(cudaStreamSynchronize(stream)); } } + + bool is_available_ = true; }; typedef dmlc::ThreadLocalStore CUDAThreadStore; diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index 11589fb345c3..c9dc99a0353e 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -64,7 +64,7 @@ struct NDArray::Internal { ptr->mem = nullptr; } else if (ptr->dl_tensor.data != nullptr) { // if the array is still pinned before freeing, unpin it. - if (ptr->dl_tensor.ctx.device_type == kDLCPUPinned) { + if (IsDataPinned(&(ptr->dl_tensor))) { UnpinData(&(ptr->dl_tensor)); } dgl::runtime::DeviceAPI::Get(ptr->dl_tensor.ctx)->FreeDataSpace( @@ -206,19 +206,6 @@ NDArray NDArray::EmptyShared(const std::string &name, return ret; } -inline DLContext GetDevice(DLContext ctx) { - switch (ctx.device_type) { - case kDLCPU: - case kDLGPU: - return ctx; - break; - default: - // fallback to CPU - return DLContext{kDLCPU, 0}; - break; - } -} - NDArray NDArray::Empty(std::vector shape, DLDataType dtype, DLContext ctx) { @@ -226,7 +213,7 @@ NDArray NDArray::Empty(std::vector shape, if (td->IsAvailable()) return td->Empty(shape, dtype, ctx); - NDArray ret = Internal::Create(shape, dtype, GetDevice(ctx)); + NDArray ret = Internal::Create(shape, dtype, ctx); // setup memory content size_t size = GetDataSize(ret.data_->dl_tensor); size_t alignment = GetDataAlignment(ret.data_->dl_tensor); @@ -242,6 +229,7 @@ NDArray NDArray::FromDLPack(DLManagedTensor* tensor) { data->deleter = Internal::DLPackDeleter; data->manager_ctx = tensor; data->dl_tensor = tensor->dl_tensor; + return NDArray(data); } @@ -260,7 +248,7 @@ void NDArray::CopyFromTo(DLTensor* from, // Use the context that is *not* a cpu context to get the correct device // api manager. - DGLContext ctx = GetDevice(from->ctx).device_type != kDLCPU ? from->ctx : to->ctx; + DGLContext ctx = from->ctx.device_type != kDLCPU ? 
from->ctx : to->ctx; DeviceAPI::Get(ctx)->CopyDataFromTo( from->data, static_cast(from->byte_offset), @@ -269,19 +257,15 @@ void NDArray::CopyFromTo(DLTensor* from, } void NDArray::PinData(DLTensor* tensor) { - // Only need to call PinData once, since the pinned memory can be seen - // by all CUDA contexts, not just the one that performed the allocation - if (tensor->ctx.device_type == kDLCPUPinned) return; + if (IsDataPinned(tensor)) return; CHECK_EQ(tensor->ctx.device_type, kDLCPU) << "Only NDArray on CPU can be pinned"; DeviceAPI::Get(kDLGPU)->PinData(tensor->data, GetDataSize(*tensor)); - tensor->ctx = DLContext{kDLCPUPinned, 0}; } void NDArray::UnpinData(DLTensor* tensor) { - if (tensor->ctx.device_type != kDLCPUPinned) return; + if (!IsDataPinned(tensor)) return; DeviceAPI::Get(kDLGPU)->UnpinData(tensor->data); - tensor->ctx = DLContext{kDLCPU, 0}; } template @@ -343,6 +327,14 @@ std::shared_ptr NDArray::GetSharedMem() const { return this->data_->mem; } +bool NDArray::IsDataPinned(DLTensor* tensor) { + // Can only be pinned if on CPU... + if (tensor->ctx.device_type != kDLCPU) + return false; + // ... and CUDA device API is enabled, and the tensor is indeed in pinned memory. + auto device = DeviceAPI::Get(kDLGPU, true); + return device && device->IsPinned(tensor->data); +} void NDArray::Save(dmlc::Stream* strm) const { auto zc_strm = dynamic_cast(strm); @@ -489,10 +481,9 @@ int DGLArrayToDLPack(DGLArrayHandle from, DLManagedTensor** out, API_BEGIN(); auto* nd_container = reinterpret_cast(from); DLTensor* nd = &(nd_container->dl_tensor); - if ((alignment != 0 && !is_aligned(nd->data, alignment)) - || (nd->ctx.device_type == kDLCPUPinned)) { + if (alignment != 0 && !is_aligned(nd->data, alignment)) { std::vector shape_vec(nd->shape, nd->shape + nd->ndim); - NDArray copy_ndarray = NDArray::Empty(shape_vec, nd->dtype, GetDevice(nd->ctx)); + NDArray copy_ndarray = NDArray::Empty(shape_vec, nd->dtype, nd->ctx); copy_ndarray.CopyFrom(nd); *out = copy_ndarray.ToDLPack(); } else { diff --git a/tests/compute/test_heterograph.py b/tests/compute/test_heterograph.py index 1065df738f18..d838dc4e944f 100644 --- a/tests/compute/test_heterograph.py +++ b/tests/compute/test_heterograph.py @@ -12,6 +12,7 @@ from test_utils import parametrize_dtype, get_cases from utils import assert_is_identical_hetero from scipy.sparse import rand +import multiprocessing as mp def create_test_heterograph(idtype): # test heterograph from the docstring, plus a user -- wishes -- game relation @@ -206,6 +207,32 @@ def _test_validate_bipartite(card): assert g.device == F.cpu() assert F.array_equal(g.edata['w'], F.copy_to(F.tensor(adj.data), F.cpu())) +def test_create2(): + mat = ssp.random(20, 30, 0.1) + + # coo + mat = mat.tocoo() + row = F.tensor(mat.row, dtype=F.int64) + col = F.tensor(mat.col, dtype=F.int64) + g = dgl.heterograph( + {('A', 'AB', 'B'): ('coo', (row, col))}, num_nodes_dict={'A': 20, 'B': 30}) + + # csr + mat = mat.tocsr() + indptr = F.tensor(mat.indptr, dtype=F.int64) + indices = F.tensor(mat.indices, dtype=F.int64) + data = F.tensor([], dtype=F.int64) + g = dgl.heterograph( + {('A', 'AB', 'B'): ('csr', (indptr, indices, data))}, num_nodes_dict={'A': 20, 'B': 30}) + + # csc + mat = mat.tocsc() + indptr = F.tensor(mat.indptr, dtype=F.int64) + indices = F.tensor(mat.indices, dtype=F.int64) + data = F.tensor([], dtype=F.int64) + g = dgl.heterograph( + {('A', 'AB', 'B'): ('csc', (indptr, indices, data))}, num_nodes_dict={'A': 20, 'B': 30}) + @parametrize_dtype def test_query(idtype): g = 
create_test_heterograph(idtype) @@ -2796,6 +2823,24 @@ def test_adj_sparse(idtype, fmt): assert np.array_equal(F.asnumpy(indices_sorted), indices_sorted_np) +def _test_forking_pickler_entry(g, q): + q.put(g.formats()) + +@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="MXNet doesn't support spawning") +def test_forking_pickler(): + ctx = mp.get_context('spawn') + g = dgl.graph(([0,1,2],[1,2,3])) + g.create_formats_() + q = ctx.Queue(1) + proc = ctx.Process(target=_test_forking_pickler_entry, args=(g, q)) + proc.start() + fmt = q.get()['created'] + proc.join() + assert 'coo' in fmt + assert 'csr' in fmt + assert 'csc' in fmt + + if __name__ == '__main__': # test_create() # test_query() diff --git a/tests/pytorch/test_dataloader.py b/tests/pytorch/test_dataloader.py index 088b6b49c66d..a98a3c799b17 100644 --- a/tests/pytorch/test_dataloader.py +++ b/tests/pytorch/test_dataloader.py @@ -1,12 +1,14 @@ import os +import numpy as np import dgl import dgl.ops as OPS import backend as F import unittest import torch +from functools import partial from torch.utils.data import DataLoader from collections import defaultdict -from collections.abc import Iterator +from collections.abc import Iterator, Mapping from itertools import product import pytest @@ -99,7 +101,8 @@ def _check_device(data): assert data.device == F.ctx() @pytest.mark.parametrize('sampler_name', ['full', 'neighbor', 'neighbor2']) -@pytest.mark.parametrize('pin_graph', [True, False]) +# TODO(BarclayII): Re-enable pin_graph = True after PyTorch is upgraded to 1.9.0 on CI +@pytest.mark.parametrize('pin_graph', [False]) def test_node_dataloader(sampler_name, pin_graph): g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])) if F.ctx() != F.cpu() and pin_graph: @@ -153,7 +156,8 @@ def test_node_dataloader(sampler_name, pin_graph): dgl.dataloading.negative_sampler.Uniform(2), dgl.dataloading.negative_sampler.GlobalUniform(15, False, 3), dgl.dataloading.negative_sampler.GlobalUniform(15, True, 3)]) -@pytest.mark.parametrize('pin_graph', [True, False]) +# TODO(BarclayII): Re-enable pin_graph = True after PyTorch is upgraded to 1.9.0 on CI +@pytest.mark.parametrize('pin_graph', [False]) def test_edge_dataloader(sampler_name, neg_sampler, pin_graph): g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])) if F.ctx() != F.cpu() and pin_graph: @@ -222,15 +226,99 @@ def test_edge_dataloader(sampler_name, neg_sampler, pin_graph): if g1.is_pinned(): g1.unpin_memory_() +def _create_homogeneous(): + s = torch.randint(0, 200, (1000,), device=F.ctx()) + d = torch.randint(0, 200, (1000,), device=F.ctx()) + src = torch.cat([s, d]) + dst = torch.cat([d, s]) + g = dgl.graph((s, d), num_nodes=200) + reverse_eids = torch.cat([torch.arange(1000, 2000), torch.arange(0, 1000)]).to(F.ctx()) + always_exclude = torch.randint(0, 1000, (50,), device=F.ctx()) + seed_edges = torch.arange(0, 1000, device=F.ctx()) + return g, reverse_eids, always_exclude, seed_edges + +def _create_heterogeneous(): + edges = {} + for utype, etype, vtype in [('A', 'AA', 'A'), ('A', 'AB', 'B')]: + s = torch.randint(0, 200, (1000,), device=F.ctx()) + d = torch.randint(0, 200, (1000,), device=F.ctx()) + edges[utype, etype, vtype] = (s, d) + edges[vtype, 'rev-' + etype, utype] = (d, s) + g = dgl.heterograph(edges, num_nodes_dict={'A': 200, 'B': 200}) + reverse_etypes = {'AA': 'rev-AA', 'AB': 'rev-AB', 'rev-AA': 'AA', 'rev-AB': 'AB'} + always_exclude = { + 'AA': torch.randint(0, 1000, (50,), device=F.ctx()), + 'AB': torch.randint(0, 1000, (50,), device=F.ctx())} + seed_edges = { + 'AA': 
torch.arange(0, 1000, device=F.ctx()), + 'AB': torch.arange(0, 1000, device=F.ctx())} + return g, reverse_etypes, always_exclude, seed_edges + +def _find_edges_to_exclude(g, exclude, always_exclude, pair_eids): + if exclude == None: + return always_exclude + elif exclude == 'self': + return torch.cat([pair_eids, always_exclude]) if always_exclude is not None else pair_eids + elif exclude == 'reverse_id': + pair_eids = torch.cat([pair_eids, pair_eids + 1000]) + return torch.cat([pair_eids, always_exclude]) if always_exclude is not None else pair_eids + elif exclude == 'reverse_types': + pair_eids = {g.to_canonical_etype(k): v for k, v in pair_eids.items()} + if ('A', 'AA', 'A') in pair_eids: + pair_eids[('A', 'rev-AA', 'A')] = pair_eids[('A', 'AA', 'A')] + if ('A', 'AB', 'B') in pair_eids: + pair_eids[('B', 'rev-AB', 'A')] = pair_eids[('A', 'AB', 'B')] + if always_exclude is not None: + always_exclude = {g.to_canonical_etype(k): v for k, v in always_exclude.items()} + for k in always_exclude.keys(): + if k in pair_eids: + pair_eids[k] = torch.cat([pair_eids[k], always_exclude[k]]) + else: + pair_eids[k] = always_exclude[k] + return pair_eids + +@pytest.mark.parametrize('always_exclude_flag', [False, True]) +@pytest.mark.parametrize('exclude', [None, 'self', 'reverse_id', 'reverse_types']) +def test_edge_dataloader_excludes(exclude, always_exclude_flag): + if exclude == 'reverse_types': + g, reverse_etypes, always_exclude, seed_edges = _create_heterogeneous() + else: + g, reverse_eids, always_exclude, seed_edges = _create_homogeneous() + g = g.to(F.ctx()) + sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1) + if not always_exclude_flag: + always_exclude = None + + kwargs = {} + kwargs['exclude'] = ( + partial(_find_edges_to_exclude, g, exclude, always_exclude) if always_exclude_flag + else exclude) + kwargs['reverse_eids'] = reverse_eids if exclude == 'reverse_id' else None + kwargs['reverse_etypes'] = reverse_etypes if exclude == 'reverse_types' else None + + dataloader = dgl.dataloading.EdgeDataLoader( + g, seed_edges, sampler, batch_size=50, device=F.ctx(), **kwargs) + for input_nodes, pair_graph, blocks in dataloader: + block = blocks[0] + pair_eids = pair_graph.edata[dgl.EID] + block_eids = block.edata[dgl.EID] + + edges_to_exclude = _find_edges_to_exclude(g, exclude, always_exclude, pair_eids) + if edges_to_exclude is None: + continue + edges_to_exclude = dgl.utils.recursive_apply(edges_to_exclude, lambda x: x.cpu().numpy()) + block_eids = dgl.utils.recursive_apply(block_eids, lambda x: x.cpu().numpy()) + + if isinstance(edges_to_exclude, Mapping): + for k in edges_to_exclude.keys(): + assert not np.isin(edges_to_exclude[k], block_eids[k]).any() + else: + assert not np.isin(edges_to_exclude, block_eids).any() + if __name__ == '__main__': test_graph_dataloader() test_cluster_gcn(0) test_neighbor_nonuniform(0) - for sampler in ['full', 'neighbor']: - test_node_dataloader(sampler) - for neg_sampler in [ - dgl.dataloading.negative_sampler.Uniform(2), - dgl.dataloading.negative_sampler.GlobalUniform(2, False), - dgl.dataloading.negative_sampler.GlobalUniform(2, True)]: - for pin_graph in [True, False]: - test_edge_dataloader(sampler, neg_sampler, pin_graph) + for exclude in [None, 'self', 'reverse_id', 'reverse_types']: + test_edge_dataloader_excludes(exclude, False) + test_edge_dataloader_excludes(exclude, True)
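The new test above exercises exclude with a callable built via functools.partial. Outside the test harness the same idea looks roughly like the sketch below, on a homogeneous graph with made-up tensors: the callable receives the seed edge IDs of each minibatch and returns every edge ID that must be dropped from the sampled neighborhood.

    import numpy as np
    import torch
    import dgl

    src = torch.randint(0, 200, (1000,))
    dst = torch.randint(0, 200, (1000,))
    g = dgl.graph((torch.cat([src, dst]), torch.cat([dst, src])), num_nodes=200)
    always_exclude = torch.randint(0, 2000, (50,))      # e.g. edges reserved for testing

    def exclude_fn(seed_eids):
        # Drop the seeds, their reverse copies (offset by |E| = 1000), and the fixed set.
        return torch.cat([seed_eids, (seed_eids + 1000) % 2000, always_exclude])

    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
    dataloader = dgl.dataloading.EdgeDataLoader(
        g, torch.arange(g.num_edges()), sampler,
        exclude=exclude_fn, batch_size=50, shuffle=True)

    for input_nodes, pair_graph, blocks in dataloader:
        excluded = exclude_fn(pair_graph.edata[dgl.EID]).numpy()
        sampled = blocks[0].edata[dgl.EID].numpy()
        assert not np.isin(sampled, excluded).any()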