Skip to content

Commit

Permalink
[Feature] Gather mm (dmlc#3641)
Browse files Browse the repository at this point in the history
* init

* init

* working cublasGemm

* benchmark high-mem/low-mem, err gather_mm output

* cuda kernel for bmm like kernel

* removed cpu copy for E_per_Rel

* benchmark code from Minjie

* fixed cublas results in gathermm sorted

* use GPU shared mem in unsorted gather mm

* minor

* Added an optimal version of gather_mm_unsorted

* lint

* init gather_mm_scatter

* cublas transpose added

* fixed h_offset for multiple rel

* backward unittest

* cublas support to transpose W

* adding missed file

* forgot to add header file

* lint

* lint

* cleanup

* lint

* docstring

* lint

* added unittest

* lint

* lint

* unittest

* changed err type

* skip cpu test

* skip CPU code

* move in-len loop inside

* lint

* added check different dim length for B

* w_per_len is optional now

* moved gather_mm to pytorch/backend with backward support

* removed a_/b_trans support

* transpose op inside GEMM call

* removed out alloc from API, changed W 2D to 3D

* Added se_gather_mm, Separate API for sortedE

* Fixed gather_mm (unsorted) user interface

* unsorted gmm backward + separate CAPI for un/sorted A

* typecast to float to support atomicAdd

* lint typecast

* lint

* added gather_mm_scatter

* minor

* const

* design changes

* Added idx_a, idx_b support gmm_scatter

* dgl doc

* lint

* adding gather_mm in ops

* lint

* lint

* minor

* removed benchmark files

* minor

* empty commit

Co-authored-by: Israt Nisa <[email protected]>
  • Loading branch information
isratnisa and Israt Nisa authored Feb 15, 2022
1 parent ab50eb9 commit b3d3a2c
Show file tree
Hide file tree
Showing 12 changed files with 1,315 additions and 1 deletion.
13 changes: 13 additions & 0 deletions docs/source/api/python/dgl.ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,19 @@ DGL provide operators to reduce value tensor along the first dimension by segmen

segment_reduce

GatherMM and SegmentMM Module
-----------------------------

SegmentMM: DGL provides operators to perform matrix multiplication according to segments.

GatherMM: DGL provides operators to gather data according to the given indices and perform matrix multiplication.

.. autosummary::
    :toctree: ../../generated/

    gather_mm
    segment_mm

Supported Data types
--------------------
Operators defined in ``dgl.ops`` support floating point data types, i.e. the operands
Expand Down
45 changes: 45 additions & 0 deletions python/dgl/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1827,6 +1827,51 @@ def csrmask(A, A_weights, B):
"""
pass

def gather_mm(A, B, idx_a, idx_b):
    r"""Backend interface: index-driven dense matrix multiplication.

    Multiplies the 2-D dense tensor ``A`` with the 3-D dense tensor ``B``
    according to their relation types. ``A`` is not sorted by relation type;
    the relation type is fetched from ``idx_b``.

    This is an interface placeholder; the framework-specific backend supplies
    the actual implementation.

    Parameters
    ----------
    A : Tensor
        2-D tensor of shape (N, D1).
    B : Tensor
        3-D tensor of shape (R, D1, D2).
    idx_a : Tensor, optional
        If specified, must be a 1-D integer tensor of shape (K,).
    idx_b : Tensor, optional
        If specified, must be a 1-D integer tensor of shape (K,).

    Returns
    -------
    Tensor
        The output dense matrix of shape (N, D2).
    """
    pass

def segment_mm(A, B, seglen_A):
    r"""Backend interface: segment-wise dense matrix multiplication.

    Multiplies dense tensor ``A`` with dense tensor ``B`` according to
    relation types. ``A`` is sorted and concatenated according to relation
    types, with ``seglen_A`` giving the length of each segment.

    This is an interface placeholder; the framework-specific backend supplies
    the actual implementation.

    Parameters
    ----------
    A : Tensor
        2-D tensor of shape (N, D1).
    B : Tensor
        3-D tensor of shape (R, D1, D2).
    seglen_A : Tensor
        An integer tensor of shape (R,). Each element is the length of a
        segment of input ``A``. The elements must sum to N.

    Returns
    -------
    Tensor
        The output dense matrix of shape (N, D2).
    """
    pass


###############################################################################
# Other interfaces
Expand Down
74 changes: 73 additions & 1 deletion python/dgl/backend/pytorch/sparse.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import torch as th
from distutils.version import LooseVersion
from ...base import is_all, ALL
from ...sparse import _gspmm, _gspmm_hetero, _gsddmm, _gsddmm_hetero, _segment_reduce, _bwd_segment_cmp
from ...sparse import _csrmm, _csrsum, _csrmask, _scatter_add, _update_grad_minmax_hetero, _gather_mm, _gather_mm_scatter, _segment_mm
from ...sparse import _gspmm, _gspmm_hetero, _gsddmm, _gsddmm_hetero, _segment_reduce, _bwd_segment_cmp, _edge_softmax_forward, _edge_softmax_backward
from ...sparse import _csrmm, _csrsum, _csrmask, _scatter_add, _update_grad_minmax_hetero
from ...heterograph_index import create_unitgraph_from_csr
Expand All @@ -27,7 +29,7 @@ def decorate_bwd(*args, **kwargs):
return decorate_bwd

__all__ = ['gspmm', 'gsddmm', 'gspmm_hetero', 'gsddmm_hetero', 'edge_softmax', 'edge_softmax_hetero',
'segment_reduce', 'scatter_add', 'csrmm', 'csrsum', 'csrmask']
'segment_reduce', 'scatter_add', 'csrmm', 'csrsum', 'csrmask', 'gather_mm', 'segment_mm']


def _reduce_grad(grad, shape):
Expand Down Expand Up @@ -691,6 +693,70 @@ def backward(ctx, dB_weights):
return None, csrmask(gidxB, dB_weights, gidxA), None


class SEGMENTMM(th.autograd.Function):
    """Autograd function wrapping the segment-wise matrix multiplication kernel.

    ``forward`` multiplies consecutive row segments of ``A`` (segment lengths
    given by ``seglen_A``) with the per-segment matrices stored in the 3-D
    tensor ``B``. ``backward`` produces gradients for ``A`` and ``B`` only;
    ``seglen_A`` is an index-like input and receives no gradient.
    """

    @staticmethod
    @custom_fwd(cast_inputs=th.float16)
    def forward(ctx, A, B, seglen_A):
        """Compute the segment-wise product of ``A`` and ``B``.

        Parameters
        ----------
        ctx : object
            Autograd context; used to stash the tensors needed by ``backward``.
        A : Tensor
            2-D tensor whose dimension 0 must equal ``sum(seglen_A)``.
        B : Tensor
            3-D tensor; it is flattened to 2-D before being handed to the
            kernel.
        seglen_A : Tensor
            Integer tensor holding the length of each segment of ``A``.

        Returns
        -------
        Tensor
            Result of shape ``(A.shape[0], B.shape[2])``.

        Raises
        ------
        ValueError
            If ``sum(seglen_A) != A.shape[0]`` or ``B`` is not 3-D.
        """
        # Convert once to a Python int so the comparison and the error message
        # both use a plain number rather than a tensor repr.
        total_len = int(th.sum(seglen_A))
        if A.shape[0] != total_len:
            raise ValueError(
                "The summation of the elements of seglen_A must be equal to "
                "dimension 0 of A. Expected {}, got {}.".format(A.shape[0], total_len))
        if B.dim() != 3:
            raise ValueError("Expected dimension of B is 3. Got " + str(B.dim()))
        # Reshaping B from 3D to 2D; the original shape is kept so that the
        # gradient of B can be restored to 3D in backward.
        B_3D_shape = B.shape
        B = B.reshape(B.shape[0] * B.shape[1], B.shape[2])
        C = th.zeros((A.shape[0], B.shape[1]), device=A.device, dtype=A.dtype)
        C = _segment_mm(A, B, C, seglen_A)
        ctx.backward_cache = A, B, seglen_A, B_3D_shape
        return C

    @staticmethod
    def backward(ctx, dZ):
        """Compute gradients w.r.t. ``A`` and ``B`` from the output gradient.

        Returns
        -------
        tuple
            ``(A_grad, B_grad, None)`` -- one entry per ``forward`` input.
            An entry is ``None`` when that input does not require a gradient.
        """
        A, B, seglen_A, B_3D_shape = ctx.backward_cache
        A_grad = B_grad = None
        if ctx.needs_input_grad[0]:
            # Compute A_grad = Out_grad * B^T
            A_grad = th.zeros(A.shape, device=A.device, dtype=A.dtype)
            A_grad = _segment_mm(dZ, B, A_grad, seglen_A, b_trans=True)
        if ctx.needs_input_grad[1]:
            # Compute B_grad = A^T * Out_grad
            B_grad = th.zeros(B.shape, device=B.device, dtype=B.dtype)
            B_grad = _segment_mm(A, dZ, B_grad, seglen_A, a_trans=True)
            # B was flattened to 2D in forward; hand back a gradient that
            # matches the caller's original 3D tensor.
            B_grad = B_grad.reshape(B_3D_shape)
        # forward takes exactly three inputs: A, B, seglen_A.
        return A_grad, B_grad, None


class GATHERMM(th.autograd.Function):
    """Autograd function wrapping the gather-and-multiply kernel.

    ``forward`` multiplies rows of ``A`` with matrices selected from the 3-D
    tensor ``B`` using the optional index tensors ``idx_a`` / ``idx_b``.
    ``backward`` produces gradients for ``A`` and ``B`` only; the index
    tensors receive no gradient.
    """

    @staticmethod
    @custom_fwd(cast_inputs=th.float16)
    def forward(ctx, A, B, idx_a, idx_b):
        """Compute the gathered product of ``A`` and ``B``.

        Parameters
        ----------
        ctx : object
            Autograd context; used to stash the tensors needed by ``backward``.
        A : Tensor
            2-D tensor.
        B : Tensor
            3-D tensor; it is flattened to 2-D before being handed to the
            kernel, and its first dimension is passed as the relation count.
        idx_a : Tensor or None
            Index tensor forwarded to the kernel for ``A``.
        idx_b : Tensor or None
            Index tensor forwarded to the kernel for ``B``.

        Returns
        -------
        Tensor
            Result of shape ``(A.shape[0], B.shape[2])``.

        Raises
        ------
        ValueError
            If ``B`` is not 3-D.
        """
        if B.dim() != 3:
            raise ValueError("Expected dimension of B is 3. Got " + str(B.dim()))
        # Reshaping B from 3D to 2D; the original shape is kept so that the
        # gradient of B can be restored to 3D in backward.
        B_3D_shape = B.shape
        B = B.reshape(B.shape[0] * B.shape[1], B.shape[2])
        C = th.zeros((A.shape[0], B.shape[1]), device=A.device, dtype=A.dtype)
        C = _gather_mm(A, B, C, B_3D_shape[0], idx_a, idx_b)
        ctx.backward_cache = A, B, idx_a, idx_b, B_3D_shape
        return C

    @staticmethod
    def backward(ctx, dZ):
        """Compute gradients w.r.t. ``A`` and ``B`` from the output gradient.

        Returns
        -------
        tuple
            ``(A_grad, B_grad, None, None)`` -- one entry per ``forward``
            input. An entry is ``None`` when that input does not require a
            gradient.
        """
        A, B, idx_a, idx_b, B_3D_shape = ctx.backward_cache
        A_grad = B_grad = None
        if ctx.needs_input_grad[0]:
            # Compute A_grad = Out_grad * B^T
            A_grad = th.zeros(A.shape, device=A.device, dtype=A.dtype)
            A_grad = _gather_mm_scatter(dZ, B, A_grad, B_3D_shape[0],
                                        idx_b=idx_b, idx_c=idx_a, b_trans=True)
        if ctx.needs_input_grad[1]:
            # Compute B_grad = A^T * Out_grad
            # NOTE(review): a_trans is not passed here; presumably the scatter
            # kernel forms the per-row outer product itself -- confirm against
            # the C implementation.
            B_grad = th.zeros(B.shape, device=B.device, dtype=B.dtype)
            B_grad = _gather_mm_scatter(A, dZ, B_grad, B_3D_shape[0],
                                        idx_a=idx_a, idx_c=idx_b)
            # B was flattened to 2D in forward; hand back a gradient that
            # matches the caller's original 3D tensor.
            B_grad = B_grad.reshape(B_3D_shape)
        # forward takes exactly four inputs: A, B, idx_a, idx_b.
        return A_grad, B_grad, None, None

def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
if op == 'sub':
op = 'add'
Expand Down Expand Up @@ -766,3 +832,9 @@ def csrsum(gidxs, weights):

def csrmask(gidxA, A_weights, gidxB):
return CSRMask.apply(gidxA, A_weights, gidxB)

def segment_mm(A, B, seglen_A):
    """Run the autograd-aware segment matrix multiplication (``SEGMENTMM``)."""
    product = SEGMENTMM.apply(A, B, seglen_A)
    return product

def gather_mm(A, B, idx_a=None, idx_b=None):
    """Run the autograd-aware gather matrix multiplication (``GATHERMM``)."""
    product = GATHERMM.apply(A, B, idx_a, idx_b)
    return product
1 change: 1 addition & 0 deletions python/dgl/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .sddmm import *
from .edge_softmax import *
from .segment import *
from .gather_mm import *
68 changes: 68 additions & 0 deletions python/dgl/ops/gather_mm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""dgl gather_mm operator module."""
from ..backend import gather_mm as gather_mm_internal
from ..backend import segment_mm as segment_mm_internal

__all__ = ['gather_mm', 'segment_mm']

def segment_mm(lhs_data, rhs_data, seglen_lhs):
    r"""Performs matrix multiplication according to segments.

    Suppose ``seglen_lhs == [10, 5, 0, 3]``, the operator will perform
    four matrix multiplications::

        lhs_data[0:10] @ rhs_data[0], lhs_data[10:15] @ rhs_data[1],
        lhs_data[15:15] @ rhs_data[2], lhs_data[15:18] @ rhs_data[3]

    Parameters
    ----------
    lhs_data : tensor
        The left operand, 2-D tensor of shape (N, D1)
    rhs_data : tensor
        The right operand, 3-D tensor of shape (R, D1, D2)
    seglen_lhs : tensor
        An integer tensor of shape (R,). Each element is the length of segments
        of input ``lhs_data``. The summation of all elements must be equal to N.

    Returns
    -------
    tensor
        The output dense matrix of shape (N, D2)
    """
    return segment_mm_internal(lhs_data, rhs_data, seglen_lhs)

def gather_mm(lhs_data, rhs_data, idx_lhs=None, idx_rhs=None):
    r"""Gather data according to the given indices and perform matrix multiplication.

    Let the result tensor be C, the operator conducts the following computation:

    If both idx_lhs and idx_rhs are not none::

        c[i] = lhs_data[idx_lhs[i]] @ rhs_data[idx_rhs[i]]

    , where len(C) == len(idx_lhs) == len(idx_rhs)

    If idx_lhs is given but not idx_rhs::

        c[i] = lhs_data[idx_lhs[i]] @ rhs_data[i]

    , where len(C) == len(idx_lhs)

    If idx_rhs is given but not idx_lhs::

        c[i] = lhs_data[i] @ rhs_data[idx_rhs[i]]

    , where len(C) == len(idx_rhs)

    Parameters
    ----------
    lhs_data : tensor
        2-D tensor of shape (N, D1)
    rhs_data : tensor
        3-D tensor of shape (R, D1, D2)
    idx_lhs : Tensor, optional
        If specified, must be a 1-D integer tensor of shape (K,).
    idx_rhs : Tensor, optional
        If specified, must be a 1-D integer tensor of shape (K,).

    Returns
    -------
    Tensor
        The output dense matrix of shape (N, D2)
    """
    # NOTE(review): the docstring states both len(C) == len(idx_lhs) (i.e. K)
    # and an output shape of (N, D2); confirm which holds when K != N.
    return gather_mm_internal(lhs_data, rhs_data, idx_lhs, idx_rhs)
105 changes: 105 additions & 0 deletions python/dgl/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,111 @@ def _gspmm_hetero(gidx, op, reduce_op, u_len, u_and_e_tuple):
return out, (list_arg_u, list_arg_e, list_arg_u_ntype, list_arg_e_etype)


def _segment_mm(A, B, out, seglen_A, a_trans=False, b_trans=False):
    r"""Segment-wise dense matrix multiplication (thin wrapper over the C API).

    Multiplies dense tensor ``A`` with dense tensor ``B`` according to
    relation types and stores the result in ``out``. ``A`` is sorted and
    concatenated according to relation types.

    Parameters
    ----------
    A : Tensor
        2-D tensor of shape (N, D1).
    B : Tensor
        2-D tensor of shape (R * D1, D2).
    out : Tensor
        Pre-allocated tensor that the kernel writes the result into.
    seglen_A : Tensor
        An integer tensor of shape (R,). Each element is the length of a
        segment of input ``A``. The elements must sum to N.
    a_trans : bool
        Indicates whether matrix A needs to be transposed.
    b_trans : bool
        Indicates whether matrix B needs to be transposed.

    Returns
    -------
    Tensor
        ``out``, the output dense matrix of shape (N, D2).
    """
    # TODO(Israt): Add CPU support. Currently, only handles GPU code
    lhs_nd = to_dgl_nd(A)
    rhs_nd = to_dgl_nd(B)
    out_nd = to_dgl_nd_for_write(out)
    seglen_nd = to_dgl_nd(seglen_A)
    _CAPI_DGLKernelSEGMENTMM(lhs_nd, rhs_nd, out_nd, seglen_nd, a_trans, b_trans)
    return out


def _gather_mm(A, B, out, num_rel, idx_a=None, idx_b=None):
    r"""Gathered dense matrix multiplication (thin wrapper over the C API).

    Multiplies tensor ``A`` and ``B`` according to relation types and stores
    the result in ``out``. ``B`` is a concatenated tensor across relation
    types. ``A`` is unsorted; the index tensors are forwarded to the kernel.

    Parameters
    ----------
    A : Tensor
        2-D tensor of shape (N, D1).
    B : Tensor
        2-D tensor of shape (R * D1, D2).
    out : Tensor
        Pre-allocated tensor that the kernel writes the result into.
    num_rel : int
        Relation count passed through to the kernel (callers in the PyTorch
        backend pass the first dimension of the original 3-D ``B``).
    idx_a : Tensor, optional
        If specified, must be a 1-D integer tensor of shape (K,).
    idx_b : Tensor, optional
        If specified, must be a 1-D integer tensor of shape (N,).

    Returns
    -------
    Tensor
        ``out``, the output dense matrix of shape (N, D2).
    """
    # TODO(Israt): Add CPU support. Currently, only handles GPU code
    lhs_nd = to_dgl_nd(A)
    rhs_nd = to_dgl_nd(B)
    out_nd = to_dgl_nd_for_write(out)
    idx_a_nd = to_dgl_nd(idx_a)
    idx_b_nd = to_dgl_nd(idx_b)
    _CAPI_DGLKernelGATHERMM(lhs_nd, rhs_nd, out_nd, idx_a_nd, idx_b_nd, num_rel)
    return out


def _gather_mm_scatter(A, B, out, num_rel, idx_a=None, idx_b=None, idx_c=None,
                       a_trans=False, b_trans=False):
    r"""Gathered dense matrix multiplication with indexed output (C API wrapper).

    Multiplies tensor ``A`` and ``B`` according to relation types and stores
    the result in ``out``. ``B`` is a concatenated tensor across relation
    types. ``A`` is unsorted; the three index tensors are forwarded to the
    kernel.

    Parameters
    ----------
    A : Tensor
        2-D tensor of shape (N, D1).
    B : Tensor
        2-D tensor of shape (R * D1, D2).
    out : Tensor
        Pre-allocated tensor that the kernel writes the result into.
    num_rel : int
        Relation count passed through to the kernel (callers in the PyTorch
        backend pass the first dimension of the original 3-D ``B``).
    idx_a : Tensor, optional
        If specified, must be a 1-D integer tensor of shape (K,).
    idx_b : Tensor, optional
        If specified, must be a 1-D integer tensor of shape (N,).
    idx_c : Tensor, optional
        If specified, must be a 1-D integer tensor of shape (N,).
    a_trans : bool
        Indicates whether matrix A needs to be transposed.
    b_trans : bool
        Indicates whether matrix B needs to be transposed.

    Returns
    -------
    Tensor
        ``out``, the output dense matrix of shape (N, D2).
    """
    # TODO(Israt): Add CPU support. Currently, only handles GPU code
    lhs_nd = to_dgl_nd(A)
    rhs_nd = to_dgl_nd(B)
    out_nd = to_dgl_nd_for_write(out)
    idx_a_nd = to_dgl_nd(idx_a)
    idx_b_nd = to_dgl_nd(idx_b)
    idx_c_nd = to_dgl_nd(idx_c)
    _CAPI_DGLKernelGATHERMMSCATTER(lhs_nd, rhs_nd, out_nd,
                                   idx_a_nd, idx_b_nd, idx_c_nd,
                                   num_rel, a_trans, b_trans)
    return out


def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface. It
takes the result of :attr:`op` on source node feature and destination node
Expand Down
Loading

0 comments on commit b3d3a2c

Please sign in to comment.