From 2bf8a600e0a19a2fe5f51fea2649972f7af044d7 Mon Sep 17 00:00:00 2001 From: NoakLiu <116571268+NoakLiu@users.noreply.github.com> Date: Sat, 30 Nov 2024 01:59:47 +0800 Subject: [PATCH] graphsnapshot GraphSnapShot is a framework for caching local structure for fast and efficient graph learning. It can achieve fast storage, retrieval and computation for graph learning at large scale. It can quickly store and update the local topology of graph structure, just like take snapshots of the graphs. Motivation: https://github.com/dmlc/dgl/issues/7841 Paper: https://arxiv.org/abs/2406.17918 Code implementation: https://github.com/NoakLiu/GraphSnapShot DGL acceleration module: https://github.com/NoakLiu/GraphSnapShot/tree/main/examples/dgl/dgl_cache_struct DGL test module: https://github.com/NoakLiu/GraphSnapShot/tree/main/examples/dgl/acceleration_tests_dgl --- python/dgl/dataloading/neighbor_sampler.py | 4216 +++++++++++++++++++- 1 file changed, 4212 insertions(+), 4 deletions(-) diff --git a/python/dgl/dataloading/neighbor_sampler.py b/python/dgl/dataloading/neighbor_sampler.py index 7aa7d8b7dc01..cf15a43bab25 100644 --- a/python/dgl/dataloading/neighbor_sampler.py +++ b/python/dgl/dataloading/neighbor_sampler.py @@ -1,5 +1,5 @@ """Data loading components for neighbor sampling""" - +from functools import cache from .. import backend as F from ..base import EID, NID from ..heterograph import DGLGraph @@ -7,6 +7,9 @@ from ..utils import get_num_threads from .base import BlockSampler +import torch +import dgl + class NeighborSampler(BlockSampler): """Sampler that builds computational dependency of node representations via @@ -195,11 +198,9 @@ def sample_blocks(self, g, seed_nodes, exclude_eids=None): ) block = to_block(frontier, seed_nodes) # If sampled from graphbolt-backed DistGraph, `EID` may not be in - # the block. If not exists, we should remove it from the block. + # the block. 
if EID in frontier.edata.keys(): block.edata[EID] = frontier.edata[EID] - else: - del block.edata[EID] seed_nodes = block.srcdata[NID] blocks.insert(0, block) @@ -244,3 +245,4210 @@ class MultiLayerFullNeighborSampler(NeighborSampler): def __init__(self, num_layers, **kwargs): super().__init__([-1] * num_layers, **kwargs) + +class NeighborSampler_FCR_struct(BlockSampler): + """ + A neighbor sampler that supports cache-refreshing (FCR) for efficient sampling, + tailored for multi-layer GNNs. This sampler augments the sampling process by + maintaining a cache of pre-sampled neighborhoods that can be reused across + multiple sampling iterations. It introduces cache amplification (via the alpha + parameter) and cache refresh cycles (via the T parameter) to manage the balance + between sampling efficiency and freshness of the sampled neighborhoods. + + Parameters + ---------- + g : DGLGraph + The input graph. + fanouts : list[int] or list[dict[etype, int]] + List of neighbors to sample per edge type for each GNN layer, with the i-th + element being the fanout for the i-th GNN layer. + edge_dir : str, default "in" + Direction of sampling. Can be either "in" for incoming edges or "out" for outgoing edges. + prob : str, optional + Name of the edge feature in g.edata used as the probability for edge sampling. + alpha : int, default 2 + Cache amplification ratio. Determines the size of the pre-sampled cache relative + to the actual sampling needs. A larger alpha means more neighbors are pre-sampled. + T : int, default 1 + Cache refresh cycle. Specifies how often (in terms of sampling iterations) the + cache should be refreshed. + + Examples + -------- + Initialize a graph and a NeighborSampler_FCR_struct for a 2-layer GNN with fanouts + [5, 10]. Assume alpha=2 for double the size of pre-sampling and T=3 for refreshing + the cache every 3 iterations. 
+ + >>> import dgl + >>> import torch + >>> g = dgl.rand_graph(100, 200) # Random graph with 100 nodes and 200 edges + >>> g.ndata['feat'] = torch.randn(100, 10) # Random node features + >>> sampler = NeighborSampler_FCR_struct(g, [5, 10], alpha=2, T=3) + + To perform sampling: + + >>> seed_nodes = torch.tensor([1, 2, 3]) # Nodes for which neighbors are sampled + >>> for i in range(5): # Simulate 5 sampling iterations + ... seed_nodes, output_nodes, blocks = sampler.sample_blocks(seed_nodes) + ... # Process the sampled blocks + """ + + def __init__( + self, + g, + fanouts, + edge_dir='in', + alpha=2, + T=20, + prob=None, + mask=None, + replace=False, + prefetch_node_feats=None, + prefetch_labels=None, + prefetch_edge_feats=None, + output_device=None, + fused=True, + ): + self.g = g + + super().__init__( + prefetch_node_feats=prefetch_node_feats, + prefetch_labels=prefetch_labels, + prefetch_edge_feats=prefetch_edge_feats, + output_device=output_device, + ) + self.fanouts = fanouts + self.edge_dir = edge_dir + if mask is not None and prob is not None: + raise ValueError( + "Mask and probability arguments are mutually exclusive. " + "Consider multiplying the probability with the mask " + "to achieve the same goal." + ) + self.prob = prob or mask + self.replace = replace + self.fused = fused + self.mapping = {} + + self.alpha = alpha + self.cycle = 0 # Initialize sampling cycle counter + self.amplified_fanouts = [f * alpha for f in fanouts] # Amplified fanouts for pre-sampling + self.T = T + self.Toptim = int(self.g.number_of_nodes() / max(self.amplified_fanouts)) + self.cache_struct = [] # Initialize cache structure + self.cache_refresh() # Pre-sample and populate the cache + + def cache_refresh(self,exclude_eids=None): + """ + Pre-samples neighborhoods with amplified fanouts and refreshes the cache. This method + is automatically called upon initialization and after every T sampling iterations to + ensure that the cache is periodically updated with fresh samples. 
+ """ + self.cache_struct.clear() # Clear existing cache + for fanout in self.amplified_fanouts: + # Sample neighbors for each layer with amplified fanout + # print("large") + # print(fanout) + # print("---") + frontier = self.g.sample_neighbors( + torch.arange(0, self.g.number_of_nodes()), # Consider all nodes as seeds for pre-sampling + # self.g.number_of_nodes(), + # 10, + fanout, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=exclude_eids + ) + frontier = dgl.add_self_loop(frontier) + # print(frontier) + # print(self.cache_struct) + # print("then append") + self.cache_struct.append(frontier) # Update cache with new samples + + def sample_blocks(self, g,seed_nodes, exclude_eids=None): + """ + Samples blocks from the graph for the specified seed nodes using the cache. + + Parameters + ---------- + seed_nodes : Tensor + The nodes for which the neighborhoods are to be sampled. + + Returns + ------- + tuple + A tuple containing the seed nodes for the next layer, the output nodes, and + the list of blocks sampled from the graph. 
+ """ + output_nodes = seed_nodes + + # refresh cache after a period of time for generalization + self.cycle += 1 + if self.cycle % self.Toptim == 0: + self.cache_refresh() # Refresh cache every T cycles + + blocks = [] + + if self.fused and get_num_threads() > 1: + # print("fused") + cpu = F.device_type(g.device) == "cpu" + if isinstance(seed_nodes, dict): + for ntype in list(seed_nodes.keys()): + if not cpu: + break + cpu = ( + cpu and F.device_type(seed_nodes[ntype].device) == "cpu" + ) + else: + cpu = cpu and F.device_type(seed_nodes.device) == "cpu" + if cpu and isinstance(g, DGLGraph) and F.backend_name == "pytorch": + if self.g != g: + self.mapping = {} + self.g = g + for fanout in reversed(self.fanouts): + block = g.sample_neighbors_fused( + seed_nodes, + fanout, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + exclude_edges=exclude_eids, + mapping=self.mapping, + ) + seed_nodes = block.srcdata[NID] + blocks.insert(0, block) + return seed_nodes, output_nodes, blocks + + for k in range(len(self.cache_struct)-1,-1,-1): + cached_structure = self.cache_struct[k] + fanout = self.fanouts[k] + frontier = cached_structure.sample_neighbors( + seed_nodes, + fanout, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=exclude_eids + ) + + # Sample frontier from the cache for acceleration + block = to_block(frontier, seed_nodes) + if EID in frontier.edata.keys(): + print("--------in this EID code---------") + block.edata[EID] = frontier.edata[EID] + blocks.insert(0, block) + seed_nodes = block.srcdata[NID] # Update seed nodes for the next layer + + # output_nodes = seed_nodes + return seed_nodes, output_nodes, blocks + +class NeighborSampler_FCR_struct_shared_cache(BlockSampler): + def __init__( + self, + g, + fanouts, + edge_dir='in', + alpha=2, + T=20, + prob=None, + mask=None, + replace=False, + prefetch_node_feats=None, + prefetch_labels=None, + prefetch_edge_feats=None, + 
output_device=None, + fused=True, + ): + self.g = g + + super().__init__( + prefetch_node_feats=prefetch_node_feats, + prefetch_labels=prefetch_labels, + prefetch_edge_feats=prefetch_edge_feats, + output_device=output_device, + ) + self.fanouts = fanouts + self.edge_dir = edge_dir + if mask is not None and prob is not None: + raise ValueError( + "Mask and probability arguments are mutually exclusive. " + "Consider multiplying the probability with the mask " + "to achieve the same goal." + ) + self.prob = prob or mask + self.replace = replace + self.fused = fused + self.mapping = {} + + self.alpha = alpha + self.cycle = 0 # Initialize sampling cycle counter + self.amplified_fanouts = [f * alpha for f in fanouts] # Amplified fanouts for pre-sampling + self.T = T + self.Toptim = int(self.g.number_of_nodes() / max(self.amplified_fanouts)) + # self.cache_struct = [] # Initialize cache structure + self.shared_cache_size = max(self.amplified_fanouts) + self.shared_cache = None + self.cache_refresh() # Pre-sample and populate the cache + + def cache_refresh(self,exclude_eids=None): + """ + Pre-samples neighborhoods to refresh the shared cache. This method + is automatically called upon initialization and after every T sampling iterations to + ensure that the cache is periodically updated with fresh samples. + """ + del self.shared_cache + self.shared_cache=self.g.sample_neighbors( + torch.arange(0, self.g.number_of_nodes()), # Consider all nodes as seeds for pre-sampling + self.shared_cache_size, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=exclude_eids + ) + + def sample_blocks(self, g,seed_nodes, exclude_eids=None): + """ + Samples blocks from the graph for the specified seed nodes using the cache. + + Parameters + ---------- + seed_nodes : Tensor + The nodes for which the neighborhoods are to be sampled. 
+ + Returns + ------- + tuple + A tuple containing the seed nodes for the next layer, the output nodes, and + the list of blocks sampled from the graph. + """ + output_nodes = seed_nodes + + # refresh full cache after every T cycles to learn graph structure + self.cycle += 1 + if self.cycle % self.Toptim == 0: + self.cache_refresh() + + blocks = [] + + if self.fused and get_num_threads() > 1: + # print("fused") + cpu = F.device_type(g.device) == "cpu" + if isinstance(seed_nodes, dict): + for ntype in list(seed_nodes.keys()): + if not cpu: + break + cpu = ( + cpu and F.device_type(seed_nodes[ntype].device) == "cpu" + ) + else: + cpu = cpu and F.device_type(seed_nodes.device) == "cpu" + if cpu and isinstance(g, DGLGraph) and F.backend_name == "pytorch": + if self.g != g: + self.mapping = {} + self.g = g + for fanout in reversed(self.fanouts): + block = g.sample_neighbors_fused( + seed_nodes, + fanout, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + exclude_edges=exclude_eids, + mapping=self.mapping, + ) + seed_nodes = block.srcdata[NID] + blocks.insert(0, block) + return seed_nodes, output_nodes, blocks + + for k in range(len(self.fanouts)-1,-1,-1): + fanout = self.fanouts[k] + frontier = self.shared_cache.sample_neighbors( + seed_nodes, + fanout, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=exclude_eids + ) + + # Sample frontier from the cache for acceleration + block = to_block(frontier, seed_nodes) + if EID in frontier.edata.keys(): + print("--------in this EID code---------") + block.edata[EID] = frontier.edata[EID] + blocks.insert(0, block) + seed_nodes = block.srcdata[NID] # Update seed nodes for the next layer + + # output_nodes = seed_nodes + return seed_nodes, output_nodes, blocks + +class NeighborSampler_OTF_struct_FSCRFCF_shared_cache(BlockSampler): + """ + Implements an on-the-fly (OTF) neighbor sampling strategy for Deep Graph Library (DGL) graphs. 
+ This sampler dynamically samples neighbors while balancing efficiency through caching and + freshness of samples by periodically refreshing parts of the cache. It supports specifying + fanouts, sampling direction, and probabilities, along with cache management parameters to + control the trade-offs between sampling efficiency and cache freshness. + + As for the parameters explanations, + 1. amp_rate: sample a larger cache than the original cache to store the local structure + 2. refresh_rate: decide how many portion should be sampled from disk, and the remaining comes out from cache, then combine them as new disk + 3. T: decide how long time will the cache to refresh and store the new structure (refresh mode in OTF is partially refresh) + """ + + def __init__(self, g, + fanouts, + edge_dir='in', + amp_rate=1.5, # cache amplification rate (should be bigger than 1 --> to sample for multiple time) + refresh_rate=0.4, #propotion of cache to be refresh, should be a positive float smaller than 0.5 + T=100, # refresh time + prob=None, + replace=False, + output_device=None, + exclude_eids=None, + mask=None, + prefetch_node_feats=None, + prefetch_labels=None, + prefetch_edge_feats=None, + fused=True, + ): + super().__init__( + prefetch_node_feats=prefetch_node_feats, + prefetch_labels=prefetch_labels, + prefetch_edge_feats=prefetch_edge_feats, + output_device=output_device, + ) + self.g = g + self.fanouts = fanouts + self.edge_dir = edge_dir + self.amp_rate = amp_rate + self.refresh_rate = refresh_rate + self.replace = replace + self.output_device = output_device + self.exclude_eids = exclude_eids + self.cycle = 0 + + if mask is not None and prob is not None: + raise ValueError( + "Mask and probability arguments are mutually exclusive. " + "Consider multiplying the probability with the mask " + "to achieve the same goal." 
+ ) + self.prob = prob or mask + self.fused = fused + self.mapping = {} + self.amp_cache_size = [fanout * amp_rate for fanout in fanouts] + self.Toptim = int(self.g.number_of_nodes() / (max(self.amp_cache_size))*self.amp_rate) + self.T = T + # self.cached_graph_structures = [self.initialize_cache(cache_size) for cache_size in self.cache_size] + + self.shared_cache_size = max(self.amp_cache_size) + self.shared_cache = self.initialize_cache(self.shared_cache_size) + + def initialize_cache(self, fanout_cache_storage): + """ + Initializes the cache for each layer with an amplified fanout to pre-sample a larger + set of neighbors. This pre-sampling helps in reducing the need for dynamic sampling + at every iteration, thereby improving efficiency. + """ + cached_graph = self.g.sample_neighbors( + torch.arange(0, self.g.number_of_nodes()), + fanout_cache_storage, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + # mappings=self.mapping + ) + print("end init cache") + return cached_graph + + def refresh_cache(self, fanout_cache_refresh): + """ + Refreshes a portion of the cache based on the gamma parameter by replacing some of the + cached edges with new samples from the graph. This method ensures the cache remains + relatively fresh and reflects changes in the dynamic graph structure or sampling needs. 
+ """ + fanout_cache_sample = self.shared_cache_size-fanout_cache_refresh + cache_remain = self.shared_cache.sample_neighbors( + torch.arange(0, self.g.number_of_nodes()), + fanout_cache_sample, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + + disk_to_add = self.g.sample_neighbors( + torch.arange(0, self.g.number_of_nodes()), + fanout_cache_refresh, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + + self.shared_cache = dgl.merge([cache_remain, disk_to_add]) + del cache_remain + del disk_to_add + print("end refresh cache") + + def sample_blocks(self, g, seed_nodes, exclude_eids=None): + """ + Samples blocks for GNN layers by combining cached samples with dynamically sampled + neighbors. This method also partially refreshes the cache based on specified parameters + to balance between sampling efficiency and the freshness of the samples. 
+ """ + self.cycle += 1 + blocks = [] + output_nodes = seed_nodes + if((self.cycle % self.Toptim)==0): + # Refresh cache partially + fanout_cache_refresh = int(self.shared_cache_size * self.refresh_rate) + self.refresh_cache(fanout_cache_refresh) + + for i, (fanout) in enumerate(reversed(self.fanouts)): + # Sample from cache + frontier_from_cache = self.shared_cache.sample_neighbors( + seed_nodes, + #fanout_cache_retrieval, + fanout, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + + # Convert the merged frontier to a block + block = to_block(frontier_from_cache, seed_nodes) + if EID in frontier_from_cache.edata.keys(): + print("--------in this EID code---------") + block.edata[EID] = frontier_from_cache.edata[EID] + blocks.append(block) + seed_nodes = block.srcdata[NID] # Update seed nodes for the next layer + + return seed_nodes,output_nodes, blocks + +class NeighborSampler_OTF_struct_FSCRFCF(BlockSampler): + """ + Implements an on-the-fly (OTF) neighbor sampling strategy for Deep Graph Library (DGL) graphs. + This sampler dynamically samples neighbors while balancing efficiency through caching and + freshness of samples by periodically refreshing parts of the cache. It supports specifying + fanouts, sampling direction, and probabilities, along with cache management parameters to + control the trade-offs between sampling efficiency and cache freshness. + + As for the parameters explanations, + 1. amp_rate: sample a larger cache than the original cache to store the local structure + 2. refresh_rate: decide how many portion should be sampled from disk, and the remaining comes out from cache, then combine them as new disk + 3. 
T: decide how long time will the cache to refresh and store the new structure (refresh mode in OTF is partially refresh) + """ + + def __init__(self, g, + fanouts, + edge_dir='in', + amp_rate=1.5, # cache amplification rate (should be bigger than 1 --> to sample for multiple time) + refresh_rate=0.4, #propotion of cache to be refresh, should be a positive float smaller than 0.5 + T=100, # refresh time + prob=None, + replace=False, + output_device=None, + exclude_eids=None, + mask=None, + prefetch_node_feats=None, + prefetch_labels=None, + prefetch_edge_feats=None, + fused=True, + ): + super().__init__( + prefetch_node_feats=prefetch_node_feats, + prefetch_labels=prefetch_labels, + prefetch_edge_feats=prefetch_edge_feats, + output_device=output_device, + ) + self.g = g + self.fanouts = fanouts + self.edge_dir = edge_dir + self.amp_rate = amp_rate + self.refresh_rate = refresh_rate + self.replace = replace + self.output_device = output_device + self.exclude_eids = exclude_eids + self.cycle = 0 + + if mask is not None and prob is not None: + raise ValueError( + "Mask and probability arguments are mutually exclusive. " + "Consider multiplying the probability with the mask " + "to achieve the same goal." + ) + self.prob = prob or mask + self.fused = fused + self.mapping = {} + self.cache_size = [fanout * amp_rate for fanout in fanouts] + self.T = T + self.Toptim = int(self.g.number_of_nodes() / (max(self.cache_size))*self.amp_rate) + self.cached_graph_structures = [self.initialize_cache(cache_size) for cache_size in self.cache_size] + + def initialize_cache(self, fanout_cache_storage): + """ + Initializes the cache for each layer with an amplified fanout to pre-sample a larger + set of neighbors. This pre-sampling helps in reducing the need for dynamic sampling + at every iteration, thereby improving efficiency. 
+ """ + cached_graph = self.g.sample_neighbors( + torch.arange(0, self.g.number_of_nodes()), + fanout_cache_storage, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + print("end init cache") + return cached_graph + + def refresh_cache(self,layer_id, cached_graph_structure, fanout_cache_refresh): + """ + Refreshes a portion of the cache based on the gamma parameter by replacing some of the + cached edges with new samples from the graph. This method ensures the cache remains + relatively fresh and reflects changes in the dynamic graph structure or sampling needs. + """ + fanout_cache_sample = self.cache_size[layer_id]-fanout_cache_refresh + cache_remain = cached_graph_structure.sample_neighbors( + torch.arange(0, self.g.number_of_nodes()), + fanout_cache_sample, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + + disk_to_add = self.g.sample_neighbors( + torch.arange(0, self.g.number_of_nodes()), + fanout_cache_refresh, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + + refreshed_cache = dgl.merge([cache_remain, disk_to_add]) + print("end refresh cache") + return refreshed_cache + + def sample_blocks(self, g, seed_nodes, exclude_eids=None): + """ + Samples blocks for GNN layers by combining cached samples with dynamically sampled + neighbors. This method also partially refreshes the cache based on specified parameters + to balance between sampling efficiency and the freshness of the samples. 
+ """ + blocks = [] + output_nodes = seed_nodes + self.cycle += 1 + if((self.cycle % self.Toptim)==0): + for i in range(0,len(self.cached_graph_structures)): + # Refresh cache partially + fanout_cache_refresh = int(self.cache_size[i] * self.refresh_rate) + self.cached_graph_structures[i]=self.refresh_cache(i, self.cached_graph_structures[i], fanout_cache_refresh) + + for i, (fanout, cached_graph_structure) in enumerate(zip(reversed(self.fanouts), reversed(self.cached_graph_structures))): + # Sample from cache + frontier_from_cache = self.cached_graph_structures[i].sample_neighbors( + seed_nodes, + #fanout_cache_retrieval, + fanout, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + + # Convert the merged frontier to a block + block = to_block(frontier_from_cache, seed_nodes) + if EID in frontier_from_cache.edata.keys(): + print("--------in this EID code---------") + block.edata[EID] = frontier_from_cache.edata[EID] + blocks.append(block) + seed_nodes = block.srcdata[NID] # Update seed nodes for the next layer + + return seed_nodes,output_nodes, blocks + +class NeighborSampler_OTF_struct_PCFFSCR_shared_cache(BlockSampler): + """ + Implements an on-the-fly (OTF) neighbor sampling strategy for Deep Graph Library (DGL) graphs. + This sampler dynamically samples neighbors while balancing efficiency through caching and + freshness of samples by periodically refreshing parts of the cache. It supports specifying + fanouts, sampling direction, and probabilities, along with cache management parameters to + control the trade-offs between sampling efficiency and cache freshness. + + As for the parameters explanations, + 1. amp_rate: sample a larger cache than the original cache to store the local structure + 2. refresh_rate: decide how many portion should be sampled from disk, and the remaining comes out from cache, then combine them as new disk + 3. 
T: decide how long time will the cache to refresh and store the new structure (refresh mode in OTF is partially refresh) + """ + + def __init__(self, g, + fanouts, + edge_dir='in', + amp_rate=1.5, # cache amplification rate (should be bigger than 1 --> to sample for multiple time) + fetch_rate=0.4, #propotion of cache to be fetch from cache, should be a positive float smaller than 0.5 + T_fetch=3, # fetch period of time + T_refresh=None, # refresh time + prob=None, + replace=False, + output_device=None, + exclude_eids=None, + mask=None, + prefetch_node_feats=None, + prefetch_labels=None, + prefetch_edge_feats=None, + fused=True, + ): + super().__init__( + prefetch_node_feats=prefetch_node_feats, + prefetch_labels=prefetch_labels, + prefetch_edge_feats=prefetch_edge_feats, + output_device=output_device, + ) + self.g = g + self.fanouts = fanouts + self.edge_dir = edge_dir + self.amp_rate = amp_rate + self.fetch_rate = fetch_rate + self.replace = replace + self.output_device = output_device + self.exclude_eids = exclude_eids + + if mask is not None and prob is not None: + raise ValueError( + "Mask and probability arguments are mutually exclusive. " + "Consider multiplying the probability with the mask " + "to achieve the same goal." + ) + self.prob = prob or mask + self.fused = fused + self.mapping = {} + self.amp_cache_size = [fanout * amp_rate for fanout in fanouts] + if T_refresh!=None: + self.T_refresh = T_refresh + else: + self.T_refresh = int(self.g.number_of_nodes()/max(self.fanouts) *self.amp_rate) + self.T_fetch = T_fetch + # self.cached_graph_structures = None + self.cycle = 0 + + self.shared_cache_size = max(self.amp_cache_size) + self.shared_cache = self.full_cache_refresh(self.shared_cache_size) + + def full_cache_refresh(self, fanout_cache_storage): + """ + Initializes the cache for each layer with an amplified fanout to pre-sample a larger + set of neighbors. 
This pre-sampling helps in reducing the need for dynamic sampling + at every iteration, thereby improving efficiency. + """ + cached_graph = self.g.sample_neighbors( + torch.arange(0, self.g.number_of_nodes()), + fanout_cache_storage, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + print("cache refresh") + return cached_graph + + def OTF_fetch(self,layer_id, seed_nodes, fanout_cache_fetch): + """ + Refreshes a portion of the cache based on the gamma parameter by replacing some of the + cached edges with new samples from the graph. This method ensures the cache remains + relatively fresh and reflects changes in the dynamic graph structure or sampling needs. + """ + print("OTF fetch cache") + if(fanout_cache_fetch==self.fanouts[layer_id]): + cache_fetch = self.shared_cache.sample_neighbors( + seed_nodes, + fanout_cache_fetch, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + return cache_fetch + else: + fanout_disk_fetch = self.fanouts[layer_id]-fanout_cache_fetch + cache_fetch = self.shared_cache.sample_neighbors( + seed_nodes, + fanout_cache_fetch, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + + disk_fetch = self.g.sample_neighbors( + seed_nodes, + fanout_disk_fetch, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + + OTF_fetch_res = dgl.merge([cache_fetch, disk_fetch]) + return OTF_fetch_res + + def sample_blocks(self, g, seed_nodes, exclude_eids=None): + """ + Samples blocks for GNN layers by combining cached samples with dynamically sampled + neighbors. 
This method also partially refreshes the cache based on specified parameters + to balance between sampling efficiency and the freshness of the samples. + """ + blocks = [] + output_nodes = seed_nodes + + self.cycle += 1 + print("self.T_refresh=",self.T_refresh) + # refresh full cache after a period of time + if((self.cycle%self.T_refresh)==0): + self.shared_cache = self.full_cache_refresh(self.shared_cache_size) + # self.cached_graph_structures = [self.full_cache_refresh(cache_size) for cache_size in self.cache_size] + + for i, (fanout) in enumerate(reversed(self.fanouts)): + fanout_cache_fetch = int(fanout * self.fetch_rate) + + # fetch cache partially + if((self.cycle%self.T_fetch)==0): + frontier_OTF = self.OTF_fetch(i, seed_nodes, fanout_cache_fetch) + else: + #frontier_OTF = self.OTF_fetch(i, seed_nodes, self.fanouts[i]) + frontier_OTF = self.shared_cache.sample_neighbors( + seed_nodes, + fanout, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + + # Convert the merged frontier to a block + block = to_block(frontier_OTF, seed_nodes) + if EID in frontier_OTF.edata.keys(): + print("--------in this EID code---------") + block.edata[EID] = frontier_OTF.edata[EID] + blocks.append(block) + seed_nodes = block.srcdata[NID] # Update seed nodes for the next layer + + return seed_nodes,output_nodes, blocks + +class NeighborSampler_OTF_struct_PCFFSCR(BlockSampler): + """ + Implements an on-the-fly (OTF) neighbor sampling strategy for Deep Graph Library (DGL) graphs. + This sampler dynamically samples neighbors while balancing efficiency through caching and + freshness of samples by periodically refreshing parts of the cache. It supports specifying + fanouts, sampling direction, and probabilities, along with cache management parameters to + control the trade-offs between sampling efficiency and cache freshness. + + As for the parameters explanations, + 1. 
amp_rate: sample a larger cache than the original cache to store the local structure + 2. refresh_rate: decide how many portion should be sampled from disk, and the remaining comes out from cache, then combine them as new disk + 3. T: decide how long time will the cache to refresh and store the new structure (refresh mode in OTF is partially refresh) + """ + + def __init__(self, g, + fanouts, + edge_dir='in', + amp_rate=1.5, # cache amplification rate (should be bigger than 1 --> to sample for multiple time) + fetch_rate=0.4, #propotion of cache to be refresh, should be a positive float smaller than 0.5 + T_fetch=3, # fetch period of time + T_refresh=None, # refresh time + prob=None, + replace=False, + output_device=None, + exclude_eids=None, + mask=None, + prefetch_node_feats=None, + prefetch_labels=None, + prefetch_edge_feats=None, + fused=True, + ): + super().__init__( + prefetch_node_feats=prefetch_node_feats, + prefetch_labels=prefetch_labels, + prefetch_edge_feats=prefetch_edge_feats, + output_device=output_device, + ) + self.g = g + self.fanouts = fanouts + self.edge_dir = edge_dir + self.amp_rate = amp_rate + self.fetch_rate = fetch_rate + self.replace = replace + self.output_device = output_device + self.exclude_eids = exclude_eids + + if mask is not None and prob is not None: + raise ValueError( + "Mask and probability arguments are mutually exclusive. " + "Consider multiplying the probability with the mask " + "to achieve the same goal." 
+ ) + self.prob = prob or mask + self.fused = fused + self.mapping = {} + self.cache_size = [fanout * amp_rate for fanout in fanouts] + if T_refresh!=None: + self.T_refresh = T_refresh + else: + self.T_refresh = int(self.g.number_of_nodes()/max(self.fanouts) *self.amp_rate) + self.T_fetch = T_fetch + self.cached_graph_structures = [self.full_cache_refresh(cache_size) for cache_size in self.cache_size] + self.cycle = 0 + + def full_cache_refresh(self, fanout_cache_storage): + """ + Initializes the cache for each layer with an amplified fanout to pre-sample a larger + set of neighbors. This pre-sampling helps in reducing the need for dynamic sampling + at every iteration, thereby improving efficiency. + """ + cached_graph = self.g.sample_neighbors( + torch.arange(0, self.g.number_of_nodes()), + fanout_cache_storage, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + print("cache refresh") + return cached_graph + + def OTF_fetch(self,layer_id, cached_graph_structure, seed_nodes, fanout_cache_fetch): + """ + Refreshes a portion of the cache based on the gamma parameter by replacing some of the + cached edges with new samples from the graph. This method ensures the cache remains + relatively fresh and reflects changes in the dynamic graph structure or sampling needs. 
+ """ + fanout_disk_fetch = self.fanouts[layer_id]-fanout_cache_fetch + cache_fetch = cached_graph_structure.sample_neighbors( + seed_nodes, + fanout_cache_fetch, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + + disk_fetch = self.g.sample_neighbors( + seed_nodes, + fanout_disk_fetch, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + + OTF_fetch_res = dgl.merge([cache_fetch, disk_fetch]) + print("OTF fetch cache") + return OTF_fetch_res + + def sample_blocks(self, g, seed_nodes, exclude_eids=None): + """ + Samples blocks for GNN layers by combining cached samples with dynamically sampled + neighbors. This method also partially refreshes the cache based on specified parameters + to balance between sampling efficiency and the freshness of the samples. + """ + blocks = [] + output_nodes = seed_nodes + self.cycle += 1 + + # refresh full cache after a period of time + if((self.cycle%self.T_refresh)==0): + self.cached_graph_structures = [self.full_cache_refresh(cache_size) for cache_size in self.cache_size] + + for i, (fanout, cached_graph_structure) in enumerate(zip(reversed(self.fanouts), reversed(self.cached_graph_structures))): + fanout_cache_refresh = int(fanout * self.fetch_rate) + + # fetch cache partially + if((self.cycle%self.T_fetch)==0): + frontier_OTF = self.OTF_fetch(i, cached_graph_structure, seed_nodes, fanout_cache_refresh) + else: + # frontier_OTF = self.OTF_fetch(i, cached_graph_structure, seed_nodes, self.fanouts[i]) + frontier_OTF = cached_graph_structure.sample_neighbors( + seed_nodes, + fanout, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + + # Convert the merged frontier to a block + block = to_block(frontier_OTF, seed_nodes) + if EID in frontier_OTF.edata.keys(): 
+ print("--------in this EID code---------") + block.edata[EID] = frontier_OTF.edata[EID] + blocks.append(block) + seed_nodes = block.srcdata[NID] # Update seed nodes for the next layer + + return seed_nodes,output_nodes, blocks + +class NeighborSampler_OTF_struct_PCFPSCR_SC(BlockSampler): + """ + Implements an on-the-fly (OTF) neighbor sampling strategy for Deep Graph Library (DGL) graphs. + This sampler dynamically samples neighbors while balancing efficiency through caching and + freshness of samples by periodically refreshing parts of the cache. It supports specifying + fanouts, sampling direction, and probabilities, along with cache management parameters to + control the trade-offs between sampling efficiency and cache freshness. + + As for the parameters explanations, + 1. amp_rate: sample a larger cache than the original cache to store the local structure + 2. refresh_rate: decide how many portion should be sampled from disk, and the remaining comes out from cache, then combine them as new disk + 3. 
T: decide how long time will the cache to refresh and store the new structure (refresh mode in OTF is partially refresh) + """ + + def __init__(self, g, + fanouts, + edge_dir='in', + amp_rate=1.5, # cache amplification rate (should be bigger than 1 --> to sample for multiple time) + refresh_rate=0.4, #propotion of cache to be refresh, should be a positive float smaller than 0.5 + T=50, # refresh time, for example + prob=None, + replace=False, + output_device=None, + exclude_eids=None, + mask=None, + prefetch_node_feats=None, + prefetch_labels=None, + prefetch_edge_feats=None, + fused=True, + ): + super().__init__( + prefetch_node_feats=prefetch_node_feats, + prefetch_labels=prefetch_labels, + prefetch_edge_feats=prefetch_edge_feats, + output_device=output_device, + ) + self.g = g + self.fanouts = fanouts + self.edge_dir = edge_dir + self.amp_rate = amp_rate + self.refresh_rate = refresh_rate + self.replace = replace + self.output_device = output_device + self.exclude_eids = exclude_eids + + if mask is not None and prob is not None: + raise ValueError( + "Mask and probability arguments are mutually exclusive. " + "Consider multiplying the probability with the mask " + "to achieve the same goal." + ) + self.prob = prob or mask + self.fused = fused + self.mapping = {} + self.amp_cache_size = [fanout * amp_rate for fanout in fanouts] + self.T = T + self.cycle = 0 + # self.cached_graph_structures = [self.initialize_cache(cache_size) for cache_size in self.cache_size] + + self.shared_cache_size = max(self.amp_cache_size) + self.shared_cache = self.initialize_cache(self.shared_cache_size) + + def initialize_cache(self, fanout_cache_storage): + """ + Initializes the cache for each layer with an amplified fanout to pre-sample a larger + set of neighbors. This pre-sampling helps in reducing the need for dynamic sampling + at every iteration, thereby improving efficiency. 
+ """ + cached_graph = self.g.sample_neighbors( + torch.arange(0, self.g.number_of_nodes()), + fanout_cache_storage, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + print("end init cache") + return cached_graph + + def OTF_rf_cache(self,layer_id, seed_nodes, fanout_cache_refresh, fanout): + """ + Refreshes a portion of the cache based on the gamma parameter by replacing some of the + cached edges with new samples from the graph. This method ensures the cache remains + relatively fresh and reflects changes in the dynamic graph structure or sampling needs. + """ + fanout_cache_remain = self.shared_cache_size-fanout_cache_refresh + fanout_cache_pr = fanout-fanout_cache_refresh + + all_nodes = torch.arange(0, self.g.number_of_nodes()) + # mask = ~torch.isin(all_nodes, seed_nodes) + # # 使用布尔掩码来选择不在seed_nodes中的节点 + # unchanged_nodes = all_nodes[mask] + + # unchanged_nodes = torch.arange(0, self.g.number_of_nodes())-seed_nodes + # the rest node structure remain the same + unchanged_structure = self.shared_cache.sample_neighbors( + all_nodes, + # unchanged_nodes, + self.shared_cache_size, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + # the OTF node structure should + changed_cache_remain = self.shared_cache.sample_neighbors( + seed_nodes, + fanout_cache_remain, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + cache_pr = self.shared_cache.sample_neighbors( + seed_nodes, + fanout_cache_pr, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + changed_disk_to_add = self.g.sample_neighbors( + seed_nodes, + fanout_cache_refresh, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + 
output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + refreshed_cache = dgl.merge([unchanged_structure, changed_cache_remain, changed_disk_to_add]) + retrieval_cache = dgl.merge([cache_pr, changed_disk_to_add]) + return refreshed_cache, retrieval_cache + + def sample_blocks(self, g, seed_nodes, exclude_eids=None): + """ + Samples blocks for GNN layers by combining cached samples with dynamically sampled + neighbors. This method also partially refreshes the cache based on specified parameters + to balance between sampling efficiency and the freshness of the samples. + """ + blocks = [] + output_nodes = seed_nodes + self.cycle += 1 + for i, (fanout) in enumerate(reversed(self.fanouts)): + fanout_cache_refresh = int(fanout * self.refresh_rate) + + # Refresh cache&disk partially, while retrieval cache&disk partially + if(self.cycle%self.T==0): + self.shared_cache, frontier_comp = self.OTF_rf_cache(i, seed_nodes, fanout_cache_refresh, fanout) + else: + frontier_comp = self.shared_cache.sample_neighbors( + seed_nodes, + fanout, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + + # Convert the merged frontier to a block + block = to_block(frontier_comp, seed_nodes) + if EID in frontier_comp.edata.keys(): + print("--------in this EID code---------") + block.edata[EID] = frontier_comp.edata[EID] + blocks.append(block) + seed_nodes = block.srcdata[NID] # Update seed nodes for the next layer + + return seed_nodes,output_nodes, blocks + +class NeighborSampler_OTF_struct_PCFPSCR(BlockSampler): + """ + Implements an on-the-fly (OTF) neighbor sampling strategy for Deep Graph Library (DGL) graphs. + This sampler dynamically samples neighbors while balancing efficiency through caching and + freshness of samples by periodically refreshing parts of the cache. 
It supports specifying + fanouts, sampling direction, and probabilities, along with cache management parameters to + control the trade-offs between sampling efficiency and cache freshness. + + As for the parameters explanations, + 1. amp_rate: sample a larger cache than the original cache to store the local structure + 2. refresh_rate: decide how many portion should be sampled from disk, and the remaining comes out from cache, then combine them as new disk + 3. T: decide how long time will the cache to refresh and store the new structure (refresh mode in OTF is partially refresh) + """ + + def __init__(self, g, + fanouts, + edge_dir='in', + amp_rate=1.5, # cache amplification rate (should be bigger than 1 --> to sample for multiple time) + refresh_rate=0.4, #propotion of cache to be refresh, should be a positive float smaller than 0.5 + T=50, # refresh time, for example + prob=None, + replace=False, + output_device=None, + exclude_eids=None, + mask=None, + prefetch_node_feats=None, + prefetch_labels=None, + prefetch_edge_feats=None, + fused=True, + ): + super().__init__( + prefetch_node_feats=prefetch_node_feats, + prefetch_labels=prefetch_labels, + prefetch_edge_feats=prefetch_edge_feats, + output_device=output_device, + ) + self.g = g + self.fanouts = fanouts + self.edge_dir = edge_dir + self.amp_rate = amp_rate + self.refresh_rate = refresh_rate + self.replace = replace + self.output_device = output_device + self.exclude_eids = exclude_eids + + if mask is not None and prob is not None: + raise ValueError( + "Mask and probability arguments are mutually exclusive. " + "Consider multiplying the probability with the mask " + "to achieve the same goal." 
+ ) + self.prob = prob or mask + self.fused = fused + self.mapping = {} + self.cache_size = [fanout * amp_rate for fanout in fanouts] + self.T = T + self.cached_graph_structures = [self.initialize_cache(cache_size) for cache_size in self.cache_size] + self.cycle = 0 + + def initialize_cache(self, fanout_cache_storage): + """ + Initializes the cache for each layer with an amplified fanout to pre-sample a larger + set of neighbors. This pre-sampling helps in reducing the need for dynamic sampling + at every iteration, thereby improving efficiency. + """ + cached_graph = self.g.sample_neighbors( + torch.arange(0, self.g.number_of_nodes()), + fanout_cache_storage, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + print("end init cache") + return cached_graph + + def OTF_rf_cache(self,layer_id, cached_graph_structure, seed_nodes, fanout_cache_refresh, fanout): + """ + Refreshes a portion of the cache based on the gamma parameter by replacing some of the + cached edges with new samples from the graph. This method ensures the cache remains + relatively fresh and reflects changes in the dynamic graph structure or sampling needs. 
+ """ + fanout_cache_remain = self.cache_size[layer_id]-fanout_cache_refresh + fanout_cache_pr = fanout-fanout_cache_refresh + # unchanged_nodes = range(torch.arange(0, self.g.number_of_nodes()))-seed_nodes + # the rest node structure remain the same + all_nodes = torch.arange(0, self.g.number_of_nodes()) + mask = ~torch.isin(all_nodes, seed_nodes) + # 使用布尔掩码来选择不在seed_nodes中的节点 + unchanged_nodes = all_nodes[mask] + unchanged_structure = cached_graph_structure.sample_neighbors( + unchanged_nodes, + self.cache_size[layer_id], + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + # the OTF node structure should + changed_cache_remain = cached_graph_structure.sample_neighbors( + seed_nodes, + fanout_cache_remain, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + cache_pr = cached_graph_structure.sample_neighbors( + seed_nodes, + fanout_cache_pr, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + changed_disk_to_add = self.g.sample_neighbors( + seed_nodes, + fanout_cache_refresh, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + refreshed_cache = dgl.merge([unchanged_structure, changed_cache_remain, changed_disk_to_add]) + retrieval_cache = dgl.merge([cache_pr, changed_disk_to_add]) + del unchanged_structure, changed_cache_remain, cache_pr, changed_disk_to_add + return refreshed_cache, retrieval_cache + + def sample_blocks(self, g, seed_nodes, exclude_eids=None): + """ + Samples blocks for GNN layers by combining cached samples with dynamically sampled + neighbors. 
This method also partially refreshes the cache based on specified parameters + to balance between sampling efficiency and the freshness of the samples. + """ + blocks = [] + output_nodes = seed_nodes + self.cycle += 1 + for i, (fanout, cached_graph_structure) in enumerate(zip(reversed(self.fanouts), reversed(self.cached_graph_structures))): + if(self.cycle % self.T) == 0: + fanout_cache_refresh = int(fanout * self.refresh_rate) + + # Refresh cache&disk partially, while retrieval cache&disk partially + self.cached_graph_structures[i], frontier_comp = self.OTF_rf_cache(i, cached_graph_structure, seed_nodes, fanout_cache_refresh, fanout) + else: + frontier_comp = self.cached_graph_structures[i].sample_neighbors( + seed_nodes, + fanout, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + + # Convert the merged frontier to a block + block = to_block(frontier_comp, seed_nodes) + if EID in frontier_comp.edata.keys(): + print("--------in this EID code---------") + block.edata[EID] = frontier_comp.edata[EID] + blocks.append(block) + seed_nodes = block.srcdata[NID] # Update seed nodes for the next layer + + return seed_nodes,output_nodes, blocks + +class NeighborSampler_OTF_struct_PSCRFCF_SC(BlockSampler): + """ + Implements an on-the-fly (OTF) neighbor sampling strategy for Deep Graph Library (DGL) graphs. + This sampler dynamically samples neighbors while balancing efficiency through caching and + freshness of samples by periodically refreshing parts of the cache. It supports specifying + fanouts, sampling direction, and probabilities, along with cache management parameters to + control the trade-offs between sampling efficiency and cache freshness. + + As for the parameters explanations, + 1. amp_rate: sample a larger cache than the original cache to store the local structure + 2. 
refresh_rate: decide how many portion should be sampled from disk, and the remaining comes out from cache, then combine them as new disk + 3. T: decide how long time will the cache to refresh and store the new structure (refresh mode in OTF is partially refresh) + """ + + def __init__(self, g, + fanouts, + edge_dir='in', + amp_rate=1.5, # cache amplification rate (should be bigger than 1 --> to sample for multiple time) + refresh_rate=0.4, #propotion of cache to be refresh, should be a positive float smaller than 0.5 + T=50, # refresh time + prob=None, + replace=False, + output_device=None, + exclude_eids=None, + mask=None, + prefetch_node_feats=None, + prefetch_labels=None, + prefetch_edge_feats=None, + fused=True, + ): + super().__init__( + prefetch_node_feats=prefetch_node_feats, + prefetch_labels=prefetch_labels, + prefetch_edge_feats=prefetch_edge_feats, + output_device=output_device, + ) + self.g = g + self.fanouts = fanouts + self.edge_dir = edge_dir + self.amp_rate = amp_rate + self.refresh_rate = refresh_rate + self.replace = replace + self.output_device = output_device + self.exclude_eids = exclude_eids + + if mask is not None and prob is not None: + raise ValueError( + "Mask and probability arguments are mutually exclusive. " + "Consider multiplying the probability with the mask " + "to achieve the same goal." + ) + self.prob = prob or mask + self.fused = fused + self.mapping = {} + # self.cache_size = [fanout * amp_rate for fanout in fanouts] + self.T = T + # self.cached_graph_structures = [self.initialize_cache(cache_size) for cache_size in self.cache_size] + self.cycle = 0 + + self.shared_cache_size = max(self.fanouts)*self.amp_rate + self.shared_cache = self.initialize_cache(self.shared_cache_size) + + def initialize_cache(self, fanout_cache_storage): + """ + Initializes the cache for each layer with an amplified fanout to pre-sample a larger + set of neighbors. 
This pre-sampling helps in reducing the need for dynamic sampling + at every iteration, thereby improving efficiency. + """ + cached_graph = self.g.sample_neighbors( + torch.arange(0, self.g.number_of_nodes()), + fanout_cache_storage, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + print("end init cache") + return cached_graph + + def OTF_refresh_cache(self,layer_id, cached_graph_structure, seed_nodes, fanout_cache_refresh): + """ + Refreshes a portion of the cache based on the gamma parameter by replacing some of the + cached edges with new samples from the graph. This method ensures the cache remains + relatively fresh and reflects changes in the dynamic graph structure or sampling needs. + """ + all_nodes = torch.arange(0, self.g.number_of_nodes()) + mask = ~torch.isin(all_nodes, seed_nodes) + # use bool mask to select those nodes in all nodes but not in seed_nodes + unchanged_nodes = all_nodes[mask] + fanout_cache_sample = self.shared_cache_size-fanout_cache_refresh + # unchanged_nodes = range(torch.arange(0, self.g.number_of_nodes()))-seed_nodes + # the rest node structure remain the same + unchanged_structure = cached_graph_structure.sample_neighbors( + unchanged_nodes, + self.shared_cache_size, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + # the OTF node structure should + changed_cache_remain = cached_graph_structure.sample_neighbors( + seed_nodes, + fanout_cache_sample, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + changed_disk_to_add = self.g.sample_neighbors( + seed_nodes, + fanout_cache_refresh, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + refreshed_cache = 
dgl.merge([unchanged_structure, changed_cache_remain, changed_disk_to_add]) + del unchanged_structure, changed_cache_remain, changed_disk_to_add + return refreshed_cache + + def sample_blocks(self, g, seed_nodes, exclude_eids=None): + """ + Samples blocks for GNN layers by combining cached samples with dynamically sampled + neighbors. This method also partially refreshes the cache based on specified parameters + to balance between sampling efficiency and the freshness of the samples. + """ + blocks = [] + output_nodes = seed_nodes + self.cycle += 1 + for i, (fanout) in enumerate(reversed(self.fanouts)): + fanout_cache_refresh = int(fanout * self.refresh_rate) + + # Refresh cache partially + if((self.cycle % self.T) ==0): + self.shared_cache = self.OTF_refresh_cache(i, self.shared_cache, seed_nodes, fanout_cache_refresh) + + # Sample from cache + frontier_cache = self.shared_cache.sample_neighbors( + seed_nodes, + fanout, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + + merged_frontier = frontier_cache + + # Convert the merged frontier to a block + block = to_block(merged_frontier, seed_nodes) + if EID in merged_frontier.edata.keys(): + print("--------in this EID code---------") + block.edata[EID] = merged_frontier.edata[EID] + blocks.append(block) + seed_nodes = block.srcdata[NID] # Update seed nodes for the next layer + + return seed_nodes,output_nodes, blocks + +class NeighborSampler_OTF_struct_PSCRFCF(BlockSampler): + """ + Implements an on-the-fly (OTF) neighbor sampling strategy for Deep Graph Library (DGL) graphs. + This sampler dynamically samples neighbors while balancing efficiency through caching and + freshness of samples by periodically refreshing parts of the cache. It supports specifying + fanouts, sampling direction, and probabilities, along with cache management parameters to + control the trade-offs between sampling efficiency and cache freshness. 
+ + As for the parameters explanations, + 1. amp_rate: sample a larger cache than the original cache to store the local structure + 2. refresh_rate: decide how many portion should be sampled from disk, and the remaining comes out from cache, then combine them as new disk + 3. T: decide how long time will the cache to refresh and store the new structure (refresh mode in OTF is partially refresh) + """ + + def __init__(self, g, + fanouts, + edge_dir='in', + amp_rate=1.5, # cache amplification rate (should be bigger than 1 --> to sample for multiple time) + refresh_rate=0.4, #propotion of cache to be refresh, should be a positive float smaller than 0.5 + T=50, # refresh time + prob=None, + replace=False, + output_device=None, + exclude_eids=None, + mask=None, + prefetch_node_feats=None, + prefetch_labels=None, + prefetch_edge_feats=None, + fused=True, + ): + super().__init__( + prefetch_node_feats=prefetch_node_feats, + prefetch_labels=prefetch_labels, + prefetch_edge_feats=prefetch_edge_feats, + output_device=output_device, + ) + self.g = g + self.fanouts = fanouts + self.edge_dir = edge_dir + self.amp_rate = amp_rate + self.refresh_rate = refresh_rate + self.replace = replace + self.output_device = output_device + self.exclude_eids = exclude_eids + + if mask is not None and prob is not None: + raise ValueError( + "Mask and probability arguments are mutually exclusive. " + "Consider multiplying the probability with the mask " + "to achieve the same goal." + ) + self.prob = prob or mask + self.fused = fused + self.mapping = {} + self.cache_size = [fanout * amp_rate for fanout in fanouts] + self.T = T + self.cached_graph_structures = [self.initialize_cache(cache_size) for cache_size in self.cache_size] + self.cycle = 0 + + def initialize_cache(self, fanout_cache_storage): + """ + Initializes the cache for each layer with an amplified fanout to pre-sample a larger + set of neighbors. 
This pre-sampling helps in reducing the need for dynamic sampling + at every iteration, thereby improving efficiency. + """ + cached_graph = self.g.sample_neighbors( + torch.arange(0, self.g.number_of_nodes()), + fanout_cache_storage, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + print("end init cache") + return cached_graph + + def OTF_refresh_cache(self,layer_id, cached_graph_structure, seed_nodes, fanout_cache_refresh): + """ + Refreshes a portion of the cache based on the gamma parameter by replacing some of the + cached edges with new samples from the graph. This method ensures the cache remains + relatively fresh and reflects changes in the dynamic graph structure or sampling needs. + """ + fanout_cache_sample = self.cache_size[layer_id]-fanout_cache_refresh + # unchanged_nodes = range(torch.arange(0, self.g.number_of_nodes()))-seed_nodes + all_nodes = torch.arange(0, self.g.number_of_nodes()) + mask = ~torch.isin(all_nodes, seed_nodes) + # use a boolean mask to select the nodes that are not in seed_nodes + unchanged_nodes = all_nodes[mask] + # the structure of the untouched nodes remains the same + unchanged_structure = cached_graph_structure.sample_neighbors( + unchanged_nodes, + self.cache_size[layer_id], + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + # the OTF (seed) nodes have part of their cached structure re-sampled + changed_cache_remain = cached_graph_structure.sample_neighbors( + seed_nodes, + fanout_cache_sample, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + changed_disk_to_add = self.g.sample_neighbors( + seed_nodes, + fanout_cache_refresh, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + refreshed_cache = dgl.merge([unchanged_structure,
changed_cache_remain, changed_disk_to_add]) + del unchanged_structure, changed_cache_remain, changed_disk_to_add + return refreshed_cache + + def sample_blocks(self, g, seed_nodes, exclude_eids=None): + """ + Samples blocks for GNN layers by combining cached samples with dynamically sampled + neighbors. This method also partially refreshes the cache based on specified parameters + to balance between sampling efficiency and the freshness of the samples. + """ + blocks = [] + output_nodes = seed_nodes + self.cycle += 1 + for i, (fanout, cached_graph_structure) in enumerate(zip(reversed(self.fanouts), reversed(self.cached_graph_structures))): + fanout_cache_refresh = int(fanout * self.refresh_rate) + + # Refresh cache partially + if(self.cycle%self.T==0): + self.cached_graph_structures[i] = self.OTF_refresh_cache(i, cached_graph_structure, seed_nodes, fanout_cache_refresh) + + # Sample from cache + frontier_cache = self.cached_graph_structures[i].sample_neighbors( + seed_nodes, + fanout, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + + merged_frontier = frontier_cache + + # Convert the merged frontier to a block + block = to_block(merged_frontier, seed_nodes) + if EID in merged_frontier.edata.keys(): + print("--------in this EID code---------") + block.edata[EID] = merged_frontier.edata[EID] + blocks.append(block) + seed_nodes = block.srcdata[NID] # Update seed nodes for the next layer + + return seed_nodes,output_nodes, blocks + + +class NeighborSampler_FCR_struct_hete(BlockSampler): + def __init__( + self, + g, + fanouts, + edge_dir='in', + alpha=2, + T=20, + hete_label=None, + prob=None, + mask=None, + replace=False, + prefetch_node_feats=None, + prefetch_labels=None, + prefetch_edge_feats=None, + output_device=None, + fused=True, + ): + + super().__init__( + prefetch_node_feats=prefetch_node_feats, + prefetch_labels=prefetch_labels, + 
prefetch_edge_feats=prefetch_edge_feats, + output_device=output_device, + ) + self.fanouts = fanouts + self.edge_dir = edge_dir + if mask is not None and prob is not None: + raise ValueError( + "Mask and probability arguments are mutually exclusive. " + "Consider multiplying the probability with the mask " + "to achieve the same goal." + ) + self.g = g + self.prob = prob or mask + self.replace = replace + self.fused = fused + self.mapping = {} + + self.alpha = alpha + self.cycle = 0 # Initialize sampling cycle counter + self.amplified_fanouts = [f * alpha for f in fanouts] # Amplified fanouts for pre-sampling + self.T = T + self.Toptim = None # int(self.g.number_of_nodes() / max(self.amplified_fanouts)) + self.cache_struct = [] # Initialize cache structure + self.hete_label = hete_label + # self.cache_refresh(self.g) # Pre-sample and populate the cache + + def cache_refresh(self,g,exclude_eids=None): + """ + Pre-samples neighborhoods with amplified fanouts and refreshes the cache. This method + is automatically called upon initialization and after every T sampling iterations to + ensure that the cache is periodically updated with fresh samples. + """ + self.cache_struct.clear() # Clear existing cache + for fanout in self.amplified_fanouts: + # Sample neighbors for each layer with amplified fanout + frontier = g.sample_neighbors( + {self.hete_label:list(range(0, g.num_nodes(self.hete_label)))}, + fanout, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=exclude_eids + ) + # frontier = dgl.add_self_loop(frontier) + self.cache_struct.append(frontier) # Update cache with new samples + + def sample_blocks(self, g,seed_nodes, exclude_eids=None): + """ + Samples blocks from the graph for the specified seed nodes using the cache. + + Parameters + ---------- + seed_nodes : Tensor + The nodes for which the neighborhoods are to be sampled. 
+ + Returns + ------- + tuple + A tuple containing the seed nodes for the next layer, the output nodes, and + the list of blocks sampled from the graph. + """ + output_nodes = seed_nodes + + # refresh cache after a period of time for generalization + self.Toptim = int(g.number_of_nodes() / max(self.amplified_fanouts)) + if self.cycle % self.Toptim == 0: + self.cache_refresh(g) # Refresh cache every T cycles + + self.cycle += 1 + + blocks = [] + + if self.fused and get_num_threads() > 1: + # print("fused") + cpu = F.device_type(g.device) == "cpu" + if isinstance(seed_nodes, dict): + for ntype in list(seed_nodes.keys()): + if not cpu: + break + cpu = ( + cpu and F.device_type(seed_nodes[ntype].device) == "cpu" + ) + else: + cpu = cpu and F.device_type(seed_nodes.device) == "cpu" + if cpu and isinstance(g, DGLGraph) and F.backend_name == "pytorch": + if self.g != g: + self.mapping = {} + self.g = g + for fanout in reversed(self.fanouts): + block = g.sample_neighbors_fused( + seed_nodes, + fanout, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + exclude_edges=exclude_eids, + mapping=self.mapping, + ) + seed_nodes = block.srcdata[NID] + blocks.insert(0, block) + return seed_nodes, output_nodes, blocks + + for k in range(len(self.cache_struct)-1,-1,-1): + cached_structure = self.cache_struct[k] + fanout = self.fanouts[k] + frontier = cached_structure.sample_neighbors( + seed_nodes, + fanout, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=exclude_eids + ) + + # Sample frontier from the cache for acceleration + block = to_block(frontier, seed_nodes) + if EID in frontier.edata.keys(): + print("--------in this EID code---------") + block.edata[EID] = frontier.edata[EID] + blocks.insert(0, block) + seed_nodes = block.srcdata[NID] # Update seed nodes for the next layer + + # output_nodes = seed_nodes + return seed_nodes, output_nodes, blocks + +# class 
NeighborSampler_FCR_struct_hete(BlockSampler): + +# def __init__( +# self, +# g, +# fanouts, +# edge_dir="in", +# alpha = 2, +# T = 20, +# hete_label = None, +# prob=None, +# mask=None, +# replace=False, +# prefetch_node_feats=None, +# prefetch_labels=None, +# prefetch_edge_feats=None, +# output_device=None, +# fused=True, +# ): +# super().__init__( +# prefetch_node_feats=prefetch_node_feats, +# prefetch_labels=prefetch_labels, +# prefetch_edge_feats=prefetch_edge_feats, +# output_device=output_device, +# ) +# self.g = g +# self.fanouts = fanouts +# self.edge_dir = edge_dir +# if mask is not None and prob is not None: +# raise ValueError( +# "Mask and probability arguments are mutually exclusive. " +# "Consider multiplying the probability with the mask " +# "to achieve the same goal." +# ) +# self.prob = prob or mask +# self.replace = replace +# self.fused = fused +# self.mapping = {} +# self.g = g +# self.cycle = 0 +# self.cached_structure = [] +# self.amplified_fanouts = [f * alpha for f in fanouts] # Amplified fanouts for pre-sampling +# self.T = T +# self.Toptim = None #int(self.g[self.hete_label].number_of_nodes() / max(self.amplified_fanouts)) +# self.hete_label = hete_label +# self.cache_refresh() + +# # print(self.g.number_of_nodes("paper")) + +# def cache_refresh(self,exclude_eids=None): +# for i in range(0,len(self.fanouts)): +# frontier1 = self.g.sample_neighbors( +# {self.hete_label:list(range(0, self.g.num_nodes(self.hete_label)))}, +# self.amplified_fanouts[i], +# edge_dir=self.edge_dir, +# prob=self.prob, +# replace=self.replace, +# output_device=self.output_device, +# exclude_edges=exclude_eids, +# ) +# self.cached_structure.append(frontier1) + + +# def sample_blocks(self, g, seed_nodes, exclude_eids=None): +# output_nodes = seed_nodes +# blocks = [] +# # sample_neighbors_fused function requires multithreading to be more efficient +# # than sample_neighbors + +# self.Toptim = int(g.number_of_nodes() / max(self.amplified_fanouts)) + +# # self.cycle 
+= 1 +# if(self.cycle%self.T == 0): +# self.cache_refresh(exclude_eids=exclude_eids) # refresh cache every T cycles + +# self.cycle += 1 + +# if self.fused and get_num_threads() > 1: +# cpu = F.device_type(g.device) == "cpu" +# if isinstance(seed_nodes, dict): +# for ntype in list(seed_nodes.keys()): +# print("seed dict",seed_nodes.keys) +# if not cpu: +# break +# cpu = ( +# cpu and F.device_type(seed_nodes[ntype].device) == "cpu" +# ) +# else: +# cpu = cpu and F.device_type(seed_nodes.device) == "cpu" +# if cpu and isinstance(g, DGLGraph) and F.backend_name == "pytorch": +# if self.g != g: +# self.mapping = {} +# self.g = g +# for fanout in reversed(self.fanouts): +# block = self.cached_structure[k].sample_neighbors_fused( +# seed_nodes, +# fanout, +# edge_dir=self.edge_dir, +# prob=self.prob, +# replace=self.replace, +# exclude_edges=exclude_eids, +# mapping=self.mapping, +# ) +# seed_nodes = block.srcdata[NID] +# blocks.insert(0, block) +# return seed_nodes, output_nodes, blocks + +# k = len(self.fanouts)-1 +# for fanout in reversed(self.fanouts): +# print("seeds nodes:",seed_nodes) +# print("org g:",g) +# frontier = self.cached_structure[k].sample_neighbors( +# seed_nodes, +# fanout, +# edge_dir=self.edge_dir, +# prob=self.prob, +# replace=self.replace, +# output_device=self.output_device, +# exclude_edges=exclude_eids, +# ) +# k-=1 +# print("sampled frontier:",frontier) +# block = to_block(frontier, seed_nodes) +# # If sampled from graphbolt-backed DistGraph, `EID` may not be in +# # the block. 
#             if EID in frontier.edata.keys():
#                 print("--------in this EID code---------")
#                 block.edata[EID] = frontier.edata[EID]
#             seed_nodes = block.srcdata[NID]
#             blocks.insert(0, block)

#         return seed_nodes, output_nodes, blocks

class NeighborSampler_FCR_struct_shared_cache_hete(BlockSampler):
    """FCR sampler with a single shared cache, for heterogeneous graphs.

    Pre-samples an amplified neighborhood (the "shared cache") once for
    every node of type ``hete_label`` and serves all GNN layers from that
    cache.  The cache is rebuilt from the full graph every ``Toptim``
    calls to :meth:`sample_blocks`.

    Parameters
    ----------
    g : DGLGraph
        The (heterogeneous) input graph.
    fanouts : list[int]
        Number of neighbors to sample per GNN layer.
    edge_dir : str
        Sampling direction, ``"in"`` or ``"out"``.
    alpha : int or float
        Cache amplification factor; the shared cache stores
        ``max(fanouts) * alpha`` neighbors per node.
    T : int
        Nominal refresh period.  Kept for API compatibility; the period
        actually used is ``Toptim``, derived from the graph size.
    hete_label : str
        Node type whose nodes seed the cache.
    prob, mask : str, optional
        Mutually exclusive edge-weight / edge-mask feature names.
    replace : bool
        Whether to sample with replacement.
    fused : bool
        Enable the fused sampling fast path (CPU + PyTorch backend only).
    """

    def __init__(
        self,
        g,
        fanouts,
        edge_dir='in',
        alpha=2,
        T=20,
        hete_label=None,
        prob=None,
        mask=None,
        replace=False,
        prefetch_node_feats=None,
        prefetch_labels=None,
        prefetch_edge_feats=None,
        output_device=None,
        fused=True,
    ):
        super().__init__(
            prefetch_node_feats=prefetch_node_feats,
            prefetch_labels=prefetch_labels,
            prefetch_edge_feats=prefetch_edge_feats,
            output_device=output_device,
        )
        self.fanouts = fanouts
        self.edge_dir = edge_dir
        if mask is not None and prob is not None:
            raise ValueError(
                "Mask and probability arguments are mutually exclusive. "
                "Consider multiplying the probability with the mask "
                "to achieve the same goal."
            )
        self.g = g
        self.prob = prob or mask
        self.replace = replace
        self.fused = fused
        self.mapping = {}

        self.alpha = alpha
        self.cycle = 0  # number of sample_blocks() calls so far
        # Shared cache capacity: the largest amplified fanout, so every
        # layer can be served from the same cached structure.  Truncate to
        # int because sample_neighbors requires an integer fanout.
        self.sc_size = int(max(f * alpha for f in fanouts))
        self.T = T
        self.shared_cache = None
        self.hete_label = hete_label
        self.cache_refresh()  # pre-sample and populate the cache
        # Refresh period; clamped to >= 1 so the modulo in sample_blocks
        # never divides by zero when num_nodes(hete_label) < sc_size.
        self.Toptim = max(
            1, int(self.g.num_nodes(self.hete_label) / self.sc_size)
        )

    def cache_refresh(self, exclude_eids=None):
        """Rebuild the shared cache.

        Samples ``sc_size`` neighbors for every node of type
        ``hete_label`` from the full graph and stores the result as the
        new shared cache.
        """
        self.shared_cache = self.g.sample_neighbors(
            {self.hete_label: list(range(0, self.g.num_nodes(self.hete_label)))},
            self.sc_size,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=exclude_eids,
        )

    def sample_blocks(self, g, seed_nodes, exclude_eids=None):
        """Sample MFGs (blocks) for ``seed_nodes``, serving every layer
        from the shared cache.

        Returns
        -------
        tuple
            ``(input_nodes, output_nodes, blocks)``.
        """
        output_nodes = seed_nodes

        # Periodically rebuild the cache so samples stay fresh.  cycle is
        # 0 on the first call, so the cache is refreshed once more right
        # after construction.
        if self.cycle % self.Toptim == 0:
            self.cache_refresh()
        self.cycle += 1

        blocks = []

        # Fused fast path: samples directly from the full graph (not the
        # cache).  Requires CPU tensors, a DGLGraph and the PyTorch
        # backend, and pays off only with multiple threads.
        if self.fused and get_num_threads() > 1:
            cpu = F.device_type(g.device) == "cpu"
            if isinstance(seed_nodes, dict):
                for ntype in list(seed_nodes.keys()):
                    if not cpu:
                        break
                    cpu = (
                        cpu and F.device_type(seed_nodes[ntype].device) == "cpu"
                    )
            else:
                cpu = cpu and F.device_type(seed_nodes.device) == "cpu"
            if cpu and isinstance(g, DGLGraph) and F.backend_name == "pytorch":
                if self.g != g:
                    self.mapping = {}
                    self.g = g
                for fanout in reversed(self.fanouts):
                    block = self.g.sample_neighbors_fused(
                        seed_nodes,
                        fanout,
                        edge_dir=self.edge_dir,
                        prob=self.prob,
                        replace=self.replace,
                        exclude_edges=exclude_eids,
                        mapping=self.mapping,
                    )
                    seed_nodes = block.srcdata[NID]
                    blocks.insert(0, block)
                return seed_nodes, output_nodes, blocks

        # Non-fused path: sample each layer's frontier from the cache.
        for k in range(len(self.fanouts) - 1, -1, -1):
            frontier = self.shared_cache.sample_neighbors(
                seed_nodes,
                self.fanouts[k],
                edge_dir=self.edge_dir,
                prob=self.prob,
                replace=self.replace,
                output_device=self.output_device,
                exclude_edges=exclude_eids,
            )
            block = to_block(frontier, seed_nodes)
            # `EID` may be absent, e.g. when sampling from a
            # graphbolt-backed DistGraph.
            if EID in frontier.edata.keys():
                block.edata[EID] = frontier.edata[EID]
            blocks.insert(0, block)
            seed_nodes = block.srcdata[NID]  # seeds for the next layer

        return seed_nodes, output_nodes, blocks

# class NeighborSampler_OTF_struct_FSCRFCF_hete(BlockSampler):

#     def __init__(self, g,
#                 fanouts,
#                 edge_dir='in',
#                 amp_rate=1.5, # cache amplification rate (should be bigger than 1 --> to sample for multiple time)
#                 refresh_rate=0.4, #propotion of cache to be refresh, should be a positive float smaller than 0.5
#                 T=100, # refresh time
#                 hete_label = None,
#                 prob=None,
#                 replace=False,
#                 output_device=None,
#                 exclude_eids=None,
#                 mask=None,
#                 prefetch_node_feats=None,
#                 prefetch_labels=None,
#                 prefetch_edge_feats=None,
#                 fused=True,
#                 ):
#         super().__init__(
#             prefetch_node_feats=prefetch_node_feats,
#             prefetch_labels=prefetch_labels,
#             prefetch_edge_feats=prefetch_edge_feats,
#             output_device=output_device,
#         )
#         self.g = g
#         self.fanouts = fanouts
#         self.edge_dir = edge_dir
#         self.amp_rate = amp_rate
#         self.refresh_rate = refresh_rate
#         self.replace = replace
#         self.output_device = output_device
#         self.exclude_eids = exclude_eids
#         self.cycle = 0

#         if mask is not None and prob is not None:
#             raise ValueError(
#                 "Mask and probability arguments are mutually exclusive. "
#                 "Consider multiplying the probability with the mask "
#                 "to achieve the same goal."
+# ) +# self.prob = prob or mask +# self.fused = fused +# self.mapping = {} +# self.cache_size = [int(fanout * amp_rate) for fanout in fanouts] +# self.T = T +# self.hete_label = hete_label +# self.Toptim = int(self.g.num_nodes(self.hete_label) / (max(self.cache_size))*self.amp_rate) +# self.cached_graph_structures = [self.initialize_cache(cache_size) for cache_size in self.cache_size] + +# def initialize_cache(self, fanout_cache_storage): +# """ +# Initializes the cache for each layer with an amplified fanout to pre-sample a larger +# set of neighbors. This pre-sampling helps in reducing the need for dynamic sampling +# at every iteration, thereby improving efficiency. +# """ +# cached_graph = self.g.sample_neighbors( +# # torch.arange(0, self.g.number_of_nodes()), +# {self.hete_label:list(range(0, self.g.num_nodes(self.hete_label)))}, +# fanout_cache_storage, +# edge_dir=self.edge_dir, +# prob=self.prob, +# replace=self.replace, +# output_device=self.output_device, +# exclude_edges=self.exclude_eids, +# ) +# print("end init cache") +# return cached_graph + +# def refresh_cache(self,layer_id, cached_graph_structure, fanout_cache_refresh): +# """ +# Refreshes a portion of the cache based on the gamma parameter by replacing some of the +# cached edges with new samples from the graph. This method ensures the cache remains +# relatively fresh and reflects changes in the dynamic graph structure or sampling needs. 
+# """ +# fanout_cache_sample = self.cache_size[layer_id]-fanout_cache_refresh +# cache_remain = cached_graph_structure.sample_neighbors( +# # torch.arange(0, self.g.number_of_nodes()), +# {self.hete_label:list(range(0, self.g.num_nodes(self.hete_label)))}, +# fanout_cache_sample, +# edge_dir=self.edge_dir, +# prob=self.prob, +# replace=self.replace, +# output_device=self.output_device, +# exclude_edges=self.exclude_eids, +# ) + +# disk_to_add = self.g.sample_neighbors( +# # torch.arange(0, self.g.number_of_nodes()), +# {self.hete_label:list(range(0, self.g.num_nodes(self.hete_label)))}, +# fanout_cache_refresh, +# edge_dir=self.edge_dir, +# prob=self.prob, +# replace=self.replace, +# output_device=self.output_device, +# exclude_edges=self.exclude_eids, +# ) + +# refreshed_cache = dgl.merge([cache_remain, disk_to_add]) +# print("end refresh cache") +# return refreshed_cache + +# def sample_blocks(self, g, seed_nodes, exclude_eids=None): +# """ +# Samples blocks for GNN layers by combining cached samples with dynamically sampled +# neighbors. This method also partially refreshes the cache based on specified parameters +# to balance between sampling efficiency and the freshness of the samples. 
#         """
#         blocks = []
#         output_nodes = seed_nodes
#         self.cycle += 1
#         if((self.cycle % self.Toptim)==0):
#             for i in range(0,len(self.cached_graph_structures)):
#                 # Refresh cache partially
#                 fanout_cache_refresh = int(self.cache_size[i] * self.refresh_rate)
#                 self.cached_graph_structures[i]=self.refresh_cache(i, self.cached_graph_structures[i], fanout_cache_refresh)

