forked from dmlc/dgl
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* init
* init
* working cublasGemm
* benchmark high-mem/low-mem, err gather_mm output
* cuda kernel for bmm like kernel
* removed cpu copy for E_per_Rel
* benchmark code from Minjie
* fixed cublas results in gathermm sorted
* use GPU shared mem in unsorted gather mm
* minor
* Added an optimal version of gather_mm_unsorted
* lint
* init gather_mm_scatter
* cublas transpose added
* fixed h_offset for multiple rel
* backward unittest
* cublas support to transpose W
* adding missed file
* forgot to add header file
* lint
* lint
* cleanup
* lint
* docstring
* lint
* added unittest
* lint
* lint
* unittest
* changed err type
* skip cpu test
* skip CPU code
* move in-len loop inside
* lint
* added check different dim length for B
* w_per_len is optional now
* moved gather_mm to pytorch/backend with backward support
* removed a_/b_trans support
* transpose op inside GEMM call
* removed out alloc from API, changed W 2D to 3D
* Added se_gather_mm, Separate API for sortedE
* Fixed gather_mm (unsorted) user interface
* unsorted gmm backward + separate CAPI for un/sorted A
* typecast to float to support atomicAdd
* lint typecast
* lint
* added gather_mm_scatter
* minor
* const
* design changes
* Added idx_a, idx_b support gmm_scatter
* dgl doc
* lint
* adding gather_mm in ops
* lint
* lint
* minor
* removed benchmark files
* minor
* empty commit

Co-authored-by: Israt Nisa <[email protected]>
Showing 12 changed files with 1,315 additions and 1 deletion.
@@ -3,3 +3,4 @@
from .sddmm import *
from .edge_softmax import *
from .segment import *
from .gather_mm import *
@@ -0,0 +1,68 @@
"""dgl gather_mm operator module."""
from ..backend import gather_mm as gather_mm_internal
from ..backend import segment_mm as segment_mm_internal

__all__ = ['gather_mm', 'segment_mm']

def segment_mm(lhs_data, rhs_data, seglen_lhs):
    r"""Performs matrix multiplication according to segments.

    Suppose ``seglen_lhs == [10, 5, 0, 3]``, the operator will perform
    four matrix multiplications:

        lhs_data[0:10] @ rhs_data[0], lhs_data[10:15] @ rhs_data[1],
        lhs_data[15:15] @ rhs_data[2], lhs_data[15:18] @ rhs_data[3]

    Parameters
    ----------
    lhs_data : tensor
        The left operand, 2-D tensor of shape (N, D1)
    rhs_data : tensor
        The right operand, 2-D tensor of shape (R * D1, D2)
    seglen_lhs : tensor
        An integer tensor of shape (R,). Each element is the length of segments
        of input ``lhs_data``. The summation of all elements must be equal to N.

    Returns
    -------
    tensor
        The output dense matrix of shape (N, D2)
    """
    return segment_mm_internal(lhs_data, rhs_data, seglen_lhs)

def gather_mm(lhs_data, rhs_data, idx_lhs=None, idx_rhs=None):
    r"""Gather data according to the given indices and perform matrix multiplication.

    Let the result tensor be C, the operator conducts the following computation:

    If both idx_lhs and idx_rhs are not None:

        C[i] = lhs_data[idx_lhs[i]] @ rhs_data[idx_rhs[i]]

    , where len(C) == len(idx_lhs) == len(idx_rhs)

    If idx_lhs is given but not idx_rhs:

        C[i] = lhs_data[idx_lhs[i]] @ rhs_data[i]

    , where len(C) == len(idx_lhs)

    If idx_rhs is given but not idx_lhs:

        C[i] = lhs_data[i] @ rhs_data[idx_rhs[i]]

    , where len(C) == len(idx_rhs)

    Parameters
    ----------
    lhs_data : tensor
        2-D tensor of shape (N, D1)
    rhs_data : tensor
        3-D tensor of shape (R, D1, D2)
    idx_lhs : Tensor, optional
        If specified, must be a 1-D integer tensor of shape (K,).
    idx_rhs : Tensor, optional
        If specified, must be a 1-D integer tensor of shape (K,).

    Returns
    -------
    Tensor
        The output dense matrix of shape (N, D2)
    """
    return gather_mm_internal(lhs_data, rhs_data, idx_lhs, idx_rhs)
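
The three indexing cases in the `gather_mm` docstring can be checked against a small pure-Python reference. This is only a sketch of the documented semantics on nested lists, not the CUDA/cuBLAS kernel this commit adds. The names `vecmat` and `gather_mm_reference` are hypothetical, and raising an error when neither index is given is a choice of this sketch, since the docstring does not describe that case.

```python
def vecmat(row, weight):
    """Multiply a length-D1 row vector by a (D1, D2) matrix (nested lists)."""
    return [sum(x * y for x, y in zip(row, col)) for col in zip(*weight)]


def gather_mm_reference(lhs_data, rhs_data, idx_lhs=None, idx_rhs=None):
    """Reference for C[i] = lhs_data[idx_lhs[i]] @ rhs_data[idx_rhs[i]].

    A missing index falls back to position i, as in the docstring above.
    """
    if idx_lhs is None and idx_rhs is None:
        # Assumption of this sketch: the docstring leaves this case undefined.
        raise ValueError("expected idx_lhs, idx_rhs, or both")
    if idx_lhs is not None and idx_rhs is not None and len(idx_lhs) != len(idx_rhs):
        raise ValueError("idx_lhs and idx_rhs must have the same length")

    length = len(idx_lhs) if idx_lhs is not None else len(idx_rhs)
    out = []
    for i in range(length):
        row = lhs_data[idx_lhs[i]] if idx_lhs is not None else lhs_data[i]
        weight = rhs_data[idx_rhs[i]] if idx_rhs is not None else rhs_data[i]
        out.append(vecmat(row, weight))
    return out


# Two input rows (N=2, D1=2) and two weight matrices (R=2, D1=2, D2=2).
lhs = [[1, 0], [0, 1]]
rhs = [[[1, 2], [3, 4]], [[10, 20], [30, 40]]]

print(gather_mm_reference(lhs, rhs, idx_lhs=[1, 0, 1], idx_rhs=[0, 1, 1]))
# [[3, 4], [10, 20], [30, 40]]
print(gather_mm_reference(lhs, rhs, idx_rhs=[1, 0]))
# [[10, 20], [3, 4]]
```

In the first call the result has three rows, one per index pair, which matches the docstring's `len(C) == len(idx_lhs) == len(idx_rhs)`.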
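
The segment-wise product described in the `segment_mm` docstring can be illustrated the same way. The sketch below is a pure-Python reference on nested lists, not the kernel from this commit. The names `matmul` and `segment_mm_reference` are hypothetical. It assumes `rhs_data` is indexable as `rhs_data[r]`, one `(D1, D2)` matrix per segment, as in the docstring's own example, even though the docstring states the shape as `(R * D1, D2)`.

```python
def matmul(a, b):
    """Multiply an (n, k) matrix by a (k, m) matrix, both nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def segment_mm_reference(lhs_data, rhs_data, seglen_lhs):
    """Multiply consecutive row segments of lhs_data by rhs_data[r].

    Assumption: rhs_data is a list of R matrices of shape (D1, D2).
    """
    if sum(seglen_lhs) != len(lhs_data):
        raise ValueError("segment lengths must sum to the number of lhs rows")
    if len(seglen_lhs) != len(rhs_data):
        raise ValueError("need one rhs matrix per segment")

    out = []
    start = 0
    for r, length in enumerate(seglen_lhs):
        # An empty segment (length 0) contributes no output rows.
        out.extend(matmul(lhs_data[start:start + length], rhs_data[r]))
        start += length
    return out


# Three input rows split into segments of length 2, 0, and 1.
lhs = [[1, 0], [0, 1], [1, 1]]
rhs = [
    [[1, 2], [3, 4]],
    [[5, 6], [7, 8]],
    [[10, 20], [30, 40]],
]

print(segment_mm_reference(lhs, rhs, [2, 0, 1]))
# [[1, 2], [3, 4], [40, 60]]
```

The middle segment is empty, so `rhs[1]` is never applied, mirroring the `lhs_data[15:15] @ rhs_data[2]` case in the docstring.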