Skip to content

Commit

Permalink
[NN] HeteroLinear and HeteroEmbedding (dmlc#3678)
Browse files Browse the repository at this point in the history
* modify hetero

* modify rst document

* update hetero

* update hetero

* update hetero

* update hetero

* Update

* Update

* Update

* Update

* 20220216

* Update

* Update

* Fix

Co-authored-by: Mufei Li <[email protected]>
Co-authored-by: ShelkerX <[email protected]>
  • Loading branch information
3 people authored Feb 17, 2022
1 parent e9c3c0e commit 9e358df
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 5 deletions.
14 changes: 13 additions & 1 deletion docs/source/api/python/nn.pytorch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -299,12 +299,24 @@ Heterogeneous Graph Convolution Module
----------------------------------------

HeteroGraphConv
~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.pytorch.HeteroGraphConv
:members:
:show-inheritance:

HeteroLinear
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: dgl.nn.pytorch.HeteroLinear
:members:
:show-inheritance:

HeteroEmbedding
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: dgl.nn.pytorch.HeteroEmbedding
:members:
:show-inheritance:

.. _apinn-pytorch-util:

Utility Modules
Expand Down
128 changes: 127 additions & 1 deletion python/dgl/nn/pytorch/hetero.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch.nn as nn
from ...base import DGLError

__all__ = ['HeteroGraphConv']
__all__ = ['HeteroGraphConv', 'HeteroLinear', 'HeteroEmbedding']

class HeteroGraphConv(nn.Module):
r"""A generic module for computing convolution on heterogeneous graphs.
Expand Down Expand Up @@ -250,3 +250,129 @@ def get_aggregate_fn(agg):
return _stack_agg_func
else:
return partial(_agg_func, fn=fn)

class HeteroLinear(nn.Module):
    """Apply linear transformations on heterogeneous inputs.

    One ``torch.nn.Linear`` is created per input type; all of them share the
    same output size.

    Parameters
    ----------
    in_size : dict[key, int]
        Input feature size for heterogeneous inputs. A key can be a string or
        a tuple of strings.
    out_size : int
        Output feature size.

    Examples
    --------
    >>> import dgl
    >>> import torch
    >>> from dgl.nn import HeteroLinear
    >>> layer = HeteroLinear({'user': 1, ('user', 'follows', 'user'): 2}, 3)
    >>> in_feats = {'user': torch.randn(2, 1), ('user', 'follows', 'user'): torch.randn(3, 2)}
    >>> out_feats = layer(in_feats)
    >>> print(out_feats['user'].shape)
    torch.Size([2, 3])
    >>> print(out_feats[('user', 'follows', 'user')].shape)
    torch.Size([3, 3])
    """
    def __init__(self, in_size, out_size):
        super(HeteroLinear, self).__init__()

        # ModuleDict keys must be strings, so tuple keys are stringified.
        self.linears = nn.ModuleDict({
            str(key): nn.Linear(size, out_size)
            for key, size in in_size.items()
        })

    def forward(self, feat):
        """Forward function.

        Parameters
        ----------
        feat : dict[key, Tensor]
            Heterogeneous input features. It maps keys to features.

        Returns
        -------
        dict[key, Tensor]
            Transformed features.
        """
        # Look up each type's linear layer by the stringified key.
        return {key: self.linears[str(key)](x) for key, x in feat.items()}

class HeteroEmbedding(nn.Module):
    """Create a heterogeneous embedding table.

    It internally contains multiple ``torch.nn.Embedding`` with different
    dictionary sizes.

    Parameters
    ----------
    num_embeddings : dict[key, int]
        Size of the dictionaries. A key can be a string or a tuple of strings.
    embedding_dim : int
        Size of each embedding vector.

    Examples
    --------
    >>> import dgl
    >>> import torch
    >>> from dgl.nn import HeteroEmbedding
    >>> layer = HeteroEmbedding({'user': 2, ('user', 'follows', 'user'): 3}, 4)
    >>> # Get the heterogeneous embedding table
    >>> embeds = layer.weight
    >>> print(embeds['user'].shape)
    torch.Size([2, 4])
    >>> print(embeds[('user', 'follows', 'user')].shape)
    torch.Size([3, 4])
    >>> # Get the embeddings for a subset
    >>> input_ids = {'user': torch.LongTensor([0]),
    ...              ('user', 'follows', 'user'): torch.LongTensor([0, 2])}
    >>> embeds = layer(input_ids)
    >>> print(embeds['user'].shape)
    torch.Size([1, 4])
    >>> print(embeds[('user', 'follows', 'user')].shape)
    torch.Size([2, 4])
    """
    def __init__(self, num_embeddings, embedding_dim):
        super(HeteroEmbedding, self).__init__()

        self.embeds = nn.ModuleDict()
        # ModuleDict keys must be strings; remember the original (possibly
        # tuple) key under its stringified form so it can be recovered later.
        self.raw_keys = dict()
        for key, num_rows in num_embeddings.items():
            name = str(key)
            self.embeds[name] = nn.Embedding(num_rows, embedding_dim)
            self.raw_keys[name] = key

    @property
    def weight(self):
        """Get the heterogeneous embedding table.

        Returns
        -------
        dict[key, Tensor]
            Heterogeneous embedding table
        """
        table = dict()
        for name, emb in self.embeds.items():
            # Map back from the stringified key to the original key.
            table[self.raw_keys[name]] = emb.weight
        return table

    def forward(self, input_ids):
        """Forward function.

        Parameters
        ----------
        input_ids : dict[key, Tensor]
            The row IDs to retrieve embeddings. It maps a key to key-specific IDs.

        Returns
        -------
        dict[key, Tensor]
            The retrieved embeddings.
        """
        return {key: self.embeds[str(key)](ids) for key, ids in input_ids.items()}
35 changes: 32 additions & 3 deletions tests/pytorch/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ def test_gin_conv(g, idtype, aggregator_type):
th.save(gin, tmp_buffer)

assert h.shape == (g.number_of_dst_nodes(), 12)

gin = nn.GINConv(None, aggregator_type)
th.save(gin, tmp_buffer)
gin = gin.to(ctx)
Expand Down Expand Up @@ -1246,6 +1246,35 @@ def forward(self, g, h, arg1=None, *, arg2=None):
{'user': uf, 'game': gf, 'store': sf[0:0]}))
assert set(h.keys()) == {'user', 'game'}

@pytest.mark.parametrize('out_dim', [1, 2, 100])
def test_hetero_linear(out_dim):
    # Two input types: a node type (string key) and an edge type (tuple key).
    keys = ['user', ('user', 'follows', 'user')]
    in_feats = {keys[0]: F.randn((2, 1)), keys[1]: F.randn((3, 2))}

    layer = nn.HeteroLinear({keys[0]: 1, keys[1]: 2}, out_dim).to(F.ctx())
    out_feats = layer(in_feats)
    for key, num_rows in zip(keys, (2, 3)):
        assert out_feats[key].shape == (num_rows, out_dim)

@pytest.mark.parametrize('out_dim', [1, 2, 100])
def test_hetero_embedding(out_dim):
    etype = ('user', 'follows', 'user')
    layer = nn.HeteroEmbedding({'user': 2, etype: 3}, out_dim).to(F.ctx())

    # Full embedding table: one tensor per key, sized (num_rows, out_dim).
    table = layer.weight
    assert table['user'].shape == (2, out_dim)
    assert table[etype].shape == (3, out_dim)

    # Row lookup for a subset of IDs per key.
    ids = {'user': F.tensor([0], dtype=F.int64),
           etype: F.tensor([0, 2], dtype=F.int64)}
    looked_up = layer(ids)
    assert looked_up['user'].shape == (1, out_dim)
    assert looked_up[etype].shape == (2, out_dim)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
Expand Down Expand Up @@ -1348,13 +1377,13 @@ def test_ke_score_funcs():
score_func(h_src, h_dst, rels).shape == (num_edges)


def test_twirls():
    # The source showed two consecutive `def test_twirls():` lines (a diff
    # rendering artifact), which is a SyntaxError — collapsed to a single def.
    # Smoke test: TWIRLSConv maps 10-dim node features to 2-dim outputs.
    g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3]))
    feat = th.ones(6, 10)
    conv = nn.TWIRLSConv(10, 2, 128, prop_step=64)
    res = conv(g, feat)
    assert res.size() == (6, 2)



if __name__ == '__main__':
Expand Down

0 comments on commit 9e358df

Please sign in to comment.