Skip to content

Commit

Permalink
Merge pull request #4192 from google:nnx-transforms-guide
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675269951
  • Loading branch information
Flax Authors committed Sep 16, 2024
2 parents d111adf + ea57438 commit 03e034d
Show file tree
Hide file tree
Showing 109 changed files with 6,820 additions and 403 deletions.
6 changes: 3 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
*.pyc
.tfds
.DS_Store
docs/**/_autosummary
docs/_build
dist/
build/
*.egg-info
*.rej
.pytype
.vscode/*
/.devcontainer
docs/**/tmp
docs*/**/_autosummary
docs*/_build
docs*/**/tmp

# used by direnv
.envrc
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ repos:
hooks:
- id: check-toml
- id: trailing-whitespace
exclude: ^docs/.*\.md$
exclude: ^docs*/.*\.md$
- repo: https://github.com/kynan/nbstripout
rev: 0.6.1
hooks:
Expand Down
1 change: 0 additions & 1 deletion docs/api_reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ API Reference
flax.core.frozen_dict
flax.cursor
flax.errors
flax.nnx/index
flax.jax_utils
flax.linen/index
flax.serialization
Expand Down
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,10 @@
# href with no underline and white bold text color
announcement = """
<a
href="https://flax.readthedocs.io/en/latest/nnx/index.html"
href="https://flax-nnx.readthedocs.io/en/latest/index.html"
style="text-decoration: none; color: white;"
>
📣 Check out the new <b>NNX</b> API!
This is the Flax Linen site. Check out the new <b>Flax NNX</b>!
</a>
"""

Expand Down
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -326,4 +326,4 @@ Notable examples in Flax include:
philosophy
contributing
api_reference/index
Flax NNX <nnx/index>
Flax NNX <https://flax-nnx.readthedocs.io/en/latest/index.html>
1 change: 1 addition & 0 deletions docs_nnx/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
_formatted_howtos
31 changes: 31 additions & 0 deletions docs_nnx/.readthedocs.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# .readthedocs.yml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details

# Required
version: 2

build:
os: ubuntu-22.04
tools:
python: "3.10"

# Build documentation in the docs/ directory with Sphinx
sphinx:
configuration: docs_nnx/conf.py

# Optionally build your docs in additional formats such as PDF and ePub
formats:
- htmlzip
- epub
# - pdf

# Optionally set the version of Python and requirements required to build your docs
python:
install:
- method: pip
path: .
extra_requirements:
- all
- testing
- docs
20 changes: 20 additions & 0 deletions docs_nnx/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build

# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
148 changes: 148 additions & 0 deletions docs_nnx/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# Where to find the docs

The FLAX documentation can be found here:
https://flax.readthedocs.io/en/latest/

# How to build the docs

1. Clone the `flax` repository with `git clone https://github.com/google/flax.git`.
2. In the main `flax` folder, install the required dependencies using `pip install -r docs/requirements.txt`.
3. Install [`pandoc`](https://pandoc.org): `pip install pandoc`.
4. [Optional] If you need to make any local changes to the docs, create and switch to a branch. Make your changes to the docs in that branch.
5. To build the docs, in the `flax/docs` folder run the make script: `make html`. Alternatively, install [`entr`](https://github.com/eradman/entr/), which helps run arbitrary commands when files change. Then run `find ../ ! -regex '.*/[\.|\_].*' | entr -s 'make html'`.
6. If the build is successful, you should get the `The HTML pages are in _build/html.` message. You can preview the docs in `flax/docs/_build/html`.

# How to run embedded code tests

We use `doctest` blocks for embedded code in documents, that are also
tested. Learn more at https://www.sphinx-doc.org/en/master/usage/extensions/doctest.html

To run tests locally, run `make doctest`

# How to write code documentation

Our documentation is written in reStructuredText for Sphinx. It is a
meta-language that is compiled into online documentation. For more details,
check out
[Sphinx's documentation](https://www.sphinx-doc.org/en/master/usage/restructuredtext/index.html).
As a result, our docstrings adhere to a specific syntax that has to be kept in
mind. Below we provide some guidelines.

To learn how to contribute to Jupyter Notebooks or other formats in Flax docs,
refer to the dedicated
[Contributing](https://flax.readthedocs.io/en/latest/contributing.html) page.

## How much information to put in a docstring

Docstring should be informative. We prefer to err on the side of too much
documentation than too little. For instance, providing a one-line explanation
to a new `Module` which implements new functionality is not sufficient.

Furthermore, we highly encourage adding examples to your docstrings, so users
can directly see how code can be used.

## How to write inline tested code

We use [doctest](https://docs.python.org/3/library/doctest.html) syntax for
writing examples in documentation. These examples are ran as tests as part of
our CI process. In order to write `doctest` code in your documentation, please
use the following notation:

```bash
# Example code::
#
# def sum(a, b):
# return a + b
#
# sum(0, 1)
```

The `Example code` string at the beginning can be replaced by anything as long
as there are two semicolons and a newline following it, and the code is
indented.

## How to use "code font"

When writing code font in a docstring, please use double backticks. Example:

```bash
# This returns a ``str`` object.
```

Note that argument names and objects like True, None or any strings should
usually be put in `code`.

## How to create cross-references/links

It is possible to create cross-references to other classes, functions, and
methods. In the following, `obj_typ` is either `class`, `func`, or `meth`.

```bash
# First method:
# <obj_type>:`path_to_obj`

# Second method:
# :<obj_type>:`description <path_to_obj>`
```

You can use the second method if the `path_to_obj` is very long. Some examples:

```bash
# Create: a reference to class flax.linen.Module.
# :class:`flax.linen.Module`

# Create a reference to local function my_func.
# :func:`my_func`

# Create a reference "Module.apply()" to method flax.linen.Module.apply.
# :meth:`Module.apply() <flax.linen.Module.apply>` #
```

To creata a hyperlink, use the following syntax:
```bash
# Note the double underscore at the end:
# `Link to Google <http://www.google.com>`__
```

### How to specify arguments for classes and methods

* Class attributes should be specified using the `Attributes:` tag.
* Method argument should be specified using the `Args:` tags.
* All attributes and arguments should have types.

Here is an example from our library:

```python
class DenseGeneral(Module):
"""A linear transformation with flexible axes.
Attributes:
features: int or tuple with number of output features.
axis: int or tuple with axes to apply the transformation on. For instance,
(-2, -1) will apply the transformation to the last two axes.
batch_dims: tuple with batch axes.
use_bias: whether to add a bias to the output (default: True).
dtype: the dtype of the computation (default: float32).
kernel_init: initializer function for the weight matrix.
bias_init: initializer function for the bias.
precision: numerical precision of the computation see `jax.lax.Precision`
for details.
"""
features: Union[int, Iterable[int]]
axis: Union[int, Iterable[int]] = -1
batch_dims: Iterable[int] = ()
use_bias: bool = True
dtype: Dtype = jnp.float32
kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros
precision: Any = None

@compact
def __call__(self, inputs: Array) -> Array:
"""Applies a linear transformation to the inputs along multiple dimensions.
Args:
inputs: The nd-array to be transformed.
Returns:
The transformed input.
"""
...
```
Loading

0 comments on commit 03e034d

Please sign in to comment.