#         for i, (fanout, cached_graph_structure) in enumerate(zip(reversed(self.fanouts), reversed(self.cached_graph_structures))):
#             # Sample from cache
#             frontier_from_cache = self.cached_graph_structures[i].sample_neighbors(
#                 seed_nodes,
#                 #fanout_cache_retrieval,
#                 fanout,
#                 edge_dir=self.edge_dir,
#                 prob=self.prob,
#                 replace=self.replace,
#                 output_device=self.output_device,
#                 exclude_edges=self.exclude_eids,
#             )

#             # Convert the merged frontier to a block
#             block = to_block(frontier_from_cache, seed_nodes)
#             if EID in frontier_from_cache.edata.keys():
#                 print("--------in this EID code---------")
#                 block.edata[EID] = frontier_from_cache.edata[EID]
#             blocks.append(block)
#             seed_nodes = block.srcdata[NID]  # Update seed nodes for the next layer

#         return seed_nodes,output_nodes, blocks

class NeighborSampler_OTF_struct_FSCRFCF_shared_cache_hete(BlockSampler):
    """OTF (on-the-fly) sampler with a single, partially refreshed shared
    cache, for heterogeneous graphs.

    A cache holding ``max(fanouts) * amp_rate`` neighbors per node of type
    ``hete_label`` is pre-sampled once.  Every ``Toptim`` iterations a
    fraction ``refresh_rate`` of its capacity is re-sampled from the full
    graph and merged with the retained part, keeping the cache fresh
    without a full rebuild.

    Parameters
    ----------
    g : DGLGraph
        The (heterogeneous) input graph.
    fanouts : list[int]
        Number of neighbors to sample per GNN layer.
    amp_rate : float
        Cache amplification rate (should be > 1).
    refresh_rate : float
        Fraction of the cache replaced per refresh (0 < refresh_rate < 0.5).
    T : int
        Nominal refresh period.  Kept for API compatibility; the period
        actually used is ``Toptim``.
    hete_label : str
        Node type whose nodes seed the cache.
    """

    def __init__(self, g,
            fanouts,
            edge_dir='in',
            amp_rate=1.5,
            refresh_rate=0.4,
            T=100,
            hete_label=None,
            prob=None,
            replace=False,
            output_device=None,
            exclude_eids=None,
            mask=None,
            prefetch_node_feats=None,
            prefetch_labels=None,
            prefetch_edge_feats=None,
            fused=True,
            ):
        super().__init__(
            prefetch_node_feats=prefetch_node_feats,
            prefetch_labels=prefetch_labels,
            prefetch_edge_feats=prefetch_edge_feats,
            output_device=output_device,
        )
        self.g = g
        self.fanouts = fanouts
        self.edge_dir = edge_dir
        self.amp_rate = amp_rate
        self.refresh_rate = refresh_rate
        self.replace = replace
        self.output_device = output_device
        self.exclude_eids = exclude_eids
        self.cycle = 0  # number of sample_blocks() calls so far

        if mask is not None and prob is not None:
            raise ValueError(
                "Mask and probability arguments are mutually exclusive. "
                "Consider multiplying the probability with the mask "
                "to achieve the same goal."
            )
        self.prob = prob or mask
        self.fused = fused
        self.mapping = {}
        self.amp_cache_size = [fanout * amp_rate for fanout in fanouts]
        self.hete_label = hete_label
        self.T = T

        # sample_neighbors requires an integer fanout; amp_rate is a
        # float, so truncate the amplified capacity.
        self.shared_cache_size = int(max(self.amp_cache_size))
        self.shared_cache = self.initialize_cache(self.shared_cache_size)
        # Refresh period; clamped to >= 1 so the modulo in sample_blocks
        # never divides by zero on small graphs.
        self.Toptim = max(
            1,
            int(
                self.g.num_nodes(self.hete_label)
                / (self.shared_cache_size * self.amp_rate)
            ),
        )

    def initialize_cache(self, fanout_cache_storage):
        """Pre-sample ``fanout_cache_storage`` neighbors for every node of
        type ``hete_label`` and return the resulting cached graph."""
        cached_graph = self.g.sample_neighbors(
            {self.hete_label: list(range(0, self.g.num_nodes(self.hete_label)))},
            fanout_cache_storage,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=self.exclude_eids,
        )
        return cached_graph

    def refresh_cache(self, fanout_cache_refresh):
        """Partially refresh the shared cache in place.

        Keeps ``shared_cache_size - fanout_cache_refresh`` neighbors per
        node from the current cache and merges in ``fanout_cache_refresh``
        freshly sampled neighbors from the full graph.
        """
        fanout_cache_sample = self.shared_cache_size - fanout_cache_refresh
        # Retain most of the current cache ...
        cache_remain = self.shared_cache.sample_neighbors(
            {self.hete_label: list(range(0, self.g.num_nodes(self.hete_label)))},
            fanout_cache_sample,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=self.exclude_eids,
        )
        # ... and top it up with fresh edges from the full graph.
        disk_to_add = self.g.sample_neighbors(
            {self.hete_label: list(range(0, self.g.num_nodes(self.hete_label)))},
            fanout_cache_refresh,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=self.exclude_eids,
        )
        self.shared_cache = dgl.merge([cache_remain, disk_to_add])
        # Drop references promptly to keep peak memory down.
        del cache_remain
        del disk_to_add

    def sample_blocks(self, g, seed_nodes, exclude_eids=None):
        """Sample MFGs (blocks) for ``seed_nodes`` from the shared cache,
        partially refreshing it every ``Toptim`` calls.

        Returns
        -------
        tuple
            ``(input_nodes, output_nodes, blocks)``.
        """
        self.cycle += 1
        blocks = []
        output_nodes = seed_nodes
        if (self.cycle % self.Toptim) == 0:
            # Partially refresh the shared cache.
            fanout_cache_refresh = int(self.shared_cache_size * self.refresh_rate)
            self.refresh_cache(fanout_cache_refresh)

        # NOTE(review): blocks are appended outermost-layer first, the
        # reverse of the insert(0, ...) ordering used by the other
        # samplers in this file -- confirm callers expect this order.
        for fanout in reversed(self.fanouts):
            frontier_from_cache = self.shared_cache.sample_neighbors(
                seed_nodes,
                fanout,
                edge_dir=self.edge_dir,
                prob=self.prob,
                replace=self.replace,
                output_device=self.output_device,
                exclude_edges=self.exclude_eids,
            )
            block = to_block(frontier_from_cache, seed_nodes)
            # `EID` may be absent, e.g. for graphbolt-backed DistGraph.
            if EID in frontier_from_cache.edata.keys():
                block.edata[EID] = frontier_from_cache.edata[EID]
            blocks.append(block)
            seed_nodes = block.srcdata[NID]  # seeds for the next layer

        return seed_nodes, output_nodes, blocks


