Commit

update
rusty1s committed Aug 17, 2024
1 parent 503b71c commit a7f9b26
Showing 6 changed files with 105 additions and 126 deletions.
9 changes: 9 additions & 0 deletions docs/source/conf.py
@@ -1,3 +1,4 @@
import copy
import datetime
import os.path as osp
import sys
@@ -38,3 +39,11 @@
'python': ('http://docs.python.org', None),
'torch': ('https://pytorch.org/docs/master', None),
}

typehints_use_rtype = False
typehints_defaults = 'comma'


def setup(app):
# Do not drop type hints in signatures:
del app.events.listeners['autodoc-process-signature']
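
The typehints_use_rtype and typehints_defaults options above come from the sphinx_autodoc_typehints extension (assuming it is listed in extensions), and the setup() hook keeps the type annotations visible in rendered signatures. Below is a minimal, hypothetical sketch (not part of pyg_lib) of the docstring style this configuration targets, where types live only in the signature:

from typing import Optional

import torch


def scale(src: torch.Tensor, factor: float,
          out: Optional[torch.Tensor] = None) -> torch.Tensor:
    r"""Scales :obj:`src` by :obj:`factor`.

    Args:
        src: The input tensor.
        factor: The multiplicative factor.
        out: The optional output tensor.

    Returns:
        The scaled tensor.
    """
    # Types are rendered from the annotations; the docstring stays type-free:
    return torch.mul(src, factor, out=out)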
4 changes: 2 additions & 2 deletions pyg_lib/__init__.py
@@ -34,15 +34,15 @@ def load_library(lib_name: str) -> None:
load_library('libpyg')

import pyg_lib.ops # noqa
import pyg_lib.sampler # noqa
import pyg_lib.partition # noqa
import pyg_lib.sampler # noqa


def cuda_version() -> int:
r"""Returns the CUDA version for which :obj:`pyg_lib` was compiled with.
Returns:
(int): The CUDA version.
The CUDA version.
"""
return torch.ops.pyg.cuda_version()

4 changes: 2 additions & 2 deletions pyg_lib/home.py
@@ -15,7 +15,7 @@ def get_home_dir() -> str:
variable :obj:`$PYG_LIB_HOME` which defaults to :obj:`"~/.cache/pyg_lib"`.
Returns:
(str): The cache directory.
The cache directory.
"""
if _home_dir is not None:
return _home_dir
@@ -29,7 +29,7 @@ def set_home_dir(path: str):
r"""Sets the cache directory used for storing all :obj:`pyg-lib` data.
Args:
path (str): The path to a local folder.
path: The path to a local folder.
"""
global _home_dir
_home_dir = path
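
A minimal usage sketch for the two helpers above, assuming they are imported from pyg_lib.home (the path is illustrative only):

from pyg_lib.home import get_home_dir, set_home_dir

# Override the $PYG_LIB_HOME / "~/.cache/pyg_lib" default:
set_home_dir('/tmp/pyg_lib_cache')
assert get_home_dir() == '/tmp/pyg_lib_cache'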
98 changes: 39 additions & 59 deletions pyg_lib/ops/__init__.py
@@ -117,16 +117,12 @@ def grouped_matmul(
assert outs[1] == inputs[1] @ others[1]
Args:
inputs (List[torch.Tensor]): List of left operand 2D matrices of shapes
:obj:`[N_i, K_i]`.
others (List[torch.Tensor]): List of right operand 2D matrices of
shapes :obj:`[K_i, M_i]`.
biases (List[torch.Tensor], optional): Optional bias terms to apply for
each element. (default: :obj:`None`)
inputs: List of left operand 2D matrices of shapes :obj:`[N_i, K_i]`.
others: List of right operand 2D matrices of shapes :obj:`[K_i, M_i]`.
biases: Optional bias terms to apply for each element.
Returns:
List[torch.Tensor]: List of 2D output matrices of shapes
:obj:`[N_i, M_i]`.
List of 2D output matrices of shapes :obj:`[N_i, M_i]`.
"""
# Combine inputs into a single tuple for autograd:
outs = list(GroupedMatmul.apply(tuple(inputs + others)))
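
A minimal usage sketch matching the signature documented above; the shapes are chosen purely for illustration:

import torch

import pyg_lib

inputs = [torch.randn(5, 16), torch.randn(3, 32)]    # [N_i, K_i]
others = [torch.randn(16, 32), torch.randn(32, 64)]  # [K_i, M_i]

outs = pyg_lib.ops.grouped_matmul(inputs, others)
assert outs[0].size() == (5, 32) and outs[1].size() == (3, 64)
assert torch.allclose(outs[1], inputs[1] @ others[1], atol=1e-5)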
@@ -160,18 +156,14 @@ def segment_matmul(
assert out[5:8] == inputs[5:8] @ other[1]
Args:
inputs (torch.Tensor): The left operand 2D matrix of shape
:obj:`[N, K]`.
ptr (torch.Tensor): Compressed vector of shape :obj:`[B + 1]`, holding
the boundaries of segments. For best performance, given as a CPU
tensor.
other (torch.Tensor): The right operand 3D tensor of shape
:obj:`[B, K, M]`.
bias (torch.Tensor, optional): Optional bias term of shape
:obj:`[B, M]` (default: :obj:`None`)
inputs: The left operand 2D matrix of shape :obj:`[N, K]`.
ptr: Compressed vector of shape :obj:`[B + 1]`, holding the boundaries
of segments. For best performance, given as a CPU tensor.
other: The right operand 3D tensor of shape :obj:`[B, K, M]`.
bias: The bias term of shape :obj:`[B, M]`.
Returns:
torch.Tensor: The 2D output matrix of shape :obj:`[N, M]`.
The 2D output matrix of shape :obj:`[N, M]`.
"""
out = torch.ops.pyg.segment_matmul(inputs, ptr, other)
if bias is not None:
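
A minimal usage sketch for segment_matmul with two segments; the shapes are illustrative only:

import torch

import pyg_lib

inputs = torch.randn(8, 16)     # [N, K]
ptr = torch.tensor([0, 5, 8])   # B = 2 segments: rows 0..4 and rows 5..7
other = torch.randn(2, 16, 32)  # [B, K, M]

out = pyg_lib.ops.segment_matmul(inputs, ptr, other)
assert out.size() == (8, 32)
assert torch.allclose(out[0:5], inputs[0:5] @ other[0], atol=1e-5)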
@@ -198,15 +190,13 @@ def sampled_add(
being more runtime and memory-efficient.
Args:
left (torch.Tensor): The left tensor.
right (torch.Tensor): The right tensor.
left_index (torch.LongTensor, optional): The values to sample from the
:obj:`left` tensor. (default: :obj:`None`)
right_index (torch.LongTensor, optional): The values to sample from the
:obj:`right` tensor. (default: :obj:`None`)
left: The left tensor.
right: The right tensor.
left_index: The values to sample from the :obj:`left` tensor.
right_index: The values to sample from the :obj:`right` tensor.
Returns:
torch.Tensor: The output tensor.
The output tensor.
"""
out = torch.ops.pyg.sampled_op(left, right, left_index, right_index, "add")
return out
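
A minimal sketch of the fused behavior described above; the same pattern applies to sampled_sub, sampled_mul and sampled_div below:

import torch

import pyg_lib

left = torch.randn(10, 32)
right = torch.randn(6, 32)
left_index = torch.tensor([0, 2, 4])
right_index = torch.tensor([1, 3, 5])

out = pyg_lib.ops.sampled_add(left, right, left_index, right_index)
# Matches the unfused version without materializing the indexed tensors:
assert torch.allclose(out, left[left_index] + right[right_index])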
@@ -230,15 +220,13 @@ def sampled_sub(
being more runtime and memory-efficient.
Args:
left (torch.Tensor): The left tensor.
right (torch.Tensor): The right tensor.
left_index (torch.LongTensor, optional): The values to sample from the
:obj:`left` tensor. (default: :obj:`None`)
right_index (torch.LongTensor, optional): The values to sample from the
:obj:`right` tensor. (default: :obj:`None`)
left: The left tensor.
right: The right tensor.
left_index: The values to sample from the :obj:`left` tensor.
right_index: The values to sample from the :obj:`right` tensor.
Returns:
torch.Tensor: The output tensor.
The output tensor.
"""
out = torch.ops.pyg.sampled_op(left, right, left_index, right_index, "sub")
return out
@@ -262,15 +250,13 @@ def sampled_mul(
thus being more runtime and memory-efficient.
Args:
left (torch.Tensor): The left tensor.
right (torch.Tensor): The right tensor.
left_index (torch.LongTensor, optional): The values to sample from the
:obj:`left` tensor. (default: :obj:`None`)
right_index (torch.LongTensor, optional): The values to sample from the
:obj:`right` tensor. (default: :obj:`None`)
left: The left tensor.
right: The right tensor.
left_index: The values to sample from the :obj:`left` tensor.
right_index: The values to sample from the :obj:`right` tensor.
Returns:
torch.Tensor: The output tensor.
The output tensor.
"""
out = torch.ops.pyg.sampled_op(left, right, left_index, right_index, "mul")
return out
@@ -294,15 +280,13 @@ def sampled_div(
being more runtime and memory-efficient.
Args:
left (torch.Tensor): The left tensor.
right (torch.Tensor): The right tensor.
left_index (torch.LongTensor, optional): The values to sample from the
:obj:`left` tensor. (default: :obj:`None`)
right_index (torch.LongTensor, optional): The values to sample from the
:obj:`right` tensor. (default: :obj:`None`)
left: The left tensor.
right: The right tensor.
left_index: The values to sample from the :obj:`left` tensor.
right_index: The values to sample from the :obj:`right` tensor.
Returns:
torch.Tensor: The output tensor.
The output tensor.
"""
out = torch.ops.pyg.sampled_op(left, right, left_index, right_index, "div")
return out
@@ -323,13 +307,12 @@ def index_sort(
device.
Args:
inputs (torch.Tensor): A vector with positive integer values.
max_value (int, optional): The maximum value stored inside
:obj:`inputs`. This value can be an estimation, but needs to be
greater than or equal to the real maximum. (default: :obj:`None`)
inputs: A vector with positive integer values.
max_value: The maximum value stored inside :obj:`inputs`. This value
can be an estimation, but needs to be greater than or equal to the
real maximum.
Returns:
Tuple[torch.LongTensor, torch.LongTensor]:
A tuple containing sorted values and indices of the elements in the
original :obj:`input` tensor.
"""
@@ -349,14 +332,6 @@ def softmax_csr(
:attr:`ptr`, and then proceeds to compute the softmax individually for
each group.
Args:
src (Tensor): The source tensor.
ptr (LongTensor): Groups defined by CSR representation.
dim (int, optional): The dimension in which to normalize.
(default: :obj:`0`)
:rtype: :class:`Tensor`
Examples:
>>> src = torch.randn(4, 4)
>>> ptr = torch.tensor([0, 4])
@@ -365,6 +340,11 @@
[0.1453, 0.2591, 0.5907, 0.2410],
[0.0598, 0.2923, 0.1206, 0.0921],
[0.7792, 0.3502, 0.1638, 0.2145]])
Args:
src: The source tensor.
ptr: Groups defined by CSR representation.
dim: The dimension in which to normalize.
"""
dim = dim + src.dim() if dim < 0 else dim
return torch.ops.pyg.softmax_csr(src, ptr, dim)
19 changes: 8 additions & 11 deletions pyg_lib/partition/__init__.py
@@ -18,19 +18,16 @@ def metis(
<https://arxiv.org/abs/1905.07953>`_ paper.
Args:
rowptr (torch.Tensor): Compressed source node indices.
col (torch.Tensor): Target node indices.
num_partitions (int): The number of partitions.
node_weight (torch.Tensor, optional): Optional node weights.
(default: :obj:`None`)
edge_weight (torch.Tensor, optional): Optional edge weights.
(default: :obj:`None`)
recursive (bool, optional): If set to :obj:`True`, will use multilevel
recursive bisection instead of multilevel k-way partitioning.
(default: :obj:`False`)
rowptr: Compressed source node indices.
col: Target node indices.
num_partitions: The number of partitions.
node_weight: The node weights.
edge_weight: The edge weights.
recursive: If set to :obj:`True`, will use multilevel recursive
bisection instead of multilevel k-way partitioning.
Returns:
torch.Tensor: A vector that assings each node to a partition.
A vector that assigns each node to a partition.
"""
return torch.ops.pyg.metis(rowptr, col, num_partitions, node_weight,
edge_weight, recursive)
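
A minimal usage sketch for metis on a tiny CSR graph (a 4-node cycle, purely illustrative); note that this assumes a pyg-lib build with METIS support:

import torch

import pyg_lib

rowptr = torch.tensor([0, 2, 4, 6, 8])        # compressed source node indices
col = torch.tensor([1, 3, 0, 2, 1, 3, 0, 2])  # target node indices

cluster = pyg_lib.partition.metis(rowptr, col, num_partitions=2)
assert cluster.numel() == 4 and int(cluster.max()) <= 1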
97 changes: 45 additions & 52 deletions pyg_lib/sampler/__init__.py
@@ -35,52 +35,47 @@ def neighbor_sample(
Args:
rowptr: Compressed source node indices.
col (torch.Tensor): Target node indices.
seed (torch.Tensor): The seed node indices.
col: Target node indices.
seed: The seed node indices.
num_neighbors: The number of neighbors to sample for each node in each
iteration.
If an entry is set to :obj:`-1`, all neighbors will be included.
node_time (torch.Tensor, optional): Timestamps for the nodes in the
graph. If set, temporal sampling will be used such that neighbors
are guaranteed to fulfill temporal constraints, *i.e.* sampled
node_time: Timestamps for the nodes in the graph.
If set, temporal sampling will be used such that neighbors are
guaranteed to fulfill temporal constraints, *i.e.* sampled
nodes have a timestamp earlier than or equal to that of the seed node.
If used, the :obj:`col` vector needs to be sorted according to time
within individual neighborhoods. Requires :obj:`disjoint=True`.
within individual neighborhoods.
Requires :obj:`disjoint=True`.
Only either :obj:`node_time` or :obj:`edge_time` can be specified.
(default: :obj:`None`)
edge_time (torch.Tensor, optional): Timestamps for the edges in the
graph. If set, temporal sampling will be used such that neighbors
are guaranteed to fulfill temporal constraints, *i.e.* sampled
edge_time: Timestamps for the edges in the graph.
If set, temporal sampling will be used such that neighbors are
guaranteed to fulfill temporal constraints, *i.e.* sampled
edges have a timestamp earlier than or equal to that of the seed node.
If used, the :obj:`col` vector needs to be sorted according to time
within individual neighborhoods. Requires :obj:`disjoint=True`.
within individual neighborhoods.
Requires :obj:`disjoint=True`.
Only either :obj:`node_time` or :obj:`edge_time` can be specified.
(default: :obj:`None`)
seed_time (torch.Tensor, optional): Optional values to override the
timestamp for seed nodes. If not set, will use timestamps in
:obj:`node_time` as default for seed nodes.
seed_time: Optional values to override the timestamp for seed nodes.
If not set, will use timestamps in :obj:`node_time` as default for
seed nodes.
Needs to be specified in case edge-level sampling is used via
:obj:`edge_time`. (default: :obj:`None`)
edge_weight (torch.Tensor, optional): If given, will perform biased
sampling based on the weight of each edge. (default: :obj:`None`)
csc (bool, optional): If set to :obj:`True`, assumes that the graph is
given in CSC format :obj:`(colptr, row)`. (default: :obj:`False`)
replace (bool, optional): If set to :obj:`True`, will sample with
replacement. (default: :obj:`False`)
directed (bool, optional): If set to :obj:`False`, will include all
edges between all sampled nodes. (default: :obj:`True`)
disjoint (bool, optional): If set to :obj:`True` , will create disjoint
subgraphs for every seed node. (default: :obj:`False`)
temporal_strategy (string, optional): The sampling strategy when using
temporal sampling (:obj:`"uniform"`, :obj:`"last"`).
(default: :obj:`"uniform"`)
return_edge_id (bool, optional): If set to :obj:`False`, will not
return the indices of edges of the original graph.
(default: :obj: `True`)
:obj:`edge_time`.
edge_weight: If given, will perform biased sampling based on the weight
of each edge.
csc: If set to :obj:`True`, assumes that the graph is given in CSC
format :obj:`(colptr, row)`.
replace: If set to :obj:`True`, will sample with replacement.
directed: If set to :obj:`False`, will include all edges between all
sampled nodes.
disjoint: If set to :obj:`True`, will create disjoint subgraphs for
every seed node.
temporal_strategy: The sampling strategy when using temporal sampling
(:obj:`"uniform"`, :obj:`"last"`).
return_edge_id: If set to :obj:`False`, will not return the indices of
edges of the original graph.
Returns:
(torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor],
List[int], List[int]):
Row indices, col indices of the returned subtree/subgraph, as well as
original node indices for all nodes sampled.
In addition, may return the indices of edges of the original graph.
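
A minimal usage sketch for neighbor_sample on an illustrative 4-node cycle graph; since the exact arity of the returned tuple can differ between pyg-lib versions, the sketch unpacks it defensively:

import torch

import pyg_lib

rowptr = torch.tensor([0, 2, 4, 6, 8])
col = torch.tensor([1, 3, 0, 2, 1, 3, 0, 2])
seed = torch.tensor([0])

out = pyg_lib.sampler.neighbor_sample(rowptr, col, seed, num_neighbors=[2, 2])
row, sampled_col, node_id = out[0], out[1], out[2]
# The remaining entries (edge indices and per-hop counts) follow as in the
# "Returns" section above.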
@@ -176,16 +171,16 @@ def subgraph(
:obj:`(rowptr, col)`, containing only the nodes in :obj:`nodes`.
Args:
rowptr (torch.Tensor): Compressed source node indices.
col (torch.Tensor): Target node indices.
nodes (torch.Tensor): Node indices of the induced subgraph.
return_edge_id (bool, optional): If set to :obj:`False`, will not
rowptr: Compressed source node indices.
col: Target node indices.
nodes: Node indices of the induced subgraph.
return_edge_id: If set to :obj:`False`, will not
return the indices of edges of the original graph contained in the
induced subgraph. (default: :obj:`True`)
induced subgraph.
Returns:
(torch.Tensor, torch.Tensor, Optional[torch.Tensor]): Compressed source
node indices and target node indices of the induced subgraph.
Compressed source node indices and target node indices of the induced
subgraph.
In addition, may return the indices of edges of the original graph.
"""
return torch.ops.pyg.subgraph(rowptr, col, nodes, return_edge_id)
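
A minimal usage sketch for subgraph, reusing the illustrative cycle graph:

import torch

import pyg_lib

rowptr = torch.tensor([0, 2, 4, 6, 8])
col = torch.tensor([1, 3, 0, 2, 1, 3, 0, 2])
nodes = torch.tensor([0, 1, 2])

sub_rowptr, sub_col, edge_id = pyg_lib.sampler.subgraph(rowptr, col, nodes)
assert sub_rowptr.numel() == nodes.numel() + 1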
@@ -205,19 +200,17 @@ def random_walk(
<https://arxiv.org/abs/1607.00653>`_ paper.
Args:
rowptr (torch.Tensor): Compressed source node indices.
col (torch.Tensor): Target node indices.
seed (torch.Tensor): Seed node indices from where random walks start.
walk_length (int): The walk length of a random walk.
p (float, optional): Likelihood of immediately revisiting a node in the
walk. (default: :obj:`1.0`)
q (float, optional): Control parameter to interpolate between
breadth-first strategy and depth-first strategy.
(default: :obj:`1.0`)
rowptr: Compressed source node indices.
col: Target node indices.
seed: Seed node indices from where random walks start.
walk_length: The walk length of a random walk.
p: Likelihood of immediately revisiting a node in the walk.
q: Control parameter to interpolate between breadth-first strategy and
depth-first strategy.
Returns:
torch.Tensor: A tensor of shape :obj:`[seed.size(0), walk_length + 1]`,
holding the nodes indices of each walk for each seed node.
A tensor of shape :obj:`[seed.size(0), walk_length + 1]`, holding the
node indices of each walk for each seed node.
"""
return torch.ops.pyg.random_walk(rowptr, col, seed, walk_length, p, q)
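
A minimal usage sketch for random_walk, again on the illustrative cycle graph:

import torch

import pyg_lib

rowptr = torch.tensor([0, 2, 4, 6, 8])
col = torch.tensor([1, 3, 0, 2, 1, 3, 0, 2])
seed = torch.tensor([0, 2])

walks = pyg_lib.sampler.random_walk(rowptr, col, seed, walk_length=3)
assert walks.size() == (2, 4)  # [seed.size(0), walk_length + 1]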

