Skip to content

Commit

Permalink
[Transform] Modules for Augmentation (dmlc#3668)
Browse files Browse the repository at this point in the history
* Update

* Update

* Fix

* Update

* Update

* Update

* Update

* Fix

* Update

* Update

* Update

* Update

* Fix lint

* lint

* Update

* Update

* lint fix

* Fix CI

* Fix

* Fix CI

* Update

* Fix

* Update

* Update

* Augmentation (dmlc#10)

* Update

* PPR

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* CI

* lint

* lint

* Update

* Update

* Fix AddEdge

* try import

* Update

* Fix

* CI

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Minjie Wang <[email protected]>
  • Loading branch information
3 people authored Jan 25, 2022
1 parent ba62b73 commit 2b98e76
Show file tree
Hide file tree
Showing 8 changed files with 830 additions and 8 deletions.
52 changes: 50 additions & 2 deletions docs/source/api/python/transform.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ BaseTransform
:members: __call__, __repr__
:show-inheritance:

Compose
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: Compose
:show-inheritance:

AddSelfLoop
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down Expand Up @@ -55,8 +61,50 @@ AddMetaPaths
.. autoclass:: AddMetaPaths
:show-inheritance:

KNNGraph
GCNNorm
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: GCNNorm
:show-inheritance:

PPR
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: PPR
:show-inheritance:

HeatKernel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: HeatKernel
:show-inheritance:

GDC
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: GDC
:show-inheritance:

NodeShuffle
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: NodeShuffle
:show-inheritance:

DropNode
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: DropNode
:show-inheritance:

DropEdge
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: DropEdge
:show-inheritance:

AddEdge
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: KNNGraph
.. autoclass:: AddEdge
:show-inheritance:
30 changes: 30 additions & 0 deletions python/dgl/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,21 @@ def exp(input):
"""
pass

def inverse(input):
"""Returns the inverse matrix of a square matrix if it exists.
Parameters
----------
input : Tensor
The input square matrix.
Returns
-------
Tensor
The output tensor.
"""
pass

def sqrt(input):
"""Returns a new tensor with the square root of the elements of the input tensor `input`.
Expand Down Expand Up @@ -1057,6 +1072,21 @@ def equal(x, y):
"""
pass

def allclose(x, y, rtol=1e-4, atol=1e-4):
"""Compares whether all elements are close.
Parameters
----------
x : Tensor
First tensor
y : Tensor
Second tensor
rtol : float, optional
Relative tolerance
atol : float, optional
Absolute tolerance
"""

def logical_not(input):
"""Perform a logical not operation. Equivalent to np.logical_not
Expand Down
6 changes: 6 additions & 0 deletions python/dgl/backend/mxnet/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ def argsort(input, dim, descending):
def exp(input):
return nd.exp(input)

def inverse(input):
return nd.linalg_inverse(input)

def sqrt(input):
return nd.sqrt(input)

Expand Down Expand Up @@ -327,6 +330,9 @@ def boolean_mask(input, mask):
def equal(x, y):
return x == y

def allclose(x, y, rtol=1e-4, atol=1e-4):
return np.allclose(x.asnumpy(), y.asnumpy(), rtol=rtol, atol=atol)

def logical_not(input):
return nd.logical_not(input)

Expand Down
10 changes: 8 additions & 2 deletions python/dgl/backend/pytorch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from ...function.base import TargetCode
from ...base import dgl_warning

if LooseVersion(th.__version__) < LooseVersion("1.5.0"):
raise Exception("Detected an old version of PyTorch. Please update torch>=1.5.0 "
if LooseVersion(th.__version__) < LooseVersion("1.8.0"):
raise Exception("Detected an old version of PyTorch. Please update torch>=1.8.0 "
"for the best experience.")

def data_type_dict():
Expand Down Expand Up @@ -164,6 +164,9 @@ def argtopk(input, k, dim, descending=True):
def exp(input):
return th.exp(input)

def inverse(input):
return th.inverse(input)

def sqrt(input):
return th.sqrt(input)

Expand Down Expand Up @@ -276,6 +279,9 @@ def boolean_mask(input, mask):
def equal(x, y):
return x == y

def allclose(x, y, rtol=1e-4, atol=1e-4):
return th.allclose(x, y, rtol=rtol, atol=atol)

def logical_not(input):
return ~input

Expand Down
9 changes: 9 additions & 0 deletions python/dgl/backend/tensorflow/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,10 @@ def exp(input):
return tf.exp(input)


def inverse(input):
return tf.linalg.inv(input)


def sqrt(input):
return tf.sqrt(input)

Expand Down Expand Up @@ -396,6 +400,11 @@ def equal(x, y):
return x == y


def allclose(x, y, rtol=1e-4, atol=1e-4):
return np.allclose(tf.convert_to_tensor(x).numpy(),
tf.convert_to_tensor(y).numpy(), rtol=rtol, atol=atol)


def logical_not(input):
return ~input

Expand Down
Loading

0 comments on commit 2b98e76

Please sign in to comment.