class NeighborSampler_OTF_refresh_struct_hete(BlockSampler):
    """OTF sampler with one partially refreshed cache per GNN layer, for
    heterogeneous graphs.

    Each layer owns a cache of ``fanout * alpha`` pre-sampled neighbors
    per node of type ``hete_label``.  Every ``Toptim`` iterations a
    fraction ``refresh_rate`` of each cache is re-sampled from the full
    graph and merged with the retained part.

    Parameters
    ----------
    g : DGLGraph
        The (heterogeneous) input graph.
    fanouts : list[int]
        Number of neighbors to sample per GNN layer.
    alpha : int or float
        Cache amplification factor per layer.
    T : int
        Nominal refresh period.  Kept for API compatibility; the period
        actually used is ``Toptim``, recomputed on every call.
    refresh_rate : float
        Fraction of each cache replaced per refresh.
    hete_label : str
        Node type whose nodes seed the caches.
    """

    def __init__(
        self,
        g,
        fanouts,
        edge_dir='in',
        alpha=2,
        T=20,
        refresh_rate=0.4,
        hete_label=None,
        prob=None,
        mask=None,
        replace=False,
        prefetch_node_feats=None,
        prefetch_labels=None,
        prefetch_edge_feats=None,
        output_device=None,
        fused=True,
    ):
        super().__init__(
            prefetch_node_feats=prefetch_node_feats,
            prefetch_labels=prefetch_labels,
            prefetch_edge_feats=prefetch_edge_feats,
            output_device=output_device,
        )
        self.fanouts = fanouts
        self.edge_dir = edge_dir
        if mask is not None and prob is not None:
            raise ValueError(
                "Mask and probability arguments are mutually exclusive. "
                "Consider multiplying the probability with the mask "
                "to achieve the same goal."
            )
        self.g = g
        self.prob = prob or mask
        self.replace = replace
        self.fused = fused
        self.mapping = {}

        self.alpha = alpha
        self.cycle = 0  # number of sample_blocks() calls so far
        # Per-layer cache capacities; truncate in case alpha is a float
        # (sample_neighbors requires integer fanouts).
        self.cache_size = [int(f * alpha) for f in fanouts]
        self.refresh_rate = refresh_rate
        self.T = T
        self.Toptim = None  # recomputed on every sample_blocks() call
        self.hete_label = hete_label
        self.cached_struct = [
            self.initialize_cache(cache_size) for cache_size in self.cache_size
        ]

    def initialize_cache(self, fanout_cache_storage, exclude_eids=None):
        """Pre-sample ``fanout_cache_storage`` neighbors for every node of
        type ``hete_label`` and return the resulting cached graph."""
        cached_graph = self.g.sample_neighbors(
            {self.hete_label: list(range(0, self.g.num_nodes(self.hete_label)))},
            fanout_cache_storage,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=exclude_eids,
        )
        return cached_graph

    def refresh_cache(self, layer_id, fanout_cache_refresh, exclude_eids=None):
        """Return layer ``layer_id``'s cache with ``fanout_cache_refresh``
        neighbors per node replaced by fresh samples from the full graph."""
        fanout_cache_sample = self.cache_size[layer_id] - fanout_cache_refresh
        # Retain most of the old cache ...
        cache_remain = self.cached_struct[layer_id].sample_neighbors(
            {self.hete_label: list(range(0, self.g.num_nodes(self.hete_label)))},
            fanout_cache_sample,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=exclude_eids,
        )
        # ... and top it up with fresh edges from the full graph.
        disk_to_add = self.g.sample_neighbors(
            {self.hete_label: list(range(0, self.g.num_nodes(self.hete_label)))},
            fanout_cache_refresh,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=exclude_eids,
        )
        return dgl.merge([cache_remain, disk_to_add])

    def sample_blocks(self, g, seed_nodes, exclude_eids=None):
        """Sample MFGs (blocks) for ``seed_nodes`` from the per-layer
        caches, partially refreshing them every ``Toptim`` calls.

        Returns
        -------
        tuple
            ``(input_nodes, output_nodes, blocks)``.
        """
        output_nodes = seed_nodes

        # Refresh period; clamped to >= 1 so the modulo below never
        # divides by zero on small graphs.
        # NOTE(review): this uses the total node count while the caches
        # are built from `hete_label` nodes only -- confirm intended.
        self.Toptim = max(1, int(g.number_of_nodes() / max(self.cache_size)))

        self.cycle += 1
        if (self.cycle % self.Toptim) == 0:
            for i in range(0, len(self.cached_struct)):
                # Partially refresh each layer's cache.
                fanout_cache_refresh = int(self.cache_size[i] * self.refresh_rate)
                self.cached_struct[i] = self.refresh_cache(i, fanout_cache_refresh)

        blocks = []

        # Fused fast path: samples directly from the full graph (bypasses
        # the caches); requires CPU tensors, a DGLGraph and the PyTorch
        # backend.
        if self.fused and get_num_threads() > 1:
            cpu = F.device_type(g.device) == "cpu"
            if isinstance(seed_nodes, dict):
                for ntype in list(seed_nodes.keys()):
                    if not cpu:
                        break
                    cpu = (
                        cpu and F.device_type(seed_nodes[ntype].device) == "cpu"
                    )
            else:
                cpu = cpu and F.device_type(seed_nodes.device) == "cpu"
            if cpu and isinstance(g, DGLGraph) and F.backend_name == "pytorch":
                if self.g != g:
                    self.mapping = {}
                    self.g = g
                for fanout in reversed(self.fanouts):
                    block = self.g.sample_neighbors_fused(
                        seed_nodes,
                        fanout,
                        edge_dir=self.edge_dir,
                        prob=self.prob,
                        replace=self.replace,
                        exclude_edges=exclude_eids,
                        mapping=self.mapping,
                    )
                    seed_nodes = block.srcdata[NID]
                    blocks.insert(0, block)
                return seed_nodes, output_nodes, blocks

        # Non-fused path: serve each layer from its own cache.
        for k in range(len(self.cached_struct) - 1, -1, -1):
            frontier = self.cached_struct[k].sample_neighbors(
                seed_nodes,
                self.fanouts[k],
                edge_dir=self.edge_dir,
                prob=self.prob,
                replace=self.replace,
                output_device=self.output_device,
                exclude_edges=exclude_eids,
            )
            block = to_block(frontier, seed_nodes)
            # `EID` may be absent, e.g. for graphbolt-backed DistGraph.
            if EID in frontier.edata.keys():
                block.edata[EID] = frontier.edata[EID]
            blocks.insert(0, block)
            seed_nodes = block.srcdata[NID]  # seeds for the next layer

        return seed_nodes, output_nodes, blocks


