Skip to content

Commit

Permalink
[nnx] add tabulate
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Jan 8, 2025
1 parent e2134af commit a7d60e7
Show file tree
Hide file tree
Showing 11 changed files with 707 additions and 231 deletions.
206 changes: 100 additions & 106 deletions docs_nnx/mnist_tutorial.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions docs_nnx/mnist_tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class CNN(nnx.Module):
# Instantiate the model.
model = CNN(rngs=nnx.Rngs(0))
# Visualize it.
nnx.display(model)
print(nnx.tabulate(model))
```

### Run the model
Expand All @@ -112,7 +112,7 @@ Let's put the CNN model to the test! Here, you’ll perform a forward pass with
import jax.numpy as jnp # JAX NumPy
y = model(jnp.ones((1, 28, 28, 1)))
nnx.display(y)
y
```

## 4. Create the optimizer and define some metrics
Expand Down
211 changes: 158 additions & 53 deletions docs_nnx/nnx_basics.ipynb

Large diffs are not rendered by default.

19 changes: 4 additions & 15 deletions docs_nnx/nnx_basics.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,7 @@ jupytext:

Flax NNX is a new simplified API that is designed to make it easier to create, inspect, debug, and analyze neural networks in [JAX](https://jax.readthedocs.io/). It achieves this by adding first class support for Python reference semantics. This allows users to express their models using regular Python objects, which are modeled as PyGraphs (instead of pytrees), enabling reference sharing and mutability. Such API design should make PyTorch or Keras users feel at home.

In this guide you will learn about:

- The Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) system: An example of creating and initializing a custom `Linear` layer.
- Stateful computation: An example of creating a Flax [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) and updating its value (such as state updates needed during the forward pass).
- Nested [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s: An MLP example with `Linear`, [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout), and [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layers.
- Model surgery: An example of replacing custom `Linear` layers inside a model with custom `LoraLinear` layers.
- Flax transformations: An example of using [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) for automatic state management.
- [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan) over layers.
- The Flax NNX Functional API: An example of a custom `StatefulLinear` layer with [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s with fine-grained control over the state.
- [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) and [`GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef).
- [`split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge), and `update`
- Fine-grained [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) control: An example of using [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) type `Filter`s ([`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/guides/filters_guide.html)) to split into multiple [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s.
To begin, install Flax with `pip` and import necessary dependencies:

## Setup

Expand Down Expand Up @@ -106,7 +95,7 @@ to handle them, as demonstrated in later sections of this guide.

Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s can be used to compose other `Module`s in a nested structure. These can be assigned directly as attributes, or inside an attribute of any (nested) pytree type, such as a `list`, `dict`, `tuple`, and so on.

The example below shows how to define a simple `MLP` by subclassing [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html). The model consists of two `Linear` layers, an [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer, and an [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layer:
The example below shows how to define a simple `MLP` Module consisting of two `Linear` layers, a [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer, and an [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layer. We'll also use `nnx.tabulate` to get a summary of the state at the end.

```{code-cell} ipython3
class MLP(nnx.Module):
Expand All @@ -124,7 +113,7 @@ model = MLP(2, 16, 5, rngs=nnx.Rngs(0))
y = model(x=jnp.ones((3, 2)))
nnx.display(model)
print(nnx.tabulate(model))
```

In Flax, [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) is a stateful module that stores an [`nnx.Rngs`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/rnglib.html#flax.nnx.Rngs) object, so that it can generate new masks during the forward pass without the need for the user to pass a new key each time.
Expand Down Expand Up @@ -229,7 +218,7 @@ x = jnp.ones((3, 10))
y = forward(model, x)
print(f'{y.shape = }')
nnx.display(model)
print(nnx.tabulate(model))
```

How do Flax NNX transforms achieve this? To understand how Flax NNX objects interact with JAX transforms, the next section explains the Flax NNX Functional API.
Expand Down
18 changes: 11 additions & 7 deletions flax/linen/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@
LogicalNames,
)

try:
from IPython import get_ipython

in_ipython = get_ipython() is not None
except ImportError:
in_ipython = False


class _ValueRepresentation(ABC):
"""A class that represents a value in the summary table."""
Expand Down Expand Up @@ -242,11 +249,6 @@ def tabulate(
Total Parameters: 50 (200 B)
**Note**: rows order in the table does not represent execution order,
instead it aligns with the order of keys in `variables` which are sorted
alphabetically.
**Note**: `vjp_flops` returns `0` if the module is not differentiable.
Args:
Expand All @@ -267,7 +269,9 @@ def tabulate(
mutable.
console_kwargs: An optional dictionary with additional keyword arguments
that are passed to `rich.console.Console` when rendering the table.
Default arguments are `{'force_terminal': True, 'force_jupyter': False}`.
Default arguments are ``'force_terminal': True``, and ``'force_jupyter'``
is set to ``True`` if the code is running in a Jupyter notebook, otherwise
it is set to ``False``.
table_kwargs: An optional dictionary with additional keyword arguments that
are passed to `rich.table.Table` constructor.
column_kwargs: An optional dictionary with additional keyword arguments that
Expand Down Expand Up @@ -564,7 +568,7 @@ def _render_table(
non_params_cols: list[str],
) -> str:
"""A function that renders a Table to a string representation using rich."""
console_kwargs = {'force_terminal': True, 'force_jupyter': False}
console_kwargs = {'force_terminal': True, 'force_jupyter': in_ipython}
if console_extras is not None:
console_kwargs.update(console_extras)

Expand Down
1 change: 1 addition & 0 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,4 +167,5 @@
from .extract import to_tree as to_tree
from .extract import from_tree as from_tree
from .extract import NodeStates as NodeStates
from .summary import tabulate as tabulate
from . import traversals as traversals
4 changes: 3 additions & 1 deletion flax/nnx/filterlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def to_predicate(filter: Filter) -> Predicate:
else:
raise TypeError(f'Invalid collection filter: {filter:!r}. ')

def filters_to_predicates(filters: tuple[Filter, ...]) -> tuple[Predicate, ...]:
def filters_to_predicates(
filters: tp.Sequence[Filter],
) -> tuple[Predicate, ...]:
for i, filter_ in enumerate(filters):
if filter_ in (..., True) and i != len(filters) - 1:
remaining_filters = filters[i + 1 :]
Expand Down
Loading

0 comments on commit a7d60e7

Please sign in to comment.