Skip to content

Commit

Permalink
Fully deprecate MessagePassing.jittable() (pyg-team#8781)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Jan 16, 2024
1 parent 5bfabe0 commit 93aef5e
Show file tree
Hide file tree
Showing 10 changed files with 58 additions and 509 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

- Deprecated `MessagePassing.jittable` ([#8781](https://github.com/pyg-team/pytorch_geometric/pull/8781))
- Deprecated the usage of `torch_geometric.compile`; Use `torch.compile` instead ([#8780](https://github.com/pyg-team/pytorch_geometric/pull/8780))
- Deprecated the `typing` argument in `MessagePassing.jittable()` ([#8731](https://github.com/pyg-team/pytorch_geometric/pull/8731))
- Deprecated `torch_geometric.data.makedirs` in favor of `os.makedirs` ([#8421](https://github.com/pyg-team/pytorch_geometric/pull/8421))
Expand Down
72 changes: 24 additions & 48 deletions docs/source/advanced/jit.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@ If you are unfamilar with TorchScript, we recommend to read the official "`Intro
Converting GNN Models
---------------------

Converting your :pyg:`PyG` model to a TorchScript program is straightforward and requires only a few code changes.
.. note::
From :pyg:`PyG` 2.5 (and onwards), GNN layers are now fully compatible with :meth:`torch.jit.script` without any modification needed.
If you are on an earlier version of :pyg:`PyG`, consider to convert your GNN layers into "jittable" instances first by calling :meth:`~torch_geometric.nn.conv.MessagePassing.jittable`.

As always, it is best understood by an example, so let's consider the following model for now:
Converting your :pyg:`PyG` model to a TorchScript program is straightforward and requires only a few code changes.
Let's consider the following model:

.. code-block:: python
Expand All @@ -32,56 +35,23 @@ As always, it is best understood by an example, so let's consider the following
model = GNN(dataset.num_features, dataset.num_classes)
For TorchScript support, we need to convert our GNN operators into "jittable" instances.
This is done by calling the :func:`~torch_geometric.nn.conv.message_passing.MessagePassing.jittable` function provided by the underlying :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` interface:

.. code-block:: python
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv1 = GCNConv(in_channels, 64).jittable()
self.conv2 = GCNConv(64, out_channels).jittable()
This will create temporary instances of the :class:`~torch_geometric.nn.conv.GCNConv` operator that can now be passed into :func:`torch.jit.script`:
The instantiated model can now be directly passed into :meth:`torch.jit.script`:

.. code-block:: python
model = torch.jit.script(model)
Under the hood, the :func:`~torch_geometric.nn.conv.message_passing.MessagePassing.jittable` call applies the following two modifications to the original class:

1. It parses and converts the arguments of the internal :func:`~torch_geometric.nn.conv.message_passing.MessagePassing.propagate` function into a :obj:`NamedTuple` which can be handled by the TorchScript compiler.
2. It replaces any :obj:`Union` arguments of the :func:`forward` function (*i.e.*, arguments that may contain different types) with :obj:`@torch.jit._overload_method` annotations.
With this, we can do the following while everything remains jittable:

.. code-block:: python
from typing import Union, Tuple
from torch import Tensor
def forward(self, x: Union[Tensor, Tuple[Tensor, Tensor]], edge_index: Tensor) -> Tensor:
pass
conv(x, edge_index)
conv((x_src, x_dst), edge_index)
This technique is, *e.g.*, applied in the :class:`~torch_geometric.nn.conv.SAGEConv` class, which can operate on both single node feature matrices and tuples of node feature matrices at the same time.

And that is all you need to know on how to convert your :pyg:`PyG` models to TorchScript programs.
That is all you need to know on how to convert your :pyg:`PyG` models to TorchScript programs.
You can have a further look at our JIT examples that show-case how to obtain TorchScript programs for `node <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/jit/gat.py>`_ and `graph classification <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/jit/gin.py>`_ models.

.. note::
TorchScript support is still experimental.
If you have problems converting your model to a TorchScript program, *e.g.*, because an internal function does not support TorchScript yet, please `let us know <https://github.com/pyg-team/pytorch_geometric/issues/new/choose>`_.

Creating Jittable GNN Operators
--------------------------------

All :pyg:`PyG` :class:`~torch_geometric.nn.conv.MessagePassing` operators are tested to be convertible to a TorchScript program.
However, if you want your own GNN module to be jittable, you need to account for the following two things:
However, if you want your own GNN module to be compatible with :meth:`torch.jit.script`, you need to account for the following two things:

1. As one would expect, your :meth:`forward` code may need to be adjusted so that it passes the TorchScript compiler requirements, *e.g.*, by adding type notations.
2. You need to tell the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` module the types that you pass to its :func:`~torch_geometric.nn.conv.message_passing.MessagePassing.propagate` function.
2. You need to tell the :class:`~torch_geometric.nn.conv.MessagePassing` module the types that you pass to its :meth:`~torch_geometric.nn.conv.MessagePassing.propagate` function.
This can be achieved in two different ways:

1. Declaring the type of propagation arguments in a dictionary called :obj:`propagate_type`:
Expand All @@ -93,15 +63,17 @@ However, if you want your own GNN module to be jittable, you need to account for
from torch_geometric.nn import MessagePassing
class MyConv(MessagePassing):
propagate_type = {'x': Tensor, 'edge_weight': Optional[Tensor] }
def forward(self, x: Tensor, edge_index: Tensor,
edge_weight: Optional[Tensor]) -> Tensor:
def forward(
self,
x: Tensor,
edge_index: Tensor,
edge_weight: Optional[Tensor] = None,
) -> Tensor:
return self.propagate(edge_index, x=x, edge_weight=edge_weight)
2. Declaring the type of propagation arguments as a comment anywhere inside your module:
2. Declaring the type of propagation arguments as a comment inside your module:

.. code-block:: python
Expand All @@ -110,9 +82,13 @@ However, if you want your own GNN module to be jittable, you need to account for
from torch_geometric.nn import MessagePassing
class MyConv(MessagePassing):
def forward(self, x: Tensor, edge_index: Tensor,
edge_weight: Optional[Tensor]) -> Tensor:
def forward(
self,
x: Tensor,
edge_index: Tensor,
edge_weight: Optional[Tensor] = None,
) -> Tensor:
# propagate_type: (x: Tensor, edge_weight: Optional[Tensor])
return self.propagate(edge_index, x=x, edge_weight=edge_weight)
If none of these options are given, the :class:`~torch_geometric.nn.conv.MessagePassing` module will infer the arguments of :meth:`~torch_geometric.nn.conv.MessagePassing.propagate` to be of type :class:`torch.Tensor` (mimicing the default type that TorchScript is inferring for non-annotated arguments).
2 changes: 2 additions & 0 deletions torch_geometric/nn/conv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@
from .dir_gnn_conv import DirGNNConv
from .mixhop_conv import MixHopConv

import torch_geometric.nn.conv.utils # noqa

__all__ = [
'MessagePassing',
'SimpleConv',
Expand Down
231 changes: 0 additions & 231 deletions torch_geometric/nn/conv/message_passing.jinja

This file was deleted.

Loading

0 comments on commit 93aef5e

Please sign in to comment.