class NeighborSampler_OTF_refresh_struct_shared_cache_hete(BlockSampler):
    """OTF sampler with a single partially refreshed cache shared by all
    GNN layers, for heterogeneous graphs.

    One cache holding ``max(fanouts) * alpha`` neighbors per node of type
    ``hete_label`` serves every layer; a fraction ``refresh_rate`` of it
    is re-sampled from the full graph every ``Toptim`` calls.

    Parameters
    ----------
    g : DGLGraph
        The (heterogeneous) input graph.
    fanouts : list[int]
        Number of neighbors to sample per GNN layer.
    alpha : int or float
        Cache amplification factor.
    T : int
        Nominal refresh period.  Kept for API compatibility; the period
        actually used is ``Toptim``, recomputed on every call.
    refresh_rate : float
        Fraction of the cache replaced per refresh.
    hete_label : str
        Node type whose nodes seed the cache.
    """

    def __init__(
        self,
        g,
        fanouts,
        edge_dir='in',
        alpha=2,
        T=20,
        refresh_rate=0.4,
        hete_label=None,
        prob=None,
        mask=None,
        replace=False,
        prefetch_node_feats=None,
        prefetch_labels=None,
        prefetch_edge_feats=None,
        output_device=None,
        fused=True,
    ):
        super().__init__(
            prefetch_node_feats=prefetch_node_feats,
            prefetch_labels=prefetch_labels,
            prefetch_edge_feats=prefetch_edge_feats,
            output_device=output_device,
        )
        self.fanouts = fanouts
        self.edge_dir = edge_dir
        if mask is not None and prob is not None:
            raise ValueError(
                "Mask and probability arguments are mutually exclusive. "
                "Consider multiplying the probability with the mask "
                "to achieve the same goal."
            )
        self.g = g
        self.prob = prob or mask
        self.replace = replace
        self.fused = fused
        self.mapping = {}

        self.alpha = alpha
        self.cycle = 0  # number of sample_blocks() calls so far
        # Shared cache capacity; truncate in case alpha is a float
        # (sample_neighbors requires an integer fanout).
        self.sc_size = int(max(f * alpha for f in fanouts))
        self.refresh_rate = refresh_rate
        self.T = T
        self.Toptim = None  # recomputed on every sample_blocks() call
        self.hete_label = hete_label
        self.shared_cache = self.initialize_cache(self.sc_size)

    def initialize_cache(self, fanout_cache_storage, exclude_eids=None):
        """Pre-sample ``fanout_cache_storage`` neighbors for every node of
        type ``hete_label`` and return the resulting cached graph."""
        cached_graph = self.g.sample_neighbors(
            {self.hete_label: list(range(0, self.g.num_nodes(self.hete_label)))},
            fanout_cache_storage,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=exclude_eids,
        )
        return cached_graph

    def refresh_cache(self, fanout_cache_refresh, exclude_eids=None):
        """Return the shared cache with ``fanout_cache_refresh`` neighbors
        per node replaced by fresh samples from the full graph."""
        fanout_cache_sample = self.sc_size - fanout_cache_refresh
        # Retain most of the old cache ...
        cache_remain = self.shared_cache.sample_neighbors(
            {self.hete_label: list(range(0, self.g.num_nodes(self.hete_label)))},
            fanout_cache_sample,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=exclude_eids,
        )
        # ... and top it up with fresh edges from the full graph.
        disk_to_add = self.g.sample_neighbors(
            {self.hete_label: list(range(0, self.g.num_nodes(self.hete_label)))},
            fanout_cache_refresh,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=exclude_eids,
        )
        return dgl.merge([cache_remain, disk_to_add])

    def sample_blocks(self, g, seed_nodes, exclude_eids=None):
        """Sample MFGs (blocks) for ``seed_nodes`` from the shared cache,
        partially refreshing it every ``Toptim`` calls.

        Returns
        -------
        tuple
            ``(input_nodes, output_nodes, blocks)``.
        """
        output_nodes = seed_nodes

        # Refresh period; clamped to >= 1 so the modulo below never
        # divides by zero on small graphs.
        # NOTE(review): uses the total node count while the cache is built
        # from `hete_label` nodes only -- confirm intended.
        self.Toptim = max(1, int(g.number_of_nodes() / self.sc_size))

        self.cycle += 1
        if (self.cycle % self.Toptim) == 0:
            # Partially refresh the shared cache.
            fanout_cache_refresh = int(self.sc_size * self.refresh_rate)
            self.shared_cache = self.refresh_cache(fanout_cache_refresh)

        blocks = []

        # Fused fast path: samples directly from the full graph (bypasses
        # the cache); requires CPU tensors, a DGLGraph and the PyTorch
        # backend.
        if self.fused and get_num_threads() > 1:
            cpu = F.device_type(g.device) == "cpu"
            if isinstance(seed_nodes, dict):
                for ntype in list(seed_nodes.keys()):
                    if not cpu:
                        break
                    cpu = (
                        cpu and F.device_type(seed_nodes[ntype].device) == "cpu"
                    )
            else:
                cpu = cpu and F.device_type(seed_nodes.device) == "cpu"
            if cpu and isinstance(g, DGLGraph) and F.backend_name == "pytorch":
                if self.g != g:
                    self.mapping = {}
                    self.g = g
                for fanout in reversed(self.fanouts):
                    block = self.g.sample_neighbors_fused(
                        seed_nodes,
                        fanout,
                        edge_dir=self.edge_dir,
                        prob=self.prob,
                        replace=self.replace,
                        exclude_edges=exclude_eids,
                        mapping=self.mapping,
                    )
                    seed_nodes = block.srcdata[NID]
                    blocks.insert(0, block)
                return seed_nodes, output_nodes, blocks

        # Non-fused path: serve every layer from the shared cache.
        for k in range(len(self.fanouts) - 1, -1, -1):
            frontier = self.shared_cache.sample_neighbors(
                seed_nodes,
                self.fanouts[k],
                edge_dir=self.edge_dir,
                prob=self.prob,
                replace=self.replace,
                output_device=self.output_device,
                exclude_edges=exclude_eids,
            )
            block = to_block(frontier, seed_nodes)
            # `EID` may be absent, e.g. for graphbolt-backed DistGraph.
            if EID in frontier.edata.keys():
                block.edata[EID] = frontier.edata[EID]
            blocks.insert(0, block)
            seed_nodes = block.srcdata[NID]  # seeds for the next layer

        return seed_nodes, output_nodes, blocks


# class
NeighborSampler_OTF_fetch_struct_shared_cache_hete(BlockSampler): + +# def __init__(self, g, +# fanouts, +# edge_dir='in', +# amp_rate=1.5, # cache amplification rate (should be bigger than 1 --> to sample for multiple time) +# fetch_rate=0.4, #propotion of cache to be fetch from cache, should be a positive float smaller than 0.5 +# T_fetch=3, # fetch period of time +# T_refresh=None, # refresh time +# hete_label=None, +# prob=None, +# replace=False, +# output_device=None, +# exclude_eids=None, +# mask=None, +# prefetch_node_feats=None, +# prefetch_labels=None, +# prefetch_edge_feats=None, +# fused=True, +# ): +# super().__init__( +# prefetch_node_feats=prefetch_node_feats, +# prefetch_labels=prefetch_labels, +# prefetch_edge_feats=prefetch_edge_feats, +# output_device=output_device, +# ) +# self.g = g +# self.fanouts = fanouts +# self.edge_dir = edge_dir +# self.amp_rate = amp_rate +# self.fetch_rate = fetch_rate +# self.hete_label = hete_label +# self.replace = replace +# self.output_device = output_device +# self.exclude_eids = exclude_eids + +# if mask is not None and prob is not None: +# raise ValueError( +# "Mask and probability arguments are mutually exclusive. " +# "Consider multiplying the probability with the mask " +# "to achieve the same goal." 
class NeighborSampler_OTF_fetch_struct_hete(BlockSampler):
    """On-The-Fly (OTF) fetch sampler for heterogeneous graphs with one
    pre-sampled cache per GNN layer.

    Each layer keeps a cached graph holding an amplified neighborhood
    (``fanout * amp_rate``) of every ``hete_label`` node.  Most iterations
    sample blocks entirely from the cache; every ``T_fetch`` iterations a
    fraction ``fetch_rate`` of the fanout is drawn from the cache and the
    remainder directly from the full graph, and every ``T_refresh``
    iterations all caches are rebuilt from scratch.

    Parameters
    ----------
    g : DGLGraph
        The heterogeneous graph to sample from.
    fanouts : list[int]
        Number of neighbors to sample per GNN layer.
    edge_dir : str
        Direction of sampling, ``"in"`` or ``"out"``.
    amp_rate : int or float
        Cache amplification rate; each layer caches ``fanout * amp_rate``
        neighbors per node.
    fetch_rate : float
        Fraction of a fanout served from the cache on a partial fetch.
    T_refresh : int, optional
        Full cache-refresh period in iterations.  When ``None`` it is
        derived from the graph size, the maximum fanout and ``amp_rate``.
    T_fetch : int
        Period (iterations) of the partial cache/disk fetch.
    hete_label : str
        Node type whose local structure is cached.
    prob, mask, replace, prefetch_node_feats, prefetch_labels,
    prefetch_edge_feats, output_device, fused
        Same meaning as in :class:`NeighborSampler`.
    """

    def __init__(
        self,
        g,
        fanouts,
        edge_dir='in',
        amp_rate=2,
        fetch_rate=0.4,
        T_refresh=None,
        T_fetch=3,
        hete_label=None,
        prob=None,
        mask=None,
        replace=False,
        prefetch_node_feats=None,
        prefetch_labels=None,
        prefetch_edge_feats=None,
        output_device=None,
        fused=True,
    ):
        super().__init__(
            prefetch_node_feats=prefetch_node_feats,
            prefetch_labels=prefetch_labels,
            prefetch_edge_feats=prefetch_edge_feats,
            output_device=output_device,
        )
        if mask is not None and prob is not None:
            raise ValueError(
                "Mask and probability arguments are mutually exclusive. "
                "Consider multiplying the probability with the mask "
                "to achieve the same goal."
            )
        self.g = g
        self.fanouts = fanouts
        self.edge_dir = edge_dir
        self.prob = prob or mask
        self.replace = replace
        self.fused = fused
        self.mapping = {}
        self.exclude_eids = None

        # Fix: the default T_refresh computation below reads
        # `self.amp_rate`, which the original code never assigned (it only
        # stored the value as `self.alpha`), raising AttributeError
        # whenever T_refresh was left as None.
        self.amp_rate = amp_rate
        self.alpha = amp_rate  # kept for backward compatibility
        self.cycle = 0  # sampling-iteration counter
        self.cache_size = [f * amp_rate for f in fanouts]  # amplified fanouts
        if T_refresh is not None:
            self.T_refresh = T_refresh
        else:
            self.T_refresh = int(
                self.g.number_of_nodes() / max(self.fanouts) * self.amp_rate
            )
        self.T_fetch = T_fetch
        self.Toptim = None
        self.fetch_rate = fetch_rate
        self.hete_label = hete_label
        # One amplified cache per layer.
        self.cache_struct = [
            self.full_cache_refresh(size) for size in self.cache_size
        ]

    def full_cache_refresh(self, fanout_cache_storage, exclude_eids=None):
        """Rebuild a cache by sampling ``fanout_cache_storage`` neighbors of
        every ``hete_label`` node from the full graph."""
        return self.g.sample_neighbors(
            {self.hete_label: list(range(self.g.num_nodes(self.hete_label)))},
            fanout_cache_storage,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=exclude_eids,
        )

    def OTF_fetch(self, layer_id, seed_nodes, fanout_cache_fetch,
                  exclude_eids=None):
        """Fetch a frontier for ``seed_nodes``: ``fanout_cache_fetch``
        neighbors come from the layer cache and the remainder of the layer
        fanout comes directly from the full graph."""
        if fanout_cache_fetch == self.fanouts[layer_id]:
            # The whole fanout can be served from the cache.
            return self.cache_struct[layer_id].sample_neighbors(
                seed_nodes,
                fanout_cache_fetch,
                edge_dir=self.edge_dir,
                prob=self.prob,
                replace=self.replace,
                output_device=self.output_device,
                exclude_edges=exclude_eids,
            )
        fanout_disk_fetch = self.fanouts[layer_id] - fanout_cache_fetch
        # Fix: this branch now honors the `exclude_eids` parameter instead
        # of silently falling back to `self.exclude_eids`.
        cache_fetch = self.cache_struct[layer_id].sample_neighbors(
            seed_nodes,
            fanout_cache_fetch,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=exclude_eids,
        )
        disk_fetch = self.g.sample_neighbors(
            seed_nodes,
            fanout_disk_fetch,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=exclude_eids,
        )
        return dgl.merge([cache_fetch, disk_fetch])

    def sample_blocks(self, g, seed_nodes, exclude_eids=None):
        """Sample blocks for ``seed_nodes``, serving frontiers from the
        per-layer caches.

        Returns
        -------
        tuple
            ``(input_nodes, output_nodes, blocks)``.
        """
        output_nodes = seed_nodes

        self.cycle += 1
        # Periodically rebuild every layer cache from the full graph.
        if self.cycle % self.T_refresh == 0:
            self.cache_struct = [
                self.full_cache_refresh(size) for size in self.cache_size
            ]

        blocks = []

        # Fused CPU path, mirroring NeighborSampler.sample_blocks.
        if self.fused and get_num_threads() > 1:
            cpu = F.device_type(g.device) == "cpu"
            if isinstance(seed_nodes, dict):
                for ntype in list(seed_nodes.keys()):
                    if not cpu:
                        break
                    cpu = cpu and F.device_type(seed_nodes[ntype].device) == "cpu"
            else:
                cpu = cpu and F.device_type(seed_nodes.device) == "cpu"
            if cpu and isinstance(g, DGLGraph) and F.backend_name == "pytorch":
                if self.g != g:
                    self.mapping = {}
                    self.g = g
                for fanout in reversed(self.fanouts):
                    block = self.g.sample_neighbors_fused(
                        seed_nodes,
                        fanout,
                        edge_dir=self.edge_dir,
                        prob=self.prob,
                        replace=self.replace,
                        exclude_edges=exclude_eids,
                        mapping=self.mapping,
                    )
                    seed_nodes = block.srcdata[NID]
                    blocks.insert(0, block)
                return seed_nodes, output_nodes, blocks

        for k in range(len(self.fanouts) - 1, -1, -1):
            fanout = self.fanouts[k]
            fanout_cache_fetch = int(fanout * self.fetch_rate)

            if self.cycle % self.T_fetch == 0:
                # Partial fetch: part from the cache, part from disk.
                frontier_OTF = self.OTF_fetch(k, seed_nodes, fanout_cache_fetch)
            else:
                # Serve the whole fanout from the cache.
                frontier_OTF = self.cache_struct[k].sample_neighbors(
                    seed_nodes,
                    fanout,
                    edge_dir=self.edge_dir,
                    prob=self.prob,
                    replace=self.replace,
                    output_device=self.output_device,
                    exclude_edges=self.exclude_eids,
                )

            block = to_block(frontier_OTF, seed_nodes)
            # Merged frontiers may lack EID; copy it only when present.
            if EID in frontier_OTF.edata.keys():
                block.edata[EID] = frontier_OTF.edata[EID]
            blocks.insert(0, block)
            seed_nodes = block.srcdata[NID]

        return seed_nodes, output_nodes, blocks


class NeighborSampler_OTF_fetch_struct_shared_cache_hete(BlockSampler):
    """Shared-cache variant of :class:`NeighborSampler_OTF_fetch_struct_hete`.

    A single cache sized for the largest amplified fanout
    (``max(fanout) * amp_rate``) serves every GNN layer, instead of one
    cache per layer.  Refresh and partial-fetch scheduling is otherwise
    identical to the per-layer variant.
    """

    def __init__(
        self,
        g,
        fanouts,
        edge_dir='in',
        amp_rate=2,
        fetch_rate=0.4,
        T_refresh=None,
        T_fetch=3,
        hete_label=None,
        prob=None,
        mask=None,
        replace=False,
        prefetch_node_feats=None,
        prefetch_labels=None,
        prefetch_edge_feats=None,
        output_device=None,
        fused=True,
    ):
        super().__init__(
            prefetch_node_feats=prefetch_node_feats,
            prefetch_labels=prefetch_labels,
            prefetch_edge_feats=prefetch_edge_feats,
            output_device=output_device,
        )
        if mask is not None and prob is not None:
            raise ValueError(
                "Mask and probability arguments are mutually exclusive. "
                "Consider multiplying the probability with the mask "
                "to achieve the same goal."
            )
        self.g = g
        self.fanouts = fanouts
        self.edge_dir = edge_dir
        self.prob = prob or mask
        self.replace = replace
        self.fused = fused
        self.mapping = {}
        self.exclude_eids = None

        # Fix: same latent AttributeError as the per-layer variant --
        # the default T_refresh computation needs `self.amp_rate`.
        self.amp_rate = amp_rate
        self.alpha = amp_rate  # kept for backward compatibility
        self.cycle = 0  # sampling-iteration counter
        # Single shared cache sized for the largest amplified fanout.
        self.sc_size = max(f * amp_rate for f in fanouts)
        if T_refresh is not None:
            self.T_refresh = T_refresh
        else:
            self.T_refresh = int(
                self.g.number_of_nodes() / max(self.fanouts) * self.amp_rate
            )
        self.T_fetch = T_fetch
        self.Toptim = None
        self.fetch_rate = fetch_rate
        self.hete_label = hete_label
        self.shared_cache = self.full_cache_refresh(self.sc_size)

    def full_cache_refresh(self, fanout_cache_storage, exclude_eids=None):
        """Rebuild the shared cache by sampling ``fanout_cache_storage``
        neighbors of every ``hete_label`` node from the full graph."""
        return self.g.sample_neighbors(
            {self.hete_label: list(range(self.g.num_nodes(self.hete_label)))},
            fanout_cache_storage,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=exclude_eids,
        )

    def OTF_fetch(self, layer_id, seed_nodes, fanout_cache_fetch,
                  exclude_eids=None):
        """Fetch a frontier for ``seed_nodes``: ``fanout_cache_fetch``
        neighbors come from the shared cache and the remainder of the layer
        fanout comes directly from the full graph."""
        if fanout_cache_fetch == self.fanouts[layer_id]:
            # The whole fanout can be served from the shared cache.
            return self.shared_cache.sample_neighbors(
                seed_nodes,
                fanout_cache_fetch,
                edge_dir=self.edge_dir,
                prob=self.prob,
                replace=self.replace,
                output_device=self.output_device,
                exclude_edges=exclude_eids,
            )
        fanout_disk_fetch = self.fanouts[layer_id] - fanout_cache_fetch
        # Fix: honor the `exclude_eids` parameter in this branch too.
        cache_fetch = self.shared_cache.sample_neighbors(
            seed_nodes,
            fanout_cache_fetch,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=exclude_eids,
        )
        disk_fetch = self.g.sample_neighbors(
            seed_nodes,
            fanout_disk_fetch,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=exclude_eids,
        )
        return dgl.merge([cache_fetch, disk_fetch])

    def sample_blocks(self, g, seed_nodes, exclude_eids=None):
        """Sample blocks for ``seed_nodes``, serving frontiers from the
        shared cache.

        Returns
        -------
        tuple
            ``(input_nodes, output_nodes, blocks)``.
        """
        output_nodes = seed_nodes

        self.cycle += 1
        # Periodically rebuild the shared cache from the full graph.
        if self.cycle % self.T_refresh == 0:
            self.shared_cache = self.full_cache_refresh(self.sc_size)

        blocks = []

        # Fused CPU path, mirroring NeighborSampler.sample_blocks.
        if self.fused and get_num_threads() > 1:
            cpu = F.device_type(g.device) == "cpu"
            if isinstance(seed_nodes, dict):
                for ntype in list(seed_nodes.keys()):
                    if not cpu:
                        break
                    cpu = cpu and F.device_type(seed_nodes[ntype].device) == "cpu"
            else:
                cpu = cpu and F.device_type(seed_nodes.device) == "cpu"
            if cpu and isinstance(g, DGLGraph) and F.backend_name == "pytorch":
                if self.g != g:
                    self.mapping = {}
                    self.g = g
                for fanout in reversed(self.fanouts):
                    block = self.g.sample_neighbors_fused(
                        seed_nodes,
                        fanout,
                        edge_dir=self.edge_dir,
                        prob=self.prob,
                        replace=self.replace,
                        exclude_edges=exclude_eids,
                        mapping=self.mapping,
                    )
                    seed_nodes = block.srcdata[NID]
                    blocks.insert(0, block)
                return seed_nodes, output_nodes, blocks

        for k in range(len(self.fanouts) - 1, -1, -1):
            fanout = self.fanouts[k]
            fanout_cache_fetch = int(fanout * self.fetch_rate)

            if self.cycle % self.T_fetch == 0:
                # Partial fetch: part from the cache, part from disk.
                frontier_OTF = self.OTF_fetch(k, seed_nodes, fanout_cache_fetch)
            else:
                # Serve the whole fanout from the shared cache.
                frontier_OTF = self.shared_cache.sample_neighbors(
                    seed_nodes,
                    fanout,
                    edge_dir=self.edge_dir,
                    prob=self.prob,
                    replace=self.replace,
                    output_device=self.output_device,
                    exclude_edges=self.exclude_eids,
                )

            block = to_block(frontier_OTF, seed_nodes)
            # Merged frontiers may lack EID; copy it only when present.
            if EID in frontier_OTF.edata.keys():
                block.edata[EID] = frontier_OTF.edata[EID]
            blocks.insert(0, block)
            seed_nodes = block.srcdata[NID]

        return seed_nodes, output_nodes, blocks
+# ) +# self.prob = prob or mask +# self.fused = fused +# self.mapping = {} +# self.cache_size = [fanout * amp_rate for fanout in fanouts] +# self.T = T +# self.cached_graph_structures = [self.initialize_cache(cache_size) for cache_size in self.cache_size] + +# def initialize_cache(self, fanout_cache_storage): +# """ +# Initializes the cache for each layer with an amplified fanout to pre-sample a larger +# set of neighbors. This pre-sampling helps in reducing the need for dynamic sampling +# at every iteration, thereby improving efficiency. +# """ +# cached_graph = self.g.sample_neighbors( +# # torch.arange(0, self.g.number_of_nodes()), +# {self.hete_label:list(range(0, self.g.num_nodes(self.hete_label)))}, +# fanout_cache_storage, +# edge_dir=self.edge_dir, +# prob=self.prob, +# replace=self.replace, +# output_device=self.output_device, +# exclude_edges=self.exclude_eids, +# ) +# print("end init cache") +# return cached_graph + +# def OTF_rf_cache(self,layer_id, cached_graph_structure, seed_nodes, fanout_cache_refresh, fanout): +# """ +# Refreshes a portion of the cache based on the gamma parameter by replacing some of the +# cached edges with new samples from the graph. This method ensures the cache remains +# relatively fresh and reflects changes in the dynamic graph structure or sampling needs. 
+# """ +# fanout_cache_remain = self.cache_size[layer_id]-fanout_cache_refresh +# fanout_cache_pr = fanout-fanout_cache_refresh +# # unchanged_nodes = range(torch.arange(0, self.g.number_of_nodes()))-seed_nodes +# # the rest node structure remain the same +# all_nodes = torch.arange(0, self.g.num_nodes(self.hete_label)) +# print("seed nodes:",seed_nodes) +# print("all nodes",all_nodes) +# mask = ~torch.isin(all_nodes, seed_nodes[self.hete_label]) +# # bool mask to select those nodes do not in seed_nodes +# unchanged_nodes = {self.hete_label: all_nodes[mask]} +# unchanged_structure = cached_graph_structure.sample_neighbors( +# unchanged_nodes, +# self.cache_size[layer_id], +# edge_dir=self.edge_dir, +# prob=self.prob, +# replace=self.replace, +# output_device=self.output_device, +# exclude_edges=self.exclude_eids, +# ) +# # the OTF node structure should +# changed_cache_remain = cached_graph_structure.sample_neighbors( +# seed_nodes, +# fanout_cache_remain, +# edge_dir=self.edge_dir, +# prob=self.prob, +# replace=self.replace, +# output_device=self.output_device, +# exclude_edges=self.exclude_eids, +# ) +# cache_pr = cached_graph_structure.sample_neighbors( +# seed_nodes, +# fanout_cache_pr, +# edge_dir=self.edge_dir, +# prob=self.prob, +# replace=self.replace, +# output_device=self.output_device, +# exclude_edges=self.exclude_eids, +# ) +# changed_disk_to_add = self.g.sample_neighbors( +# seed_nodes, +# fanout_cache_refresh, +# edge_dir=self.edge_dir, +# prob=self.prob, +# replace=self.replace, +# output_device=self.output_device, +# exclude_edges=self.exclude_eids, +# ) +# refreshed_cache = dgl.merge([unchanged_structure, changed_cache_remain, changed_disk_to_add]) +# retrieval_cache = dgl.merge([cache_pr, changed_disk_to_add]) +# return refreshed_cache, retrieval_cache + +# def sample_blocks(self, g, seed_nodes, exclude_eids=None): +# """ +# Samples blocks for GNN layers by combining cached samples with dynamically sampled +# neighbors. 
class NeighborSampler_OTF_struct_PCFPSCR_hete(BlockSampler):
    """OTF sampler (PCF/PSCR mode) for heterogeneous graphs with one
    amplified cache per GNN layer.

    On every iteration a fraction ``refresh_rate`` of each layer's fanout
    is re-sampled from the full graph while the rest of the cached
    structure is kept; the retrieved frontier mixes the partially
    refreshed cache with the freshly sampled edges.

    Parameters
    ----------
    g : DGLGraph
        The heterogeneous graph to sample from.
    fanouts : list[int]
        Number of neighbors to sample per GNN layer.
    edge_dir : str
        Direction of sampling, ``"in"`` or ``"out"``.
    amp_rate : int or float
        Cache amplification rate (cache size = ``fanout * amp_rate``).
    T : int
        Refresh period hyper-parameter (currently unused by the sampling
        loop; kept for interface compatibility).
    refresh_rate : float
        Fraction of each fanout replaced by fresh samples per iteration.
    hete_label : str
        Node type whose local structure is cached.
    prob, mask, replace, prefetch_node_feats, prefetch_labels,
    prefetch_edge_feats, output_device, fused
        Same meaning as in :class:`NeighborSampler`.
    """

    def __init__(
        self,
        g,
        fanouts,
        edge_dir='in',
        amp_rate=2,
        T=20,
        refresh_rate=0.4,
        hete_label=None,
        prob=None,
        mask=None,
        replace=False,
        prefetch_node_feats=None,
        prefetch_labels=None,
        prefetch_edge_feats=None,
        output_device=None,
        fused=True,
    ):
        super().__init__(
            prefetch_node_feats=prefetch_node_feats,
            prefetch_labels=prefetch_labels,
            prefetch_edge_feats=prefetch_edge_feats,
            output_device=output_device,
        )
        if mask is not None and prob is not None:
            raise ValueError(
                "Mask and probability arguments are mutually exclusive. "
                "Consider multiplying the probability with the mask "
                "to achieve the same goal."
            )
        self.g = g
        self.fanouts = fanouts
        self.edge_dir = edge_dir
        self.prob = prob or mask
        self.replace = replace
        self.fused = fused
        self.mapping = {}
        self.exclude_eids = None

        self.amp_rate = amp_rate
        self.hete_label = hete_label
        self.cycle = 0  # sampling-iteration counter
        # Amplified per-layer fanouts used as cache sizes.
        self.amplified_fanouts = [f * self.amp_rate for f in fanouts]
        self.T = T
        self.refresh_rate = refresh_rate
        self.Toptim = None
        self.cache_struct = [
            self.initialize_cache(fanout_cache_storage=ampf)
            for ampf in self.amplified_fanouts
        ]

    def initialize_cache(self, fanout_cache_storage):
        """Pre-sample an amplified neighborhood of every ``hete_label``
        node, forming one layer's cache."""
        return self.g.sample_neighbors(
            {self.hete_label: list(range(self.g.num_nodes(self.hete_label)))},
            fanout_cache_storage,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=self.exclude_eids,
        )

    def OTF_rf_cache(self, layer_id, cached_graph_structure, seed_nodes,
                     fanout_cache_refresh, fanout):
        """Partially refresh one layer cache and build the retrieval
        frontier in a single pass.

        Returns
        -------
        tuple
            ``(refreshed_cache, retrieval_cache)`` where the first replaces
            the layer cache and the second is the frontier for this batch.
        """
        fanout_cache_remain = self.amplified_fanouts[layer_id] - fanout_cache_refresh
        fanout_cache_pr = fanout - fanout_cache_refresh
        # Nodes outside the seed set keep their cached structure untouched.
        all_nodes = torch.arange(0, self.g.num_nodes(self.hete_label))
        keep = ~torch.isin(all_nodes, seed_nodes[self.hete_label])
        unchanged_nodes = {self.hete_label: all_nodes[keep]}
        unchanged_structure = cached_graph_structure.sample_neighbors(
            unchanged_nodes,
            self.amplified_fanouts[layer_id],
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=self.exclude_eids,
        )
        # Seed nodes: keep most of their cached edges ...
        changed_cache_remain = cached_graph_structure.sample_neighbors(
            seed_nodes,
            fanout_cache_remain,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=self.exclude_eids,
        )
        # ... use part of the cache for retrieval ...
        cache_pr = cached_graph_structure.sample_neighbors(
            seed_nodes,
            fanout_cache_pr,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=self.exclude_eids,
        )
        # ... and re-sample the refreshed share directly from the graph.
        changed_disk_to_add = self.g.sample_neighbors(
            seed_nodes,
            fanout_cache_refresh,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=self.exclude_eids,
        )
        refreshed_cache = dgl.merge(
            [unchanged_structure, changed_cache_remain, changed_disk_to_add]
        )
        retrieval_cache = dgl.merge([cache_pr, changed_disk_to_add])
        return refreshed_cache, retrieval_cache

    def sample_blocks(self, g, seed_nodes, exclude_eids=None):
        """Sample blocks for ``seed_nodes`` while partially refreshing the
        per-layer caches.

        Returns
        -------
        tuple
            ``(input_nodes, output_nodes, blocks)``.
        """
        output_nodes = seed_nodes
        self.cycle += 1

        blocks = []

        # Fused CPU path, mirroring NeighborSampler.sample_blocks.
        if self.fused and get_num_threads() > 1:
            cpu = F.device_type(g.device) == "cpu"
            if isinstance(seed_nodes, dict):
                for ntype in list(seed_nodes.keys()):
                    if not cpu:
                        break
                    cpu = cpu and F.device_type(seed_nodes[ntype].device) == "cpu"
            else:
                cpu = cpu and F.device_type(seed_nodes.device) == "cpu"
            if cpu and isinstance(g, DGLGraph) and F.backend_name == "pytorch":
                if self.g != g:
                    self.mapping = {}
                    self.g = g
                for fanout in reversed(self.fanouts):
                    block = g.sample_neighbors_fused(
                        seed_nodes,
                        fanout,
                        edge_dir=self.edge_dir,
                        prob=self.prob,
                        replace=self.replace,
                        exclude_edges=exclude_eids,
                        mapping=self.mapping,
                    )
                    seed_nodes = block.srcdata[NID]
                    blocks.insert(0, block)
                return seed_nodes, output_nodes, blocks

        for k in range(len(self.cache_struct) - 1, -1, -1):
            fanout_cache_refresh = int(self.fanouts[k] * self.refresh_rate)

            # Refresh cache & disk partially, retrieving the frontier in
            # the same pass.
            self.cache_struct[k], frontier_comp = self.OTF_rf_cache(
                k, self.cache_struct[k], seed_nodes,
                fanout_cache_refresh, self.fanouts[k],
            )

            block = to_block(frontier_comp, seed_nodes)
            # Merged frontiers may lack EID; copy it only when present.
            if EID in frontier_comp.edata.keys():
                block.edata[EID] = frontier_comp.edata[EID]
            blocks.insert(0, block)
            seed_nodes = block.srcdata[NID]

        return seed_nodes, output_nodes, blocks


class NeighborSampler_OTF_struct_PCFPSCR_shared_cache_hete(BlockSampler):
    """Shared-cache variant of
    :class:`NeighborSampler_OTF_struct_PCFPSCR_hete`: a single cache sized
    for the largest amplified fanout serves every layer, refreshed
    partially on each iteration.
    """

    def __init__(
        self,
        g,
        fanouts,
        edge_dir='in',
        amp_rate=2,
        T=20,
        refresh_rate=0.4,
        hete_label=None,
        prob=None,
        mask=None,
        replace=False,
        prefetch_node_feats=None,
        prefetch_labels=None,
        prefetch_edge_feats=None,
        output_device=None,
        fused=True,
    ):
        super().__init__(
            prefetch_node_feats=prefetch_node_feats,
            prefetch_labels=prefetch_labels,
            prefetch_edge_feats=prefetch_edge_feats,
            output_device=output_device,
        )
        if mask is not None and prob is not None:
            raise ValueError(
                "Mask and probability arguments are mutually exclusive. "
                "Consider multiplying the probability with the mask "
                "to achieve the same goal."
            )
        self.g = g
        self.fanouts = fanouts
        self.edge_dir = edge_dir
        self.prob = prob or mask
        self.replace = replace
        self.fused = fused
        self.mapping = {}
        self.exclude_eids = None

        self.amp_rate = amp_rate
        self.hete_label = hete_label
        self.cycle = 0  # sampling-iteration counter
        # Single shared cache sized for the largest amplified fanout.
        self.sc_size = max(f * self.amp_rate for f in fanouts)
        self.T = T
        self.refresh_rate = refresh_rate
        self.Toptim = None
        self.shared_cache = self.initialize_cache(
            fanout_cache_storage=self.sc_size
        )

    def initialize_cache(self, fanout_cache_storage):
        """Pre-sample an amplified neighborhood of every ``hete_label``
        node, forming the shared cache."""
        return self.g.sample_neighbors(
            {self.hete_label: list(range(self.g.num_nodes(self.hete_label)))},
            fanout_cache_storage,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=self.exclude_eids,
        )

    def OTF_rf_cache(self, seed_nodes, fanout_cache_refresh, fanout):
        """Partially refresh the shared cache and build the retrieval
        frontier in a single pass.

        Returns
        -------
        tuple
            ``(refreshed_cache, retrieval_cache)``.
        """
        fanout_cache_remain = self.sc_size - fanout_cache_refresh
        fanout_cache_pr = fanout - fanout_cache_refresh
        # Nodes outside the seed set keep their cached structure untouched.
        all_nodes = torch.arange(0, self.g.num_nodes(self.hete_label))
        keep = ~torch.isin(all_nodes, seed_nodes[self.hete_label])
        unchanged_nodes = {self.hete_label: all_nodes[keep]}
        unchanged_structure = self.shared_cache.sample_neighbors(
            unchanged_nodes,
            self.sc_size,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=self.exclude_eids,
        )
        # Seed nodes: keep most of their cached edges ...
        changed_cache_remain = self.shared_cache.sample_neighbors(
            seed_nodes,
            fanout_cache_remain,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=self.exclude_eids,
        )
        # ... use part of the cache for retrieval ...
        cache_pr = self.shared_cache.sample_neighbors(
            seed_nodes,
            fanout_cache_pr,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=self.exclude_eids,
        )
        # ... and re-sample the refreshed share directly from the graph.
        changed_disk_to_add = self.g.sample_neighbors(
            seed_nodes,
            fanout_cache_refresh,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=self.exclude_eids,
        )
        refreshed_cache = dgl.merge(
            [unchanged_structure, changed_cache_remain, changed_disk_to_add]
        )
        retrieval_cache = dgl.merge([cache_pr, changed_disk_to_add])
        return refreshed_cache, retrieval_cache

    def sample_blocks(self, g, seed_nodes, exclude_eids=None):
        """Sample blocks for ``seed_nodes`` while partially refreshing the
        shared cache.

        Returns
        -------
        tuple
            ``(input_nodes, output_nodes, blocks)``.
        """
        output_nodes = seed_nodes
        self.cycle += 1

        blocks = []

        # Fused CPU path, mirroring NeighborSampler.sample_blocks.
        if self.fused and get_num_threads() > 1:
            cpu = F.device_type(g.device) == "cpu"
            if isinstance(seed_nodes, dict):
                for ntype in list(seed_nodes.keys()):
                    if not cpu:
                        break
                    cpu = cpu and F.device_type(seed_nodes[ntype].device) == "cpu"
            else:
                cpu = cpu and F.device_type(seed_nodes.device) == "cpu"
            if cpu and isinstance(g, DGLGraph) and F.backend_name == "pytorch":
                if self.g != g:
                    self.mapping = {}
                    self.g = g
                for fanout in reversed(self.fanouts):
                    block = g.sample_neighbors_fused(
                        seed_nodes,
                        fanout,
                        edge_dir=self.edge_dir,
                        prob=self.prob,
                        replace=self.replace,
                        exclude_edges=exclude_eids,
                        mapping=self.mapping,
                    )
                    seed_nodes = block.srcdata[NID]
                    blocks.insert(0, block)
                return seed_nodes, output_nodes, blocks

        for k in range(len(self.fanouts) - 1, -1, -1):
            fanout_cache_refresh = int(self.fanouts[k] * self.refresh_rate)

            # Refresh cache & disk partially, retrieving the frontier in
            # the same pass.
            self.shared_cache, frontier_comp = self.OTF_rf_cache(
                seed_nodes, fanout_cache_refresh, self.fanouts[k]
            )

            block = to_block(frontier_comp, seed_nodes)
            # Merged frontiers may lack EID; copy it only when present.
            if EID in frontier_comp.edata.keys():
                block.edata[EID] = frontier_comp.edata[EID]
            blocks.insert(0, block)
            seed_nodes = block.srcdata[NID]

        return seed_nodes, output_nodes, blocks
fanout_cache_refresh = int(self.fanouts[k] * self.refresh_rate) + + # Refresh cache&disk partially, while retrieval cache&disk partially + self.shared_cache, frontier_comp = self.OTF_rf_cache( seed_nodes, fanout_cache_refresh, self.fanouts[k]) + + # Sample frontier from the cache for acceleration + block = to_block(frontier_comp, seed_nodes) + if EID in frontier_comp.edata.keys(): + print("--------in this EID code---------") + block.edata[EID] = frontier_comp.edata[EID] + blocks.insert(0, block) + seed_nodes = block.srcdata[NID] # Update seed nodes for the next layer + + # output_nodes = seed_nodes + return seed_nodes, output_nodes, blocks + + +class NeighborSampler_OTF_struct_PSCRFCF_hete(BlockSampler): + def __init__( + self, + g, + fanouts, + edge_dir='in', + amp_rate=2, + T=20, + refresh_rate=0.4, + hete_label=None, + prob=None, + mask=None, + replace=False, + prefetch_node_feats=None, + prefetch_labels=None, + prefetch_edge_feats=None, + output_device=None, + fused=True, + ): + + super().__init__( + prefetch_node_feats=prefetch_node_feats, + prefetch_labels=prefetch_labels, + prefetch_edge_feats=prefetch_edge_feats, + output_device=output_device, + ) + self.fanouts = fanouts + self.edge_dir = edge_dir + if mask is not None and prob is not None: + raise ValueError( + "Mask and probability arguments are mutually exclusive. " + "Consider multiplying the probability with the mask " + "to achieve the same goal." 
        )
        self.g = g
        # `prob` and `mask` are mutually exclusive (validated above); keep
        # whichever one was supplied as the sampling probability.
        self.prob = prob or mask
        self.replace = replace
        self.fused = fused
        self.mapping = {}
        # NOTE(review): the cache path always samples with
        # exclude_edges=self.exclude_eids, which stays None here — the
        # `exclude_eids` argument of sample_blocks() is only honored on the
        # fused fast path. Confirm this is intended.
        self.exclude_eids = None

        self.amp_rate = amp_rate
        self.hete_label = hete_label
        self.cycle = 0  # Counts sampling iterations (refresh bookkeeping).
        # Amplified per-layer fanouts used to pre-sample the cache.
        self.cache_size = [f * self.amp_rate for f in fanouts]
        self.T = T
        self.refresh_rate = refresh_rate
        self.Toptim = None  # int(self.g.number_of_nodes() / max(self.amplified_fanouts))
        # One pre-sampled cached subgraph per GNN layer.
        self.cache_struct = [self.initialize_cache(fanout_cache_storage=ampf) for ampf in self.cache_size]
        # self.cache_refresh(self.g)  # Pre-sample and populate the cache

    def initialize_cache(self, fanout_cache_storage):
        """
        Pre-sample an amplified neighborhood for every node of
        ``self.hete_label`` and return it as a cached subgraph.

        The cache is larger than the per-iteration fanout (amplified by
        ``amp_rate``), so later iterations can draw samples from the cache
        instead of the full graph.

        Parameters
        ----------
        fanout_cache_storage : int
            Number of neighbors to pre-sample per node (the amplified fanout).

        Returns
        -------
        DGLGraph
            The cached subgraph holding the pre-sampled neighborhoods.
        """
        cached_graph = self.g.sample_neighbors(
            # torch.arange(0, self.g.number_of_nodes()),
            {self.hete_label: list(range(0, self.g.num_nodes(self.hete_label)))},
            fanout_cache_storage,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=self.exclude_eids,
        )
        print("end init cache")
        return cached_graph

    def OTF_refresh_cache(self, layer_id, cached_graph_structure, seed_nodes, fanout_cache_refresh):
        """
        Partially refresh one layer's cache around the current seed nodes.

        Nodes outside ``seed_nodes`` keep their cached structure; for the
        seed nodes, ``fanout_cache_refresh`` cached neighbors are replaced
        with fresh samples from the underlying graph, keeping the cache
        reasonably up to date.

        Parameters
        ----------
        layer_id : int
            Index of the layer whose cache is refreshed.
        cached_graph_structure : DGLGraph
            The current cached subgraph for this layer.
        seed_nodes : dict[ntype, Tensor]
            Seed nodes whose cached neighborhoods are partially refreshed.
        fanout_cache_refresh : int
            Number of cached neighbors per seed node to re-sample from disk.

        Returns
        -------
        DGLGraph
            The merged, refreshed cache for this layer.
        """
        fanout_cache_sample = self.cache_size[layer_id] - fanout_cache_refresh
        all_nodes = torch.arange(0, self.g.num_nodes(self.hete_label))
        print("seed nodes:", seed_nodes)
        print("all nodes", all_nodes)
        # Boolean mask selecting nodes that are NOT among the seed nodes.
        mask = ~torch.isin(all_nodes, seed_nodes[self.hete_label])
        unchanged_nodes = {self.hete_label: all_nodes[mask]}
        # Non-seed nodes: keep their full cached structure unchanged.
        unchanged_structure = cached_graph_structure.sample_neighbors(
            unchanged_nodes,
            self.cache_size[layer_id],
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=self.exclude_eids,
        )
        # Seed nodes: retain only part of their cached neighbors...
        changed_cache_remain = cached_graph_structure.sample_neighbors(
            seed_nodes,
            fanout_cache_sample,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=self.exclude_eids,
        )
        # ...and top them up with fresh samples drawn from the full graph.
        changed_disk_to_add = self.g.sample_neighbors(
            seed_nodes,
            fanout_cache_refresh,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=self.exclude_eids,
        )
        refreshed_cache = dgl.merge([unchanged_structure, changed_cache_remain, changed_disk_to_add])
        return refreshed_cache

    def sample_blocks(self, g, seed_nodes, exclude_eids=None):
        """
        Sample per-layer MFG blocks for ``seed_nodes`` using the layer caches.

        Parameters
        ----------
        g : DGLGraph
            The graph to sample from.
        seed_nodes : Tensor or dict[ntype, Tensor]
            The nodes for which the neighborhoods are to be sampled.
        exclude_eids : Tensor, optional
            Edge IDs to exclude; only used on the fused fast path (see
            NOTE in ``__init__``).

        Returns
        -------
        tuple
            A tuple containing the seed nodes for the next layer, the output
            nodes, and the list of blocks sampled from the graph.
        """
        output_nodes = seed_nodes

        # # refresh cache after a period of time for generalization
        # self.Toptim = int(g.number_of_nodes() / max(self.amplified_fanouts))
        # if self.cycle % self.Toptim == 0:
        #     self.cache_refresh(g)  # Refresh cache every T cycles

        self.cycle += 1

        blocks = []

        # Fused CPU fast path (mirrors the base NeighborSampler): bypasses
        # the cache entirely when every involved tensor lives on CPU.
        if self.fused and get_num_threads() > 1:
            cpu = F.device_type(g.device) == "cpu"
            if isinstance(seed_nodes, dict):
                for ntype in list(seed_nodes.keys()):
                    if not cpu:
                        break
                    cpu = (
                        cpu and F.device_type(seed_nodes[ntype].device) == "cpu"
                    )
            else:
                cpu = cpu and F.device_type(seed_nodes.device) == "cpu"
            if cpu and isinstance(g, DGLGraph) and F.backend_name == "pytorch":
                if self.g != g:
                    # Graph changed since construction: reset fused mapping.
                    self.mapping = {}
                    self.g = g
                for fanout in reversed(self.fanouts):
                    block = g.sample_neighbors_fused(
                        seed_nodes,
                        fanout,
                        edge_dir=self.edge_dir,
                        prob=self.prob,
                        replace=self.replace,
                        exclude_edges=exclude_eids,
                        mapping=self.mapping,
                    )
                    seed_nodes = block.srcdata[NID]
                    blocks.insert(0, block)
                return seed_nodes, output_nodes, blocks

        # Cache path: iterate layers from output to input.
        for k in range(len(self.cache_struct) - 1, -1, -1):
            fanout_cache_refresh = int(self.fanouts[k] * self.refresh_rate)

            # Partially refresh this layer's cache around the current seeds.
            self.cache_struct[k] = self.OTF_refresh_cache(k, self.cache_struct[k], seed_nodes, fanout_cache_refresh)

            # Sample this layer's frontier from the (refreshed) cache.
            frontier_cache = self.cache_struct[k].sample_neighbors(
                seed_nodes,
                self.fanouts[k],
                edge_dir=self.edge_dir,
                prob=self.prob,
                replace=self.replace,
                output_device=self.output_device,
                exclude_edges=self.exclude_eids,
            )

            # Convert the frontier into an MFG block for this layer.
            block = to_block(frontier_cache, seed_nodes)
            if EID in frontier_cache.edata.keys():
                print("--------in this EID code---------")
                block.edata[EID] = frontier_cache.edata[EID]
            blocks.insert(0, block)
            seed_nodes = block.srcdata[NID]  # Update seed nodes for the next layer

        # output_nodes
        # output_nodes = seed_nodes
        return seed_nodes, output_nodes, blocks


class NeighborSampler_OTF_struct_PSCRFCF_shared_cache_hete(BlockSampler):
    """
    On-the-fly (OTF) neighbor sampler that keeps a single cache shared by
    all GNN layers for a heterogeneous graph with one labeled node type.

    The cache pre-samples an amplified neighborhood (``amp_rate`` times the
    largest fanout) and is partially refreshed around the seed nodes at
    every ``sample_blocks`` call, trading sampling freshness for speed.
    """

    def __init__(
        self,
        g,
        fanouts,
        edge_dir='in',
        amp_rate=2,
        T=20,
        refresh_rate=0.4,
        hete_label=None,
        prob=None,
        mask=None,
        replace=False,
        prefetch_node_feats=None,
        prefetch_labels=None,
        prefetch_edge_feats=None,
        output_device=None,
        fused=True,
    ):

        super().__init__(
            prefetch_node_feats=prefetch_node_feats,
            prefetch_labels=prefetch_labels,
            prefetch_edge_feats=prefetch_edge_feats,
            output_device=output_device,
        )
        self.fanouts = fanouts
        self.edge_dir = edge_dir
        if mask is not None and prob is not None:
            raise ValueError(
                "Mask and probability arguments are mutually exclusive. "
                "Consider multiplying the probability with the mask "
                "to achieve the same goal."
            )
        self.g = g
        # `prob` and `mask` are mutually exclusive (validated above).
        self.prob = prob or mask
        self.replace = replace
        self.fused = fused
        self.mapping = {}
        # NOTE(review): cache-path sampling always uses self.exclude_eids
        # (None); the `exclude_eids` argument of sample_blocks() is only
        # honored on the fused path. Confirm intended.
        self.exclude_eids = None

        self.amp_rate = amp_rate
        self.hete_label = hete_label
        self.cycle = 0  # Counts sampling iterations (refresh bookkeeping).
        # Single shared-cache size: the largest amplified fanout.
        self.sc_size = max([f * self.amp_rate for f in fanouts])
        self.T = T
        self.refresh_rate = refresh_rate
        self.Toptim = None  # int(self.g.number_of_nodes() / max(self.amplified_fanouts))
        # One cache shared across all layers (unlike the per-layer variant).
        self.shared_cache = self.initialize_cache(fanout_cache_storage=self.sc_size)
        # self.cache_refresh(self.g)  # Pre-sample and populate the cache

    def initialize_cache(self, fanout_cache_storage):
        """
        Pre-sample an amplified neighborhood for every node of
        ``self.hete_label`` and return it as the shared cached subgraph.

        Parameters
        ----------
        fanout_cache_storage : int
            Number of neighbors to pre-sample per node (the amplified fanout).

        Returns
        -------
        DGLGraph
            The cached subgraph holding the pre-sampled neighborhoods.
        """
        cached_graph = self.g.sample_neighbors(
            # torch.arange(0, self.g.number_of_nodes()),
            {self.hete_label: list(range(0, self.g.num_nodes(self.hete_label)))},
            fanout_cache_storage,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=self.exclude_eids,
        )
        print("end init cache")
        return cached_graph

    def OTF_refresh_cache(self, seed_nodes, fanout_cache_refresh):
        """
        Partially refresh the shared cache around the current seed nodes.

        Nodes outside ``seed_nodes`` keep their cached structure; for the
        seed nodes, ``fanout_cache_refresh`` cached neighbors are replaced
        with fresh samples from the underlying graph.

        Parameters
        ----------
        seed_nodes : dict[ntype, Tensor]
            Seed nodes whose cached neighborhoods are partially refreshed.
        fanout_cache_refresh : int
            Number of cached neighbors per seed node to re-sample from disk.

        Returns
        -------
        DGLGraph
            The merged, refreshed shared cache.
        """
        fanout_cache_sample = self.sc_size - fanout_cache_refresh
        all_nodes = torch.arange(0, self.g.num_nodes(self.hete_label))
        print("seed nodes:", seed_nodes)
        print("all nodes", all_nodes)
        # Boolean mask selecting nodes that are NOT among the seed nodes.
        mask = ~torch.isin(all_nodes, seed_nodes[self.hete_label])
        unchanged_nodes = {self.hete_label: all_nodes[mask]}
        # Non-seed nodes: keep their full cached structure unchanged.
        unchanged_structure = self.shared_cache.sample_neighbors(
            unchanged_nodes,
            self.sc_size,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=self.exclude_eids,
        )
        # Seed nodes: retain only part of their cached neighbors...
        changed_cache_remain = self.shared_cache.sample_neighbors(
            seed_nodes,
            fanout_cache_sample,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=self.exclude_eids,
        )
        # ...and top them up with fresh samples drawn from the full graph.
        changed_disk_to_add = self.g.sample_neighbors(
            seed_nodes,
            fanout_cache_refresh,
            edge_dir=self.edge_dir,
            prob=self.prob,
            replace=self.replace,
            output_device=self.output_device,
            exclude_edges=self.exclude_eids,
        )
        refreshed_cache = dgl.merge([unchanged_structure, changed_cache_remain, changed_disk_to_add])
        return refreshed_cache
def sample_blocks(self, g,seed_nodes, exclude_eids=None): + """ + Samples blocks from the graph for the specified seed nodes using the cache. + + Parameters + ---------- + seed_nodes : Tensor + The nodes for which the neighborhoods are to be sampled. + + Returns + ------- + tuple + A tuple containing the seed nodes for the next layer, the output nodes, and + the list of blocks sampled from the graph. + """ + output_nodes = seed_nodes + + self.cycle += 1 + + blocks = [] + + if self.fused and get_num_threads() > 1: + # print("fused") + cpu = F.device_type(g.device) == "cpu" + if isinstance(seed_nodes, dict): + for ntype in list(seed_nodes.keys()): + if not cpu: + break + cpu = ( + cpu and F.device_type(seed_nodes[ntype].device) == "cpu" + ) + else: + cpu = cpu and F.device_type(seed_nodes.device) == "cpu" + if cpu and isinstance(g, DGLGraph) and F.backend_name == "pytorch": + if self.g != g: + self.mapping = {} + self.g = g + for fanout in reversed(self.fanouts): + block = g.sample_neighbors_fused( + seed_nodes, + fanout, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + exclude_edges=exclude_eids, + mapping=self.mapping, + ) + seed_nodes = block.srcdata[NID] + blocks.insert(0, block) + return seed_nodes, output_nodes, blocks + + for k in range(len(self.fanouts)-1,-1,-1): + fanout_cache_refresh = int(self.fanouts[k] * self.refresh_rate) + + # Refresh cache&disk partially, while retrieval cache&disk partially + self.shared_cache = self.OTF_refresh_cache(seed_nodes, fanout_cache_refresh) + + # Sample from cache + frontier_cache = self.shared_cache.sample_neighbors( + seed_nodes, + self.fanouts[k], + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=self.exclude_eids, + ) + + # Sample frontier from the cache for acceleration + block = to_block(frontier_cache, seed_nodes) + if EID in frontier_cache.edata.keys(): + print("--------in this EID code---------") + block.edata[EID] = 
frontier_cache.edata[EID] + blocks.insert(0, block) + seed_nodes = block.srcdata[NID] # Update seed nodes for the next layer + + # output_nodes = seed_nodes + return seed_nodes, output_nodes, blocks