From ea57438f605f94c23c6ce186a77534c308f3e8d6 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Mon, 16 Sep 2024 21:40:14 +0200 Subject: [PATCH] split docs --- .gitignore | 6 +- .pre-commit-config.yaml | 2 +- docs/api_reference/index.rst | 1 - docs/conf.py | 4 +- docs/index.rst | 2 +- docs_nnx/.gitignore | 1 + docs_nnx/.readthedocs.yaml | 31 + docs_nnx/Makefile | 20 + docs_nnx/README.md | 148 +++ docs_nnx/_ext/codediff.py | 221 ++++ docs_nnx/_ext/codediff_test.py | 120 ++ docs_nnx/_ext/flax_module.py | 87 ++ docs_nnx/_static/css/flax_theme.css | 23 + .../_templates/autosummary/flax_module.rst | 29 + .../api_reference/flax.core.frozen_dict.rst | 18 + .../api_reference/flax.nnx/bridge.rst | 0 .../api_reference/flax.nnx/filterlib.rst | 0 .../api_reference/flax.nnx/graph.rst | 0 .../api_reference/flax.nnx/helpers.rst | 0 .../api_reference/flax.nnx/index.rst | 2 - .../api_reference/flax.nnx/module.rst | 0 .../api_reference/flax.nnx/nn/activations.rst | 0 .../api_reference/flax.nnx/nn/attention.rst | 0 .../api_reference/flax.nnx/nn/index.rst | 0 .../flax.nnx/nn/initializers.rst | 0 .../api_reference/flax.nnx/nn/linear.rst | 0 .../flax.nnx/nn/normalization.rst | 0 .../api_reference/flax.nnx/nn/stochastic.rst | 0 .../api_reference/flax.nnx/rnglib.rst | 0 .../api_reference/flax.nnx/spmd.rst | 0 .../api_reference/flax.nnx/state.rst | 0 .../api_reference/flax.nnx/training/index.rst | 0 .../flax.nnx/training/metrics.rst | 0 .../flax.nnx/training/optimizer.rst | 0 .../api_reference/flax.nnx/transforms.rst | 0 .../api_reference/flax.nnx/variables.rst | 0 .../api_reference/flax.nnx/visualization.rst | 0 docs_nnx/api_reference/flax.struct.rst | 13 + docs_nnx/api_reference/flax.training.rst | 12 + docs_nnx/api_reference/index.rst | 10 + docs_nnx/conf.py | 186 +++ docs_nnx/conf_sphinx_patch.py | 200 +++ docs_nnx/contributing.md | 297 +++++ docs_nnx/examples/community_examples.rst | 110 ++ docs_nnx/examples/core_examples.rst | 87 ++ .../examples/google_research_examples.rst | 269 
++++ docs_nnx/examples/index.rst | 12 + .../examples/repositories_that_use_flax.rst | 51 + docs_nnx/faq.rst | 38 + docs_nnx/flax.png | Bin 0 -> 20991 bytes docs_nnx/flip/0000-template.md | 25 + docs_nnx/flip/1009-optimizer-api.md | 504 ++++++++ docs_nnx/flip/1777-default-dtype.md | 133 ++ docs_nnx/flip/2396-rnn.md | 238 ++++ docs_nnx/flip/2434-general-metadata.md | 230 ++++ docs_nnx/flip/2974-kw-only-dataclasses.md | 99 ++ docs_nnx/flip/3099-rnnbase-refactor.md | 79 ++ .../flip/4105-jax-style-nnx-transforms.md | 177 +++ docs_nnx/flip/README.md | 32 + docs_nnx/glossary.rst | 112 ++ {docs/nnx => docs_nnx/guides}/blog.md | 0 .../guides}/bridge_guide.ipynb | 0 {docs/nnx => docs_nnx/guides}/bridge_guide.md | 30 +- {docs/nnx => docs_nnx/guides}/demo.ipynb | 0 {docs/nnx => docs_nnx/guides}/demo.md | 0 .../guides}/filters_guide.ipynb | 22 +- .../nnx => docs_nnx/guides}/filters_guide.md | 22 +- docs_nnx/guides/graph_mutations.ipynb | 23 + docs_nnx/guides/graph_mutations.md | 13 + .../guides}/haiku_linen_vs_nnx.rst | 0 .../guides}/images/stateful-transforms.png | Bin docs_nnx/guides/index.rst | 12 + .../guides/jax_and_nnx_transforms.rst | 0 .../nnx => docs_nnx/guides}/quick_start.ipynb | 0 {docs/nnx => docs_nnx/guides}/surgery.ipynb | 0 {docs/nnx => docs_nnx/guides}/surgery.md | 0 {docs/nnx => docs_nnx/guides}/tiny_nnx.ipynb | 0 {docs/nnx => docs_nnx/guides}/why.ipynb | 0 {docs/nnx => docs_nnx/guides}/why.md | 0 {docs/nnx => docs_nnx}/index.rst | 11 +- docs_nnx/linen_intro.ipynb | 1097 +++++++++++++++++ docs_nnx/linen_intro.md | 597 +++++++++ {docs/nnx => docs_nnx}/mnist_tutorial.ipynb | 18 +- {docs/nnx => docs_nnx}/mnist_tutorial.md | 18 +- {docs/nnx => docs_nnx}/nnx_basics.ipynb | 52 +- {docs/nnx => docs_nnx}/nnx_basics.md | 52 +- docs_nnx/philosophy.md | 121 ++ docs_nnx/quick_start.ipynb | 701 +++++++++++ docs_nnx/quick_start.md | 355 ++++++ docs_nnx/robots.txt | 5 + flax/core/meta.py | 14 - flax/errors.py | 9 - flax/linen/spmd.py | 15 - flax/nnx/bridge/variables.py | 30 
+- flax/nnx/bridge/wrappers.py | 53 +- flax/nnx/errors.py | 17 + flax/nnx/extract.py | 18 +- flax/nnx/object.py | 2 +- flax/nnx/spmd.py | 10 +- flax/nnx/transforms/compilation.py | 1 - flax/nnx/transforms/iteration.py | 28 +- flax/nnx/variables.py | 98 +- tests/nnx/bridge/wrappers_test.py | 47 +- tests/nnx/module_test.py | 4 +- tests/nnx/rngs_test.py | 5 +- tests/nnx/spmd_test.py | 58 - tests/nnx/transforms_test.py | 23 +- tests/run_all_tests.sh | 2 + uv.lock | 11 +- 109 files changed, 6820 insertions(+), 403 deletions(-) create mode 100644 docs_nnx/.gitignore create mode 100644 docs_nnx/.readthedocs.yaml create mode 100644 docs_nnx/Makefile create mode 100644 docs_nnx/README.md create mode 100644 docs_nnx/_ext/codediff.py create mode 100644 docs_nnx/_ext/codediff_test.py create mode 100644 docs_nnx/_ext/flax_module.py create mode 100644 docs_nnx/_static/css/flax_theme.css create mode 100644 docs_nnx/_templates/autosummary/flax_module.rst create mode 100644 docs_nnx/api_reference/flax.core.frozen_dict.rst rename {docs => docs_nnx}/api_reference/flax.nnx/bridge.rst (100%) rename {docs => docs_nnx}/api_reference/flax.nnx/filterlib.rst (100%) rename {docs => docs_nnx}/api_reference/flax.nnx/graph.rst (100%) rename {docs => docs_nnx}/api_reference/flax.nnx/helpers.rst (100%) rename {docs => docs_nnx}/api_reference/flax.nnx/index.rst (95%) rename {docs => docs_nnx}/api_reference/flax.nnx/module.rst (100%) rename {docs => docs_nnx}/api_reference/flax.nnx/nn/activations.rst (100%) rename {docs => docs_nnx}/api_reference/flax.nnx/nn/attention.rst (100%) rename {docs => docs_nnx}/api_reference/flax.nnx/nn/index.rst (100%) rename {docs => docs_nnx}/api_reference/flax.nnx/nn/initializers.rst (100%) rename {docs => docs_nnx}/api_reference/flax.nnx/nn/linear.rst (100%) rename {docs => docs_nnx}/api_reference/flax.nnx/nn/normalization.rst (100%) rename {docs => docs_nnx}/api_reference/flax.nnx/nn/stochastic.rst (100%) rename {docs => docs_nnx}/api_reference/flax.nnx/rnglib.rst 
(100%) rename {docs => docs_nnx}/api_reference/flax.nnx/spmd.rst (100%) rename {docs => docs_nnx}/api_reference/flax.nnx/state.rst (100%) rename {docs => docs_nnx}/api_reference/flax.nnx/training/index.rst (100%) rename {docs => docs_nnx}/api_reference/flax.nnx/training/metrics.rst (100%) rename {docs => docs_nnx}/api_reference/flax.nnx/training/optimizer.rst (100%) rename {docs => docs_nnx}/api_reference/flax.nnx/transforms.rst (100%) rename {docs => docs_nnx}/api_reference/flax.nnx/variables.rst (100%) rename {docs => docs_nnx}/api_reference/flax.nnx/visualization.rst (100%) create mode 100644 docs_nnx/api_reference/flax.struct.rst create mode 100644 docs_nnx/api_reference/flax.training.rst create mode 100644 docs_nnx/api_reference/index.rst create mode 100644 docs_nnx/conf.py create mode 100644 docs_nnx/conf_sphinx_patch.py create mode 100644 docs_nnx/contributing.md create mode 100644 docs_nnx/examples/community_examples.rst create mode 100644 docs_nnx/examples/core_examples.rst create mode 100644 docs_nnx/examples/google_research_examples.rst create mode 100644 docs_nnx/examples/index.rst create mode 100644 docs_nnx/examples/repositories_that_use_flax.rst create mode 100644 docs_nnx/faq.rst create mode 100644 docs_nnx/flax.png create mode 100644 docs_nnx/flip/0000-template.md create mode 100644 docs_nnx/flip/1009-optimizer-api.md create mode 100644 docs_nnx/flip/1777-default-dtype.md create mode 100644 docs_nnx/flip/2396-rnn.md create mode 100644 docs_nnx/flip/2434-general-metadata.md create mode 100644 docs_nnx/flip/2974-kw-only-dataclasses.md create mode 100644 docs_nnx/flip/3099-rnnbase-refactor.md create mode 100644 docs_nnx/flip/4105-jax-style-nnx-transforms.md create mode 100644 docs_nnx/flip/README.md create mode 100644 docs_nnx/glossary.rst rename {docs/nnx => docs_nnx/guides}/blog.md (100%) rename {docs/nnx => docs_nnx/guides}/bridge_guide.ipynb (100%) rename {docs/nnx => docs_nnx/guides}/bridge_guide.md (98%) rename {docs/nnx => 
docs_nnx/guides}/demo.ipynb (100%) rename {docs/nnx => docs_nnx/guides}/demo.md (100%) rename {docs/nnx => docs_nnx/guides}/filters_guide.ipynb (96%) rename {docs/nnx => docs_nnx/guides}/filters_guide.md (95%) create mode 100644 docs_nnx/guides/graph_mutations.ipynb create mode 100644 docs_nnx/guides/graph_mutations.md rename {docs/nnx => docs_nnx/guides}/haiku_linen_vs_nnx.rst (100%) rename {docs/nnx => docs_nnx/guides}/images/stateful-transforms.png (100%) create mode 100644 docs_nnx/guides/index.rst rename docs/nnx/transforms.rst => docs_nnx/guides/jax_and_nnx_transforms.rst (100%) rename {docs/nnx => docs_nnx/guides}/quick_start.ipynb (100%) rename {docs/nnx => docs_nnx/guides}/surgery.ipynb (100%) rename {docs/nnx => docs_nnx/guides}/surgery.md (100%) rename {docs/nnx => docs_nnx/guides}/tiny_nnx.ipynb (100%) rename {docs/nnx => docs_nnx/guides}/why.ipynb (100%) rename {docs/nnx => docs_nnx/guides}/why.md (100%) rename {docs/nnx => docs_nnx}/index.rst (95%) create mode 100644 docs_nnx/linen_intro.ipynb create mode 100644 docs_nnx/linen_intro.md rename {docs/nnx => docs_nnx}/mnist_tutorial.ipynb (99%) rename {docs/nnx => docs_nnx}/mnist_tutorial.md (97%) rename {docs/nnx => docs_nnx}/nnx_basics.ipynb (99%) rename {docs/nnx => docs_nnx}/nnx_basics.md (96%) create mode 100644 docs_nnx/philosophy.md create mode 100644 docs_nnx/quick_start.ipynb create mode 100644 docs_nnx/quick_start.md create mode 100644 docs_nnx/robots.txt create mode 100644 flax/nnx/errors.py diff --git a/.gitignore b/.gitignore index ab2066a79c..2d436c7105 100644 --- a/.gitignore +++ b/.gitignore @@ -3,8 +3,6 @@ *.pyc .tfds .DS_Store -docs/**/_autosummary -docs/_build dist/ build/ *.egg-info @@ -12,7 +10,9 @@ build/ .pytype .vscode/* /.devcontainer -docs/**/tmp +docs*/**/_autosummary +docs*/_build +docs*/**/tmp # used by direnv .envrc diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2860a1d920..776f5c3d82 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ 
-23,7 +23,7 @@ repos: hooks: - id: check-toml - id: trailing-whitespace - exclude: ^docs/.*\.md$ + exclude: ^docs*/.*\.md$ - repo: https://github.com/kynan/nbstripout rev: 0.6.1 hooks: diff --git a/docs/api_reference/index.rst b/docs/api_reference/index.rst index 2c0d360254..3e062ff2f4 100644 --- a/docs/api_reference/index.rst +++ b/docs/api_reference/index.rst @@ -8,7 +8,6 @@ API Reference flax.core.frozen_dict flax.cursor flax.errors - flax.nnx/index flax.jax_utils flax.linen/index flax.serialization diff --git a/docs/conf.py b/docs/conf.py index 93d3d7009e..32dc8addfb 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -113,10 +113,10 @@ # href with no underline and white bold text color announcement = """ - 📣 Check out the new NNX API! + This is the Flax Linen site. Check out the new Flax NNX! """ diff --git a/docs/index.rst b/docs/index.rst index ce04817a65..202f81e7e9 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -326,4 +326,4 @@ Notable examples in Flax include: philosophy contributing api_reference/index - Flax NNX + Flax NNX diff --git a/docs_nnx/.gitignore b/docs_nnx/.gitignore new file mode 100644 index 0000000000..2f934a9a92 --- /dev/null +++ b/docs_nnx/.gitignore @@ -0,0 +1 @@ +_formatted_howtos diff --git a/docs_nnx/.readthedocs.yaml b/docs_nnx/.readthedocs.yaml new file mode 100644 index 0000000000..eaf0b4f673 --- /dev/null +++ b/docs_nnx/.readthedocs.yaml @@ -0,0 +1,31 @@ +# .readthedocs.yml +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +build: + os: ubuntu-22.04 + tools: + python: "3.10" + +# Build documentation in the docs/ directory with Sphinx +sphinx: + configuration: docs_nnx/conf.py + +# Optionally build your docs in additional formats such as PDF and ePub +formats: + - htmlzip + - epub + # - pdf + +# Optionally set the version of Python and requirements required to build your docs +python: + install: + - method: pip + path: . 
+ extra_requirements: + - all + - testing + - docs diff --git a/docs_nnx/Makefile b/docs_nnx/Makefile new file mode 100644 index 0000000000..d4bb2cbb9e --- /dev/null +++ b/docs_nnx/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs_nnx/README.md b/docs_nnx/README.md new file mode 100644 index 0000000000..88ed7e24af --- /dev/null +++ b/docs_nnx/README.md @@ -0,0 +1,148 @@ +# Where to find the docs + +The Flax documentation can be found here: +https://flax.readthedocs.io/en/latest/ + +# How to build the docs + +1. Clone the `flax` repository with `git clone https://github.com/google/flax.git`. +2. In the main `flax` folder, install the required dependencies using `pip install -r docs/requirements.txt`. +3. Install [`pandoc`](https://pandoc.org): `pip install pandoc`. +4. [Optional] If you need to make any local changes to the docs, create and switch to a branch. Make your changes to the docs in that branch. +5. To build the docs, in the `flax/docs_nnx` folder run the make script: `make html`. Alternatively, install [`entr`](https://github.com/eradman/entr/), which helps run arbitrary commands when files change. Then run `find ../ ! -regex '.*/[\.|\_].*' | entr -s 'make html'`. +6. If the build is successful, you should get the `The HTML pages are in _build/html.` message. You can preview the docs in `flax/docs_nnx/_build/html`.
+ +# How to run embedded code tests + +We use `doctest` blocks for code embedded in documents, which is also +tested. Learn more at https://www.sphinx-doc.org/en/master/usage/extensions/doctest.html + +To run tests locally, run `make doctest` + +# How to write code documentation + +Our documentation is written in reStructuredText for Sphinx. It is a +meta-language that is compiled into online documentation. For more details, +check out +[Sphinx's documentation](https://www.sphinx-doc.org/en/master/usage/restructuredtext/index.html). +As a result, our docstrings adhere to a specific syntax that has to be kept in +mind. Below we provide some guidelines. + +To learn how to contribute to Jupyter Notebooks or other formats in Flax docs, +refer to the dedicated +[Contributing](https://flax.readthedocs.io/en/latest/contributing.html) page. + +## How much information to put in a docstring + +Docstrings should be informative. We prefer to err on the side of too much +documentation rather than too little. For instance, providing a one-line explanation +to a new `Module` which implements new functionality is not sufficient. + +Furthermore, we highly encourage adding examples to your docstrings, so users +can directly see how code can be used. + +## How to write inline tested code + +We use [doctest](https://docs.python.org/3/library/doctest.html) syntax for +writing examples in documentation. These examples are run as tests as part of +our CI process. In order to write `doctest` code in your documentation, please +use the following notation: + +```bash +# Example code:: +# +# def sum(a, b): +# return a + b +# +# sum(0, 1) +``` + +The `Example code` string at the beginning can be replaced by anything as long +as there are two colons and a newline following it, and the code is +indented. + +## How to use "code font" + +When writing code font in a docstring, please use double backticks. Example: + +```bash +# This returns a ``str`` object.
+``` + +Note that argument names and objects like True, None or any strings should +usually be put in `code`. + +## How to create cross-references/links + +It is possible to create cross-references to other classes, functions, and +methods. In the following, `obj_typ` is either `class`, `func`, or `meth`. + +```bash +# First method: +# :<obj_typ>:`path_to_obj` + +# Second method: +# :<obj_typ>:`description <path_to_obj>` +``` + +You can use the second method if the `path_to_obj` is very long. Some examples: + +```bash +# Create a reference to class flax.linen.Module. +# :class:`flax.linen.Module` + +# Create a reference to local function my_func. +# :func:`my_func` + +# Create a reference "Module.apply()" to method flax.linen.Module.apply. +# :meth:`Module.apply() <flax.linen.Module.apply>` # +``` + +To create a hyperlink, use the following syntax: +```bash +# Note the double underscore at the end: +# `Link to Google <https://www.google.com>`__ +``` + +### How to specify arguments for classes and methods + +* Class attributes should be specified using the `Attributes:` tag. +* Method arguments should be specified using the `Args:` tag. +* All attributes and arguments should have types. + +Here is an example from our library: + +```python +class DenseGeneral(Module): + """A linear transformation with flexible axes. + Attributes: + features: int or tuple with number of output features. + axis: int or tuple with axes to apply the transformation on. For instance, + (-2, -1) will apply the transformation to the last two axes. + batch_dims: tuple with batch axes. + use_bias: whether to add a bias to the output (default: True). + dtype: the dtype of the computation (default: float32). + kernel_init: initializer function for the weight matrix. + bias_init: initializer function for the bias. + precision: numerical precision of the computation see `jax.lax.Precision` + for details.
+ """ + features: Union[int, Iterable[int]] + axis: Union[int, Iterable[int]] = -1 + batch_dims: Iterable[int] = () + use_bias: bool = True + dtype: Dtype = jnp.float32 + kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init + bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros + precision: Any = None + + @compact + def __call__(self, inputs: Array) -> Array: + """Applies a linear transformation to the inputs along multiple dimensions. + Args: + inputs: The nd-array to be transformed. + Returns: + The transformed input. + """ + ... +``` \ No newline at end of file diff --git a/docs_nnx/_ext/codediff.py b/docs_nnx/_ext/codediff.py new file mode 100644 index 0000000000..3c0a8c0248 --- /dev/null +++ b/docs_nnx/_ext/codediff.py @@ -0,0 +1,221 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sphinx directive for creating code diff tables. + +Use directive as follows: + +.. codediff:: + :title: , + + + --- + + +In order to highlight a line of code, append "#!" to it. 
+""" + + +import sphinx +from docutils import nodes +from docutils.parsers.rst import directives +from docutils.statemachine import ViewList +from sphinx.util.docutils import SphinxDirective + +MISSING = object() + + +class CodeDiffParser: + def parse( + self, + lines: list[str], + title: str, + groups: list[str] | None = None, + skip_test: str | None = None, + code_sep: str = '---', + sync: object = MISSING, + ): + """Parse the code diff block and format it so that it + renders in different tabs and is tested by doctest. + + For example: + + .. testcode:: tab0, tab2, tab3 + + + + .. codediff:: + :title: Tab 0, Tab 1, Tab 2, Tab 3 + :groups: tab0, tab1, tab2, tab3 + :skip_test: tab1, tab3 + + + + --- + + + + --- + + + + --- + + + + For group tab0: and are executed. + For group tab1: Nothing is executed. + For group tab2: and are executed. + For group tab3: is executed. + + Arguments: + lines: a string list, where each element is a single string code line + title: a single string that contains the titles of each tab (they should + be separated by commas) + groups: a single string that contains the group of each tab (they should + be separated by commas). Code snippets that are part of the same group + will be executed together. If groups=None, then the group names will + default to the tab title names. + skip_test: a single string denoting which group(s) to skip testing (they + should be separated by commas). This is useful for legacy code snippets + that no longer run correctly anymore. If skip_test=None, then no tests + are skipped. + code_sep: the separator character(s) used to denote a separate code block + for a new tab. The default code separator is '---'. + sync: an option for Sphinx directives, that will sync all tabs together. + This means that if the user clicks to switch to another tab, all tabs + will switch to the new tab. 
+ """ + titles = [t.strip() for t in title.split(',')] + num_tabs = len(titles) + + sync = sync is not MISSING + # skip legacy code snippets in upgrade guides + if skip_test is not None: + skip_tests = {index.strip() for index in skip_test.split(',')} + else: + skip_tests = set() + + code_blocks = '\n'.join(lines) + if code_blocks.count(code_sep) != num_tabs - 1: + raise ValueError( + f'Expected {num_tabs-1} code separator(s) for {num_tabs} tab(s), but got {code_blocks.count(code_sep)} code separator(s) instead.' + ) + code_blocks = [ + code_block.split('\n') + for code_block in code_blocks.split(code_sep + '\n') + ] # list[code_tab_list1[string_line1, ...], ...] + + # by default, put each code snippet in a different group denoted by an index number, to be executed separately + if groups is not None: + groups = [group_name.strip() for group_name in groups.split(',')] + else: + groups = titles + if len(groups) != num_tabs: + raise ValueError( + f'Expected {num_tabs} group assignment(s) for {num_tabs} tab(s), but got {len(groups)} group assignment(s) instead.' + ) + + tabs = [] + test_codes = [] + for i, code_block in enumerate(code_blocks): + if groups[i] not in skip_tests: + test_codes.append((code_block, groups[i])) + tabs.append((titles[i], self._code_block(code_block))) + output = self._tabs(*tabs, sync=sync) + + return output, test_codes + + def _code_block(self, lines): + """Creates a codeblock.""" + # Remove right trailing whitespace so we can detect the comments. + lines = [x.rstrip() for x in lines] + highlight = lambda x: x.endswith('#!') + code = map(lambda x: x[:-2].rstrip() if highlight(x) else x, lines) + highlights = [i + 1 for i in range(len(lines)) if highlight(lines[i])] + highlights = ','.join(str(i) for i in highlights) + + directive = ['.. code-block:: python'] + if highlights: + directive += [f' :emphasize-lines: {highlights}'] + + # Indent code and add empty line so the code is picked up by the directive. 
+ return directive + [''] + list(map(lambda x: ' ' + x, code)) + + def _tabs(self, *contents: tuple[str, list[str]], sync): + output = ['.. tab-set::'] + [' '] + + for title, content in contents: + output += [f' .. tab-item:: {title}'] + + if sync: + key = title.strip() + output += [f' :sync: {key}'] + + output += [' '] + output += [' ' + line for line in content] + + return output + + +class CodeDiffDirective(SphinxDirective): + has_content = True + option_spec = { + 'title': directives.unchanged, + 'groups': directives.unchanged, + 'skip_test': directives.unchanged, + 'code_sep': directives.unchanged, + 'sync': directives.flag, + } + + def run(self): + table_code, test_codes = CodeDiffParser().parse( + list(self.content), **self.options + ) + + # Create a test node as a comment node so it won't show up in the docs. + # We add attribute "testnodetype" so it is be picked up by the doctest + # builder. This functionality is not officially documented but can be found + # in the source code: + # https://github.com/sphinx-doc/sphinx/blob/master/sphinx/ext/doctest.py + # (search for 'testnodetype'). + test_nodes = [] + for test_code, group in test_codes: + test_node = nodes.comment( + '\n'.join(test_code), + '\n'.join(test_code), + testnodetype='testcode', + groups=[group], + ) + self.set_source_info(test_node) + test_node['options'] = {} + test_node['language'] = 'python3' + test_nodes.append(test_node) + + # The table node is the side-by-side diff view that will be shown on RTD. 
+ table_node = nodes.paragraph() + self.content = ViewList(table_code, self.content.parent) + self.state.nested_parse(self.content, self.content_offset, table_node) + + return [table_node] + test_nodes + + +def setup(app): + app.add_directive('codediff', CodeDiffDirective) + + return { + 'version': sphinx.__display_version__, + 'parallel_read_safe': True, + 'parallel_write_safe': True, + } diff --git a/docs_nnx/_ext/codediff_test.py b/docs_nnx/_ext/codediff_test.py new file mode 100644 index 0000000000..83e15733c9 --- /dev/null +++ b/docs_nnx/_ext/codediff_test.py @@ -0,0 +1,120 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for codediff Sphinx extension.""" + +from absl.testing import parameterized +from codediff import CodeDiffParser + + +class CodeDiffTest(parameterized.TestCase): + def test_parse(self): + input_text = r"""@jax.jit #! +def get_initial_params(key): #! + init_val = jnp.ones((1, 28, 28, 1), jnp.float32) + initial_params = CNN().init(key, init_val)['params'] + extra_line + return initial_params +--- +@jax.pmap #! +def get_initial_params(key): + init_val = jnp.ones((1, 28, 28, 1), jnp.float32) + initial_params = CNN().init(key, init_val)['params'] + return initial_params""" + + expected_table = """.. tab-set::\n \n .. tab-item:: Single device\n \n .. 
code-block:: python\n :emphasize-lines: 1,2\n \n @jax.jit\n def get_initial_params(key):\n init_val = jnp.ones((1, 28, 28, 1), jnp.float32)\n initial_params = CNN().init(key, init_val)['params']\n extra_line\n return initial_params\n \n .. tab-item:: Ensembling on multiple devices\n \n .. code-block:: python\n :emphasize-lines: 1\n \n @jax.pmap\n def get_initial_params(key):\n init_val = jnp.ones((1, 28, 28, 1), jnp.float32)\n initial_params = CNN().init(key, init_val)['params']\n return initial_params""" + + expected_testcodes = [ + r"""@jax.jit #! +def get_initial_params(key): #! + init_val = jnp.ones((1, 28, 28, 1), jnp.float32) + initial_params = CNN().init(key, init_val)['params'] + extra_line + return initial_params +""", + r"""@jax.pmap #! +def get_initial_params(key): + init_val = jnp.ones((1, 28, 28, 1), jnp.float32) + initial_params = CNN().init(key, init_val)['params'] + return initial_params""", + ] + + title_left = 'Single device' + title_right = 'Ensembling on multiple devices' + + actual_table, actual_testcodes = CodeDiffParser().parse( + lines=input_text.split('\n'), + title=f'{title_left}, {title_right}', + ) + + actual_table = '\n'.join(actual_table) + actual_testcodes = ['\n'.join(testcode) for testcode, _ in actual_testcodes] + + self.assertEqual(expected_table, actual_table) + self.assertEqual(expected_testcodes[0], actual_testcodes[0]) + self.assertEqual(expected_testcodes[1], actual_testcodes[1]) + + @parameterized.parameters( + { + 'input_text': r"""x = 1 + --- + x = 2 +""", + 'title': 'Tab 0, Tab1, Tab2', + 'groups': None, + 'error_msg': 'Expected 2 code separator\\(s\\) for 3 tab\\(s\\), but got 1 code separator\\(s\\) instead.', + }, + { + 'input_text': r"""x = 1 + --- + x = 2 + --- + x = 3 + --- + x = 4 +""", + 'title': 'Tab 0, Tab1, Tab2', + 'groups': None, + 'error_msg': 'Expected 2 code separator\\(s\\) for 3 tab\\(s\\), but got 3 code separator\\(s\\) instead.', + }, + { + 'input_text': r"""x = 1 + --- + x = 2 + --- + x = 3 +""", + 
'title': 'Tab 0, Tab1, Tab2', + 'groups': 'tab0, tab2', + 'error_msg': 'Expected 3 group assignment\\(s\\) for 3 tab\\(s\\), but got 2 group assignment\\(s\\) instead.', + }, + { + 'input_text': r"""x = 1 + --- + x = 2 + --- + x = 3 +""", + 'title': 'Tab 0, Tab1, Tab2', + 'groups': 'tab0, tab1, tab2, tab3', + 'error_msg': 'Expected 3 group assignment\\(s\\) for 3 tab\\(s\\), but got 4 group assignment\\(s\\) instead.', + }, + ) + def test_parse_errors(self, input_text, title, groups, error_msg): + with self.assertRaisesRegex(ValueError, error_msg): + _, _ = CodeDiffParser().parse( + lines=input_text.split('\n'), + title=title, + groups=groups, + ) diff --git a/docs_nnx/_ext/flax_module.py b/docs_nnx/_ext/flax_module.py new file mode 100644 index 0000000000..8cfebd9bd5 --- /dev/null +++ b/docs_nnx/_ext/flax_module.py @@ -0,0 +1,87 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sphinx directive for visualizing Flax modules. + +Use directive as follows: + +.. 
flax_module:: + :module: flax.linen + :class: Dense +""" + +import importlib + +import sphinx +import sphinx.ext.autosummary.generate as ag +from docutils import nodes +from docutils.parsers.rst import directives +from docutils.statemachine import ViewList +from sphinx.util.docutils import SphinxDirective + +from docs.conf_sphinx_patch import generate_autosummary_content + + +def render_module(modname: str, qualname: str, app): + parent = importlib.import_module(modname) + obj = getattr(parent, qualname) + template = ag.AutosummaryRenderer(app) + template_name = 'flax_module' + imported_members = False + recursive = False + context = {} + return generate_autosummary_content( + qualname, + obj, + parent, + template, + template_name, + imported_members, + app, + recursive, + context, + modname, + qualname, + ) + + +class FlaxModuleDirective(SphinxDirective): + has_content = True + option_spec = { + 'module': directives.unchanged, + 'class': directives.unchanged, + } + + def run(self): + module_template = render_module( + self.options['module'], self.options['class'], self.env.app + ) + module_template = module_template.splitlines() + + # Create a container for the rendered nodes + container_node = nodes.container() + self.content = ViewList(module_template, self.content.parent) + self.state.nested_parse(self.content, self.content_offset, container_node) + + return [container_node] + + +def setup(app): + app.add_directive('flax_module', FlaxModuleDirective) + + return { + 'version': sphinx.__display_version__, + 'parallel_read_safe': True, + 'parallel_write_safe': True, + } diff --git a/docs_nnx/_static/css/flax_theme.css b/docs_nnx/_static/css/flax_theme.css new file mode 100644 index 0000000000..b8207032d2 --- /dev/null +++ b/docs_nnx/_static/css/flax_theme.css @@ -0,0 +1,23 @@ +@import url("theme.css"); + +.wy-nav-content { + max-width: 1290px; +} + +.rst-content table.docutils { + width: 100%; +} + +.rst-content table.docutils td { + vertical-align: top; + 
padding: 0; +} + +.rst-content table.docutils td p { + padding: 8px; +} + +.rst-content div[class^=highlight] { + border: 0; + margin: 0; +} diff --git a/docs_nnx/_templates/autosummary/flax_module.rst b/docs_nnx/_templates/autosummary/flax_module.rst new file mode 100644 index 0000000000..21b8d8c6cb --- /dev/null +++ b/docs_nnx/_templates/autosummary/flax_module.rst @@ -0,0 +1,29 @@ +{{ fullname | escape | underline }} + +.. currentmodule:: {{ module }} + +.. autoclass:: {{ objname }} + :exclude-members: + + .. automethod:: __call__ + + {% block methods %} + + {% for item in methods %} + {%- if item not in inherited_members and item not in annotations and not item in ['__init__', 'setup'] %} + .. automethod:: {{ item }} + {%- endif %} + {%- endfor %} + + {% if methods %} + .. rubric:: Methods + + .. autosummary:: + + {% for item in methods %} + {%- if item not in inherited_members and item not in annotations and not item in ['__init__', 'setup'] %} + ~{{ name }}.{{ item }} + {%- endif %} + {%- endfor %} + {% endif %} + {% endblock %} \ No newline at end of file diff --git a/docs_nnx/api_reference/flax.core.frozen_dict.rst b/docs_nnx/api_reference/flax.core.frozen_dict.rst new file mode 100644 index 0000000000..30bca1faf6 --- /dev/null +++ b/docs_nnx/api_reference/flax.core.frozen_dict.rst @@ -0,0 +1,18 @@ + +flax.core.frozen_dict package +============================= + +.. currentmodule:: flax.core.frozen_dict + +.. autoclass:: FrozenDict + :members: pretty_repr, copy, pop, unfreeze, tree_flatten + +.. autofunction:: freeze + +.. autofunction:: unfreeze + +.. autofunction:: copy + +.. autofunction:: pop + +.. 
autofunction:: pretty_repr diff --git a/docs/api_reference/flax.nnx/bridge.rst b/docs_nnx/api_reference/flax.nnx/bridge.rst similarity index 100% rename from docs/api_reference/flax.nnx/bridge.rst rename to docs_nnx/api_reference/flax.nnx/bridge.rst diff --git a/docs/api_reference/flax.nnx/filterlib.rst b/docs_nnx/api_reference/flax.nnx/filterlib.rst similarity index 100% rename from docs/api_reference/flax.nnx/filterlib.rst rename to docs_nnx/api_reference/flax.nnx/filterlib.rst diff --git a/docs/api_reference/flax.nnx/graph.rst b/docs_nnx/api_reference/flax.nnx/graph.rst similarity index 100% rename from docs/api_reference/flax.nnx/graph.rst rename to docs_nnx/api_reference/flax.nnx/graph.rst diff --git a/docs/api_reference/flax.nnx/helpers.rst b/docs_nnx/api_reference/flax.nnx/helpers.rst similarity index 100% rename from docs/api_reference/flax.nnx/helpers.rst rename to docs_nnx/api_reference/flax.nnx/helpers.rst diff --git a/docs/api_reference/flax.nnx/index.rst b/docs_nnx/api_reference/flax.nnx/index.rst similarity index 95% rename from docs/api_reference/flax.nnx/index.rst rename to docs_nnx/api_reference/flax.nnx/index.rst index d1fef3a0aa..af7da4a53c 100644 --- a/docs/api_reference/flax.nnx/index.rst +++ b/docs_nnx/api_reference/flax.nnx/index.rst @@ -19,5 +19,3 @@ Experimental API. See the `NNX page + This is the Flax NNX site. Click here for Flax Linen. 
+ +""" + +html_theme_options = { + 'repository_url': 'https://github.com/google/flax', + 'use_repository_button': True, # add a 'link to repository' button + 'use_issues_button': False, # add an 'Open an Issue' button + 'path_to_docs': ( + 'docs' + ), # used to compute the path to launch notebooks in colab + 'launch_buttons': { + 'colab_url': 'https://colab.research.google.com/', + }, + 'prev_next_buttons_location': None, + 'show_navbar_depth': 1, + 'announcement': announcement, +} + +# -- Options for myst ---------------------------------------------- +# uncomment line below to avoid running notebooks during development +nb_execution_mode = 'off' +# Notebook cell execution timeout; defaults to 30. +nb_execution_timeout = 100 +# List of patterns, relative to source directory, that match notebook +# files that will not be executed. +myst_enable_extensions = ['dollarmath'] +nb_execution_excludepatterns = [ + 'quick_start.ipynb', # <-- times out + 'transfer_learning.ipynb', # <-- transformers requires flax<=0.7.0 + 'flax/nnx', # exclude nnx +] +# raise exceptions on execution so CI can catch errors +nb_execution_allow_errors = False +nb_execution_raise_on_error = True + +# -- Extension configuration ------------------------------------------------- + +# Tell sphinx-autodoc-typehints to generate stub parameter annotations including +# types, even if the parameters aren't explicitly documented. 
+always_document_param_types = True + +# -- doctest configuration ------------------------------------------------- +doctest_global_setup = """ +import jax +import jax.numpy as jnp +from flax import nnx + +import logging as slog +from absl import logging as alog + +# Avoid certain absl logging messages to break doctest +filtered_message = [ + 'SaveArgs.aggregate is deprecated', + '', +] + +class _CustomLogFilter(slog.Formatter): + def format(self, record): + message = super(_CustomLogFilter, self).format(record) + for m in filtered_message: + if m in message: + return '' + return message + +alog.use_absl_handler() +alog.get_absl_handler().setFormatter(_CustomLogFilter()) +""" diff --git a/docs_nnx/conf_sphinx_patch.py b/docs_nnx/conf_sphinx_patch.py new file mode 100644 index 0000000000..a423b79405 --- /dev/null +++ b/docs_nnx/conf_sphinx_patch.py @@ -0,0 +1,200 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Patch Sphinx to improve documentation aesthetics.""" + +# TODO(cgarciae): Send a PR to sphinx to upstream this fix. Issue: https://github.com/google/flax/issues/2196 +# This patch is needed to make autosummary provide the "annotations" +# variable so we can exclude function attributes from the methods list +# in flax_module.rst. 
The patch as such only adds this single line: +# +# ns['annotations'] = list(getattr(obj, '__annotations__', {}).keys())' +# +# We should consider sending a PR to sphinx so we can get rid of this. +# Original source: https://github.com/sphinx-doc/sphinx/blob/0aedcc9a916daa92d477226da67d33ce1831822e/sphinx/ext/autosummary/generate.py#L211-L351 +from typing import Any + +import sphinx.ext.autodoc +import sphinx.ext.autosummary.generate as ag + + +def generate_autosummary_content( + name: str, + obj: Any, + parent: Any, + template: ag.AutosummaryRenderer, + template_name: str, + imported_members: bool, + app: Any, + recursive: bool, + context: dict, + modname: str = None, + qualname: str = None, +) -> str: + doc = ag.get_documenter(app, obj, parent) + + def skip_member(obj: Any, name: str, objtype: str) -> bool: + try: + return app.emit_firstresult( + 'autodoc-skip-member', objtype, name, obj, False, {} + ) + except Exception as exc: + ag.logger.warning( + __( + 'autosummary: failed to determine %r to be documented, ' + 'the following exception was raised:\n%s' + ), + name, + exc, + type='autosummary', + ) + return False + + def get_class_members(obj: Any) -> dict[str, Any]: + members = sphinx.ext.autodoc.get_class_members( + obj, [qualname], ag.safe_getattr + ) + return {name: member.object for name, member in members.items()} + + def get_module_members(obj: Any) -> dict[str, Any]: + members = {} + for name in ag.members_of(obj, app.config): + try: + members[name] = ag.safe_getattr(obj, name) + except AttributeError: + continue + return members + + def get_all_members(obj: Any) -> dict[str, Any]: + if doc.objtype == 'module': + return get_module_members(obj) + elif doc.objtype == 'class': + return get_class_members(obj) + return {} + + def get_members( + obj: Any, + types: set[str], + include_public: list[str] = [], + imported: bool = True, + ) -> tuple[list[str], list[str]]: + items: list[str] = [] + public: list[str] = [] + + all_members = get_all_members(obj) + 
for name, value in all_members.items(): + documenter = ag.get_documenter(app, value, obj) + if documenter.objtype in types: + # skip imported members if expected + if imported or getattr(value, '__module__', None) == obj.__name__: + skipped = skip_member(value, name, documenter.objtype) + if skipped is True: + pass + elif skipped is False: + # show the member forcedly + items.append(name) + public.append(name) + else: + items.append(name) + if name in include_public or not name.startswith('_'): + # considers member as public + public.append(name) + return public, items + + def get_module_attrs(members: Any) -> tuple[list[str], list[str]]: + """Find module attributes with docstrings.""" + attrs, public = [], [] + try: + analyzer = ag.ModuleAnalyzer.for_module(name) + attr_docs = analyzer.find_attr_docs() + for namespace, attr_name in attr_docs: + if namespace == '' and attr_name in members: + attrs.append(attr_name) + if not attr_name.startswith('_'): + public.append(attr_name) + except ag.PycodeError: + pass # give up if ModuleAnalyzer fails to parse code + return public, attrs + + def get_modules(obj: Any) -> tuple[list[str], list[str]]: + items: list[str] = [] + for _, modname, _ispkg in ag.pkgutil.iter_modules(obj.__path__): + fullname = name + '.' 
+ modname + try: + module = ag.import_module(fullname) + if module and hasattr(module, '__sphinx_mock__'): + continue + except ImportError: + pass + + items.append(fullname) + public = [x for x in items if not x.split('.')[-1].startswith('_')] + return public, items + + ns: dict[str, Any] = {} + ns.update(context) + + if doc.objtype == 'module': + scanner = ag.ModuleScanner(app, obj) + ns['members'] = scanner.scan(imported_members) + ns['functions'], ns['all_functions'] = get_members( + obj, {'function'}, imported=imported_members + ) + ns['classes'], ns['all_classes'] = get_members( + obj, {'class'}, imported=imported_members + ) + ns['exceptions'], ns['all_exceptions'] = get_members( + obj, {'exception'}, imported=imported_members + ) + ns['attributes'], ns['all_attributes'] = get_module_attrs(ns['members']) + ispackage = hasattr(obj, '__path__') + if ispackage and recursive: + ns['modules'], ns['all_modules'] = get_modules(obj) + elif doc.objtype == 'class': + ns['members'] = dir(obj) + ns['inherited_members'] = set(dir(obj)) - set(obj.__dict__.keys()) + ns['methods'], ns['all_methods'] = get_members( + obj, {'method'}, ['__init__'] + ) + ns['attributes'], ns['all_attributes'] = get_members( + obj, {'attribute', 'property'} + ) + ns['annotations'] = list(getattr(obj, '__annotations__', {}).keys()) + + if modname is None or qualname is None: + modname, qualname = ag.split_full_qualified_name(name) + + if doc.objtype in ('method', 'attribute', 'property'): + ns['class'] = qualname.rsplit('.', 1)[0] + + if doc.objtype in ('class',): + shortname = qualname + else: + shortname = qualname.rsplit('.', 1)[-1] + + ns['fullname'] = name + ns['module'] = modname + ns['objname'] = qualname + ns['name'] = shortname + + ns['objtype'] = doc.objtype + ns['underline'] = len(name) * '=' + + if template_name: + return template.render(template_name, ns) + else: + return template.render(doc.objtype, ns) + + +ag.generate_autosummary_content = generate_autosummary_content diff --git 
a/docs_nnx/contributing.md b/docs_nnx/contributing.md new file mode 100644 index 0000000000..72c48ae1af --- /dev/null +++ b/docs_nnx/contributing.md @@ -0,0 +1,297 @@ +# How to contribute + +Everyone can contribute to Flax, and the Flax development team values everyone's contributions! +You can contribute in many more ways than just writing code. Answering questions +on the [Flax GitHub Discussions page](https://github.com/google/flax/discussions), helping +each other, and improving Flax documentation are extremely valuable to the Flax +ecosystem. + +We also appreciate if you spread the word, for instance by starring the [Flax GitHub repository](https://github.com/google/flax), +or referencing Flax in blog posts of projects that used it. + +This project follows +[Google's Open Source Community Guidelines](https://opensource.google/conduct/). + +## Ways to contribute + +We welcome pull requests (PRs), in particular for those issues +[marked as PR-ready](https://github.com/google/flax/issues?q=is%3Aopen+is%3Aissue+label%3A%22Status%3A+pull+requests+welcome%22). +For other proposals, you should first open a GitHub Issue or a GitHub Discussion to +start a conversation about your planned contribution. + +## Contributing code using pull requests + +The Flax development team performs all development using [Git](https://git-scm.com/). To contribute, +you should have basic knowledge of [Git](https://git-scm.com/) and [GitHub](https://docs.github.com). +(You can learn how to set up Git by following Git's official +[Getting Started - First-Time Git Setup](https://git-scm.com/book/en/v2/Getting-Started-First-Time-Git-Setup) +and GitHub's [Set Up Git](https://docs.github.com/en/get-started/quickstart/set-up-git) guides.) + +To contribute code to Flax on GitHub, follow these steps: + +### To create a pull request from a fork + +1. 
Using GitHub's web UI, fork the Flax repository by clicking the 'Fork' button on the + [`github.com/google/flax` repository page](http://www.github.com/google/flax). This creates a + fork (a copy) of the Flax repository in your own GitHub. + + Reference: [Creating a pull request from a fork](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork). + +2. Install [Python >=3.7](https://www.python.org/downloads/). + +3. (Optional) Create a virtual environment or a Docker container. See + [`dev/README.md`](https://github.com/google/flax/blob/main/dev/README.md) + for details on how to set up a Docker Container. To set up a virtual environment, + run the following: + + ```bash + python3 -m virtualenv env + . env/bin/activate + ``` + + This ensures all your dependencies are installed in this environment. + +4. Clone your local forked Flax repo with `git clone`. Then, install the required packages + with [pip](https://pip.pypa.io/en/stable/cli/pip_install/). This enables you to immediately + test the code after modifying it: + + ```bash + git clone https://github.com/YOUR_USERNAME/flax + cd flax + pip install -e ".[all,testing,docs]" + ``` + + You can also use [uv](https://docs.astral.sh/uv/) to set up + the development environment: + + ```bash + uv sync --all-extras + ``` + +5. Set up pre-commit hooks. This will run some automated checks during each `git` commit and + possibly update some files that require changes. + + ```bash + pip install pre-commit + pre-commit install + ``` + +6. Add the Google Flax repo (not your fork) as an upstream remote, so you can use it to sync your + changes. + + ```bash + git remote add upstream http://www.github.com/google/flax + ``` + + +7. Create a branch, such as `my_development_branch`, that you will develop from: + + ```bash + git checkout -b my_development_branch + ``` + +8.
Implement your changes using your favorite editor (we recommend + [Visual Studio Code](https://code.visualstudio.com/)). + + Make sure the tests pass by running the following command from the top of + the repository: + + ```bash + ./tests/run_all_tests.sh + ``` + +9. Once you finish making changes, don't forget to create commits + ([learn how to write a commit message](https://chris.beams.io/posts/git-commit/)): + + ```bash + git add file1.py file2.py ... + # or use `git add .` to add all changed files + git commit -m "Your commit message" + ``` + + Then sync your code with the main repository: + + ```bash + git fetch upstream + git rebase upstream/main + ``` + +10. Finally, push your commit on your `my_development_branch`, and create a remote + branch in your fork that you can use to create a pull request from: + + ```bash + git push --set-upstream origin my_development_branch + ``` + + After running the command, you should get a GitHub link in your (VS Code) terminal output for creating a pull request. + If you don't receive a link after `git push`, use the [GitHub web UI](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request?tool=webui) to create a pull request. + +11. Make sure your pull request passes the + [Flax PR checklist](https://github.com/google/flax/blob/main/.github/pull_request_template.md#checklist). + If so, create a pull request from the Flax repository and send it for review. + Consult [GitHub Help](https://help.github.com/articles/about-pull-requests/) + for more information on using pull requests. + +You can learn more in GitHub's [Creating a pull request from a fork +](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork) documentation.
+ +### Adding or updating dependencies + +To add or update dependencies, you must use `uv` after +updating the `pyproject.toml` file to ensure that the `uv.lock` file is up-to-date. + +```bash +uv sync --all-extras +``` +Alternatively, you can use `uv add` to add or update the dependencies automatically, for example: + +```bash +uv add 'some-package>=1.2.3' +``` + +### Updating Jupyter Notebooks + +We use [jupytext](https://jupytext.readthedocs.io/) to maintain two synced copies of docs +in `docs/notebooks`: one in the Jupyter Notebook (`.ipynb`) format, and one in Markdown (`.md`). + +The former can be opened and executed directly in [Google Colab](https://colab.research.google.com/). +Markdown makes it easier to track changes/diffs within version control and, for example, the GitHub +web UI, since `.ipynb` files are based on JSON. + +#### Editing Jupyter Notebooks (`.ipynb`) + +For making large changes that substantially modify code and outputs, it's recommended to edit +the notebooks in [Jupyter](https://jupyter.org/install) or in [Colab](https://colab.research.google.com/). + +If you choose to work in Colab, go to **File** and click **Upload notebook**, then pick your file. +After loading it into Colab and editing it, make sure you run the cells, and that there aren't any errors. +Click on **Runtime**, then select **Run all**. After you finish, click **File** > **Download** > **Download ipynb**. +You may also want to test that the file executes properly by using `sphinx-build`, as explained above. + +After you make changes in your Jupyter Notebook, follow the steps in _Syncing notebooks_ below. + +#### Editing Markdown files (`.md`) + +For making smaller changes to the text content of the notebooks, it is easiest to edit the +`.md` versions using a text editor. + +After you make changes in your Markdown file, follow the steps in _Syncing notebooks_ below.
+ +#### Syncing notebooks + +After editing either the `.ipynb` or `.md` versions of the docs, sync the two versions +using [jupytext](https://jupytext.readthedocs.io/) by running `jupytext --sync` on the updated +notebooks. + +First, make sure you have jupytext installed. The jupytext version should match +the one specified in [.pre-commit-config.yaml](https://github.com/google/flax/blob/main/.pre-commit-config.yaml) +(currently, it is v1.13.8). + +```bash +pip install jupytext==1.13.8 +``` + +Then, after you have made your changes in the Jupyter Notebook, sync the contents with its Markdown-equivalent +file by running the following command: + +```bash +jupytext --sync path/to/the/file.ipynb +``` + +Similarly, to sync your Markdown file with its Jupyter Notebook version, run: + +```bash +jupytext --sync path/to/the/file.md +``` + +Note that if you receive an error, and it is the first time you worked in a Jupyter Notebook, you may need +to (re)create a synced copy of the document (which is explained in detail in _Creating new notebooks_ section below): + +```bash +jupytext --set-formats ipynb,md:myst path/to/the/notebook.ipynb +``` + +Once you're finished with syncing the `.md` and `.ipynb` files, you can check that they are properly synced using the +[pre-commit](https://pre-commit.com/) framework to perform the same checks used +in the Flax GitHub CI: + +```bash +git add docs -u # pre-commit runs on files in git staging. +pre-commit run jupytext +``` + +#### Creating new notebooks + +If you are adding a new Jupyter Notebook to the documentation, you can use `jupytext --set-formats`. +It can set up both the Jupyter Notebook (`.ipynb`) and Markdown (`.md`) versions of the file: + +```bash +jupytext --set-formats ipynb,md:myst path/to/the/notebook.ipynb +``` + +This works by adding a `"jupytext"` metadata field to the notebook file which specifies the +desired formats. The `jupytext --sync` command can then recognize them when invoked. 
+ +After you make changes in your file(s), follow the steps from the _Syncing notebooks_ +section above to keep the contents of both Markdown and Jupyter Notebook files in sync. + +#### Notebooks within the Sphinx build + +Some of the notebooks are built automatically as part of the pre-submit checks and +as part of the [Read the Docs](https://flax.readthedocs.io/en/latest) build. +The build will fail if cells raise errors. If the errors are intentional, you can either catch them, +or tag the cell with `raises-exception` metadata ([example PR](https://github.com/google/jax/pull/2402/files)). +You have to add this metadata by hand in the `.ipynb` file. It will be preserved when somebody else +re-saves the notebook. + +We exclude some notebooks from the build because, for example, they contain long computations. +See `exclude_patterns` in [`conf.py`](https://github.com/google/flax/blob/main/docs/conf.py). + +### Updating the pull request contents + +Every pull request should ideally be limited to just one commit, so if you have multiple commits, please squash them. + +Assuming you now have only one commit in your pull request, and want to add changes requested during review: + +1. Make the changes locally in your editor. +2. Run `git commit -a --amend`. This updates the commit contents and allows you to edit the commit message. +3. At this point, `git push` alone will result in an error. Instead, use `git push --force`. +4. Check that it's done: The changes to your commit should be immediately reflected in the GitHub web UI. + +## Troubleshooting + +### Too many commits in a pull request + +If your PR has too many commits associated with it (for example, more than five), +you need to squash them. Otherwise, the Flax docs build process may fail with an +error message. This is because of the following reasons: + +* There are more than five commits in your pull request; and +* The Flax source sync process fails when the commit tree is too large.
+ +To squash your commits, rebase your branch onto `main` and create a new +commit containing all your changes by running the following command: + +```bash +git rebase main && git reset --soft main && git commit +``` + +This will apply all your changes to the main branch. Note that if you had to +resolve any conflicts while working on your change (for instance, you did a +`git pull upstream main` that led to a conflict), then you will have to resolve these +conflicts again. + +After you have successfully rebased your branch, you should push your changes. +And because you changed the commit history, you may have to use `git push --force`. + +## Contributor License Agreement + +Contributions to this project must be accompanied by a Contributor License +Agreement. You (or your employer) retain the copyright to your contribution; +this simply gives us permission to use and redistribute your contributions as +part of the project. Head over to <https://cla.developers.google.com/> to see +your current agreements on file or to sign a new one. + +You generally only need to submit a CLA once, so if you've already submitted one +(even if it was for a different project), you probably don't need to do it +again. diff --git a/docs_nnx/examples/community_examples.rst b/docs_nnx/examples/community_examples.rst new file mode 100644 index 0000000000..079568c9a7 --- /dev/null +++ b/docs_nnx/examples/community_examples.rst @@ -0,0 +1,110 @@ +Community examples +================== + +In addition to the `curated list of official Flax examples on GitHub <https://github.com/google/flax/tree/main/examples>`__, +there is a growing community of people using Flax to build new types of machine +learning models. We are happy to showcase any example built by the community here! + +If you want to submit your own Flax example, you can start by forking +one of the `official Flax examples on GitHub <https://github.com/google/flax/tree/main/examples>`__. + +Models +****** +..
list-table:: + :header-rows: 1 + + * - Link + - Author + - Task type + - Reference + * - `matthias-wright/flaxmodels `__ + - `@matthias-wright `__ + - Various + - GPT-2, ResNet, StyleGAN-2, VGG, ... + * - `DarshanDeshpande/jax-models `__ + - `@DarshanDeshpande `__ + - Various + - Segformer, Swin Transformer, ... also some stand-alone layers + * - `google/vision_transformer `__ + - `@andsteing `__ + - Image classification, image/text + - https://arxiv.org/abs/2010.11929, https://arxiv.org/abs/2105.01601, https://arxiv.org/abs/2111.07991, ... + * - `jax-resnet `__ + - `@n2cholas `__ + - Various resnet implementations + - `torch.hub `__ + * - `Wav2Vec2 finetuning `__ + - `@vasudevgupta7 `__ + - Automatic Speech Recognition + - https://arxiv.org/abs/2006.11477 + +Examples +******** + +.. list-table:: + :header-rows: 1 + + * - Link + - Author + - Task type + - Reference + * - `JAX-RL `__ + - `@henry-prior `__ + - Reinforcement learning + - N/A + * - `BigBird Fine-tuning `__ + - `@vasudevgupta7 `__ + - Question-Answering + - https://arxiv.org/abs/2007.14062 + * - `DCGAN `__ + - `@bkkaggle `__ + - Image Synthesis + - https://arxiv.org/abs/1511.06434 + * - `denoising-diffusion-flax `__ + - `@yiyixuxu `__ + - Image generation + - https://arxiv.org/abs/2006.11239 + +Tutorials +********* + +.. currently left empty as a placeholder for tutorials +.. list-table:: + :header-rows: 1 + + * - Link + - Author + - Task type + - Reference + * - + - + - + - + +Contributing policy +******************* + +If you are interested in adding a project to the Community Examples section, take the following +into consideration: + +* **Code examples**: Examples must contain a README that is helpful, clear, and explains + how to run the code. The code itself should be easy to follow. +* **Tutorials**: These docs should preferably be in Jupyter Notebook format + (refer to `Contributing `__ + to learn how to convert a Jupyter Notebook into a Markdown file with `jupytext`).
+ Your tutorial should be well-written, and discuss/describe an interesting topic/task. + To avoid duplication, the content of these docs must be different from + `existing docs on the Flax documentation site `__ + or other community examples mentioned in this document. +* **Models**: repositories with models ported to Flax must provide at least one of the following: + + * Metrics that are comparable to the original work when the model is trained to completion. Having + available plots of the metric's history during training is highly encouraged. + * Tests to verify numerical equivalence against a well known implementation (same inputs + + weights = same outputs) preferably using pretrained weights. + +In all cases mentioned above, the code must work with the latest stable versions of the +following packages: ``jax``, ``flax``, and ``optax``, and make substantial use of Flax. +Note that both ``jax`` and ``optax`` are `required packages `__ +of ``flax`` (refer to the `installation instructions `__ +for more details). diff --git a/docs_nnx/examples/core_examples.rst b/docs_nnx/examples/core_examples.rst new file mode 100644 index 0000000000..34e3bacc11 --- /dev/null +++ b/docs_nnx/examples/core_examples.rst @@ -0,0 +1,87 @@ +Core examples +============= + +Core examples are hosted on the GitHub Flax repository in the `examples `__ +directory. + +Each example is designed to be **self-contained and easily forkable**, while +reproducing relevant results in different areas of machine learning. + +As discussed in `#231 `__, we decided +to go for a standard pattern for all examples including the simplest ones (like MNIST). +This makes every example a bit more verbose, but once you know one example, you +know the structure of all of them. Having unit tests and integration tests is also +very useful when you fork these examples. + +Some of the examples below have a link "Interactive🕹" that lets you run them +directly in Colab. 
+ +Image classification +******************** + +- :octicon:`mark-github;0.9em` `MNIST `__ - + `Interactive🕹 `__: + Convolutional neural network for MNIST classification (featuring simple + code). + +- :octicon:`mark-github;0.9em` `ImageNet `__ - + `Interactive🕹 `__: + Resnet-50 on ImageNet with weight decay (featuring multi-host SPMD, custom + preprocessing, checkpointing, dynamic scaling, mixed precision). + +Reinforcement learning +********************** + +- :octicon:`mark-github;0.9em` `Proximal Policy Optimization `__: + Learning to play Atari games (featuring single host SPMD, RL setup). + +Natural language processing +*************************** + +- :octicon:`mark-github;0.9em` `Sequence to sequence for number + addition `__: + (featuring simple code, LSTM state handling, on the fly data generation). +- :octicon:`mark-github;0.9em` `Parts-of-speech + tagging `__: Simple + transformer encoder model using the universal dependency dataset. +- :octicon:`mark-github;0.9em` `Sentiment + classification `__: + with a LSTM model. +- :octicon:`mark-github;0.9em` `Transformer encoder/decoder model trained on + WMT `__: + Translating English/German (featuring multihost SPMD, dynamic bucketing, + attention cache, packed sequences, recipe for TPU training on GCP). +- :octicon:`mark-github;0.9em` `Transformer encoder trained on one billion word + benchmark `__: + for autoregressive language modeling, based on the WMT example above. + +Generative models +***************** + +- :octicon:`mark-github;0.9em` `Variational + auto-encoder `__: + Trained on binarized MNIST (featuring simple code, vmap). + +Graph modeling +************** + +- :octicon:`mark-github;0.9em` `Graph Neural Networks `__: + Molecular predictions on ogbg-molpcba from the Open Graph Benchmark. + +Contributing to core Flax examples +********************************** + +Most of the `core Flax examples on GitHub `__ +follow a structure that the Flax dev team found works well with Flax projects. 
+The team strives to make these examples easy to explore and fork. In particular +(as per GitHub Issue `#231 `__): + +- README: contains links to paper, command line, `TensorBoard `__ metrics. +- Focus: an example is about a single model/dataset. +- Configs: we use ``ml_collections.ConfigDict`` stored under ``configs/``. +- Tests: executable ``main.py`` loads ``train.py`` which has ``train_test.py``. +- Data: is read from `TensorFlow Datasets `__. +- Standalone: every directory is self-contained. +- Requirements: versions are pinned in ``requirements.txt``. +- Boilerplate: is reduced by using `clu `__. +- Interactive: the example can be explored with a `Colab `__. \ No newline at end of file diff --git a/docs_nnx/examples/google_research_examples.rst b/docs_nnx/examples/google_research_examples.rst new file mode 100644 index 0000000000..83e0101001 --- /dev/null +++ b/docs_nnx/examples/google_research_examples.rst @@ -0,0 +1,269 @@ +######################## +Google Research examples +######################## + +A collection of research by Google Research made with Flax. + +Attention +********* + +Fast Attention (FAVOR+) and Rethinking Attention with Performers +================================================================ + +- Code on GitHub: + + - `Performer's Fast Attention (FAVOR+) module `__ + +- Research paper: + + - `Rethinking Attention with Performers `__ (Choromanski et al., 2020) + + - Introduces *"Performers, Transformer architectures which can estimate regular (softmax) full-rank-attention Transformers with provable accuracy, but using only linear (as opposed to quadratic) space and time complexity, without relying on any priors such as sparsity or low-rankness. To approximate softmax attention-kernels, Performers use a novel Fast Attention Via positive Orthogonal Random features approach (FAVOR+), which may be of independent interest for scalable kernel methods. 
FAVOR+ can be also used to efficiently model kernelizable attention mechanisms beyond softmax."* + +Self-attention Does Not Need O(n^2) Memory +========================================== + +- `Code on GitHub `__ +- `Colab notebook `__ + +- Research paper: + + - `Self-attention Does Not Need O(n^2) Memory `__ (Rabe and Staats, 2021) + + - *"We present a very simple algorithm for attention that requires O(1) memory with respect to sequence length and an extension to self-attention that requires O(log n) memory. This is in contrast with the frequently stated belief that self-attention requires O(n^2) memory. While the time complexity is still O(n^2), device memory rather than compute capability is often the limiting factor on modern accelerators. Thus, reducing the memory requirements of attention allows processing of longer sequences than might otherwise be feasible..."* + +Computer vision +*************** + +Colorization Transformer (ColTran) +================================== + +- `Code on GitHub `__ + +- Research paper: + + - `Colorization Transformer `__ (Kumar et al., 2020) + + - *"We presented the Colorization Transformer (ColTran), an architecture that entirely relies on self-attention for image colorization. We introduce conditional transformer layers, a novel building block for conditional, generative models based on self-attention. Our ablations show the superiority of employing this mechanism over a number of different baselines. Finally, we demonstrate that ColTran can generate diverse, high-fidelity colorizations on ImageNet, which are largely indistinguishable from the ground-truth even for human raters."* + +Vision Transformer (ViT), MLP-Mixer Architectures *and* Big Vision +================================================================== + +- Code on GitHub: + + - `Vision Transformer and MLP-Mixer Architectures `__ + + - `Big Vision `__ + + - *"This codebase is designed for training large-scale vision models using Cloud TPU VMs or GPU machines. 
It is based on Jax/Flax libraries, and uses tf.data and TensorFlow Datasets for scalable and reproducible input pipelines."* + +- `Colab notebooks `__: + + - The JAX code of Vision Transformers and MLP Mixers + - More than 50k Vision Transformer and hybrid checkpoints that were used to generate the data of "How to train your ViT?" + +- Research papers: + + - `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `__ (Dosovitskiy et al., 2020) + + - *"In vision, attention is either applied in conjunction with convolutional networks, or used to replace certain components of convolutional networks while keeping their overall structure in place. We show that this reliance on CNNs is not necessary and a pure transformer applied directly to sequences of image patches can perform very well on image classification tasks. When pre-trained on large amounts of data and transferred to multiple mid-sized or small image recognition benchmarks (ImageNet, CIFAR-100, VTAB, etc.), Vision Transformer (ViT) attains excellent results compared to state-of-the-art convolutional networks while requiring substantially fewer computational resources to train."* + + - `MLP-Mixer: An All-MLP Architecture for Vision `__ (Tolstikhin et al., 2021) + + - *"In this paper we show that while convolutions and attention are both sufficient for good performance, neither of them are necessary. We present MLP-Mixer, an architecture based exclusively on multi-layer perceptrons (MLPs). MLP-Mixer contains two types of layers: one with MLPs applied independently to image patches (i.e. "mixing" the per-location features), and one with MLPs applied across patches (i.e. "mixing" spatial information). When trained on large datasets, or with modern regularization schemes, MLP-Mixer attains competitive scores on image classification benchmarks, with pre-training and inference cost comparable to state-of-the-art models."* + + - `How to Train Your ViT? 
Data, Augmentation, and Regularization in Vision Transformers `__ (Steiner et al., 2021) + + - *"Vision Transformers (ViT) have been shown to attain highly competitive performance for a wide range of vision applications, such as image classification, object detection and semantic image segmentation. In comparison to convolutional neural networks, the Vision Transformer's weaker inductive bias is generally found to cause an increased reliance on model regularization or data augmentation ("AugReg" for short) when training on smaller training datasets. We conduct a systematic empirical study in order to better understand the interplay between the amount of training data, AugReg, model size and compute budget."* + + - `When Vision Transformers Outperform ResNets without Pretraining or Strong Data Augmentations `__ (X. Chen et al., 2021) + + - *"Vision Transformers (ViTs) and MLPs signal further efforts on replacing hand-wired features or inductive biases with general-purpose neural architectures. Existing works empower the models by massive data, such as large-scale pre-training and/or repeated strong data augmentations, and still report optimization-related problems (e.g., sensitivity to initialization and learning rates). Hence, this paper investigates ViTs and MLP-Mixers from the lens of loss geometry, intending to improve the models' data efficiency at training and generalization at inference."* + + - `LiT: Zero-Shot Transfer with Locked-image Text Tuning `__ (X. Zhai et al., 2021) + + - *"This paper presents contrastive-tuning, a simple method employing contrastive training to align image and text models while still taking advantage of their pre-training. In our empirical study we find that locked pre-trained image models with unlocked text models work best. We call this instance of contrastive-tuning "Locked-image Tuning" (LiT), which just teaches a text model to read out good representations from a pre-trained image model for new tasks. 
A LiT model gains the capability of zero-shot transfer to new vision tasks, such as image classification or retrieval. The proposed LiT is widely applicable; it works reliably with multiple pre-training methods (supervised and unsupervised) and across diverse architectures (ResNet, Vision Transformers and MLP-Mixer) using three different image-text datasets."* + +Scaling Vision with Sparse Mixture of Experts (MoE) +=================================================== + +- `Code on GitHub `__ +- Research paper: + + - `Scaling Vision with Sparse Mixture of Experts `__ (Riquelme et al., 2021) + + - *"Sparsely-gated Mixture of Experts networks (MoEs) have demonstrated excellent scalability in Natural Language Processing. In Computer Vision, however, almost all performant networks are "dense", that is, every input is processed by every parameter. We present a Vision MoE (V-MoE), a sparse version of the Vision Transformer, that is scalable and competitive with the largest dense networks... we demonstrate the potential of V-MoE to scale vision models, and train a 15B parameter model that attains 90.35% on ImageNet..."* + +Diffusion +********* + +Variational Diffusion Models +============================ + +- `Code on GitHub `__ +- `Colab notebooks `__ +- Research paper: + + - `Variational Diffusion Models `__ (Kingma et al., 2021) + + - *"Diffusion-based generative models have demonstrated a capacity for perceptually impressive synthesis, but can they also be great likelihood-based models? We answer this in the affirmative, and introduce a family of diffusion-based generative models that obtain state-of-the-art likelihoods on standard image density estimation benchmarks. Unlike other diffusion-based models, our method allows for efficient optimization of the noise schedule jointly with the rest of the model. 
We show that the variational lower bound (VLB) simplifies to a remarkably short expression in terms of the signal-to-noise ratio of the diffused data, thereby improving our theoretical understanding of this model class. Using this insight, we prove an equivalence between several models proposed in the literature. In addition, we show that the continuous-time VLB is invariant to the noise schedule, except for the signal-to-noise ratio at its endpoints. This enables us to learn a noise schedule that minimizes the variance of the resulting VLB estimator, leading to faster optimization..."* + +Domain adaptation +***************** + +GIFT (Gradual Interpolation of Features toward Target) +====================================================== + +- `Code on GitHub `__ +- Research paper: + + - `Gradual Domain Adaptation in the Wild: When Intermediate Distributions are Absent `__ (Abnar et al., 2021) + + - *"We focus on the problem of domain adaptation when the goal is shifting the model towards the target distribution, rather than learning domain invariant representations. It has been shown that under the following two assumptions: (a) access to samples from intermediate distributions, and (b) samples being annotated with the amount of change from the source distribution, self-training can be successfully applied on gradually shifted samples to adapt the model toward the target distribution. We hypothesize having (a) is enough to enable iterative self-training to slowly adapt the model to the target distribution, by making use of an implicit curriculum. In the case where (a) does not hold, we observe that iterative self-training falls short. 
We propose GIFT, a method that creates virtual samples from intermediate distributions by interpolating representations of examples from source and target domains..."* + +Generalization +************** + +Surrogate Gap Minimization Improves Sharpness-Aware Training +============================================================ + +- `Code on GitHub `__ +- Research paper: + + - `Surrogate Gap Minimization Improves Sharpness-Aware Training `__ (J. Zhuang et al., 2022) + + - *"The recently proposed Sharpness-Aware Minimization (SAM) improves generalization by minimizing a perturbed loss defined as the maximum loss within a neighborhood in the parameter space. However, we show that both sharp and flat minima can have a low perturbed loss, implying that SAM does not always prefer flat minima. Instead, we define a surrogate gap, a measure equivalent to the dominant eigenvalue of Hessian at a local minimum when the radius of neighborhood (to derive the perturbed loss) is small. The surrogate gap is easy to compute and feasible for direct minimization during training. Based on the above observations, we propose Surrogate Gap Guided Sharpness-Aware Minimization (GSAM), a novel improvement over SAM with negligible computation overhead..."* + +Meta learning +************* + +``learned_optimization`` +======================== + +- Code on GitHub: `learned_optimization `__ +- `Colab notebooks `__ + +- Research papers: + + - `Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies `__ (Vicol et al., 2021) + + - *"We introduce a method called Persistent Evolution Strategies (PES), which divides the computation graph into a series of truncated unrolls, and performs an evolution strategies-based update step after each unroll. PES eliminates bias from these truncations by accumulating correction terms over the entire sequence of unrolls. 
PES allows for rapid parameter updates, has low memory usage, is unbiased, and has reasonable variance characteristics."* + + - `Gradients Are Not All You Need `__ (Metz et al., 2021) + + - *"...In this short report, we discuss a common chaos based failure mode which appears in a variety of differentiable circumstances, ranging from recurrent neural networks and numerical physics simulation to training learned optimizers. We trace this failure to the spectrum of the Jacobian of the system under study, and provide criteria for when a practitioner might expect this failure to spoil their differentiation based optimization algorithms."* + +Model efficiency +**************** + +Efficiently Scaling Transformer Inference +========================================= + +- Code on GitHub: + + - `T5X `__ + - `AQT: Accurate Quantized Training `__ + +- Research paper: + + - `Efficiently Scaling Transformer Inference `__ (Pope et al., 2022) + + - *"We develop a simple analytical model for inference efficiency to select the best multi-dimensional partitioning techniques optimized for TPU v4 slices based on the application requirements. We combine these with a suite of low-level optimizations to achieve a new Pareto frontier on the latency and model FLOPS utilization (MFU) tradeoffs on 500B+ parameter models that outperforms the FasterTransformer suite of benchmarks. We further show that with appropriate partitioning, the lower memory requirements of multiquery attention (i.e. multiple query heads share single key/value head) enables scaling up to 32× larger context lengths."* + +Neural rendering / NeRF +*********************** + +Generalizable Patch-Based Neural Rendering +========================================== + +- `Code on GitHub `__ +- Research paper: + + - `Generalizable Patch-Based Neural Rendering `__ (Suhail et al., 2022) + + - *"...We propose a different paradigm, where no deep features and no NeRF-like volume rendering are needed. 
Our method is capable of predicting the color of a target ray in a novel scene directly, just from a collection of patches sampled from the scene."* + +Voxel-based Radiance Fields in JAX and Flax +=========================================== + +- `Colab notebook `__ (Velez and Dellaert, 2022) + + - *"In this notebook we show how with JAX/Flax, it is relatively easy to quickly get a voxel-based NeRF variant up and running. Specifically, we will develop a simplified version of DVGO that directly regresses color instead of having a small MLP. It works remarkably well."* + +Optimization +************ + +Amos Optimizer *and* JEstimator +=============================== + +- Code on GitHub: + + - `Amos and JEstimator `__ + + - *"... implements Amos, an optimizer compatible with the optax library, and JEstimator, a light-weight library with a tf.Estimator-like interface to manage T5X-compatible checkpoints for machine learning programs in JAX, which we use to run experiments in the paper."* + +- Research paper: + + - `Amos: An Adam-style Optimizer with Adaptive Weight Decay towards Model-Oriented Scale `__ (Tian and Parikh, 2022) + + - Presents *"Amos, an optimizer compatible with the optax library, and JEstimator, a light-weight library with a tf.Estimator-like interface to manage T5X-compatible checkpoints for machine learning programs in JAX."* *"When used for pre-training BERT variants and T5, Amos consistently converges faster than the state-of-the-art settings of AdamW, achieving better validation loss within <=70% training steps and time, while requiring <=51% memory for slot variables."* + +Quantization +************ + +Pareto-Optimal Quantized ResNet Is Mostly 4-bit *and* AQT: Accurate Quantized Training +====================================================================================== + +- Code on GitHub: + + - `AQT: Accurate Quantized Training `__ + +- Research paper: + + - `Pareto-Optimal Quantized ResNet Is Mostly 4-bit `__ (Abdolrashidi et al., 2021) + + 
- *"In this work, we use ResNet as a case study to systematically investigate the effects of quantization on inference compute cost-quality tradeoff curves. Our results suggest that for each bfloat16 ResNet model, there are quantized models with lower cost and higher accuracy; in other words, the bfloat16 compute cost-quality tradeoff curve is Pareto-dominated by the 4-bit and 8-bit curves, with models primarily quantized to 4-bit yielding the best Pareto curve... The quantization method we used is optimized for practicality: It requires little tuning and is designed with hardware capabilities in mind... As part of this work, we contribute a quantization library written in JAX..."* + +Reinforcement learning +********************** + +Continuous Control with Action Quantization from Demonstrations (AQuaDem) +========================================================================= + +- `Code on GitHub `__ + +- Research paper: + + - `Continuous Control with Action Quantization from Demonstrations `__ (Dadashi et al., 2021) + + - Proposes *"a novel Reinforcement Learning (RL) framework for problems with continuous action spaces: Action Quantization from Demonstrations (AQuaDem). The proposed approach consists in learning a discretization of continuous action spaces from human demonstrations. This discretization returns a set of plausible actions (in light of the demonstrations) for each input state, thus capturing the priors of the demonstrator and their multimodal behavior. By discretizing the action space, any discrete action deep RL technique can be readily applied to the continuous control problem. 
Experiments show that the proposed approach outperforms state-of-the-art methods such as SAC in the RL setup, and GAIL in the Imitation Learning setup."* + +Sequence models / Model parallelism +*********************************** + +T5X: Scaling Up Models and Data with ``t5x`` and ``seqio`` +========================================================== + +- `Code on GitHub `__ + + - *"T5X is a modular, composable, research-friendly framework for high-performance, configurable, self-service training, evaluation, and inference of sequence models (starting with language) at many scales."* + +- Research paper: + + - `T5X: Scaling Up Models and Data with t5x and seqio `__ (Roberts et al., 2022) + + - *"Recent neural network-based language models have benefited greatly from scaling up the size of training datasets and the number of parameters in the models themselves. Scaling can be complicated due to various factors including the need to distribute computation on supercomputer clusters (e.g., TPUs), prevent bottlenecks when infeeding data, and ensure reproducible results. In this work, we present two software libraries that ease these issues: t5x simplifies the process of building and training large language models at scale while maintaining ease of use, and seqio provides a task-based API for simple creation of fast and reproducible training data and evaluation pipelines. These open-source libraries have been used to train models with hundreds of billions of parameters on datasets with multiple terabytes of training data. 
Along with the libraries, we release configurations and instructions for T5-like encoder-decoder models as well as GPT-like decoder-only architectures."* + +Simulation +********** + +Brax - A Differentiable Physics Engine for Large Scale Rigid Body Simulation +============================================================================ + +- `Code on GitHub `__ +- `Colab notebooks `__ +- Research paper: + + - `Brax - A Differentiable Physics Engine for Large Scale Rigid Body Simulation `__ (Freeman et al., 2021) + + - *"We present Brax, an open source library for rigid body simulation with a focus on performance and parallelism on accelerators, written in JAX. We present results on a suite of tasks inspired by the existing reinforcement learning literature, but remade in our engine. Additionally, we provide reimplementations of PPO, SAC, ES, and direct policy optimization in JAX that compile alongside our environments, allowing the learning algorithm and the environment processing to occur on the same device, and to scale seamlessly on accelerators."* diff --git a/docs_nnx/examples/index.rst b/docs_nnx/examples/index.rst new file mode 100644 index 0000000000..cd77fd9cee --- /dev/null +++ b/docs_nnx/examples/index.rst @@ -0,0 +1,12 @@ +Examples +======== + +.. toctree:: + :maxdepth: 2 + + core_examples + google_research_examples + repositories_that_use_flax + community_examples + + diff --git a/docs_nnx/examples/repositories_that_use_flax.rst b/docs_nnx/examples/repositories_that_use_flax.rst new file mode 100644 index 0000000000..dfc23f6ad4 --- /dev/null +++ b/docs_nnx/examples/repositories_that_use_flax.rst @@ -0,0 +1,51 @@ +Repositories that use Flax +========================== + +The following code bases use Flax and provide training frameworks and a wealth +of examples. 
In many cases, you can also find pre-trained weights: + + +🤗 Hugging Face +*************** + +`🤗 Hugging Face `__ is a +very popular library for building, training, and deploying state-of-the-art +machine learning models. +These models can be applied to text, images, and audio. After organizing the +`JAX/Flax community week `__, +they now have over 5,000 +`Flax/JAX models `__ in +their repository. + +🥑 DALLE Mini +************* + +`🥑 DALLE Mini `__ is a Transformer-based +text-to-image model implemented in JAX/Flax that follows the ideas from the +original `DALLE `__ paper by OpenAI. + +Scenic +****** + +`Scenic `__ is a codebase/library +for computer vision research and beyond. Scenic's main focus is around +attention-based models. Scenic has been successfully used to develop +classification, segmentation, and detection models for multiple modalities +including images, video, audio, and multimodal combinations of them. + +Big Vision +********** + +`Big Vision `__ is a codebase +designed for training large-scale vision models using Cloud TPU VMs or GPU +machines. It is based on Jax/Flax libraries, and uses tf.data and TensorFlow +Datasets for scalable and reproducible input pipelines. This is the original +codebase of ViT, MLP-Mixer, LiT, UViM, and many more models. + +T5X +*** + +`T5X `__ is a modular, composable, +research-friendly framework for high-performance, configurable, self-service +training, evaluation, and inference of sequence models (starting with +language) at many scales. \ No newline at end of file diff --git a/docs_nnx/faq.rst b/docs_nnx/faq.rst new file mode 100644 index 0000000000..726ed9bfbe --- /dev/null +++ b/docs_nnx/faq.rst @@ -0,0 +1,38 @@ +Frequently Asked Questions (FAQ) +================================ + +This is a collection of answers to frequently asked questions (FAQ). You can contribute to the Flax FAQ by starting a new topic in `GitHub Discussions `__. + +Where to search for an answer to a Flax-related question? 
+********************************************************* + +There are a number of official Flax resources to search for information: + +- `Flax Documentation on ReadTheDocs `__ (this site): Use the `search bar `__ or the table of contents on the left-hand side. +- `google/flax GitHub Discussions `__: Search for an existing topic or start a new one. If you can't find what you're looking for, feel free to ask the Flax team or community a question. +- `google/flax GitHub Issues `__: Use the search bar to look for an existing issue or a feature request, or start a new one. + +How to take the derivative with respect to an intermediate value (using :code:`Module.perturb`)? +************************************************************************************************ + +To take the derivative(s) or gradient(s) of the output with respect to a hidden/intermediate activation inside a model layer, you can use :meth:`flax.linen.Module.perturb`. You define a zero-value :class:`flax.linen.Module` "perturbation" parameter – :code:`perturb(...)` – in the forward pass with the same shape as the intermediate activation, define the loss function with :code:`'perturbations'` as an added standalone argument, and then perform a JAX derivative operation with :code:`jax.grad` on the perturbation argument. + +For full examples and detailed documentation, go to: + +- The :meth:`flax.linen.Module.perturb` API docs +- The `Extracting gradients of intermediate values `_ guide +- `Flax GitHub Discussions #1152 `__ + +Is Flax Linen :code:`remat_scan()` the same as :code:`scan(remat(...))`? +************************************************************************ + +Flax :code:`remat_scan()` (:meth:`flax.linen.remat_scan()`) and :code:`scan(remat(...))` (:meth:`flax.linen.scan` over :meth:`flax.linen.remat`) are not the same, and :code:`remat_scan()` is limited in the cases it supports. 
Namely, :code:`remat_scan()` treats the inputs and outputs as carries (hidden states that are carried through the training loop). We recommend using :code:`scan(remat(...))`, as typically you would need the extra parameters, such as ``in_axes`` (for input array axes) or ``out_axes`` (for output array axes), which :meth:`flax.linen.remat_scan` does not expose. + +What are the recommended training loop libraries? +************************************************* + +Consider using CLU (Common Loop Utils) `google/CommonLoopUtils `__. To get started, go to this `CLU Synopsis Colab `__. You can find answers to common questions about CLU with Flax on `google/flax GitHub Discussions `__. + +Check out the official `google/flax Examples `__ for examples of using the training loop with (CLU) metrics. For example, this is `Flax ImageNet's train.py `__. + +For computer vision research, consider `google-research/scenic `__. Scenic is a set of shared light-weight libraries solving commonly encountered tasks when training large-scale vision models (with examples of several projects). Scenic is developed in JAX with Flax. To get started, go to the `README page on GitHub `__. 
\ No newline at end of file diff --git a/docs_nnx/flax.png b/docs_nnx/flax.png new file mode 100644 index 0000000000000000000000000000000000000000..24b65b0c9f12c905b0d820fd6307e38b1c6831aa GIT binary patch literal 20991 zcma%i1yo#3v+dwIkRZW5cyM( zdczu~r@O1Wt52OXUA1?Gs3=JTQ3+82002-{M)D&707Llqjf@Db8LdfbhJL`Ciz$i$ z098?FPeusP@8rfZ9~A)r4;lc#FAxBDfL8hK0|3sf0KnmA0DwOU0Kl_PYgQG6HdvTy z%bF=F0vMoWWB@!2A^;9rf&sun|6%>R3?`PJ^+9)^b-Jp%Yyl@o>{Q} z(F#MD1@~X&za0;GHJt$f*lkO7Z6|F-1%6|&Ewhmc_=_pCo2~ud4gf(nerVCw)X9k4 z&DO@wk>5>-@}D03(DL7E7E1Dex;R-2QEDrykV}9aOv$;KS(#ZWg;B}L$psxu%=kY_ zg8nTI{U$_d;pAk`&%)yB>dNfO!3=gVXJO;x<6~iEXJKb&g7#o?bhmRda$~Y{r25w& z|1plFsiU!jrM;6S*pB?~xJF;V&Q3y11j4Uo+V`{#z}mf-HYqSlF0Z zS^i&Qrf!!1i`d_me~JC0*T1F{{5u%G1lZQz!PLUv+?|VT)$GSHg|ety$$$K^Z~hXSzJY1QuHkXC#4*E7!4@5J0<#ek|ZM6i{cl` zEKqJ2NeVDq~|K$4(j;6|=TQWMxFE_QEjZW6NYE?2RpQ#e>Ov;?~5kukv(bmqy9$C@TX z_yk=itiIO(HrWIu@N7r>`ntJ9d1u0Wu6G7fYz`i%*? zAzjQJc|}P5DoEG?%CeYo>ShIr%C>H`=%=nh=`TTkYL*@vhGs?ZoZ}`IT`3Ep`$jaB z6{0=Lo)vOt@nAVw%-l&CZow9v3Oz2mH6>q(&V_UvkB8@V8L0C1gAA5(R37-S!*8zH z3QK8~4R?N!g7TAXWS&QoCJ=TH><__3A85)KO|$-v|2Z3U#oF`Pw)G2m3F&dCy-$m` zB3Si+mWbhbx%P3Oj%FRVG5=4NP0slU5o^^y5uIM)n~VkP*jDY zYaW04l04)Y(YW}=r_1U|2ze{+>s^El!8I$UNFC@McyWhJTrDf`A$7xBDKRE-C$+=i zJJG08!e0wv{d=A9chtk={DxE{RWPRlFdJDcAIc&wuDVBBNARVg+w>M;_TWtCbF`mY zI$O9I3>EYetM>zvBFPg0YVw%ITx)O0e_~N4Jfwb|bKwS?ue)Ne(^@?%ic=~)$`e{l z8o)bazagM```PNi=#YI)#dWaDV{4LmWT*MXDpq%wNc|xh$I3D1 zg;^2EDNIM^hwjWuakpRZZ!&E0;D=xsE&3e6O?w!Y!yAYUR8ABwBDlNW4EJCm@GZD3VPINFUbB6rJ*rCDis530pRnOO8I?DXlMQ~$ zloLP;Z)lBIpINryP9OHLjSI2LT)oyO%jUl-M3eQ~FTB~RW!RB1lqp}X*L!*xl7Z+PggceWQ3|}n<>o?*_X?lv$ zVLbBuX+*!9(r`@$t9c}uFrTa&T4#l*1;zC0Md*Yd@>b?ws}=wY1NaqC#r0*)rDjBp zsl4UEMvQ(mW4l~ylKuvhcy~QMVk@@5d)L4xEf7PXrjj(g5eDaYQr=t3Uq+HrBNi8I zel%|!ssSdSM?FK0L}*xn+?2Qbc3$eqLT4Tt>_p1H5prPYzpS9Dsyf7Qn>+TF3O06C zsgB0<+N$`0CNO;Z<&jtCYSf5n7Sd#NFq+}&_D2ydIs-N|HxgUpXO_)Q!Lx8c2HCfv zzLc#KbSSK91(k7!B#$POWFACZR2HHVIaKR>B#Cf9nzOCKHZ^W@8($PGUlg5Ji3y_k zadb|t*IfJlA1|*mQ(SaTtP+TDuBCbWJE~y)8S5!dd&ZKQ9WiPAjL_dY} 
zy5?Pqf5B-)4bY&h(0Vhy4KVR|Zco|061ri7QiEY}OqiWcFp>oeyh9_wX~IQV`jK$= zO~`Y*YbR`b+St*=sqjy%VtBS>%NJdqd%d;JWRwF zQ0Ax#RY(~Op?Z+sFTLdloX=!2C*^ zpRs&Ihc+rcNQ7rU(b~n&j5mNSd$`%Yh#-je&8b z+D(!SH?*@EKHg?X1gq`FPDQ~V>bq;ietxQ!f3(dJ%|(u=GU%WCTwXq1tmMqS!Kt2yzE&beNR>=h%j1zxkG%ekvytR+%fI4C=mlNM4oNy=L-XWX_w z45~MuBRt$lb-1Gx0T#w05MC&j5_E?<5Fl~VUA>Vj!e6xr zRQFJN9`i`Py$P1ZO(XNC`IwEuiV&LI97K88)+~NxA9{L;pClQq<^qA`v4nGvTQBaA zJ20ovMAi@ zg{UF2@S?<4xT3{DPLcXv%^^ zXltF|fOkZw54;NO9u_WBGw1uRn>uA1)6GnogBss1@}>7ubQpJcrlAXNe&!_k^L{+q zx5a=mN2{ymRx^mzo;h0dTL4o)-U2dED14Ed)$^Xa50za1uj-wW_)yhLt(yr?n-9w% z^p0l4-P9bZCNHxg)<;fvVliz+)3>=|L1{7~F`Q|LgbC*8se(0UX?|+ut7zi*bXg%d zbTBj}=v-hT`;Y4ea6mT!qAL+p-^fj>g#)YWbfRo$UX9FTW}f|;WQ@J&EhEnUwWgR% zKE(z<(2h|Z`;RBshIEQIqHtZiV*Bt(6moah$pIRcztP%Y{~9s>9m|aF#8v~SZ2!wV z6|EK_>DRon`CFY~Dc9ZGw$?~Cl6E4_aQOA%uVL>R{6P~W+b8Jlh|sH`;I`z)@P<~L z#Sigm1?^K`7g>SD&8NBSi;=u>r3UyHhK*oY`+EKn#=$Cl%hj=b9Fe9Y=fb$k1k!6B zvE@;FHzsB;r*GYiuzVDE?}~{1**4)lj-1wQ;SD_CSI61T;Tj4`lU?fhD#9}a!1dY% z5CL%vUGm3{Y~klcqvUP&wiAKl1~1-ZyoMhN4El_hiS1y`eznX7tX4%S=&8?+_R@*6 z{FE=|N1UaoHp#|Nv>A#QTEU1d+}zU(m%GP?*(&0!f2H7EVb(7%LIr$s$jVGZhp;rL z{k&w$`U@;0T`=-q!APm0$LRcLLfKfMyRUTLtnS2vt7YZc^)_Bw`c`9G$g~-bWu=zs z0=Jp+D4(z%3Z&sS&J2wx1;3j%nlu25PM_X^+!zr%B}aIYESF#Fo)IN31s?d_QM7S? 
zBemayB88WWE|{Vhz2fEh#0_oDWq!(6xEVhDVKx?kTvJ8p-#@_-VXlT!wkA+0v|J*L zi_%fZn|&|cC%P6(K9EnPgP5g)8n*Hy{8UvyIRj?<@4 z{pQX?;tu6cS@OUEx%!oI$wYw{$kT~P!K+QHND6jh1d8JOVJGOt$mtOJo0}m`@6W5F z>x`_XE6S-7FnE-gtb*UaG(Ol6saW#9oz)^hU*!+LfJJAl*(f_QI*$M8wU@y)vrbYuT%IHK_3Xva$=sZR{f(}cN}e0JhDix6pU@7Z6?bX?~F_qZjPA-Np=#L)$f4n zi#DG1`9cNX=+1)5>B|ye(s4K+m2;g5Pcm@5kKs$vJ`jSe@!=Av{^_N}+>r0^EM;MF&oaJ40ZCq&NB4fz(%#3N8G$0-JF z(nlR)%~LmAj}mkqqTl3&=XB*`cxUe zw$GDoqC|ir3>y@fnY;(_n?VN5qP%3LI;fd=A!sI@2^VRS8#%K-a;=g-6#I75)tMWQp37)zJu$qaXDq3Tx#O= z;dbWrrI57hR49R`mfdyi{+r37J6s^-!%2G_AXqjYZO6E}SGsAZOAA7KY!MP&<6~@) zY&x$U>BS4b@h;vPIiqb_l>ZR?!qp(>#t#?l@>a{wUic0bvR5p_BQfWX7;fL(jHM9Y zNfZ%{ET8A>0FF%+4^7}kLutnu+k-T$drqUAk;OtXw09Zq4TxO<`b(68r)Y+ww-(Zg zs9PP*hx9g|7mj(CH*6-4v$axB&?*BI9zNQgwWuXvdZs@cq)=e0e_&M(iIVo1q1($* zpSp&)5;LjzIz4)aSt|yAtJFB7Xt2EcG8+P%un4q+KKXRD8N6{6m@C3IbWnhO*Q~di z@jcU!qFXHB#OJ{?JEOdSVEdH-P><(cx)DCf9=cqpXi#3j!3VUne#t@9WjCxlY3BsN z%NMg6KGYYnpIINYcc;I`_dKqSjamHgWz%Fyt>-wa4L@JE<+)4Qe|J)GYa!YQ!%j@| zR+U2(Qw=_&WCIwTO&av-_e2|B7~J(p&xhlKPv#Ps3FweM02d{Gh72M0J!=rttoW=V znB{84b$Vv#Mo1ihK^g>y92Q+Ux@%kYS9r2Lhj8dhJwWi-V2eNQGSs@`y}wBsXAzl{mtf zi8en#9T`VpgLV_i^(!7?ny;D>HP(lp1-`@s5>b7=&Lg3kiHol~>L{7Lv%(N`)t-2=o zb#q&}<2)5rNC%U1c8iHup_)`QGxYP+9dNO9`V&aO^xVd+L;JL5=}E>N#w1eQ9y&AA14Im8tRK3-rXjt|7U=L>T*$@`Jnkj3t7i4)-RHIGl}R`vAeH{A}f` zWml2QLb)YoY_4Ch*Y=4~v0IykRC1#%4 z^}S_^-{bp(+?WkaS$5xF52Q2^<-;=m>X^U(9?ITncEVw8>bPKba9-n>C-F-7O?Nr^ zjw*4?bd!~Sx;)XnEB%AFUBrj<>fc$6a_^ai4Qk{I^e@9}JF45Rh!?)Mv5tmD26SMx zzP$bHn0fz9uIFg}l9qld?dh;k(lLOVRh1DYAvBCWABt*{yP~9X zc1D?7)obMkpE-A$!BT2op0LVu6T5q�JQMU-2vi+7%mN6)q*n2ZGw=RZPHTOU^6# zOQt0FJ@=a(0pOeE3fT45uKq0Enimkk?jZNGA_#7TfOeA!gn%^Ue1$oTKTDq9ulFvB>Vd0>l?> zdAXcO=D# z8oS;o>TCYd`{|fI(9yD|cCAF_Y=kk<=o9+*xZ$Y&)a}a>H#1hoZ-ShYF^5~d5u9kl zam6F#yR6|rID^03d=_4hWzfcHq|47hX6MtIuXg(hlJ_tvqvCZ?pJ*y1S~a)-N6r^A z6Q3zcBN%dFmf))*F5zp;hhesDdhamc_mkc9p}8|vxr#A39FDN`UK|=#u;g^Q% zW!v8>60QYSyN>PdgCws#&~cGCw{>WQb2<<- zy`~+^kd~6aXMY)UU5d$eHX>M-N+!FN6H6@%jKj*GK@XKHjsO>TxT9Y*!)<&9j2=o( 
zT|zvou6SIR=U1-PSlYHla{YGabK8xBHUIF^VB|>TZm2eO1n}VJL>iyh?tsQ1wm!!} z3}(*}i@z)z?k5EE|InnH<=0#CrwIwC*wJY@IL09sPH8@*TCe(!?JIuNNt6B4@-tah z&p$CMKp_v`r!+n>U{6-PoMq>N-shfjS*N&*wsTM>uoWj0ZX5T(XtT)41&|0Y&}Mj{^2=kBAVy8vV*cMy>LjDD z`*xc6d<2m zoOn|GuRT?l?h1xH5FOkGd;EzJm_9Gc4-KFU3NzQ;k>y6+T&_HkPd z;X3IG+=ry5v(k9Y(pzFi^Cx<{PS=qQuAsy@wL;O%3LP5E-|IT>e(KN}7xFfU9&Ai+u#PS~>?m#KP8G&xM;iOv_hd8SIk2s2iUyaP>-!kM z#t|xL#DfhL?Bk#n+zuZg*Lb4}d)3cv7Ack>Ymu_iR7y$5lIGQoeE{ zS-NYxRyR1vP~B6YO&c}o4Pa($*Wd`*gXTIJhN4A$#5_JUgMqc3{9?5CZ)?~68rh52J>bxppQY1O)6Vk^MMHA>`+D6eT%DYUKxcx zL7!J$!Tfiptt4a;j%JZ2n0{KV_}gdjO!gO2dSheELAik&oLp5Xw&3V zjKsciK<>8V(G9D4kRpkYDSOV7I7LOA3Q|6blZ=h<$9H7$gnRWwHjmmx%G`$ z?@qRV@1B%Z;ZxfE_!k}4@{jGyuUsx4$td9Lc%_*iWX*FaRD5f`c4~&yne~kwFDhzj zF};euT;4jki&Ve-@LD=Q%3Dwo8KNtywahxV99W* zuQ{1-JQm07txF&TbA!lRR+?r&@M9#!MW~*@*N}Y3Tw{iTqUUr_tHfkod5ka6u;ntx ziaHBCNeRWFEZ=%QDU7HsdO1E8D4Xs|pzWN@l^FZToF%j0AkZI8D1e^oD>((8;f*D- z*BUtHZHZWK=fAKGlVaTKdx>SBw;U{Yo)Lv*Dn!kRZslUtTRTntNoN@{1G5InhD0|P ze`#3`o+52ASOhUYyGMO`Ko`-5s0pfI<|^r{e034xzW$&ob4ES3?YeF2eSH(59U$lE z1%grZ3EP_bFgjh`uv7{G@oZ(dSA+L_9zU}yQ6FxA8OrsyjJIYpA7h+L)=c2X(^fuc zRS?D-9!UA(7rf{^%B$JS!D6mZA`GYt`YpjCN<_@ zv71-+hXYuTI6%);9c*bIcj=b(EQ&0CrB?VHtTE;thwRW_&#O0AiC+j&UfK<{&;$QFI#OFSb3kH=1cA$q(2zel1T%_ zP`Q=x|W-#NS9}_WtJzgR9g~DOT8Oeq>CTC&Y zt52;8rjMse`}0~u#kcUmaYpYT=iP8fUa{*)gn(m#g%g8~+aMeJ(~S>ezvPsS=uuaf zebJnRY6qGa(~*l<)M%lo5UXf`kQOz^AP<6^FE)NwNOo{?Xg6|~RxUaG_wLSHyS@VD z?QrZUu6>)cHbjT>{fbFX8G&T~%i8iI(DIkwvN3$;JyBziBx~@e+tvbuC4#ghzEYMk zz2EbmCejL?x8XF`*NLxjRFNfoUnvO*{LrRP*rONTI=3K0-MjPtp^KeZ7$(zyL4}{w zbap(qwx=QCB#W#2{8$rIk$o9AQWQA!(W;mMZ>o_b`b@9W^mDA;y4f67^x;gq<)tQT zS7bqHcp}s0PnPp>>^o7F_=7sTvPV8r$`LrsIr7u0n2s-G4}?TQO&f8Mt2b@YrcYp9vQ2F zs_icAClkLXAit2XZ|c;s-`gz(rGuzAS*ook%VM534oV|H5679i1FKOlsh_Q`Jrxf) z)RoCq9dYc&(x?Y#Fn__Qjc^u&XIt+x6d>rNru}ej`SjzX@N=a?vt@zT zuRq4P391x}qFngZ6KiRd|ArsD3Omkm$A3>kM2l2~6hJfhXdHSbeHxusQ{)Ws*&PIT zTchmPqfHryP&b0WB6QPf_(~@#MD>Fkmt(<2E@^M;L-P5W*K@hxb0WLW)9|P3PoUr; zGSkn$e!c|B3yq#$!>t{FyfHo_?GXKtX2K 
z^z|K)I(_A^2cnu>(%l5#ww-=h3CHI13T2*?y-m+)NEzMImt*hHTm1rI2_h5;#QgP2 zd>_HNH5Fe+sB`R&?`sgoL*j7}G*>OdI1asRZ$8sCmN}>4eyQBV-G4V-_6E0idd)_= zUcb9x$b~7}$ z;v7xpey(JcKUbwz6p-$s93$vHmOM_iq^Sz*Tww!~y+jH(D5}rZ;1`v-_GPtwYl=M@ z7Ms=fX~;6*6p4k*&<&g7K0do$j( zX^o=)8jeFF-cz^+iB*_Q4@0Cn;363CrlZ*P=88M+pE_1Q2%Mo?q;Lv@dR-po>-m zz>!jqnZ>t<;7~F`8nHQ@c||SXn+YFxi(?z*g3|I-eH}~7;ZLFS^^8@<|?aJ>>4_Is){yEQcl14zy;^{KtxGt8zHK^mI7pSCkipzage z)=WnS_7}~m8BT+I;=}oQb^pAEv1O((C1#aGufz$ z`$j8(@W#MRxL^HF^w0>49yuehv&PTOrM{ZDk0JrDw|lL(!p788NoQs&pd-NXayx-m z%dLK$`RSX#UX{{yik-sq$ZKgZTai?=9V3*`o={9Q$wm_Mh~n6 zq*$(Pov%eZ)_zr;#ZM}?pMyeDiaw~hW7?`6fcZBc2h}HKu88LRX;j4n&J13+lz&nZ zN0eYhg-LUajr#ocdJ_gph_89B%Es;Zp!crk^>ok(n$NOS26EeBNCg$CFELH=%-Erj zl9p%wa6N;E`=mt=g<|M(L?g;UE9VXcf}J-;^y3!UJp6IHP6BPFH5t5*37W;X-Ml? zCH2K2M1Huk$wSfDa4_Ftk}j&REngbk)N1R$tgA}nDDo&nT?oO4`cl{utl?&O_lrMm z1Cd&59HfYyl_>6F`@O~e`G*#~RD>cdGA9KaW??K*sDJ*JpI#eJC355EwkdFIPH{Ix zBt6VFE>>u|74+tn!_?!)BWAF#_2`Z0Q<8yT+c$ zju)t&4N5YfW2{r^;&W0&H-L}-a=sJD)lGV?AjdxGCJ zHRF%k^-euSj~#;K9J88(hHJ96GG+cC+1WlLZ$FwK(gNaRtj~jGHUco542q==u`c?@ zk>4~<1#lOE=4tc41`e#87hx^GmkFj;)Z6K(gc@MzA8mbsQ4H+R3p`eg(sKW$MPksV zo-yApB}I*6l_e{qaNtdko23dp7qnclVt3&3eAQwWK;ayo|MMF_JLbmDVcu5HJ*w{S z-%kHLZR~)`%<{jYQ%MW$tIfoTIo?Jqt2Eo7H(` z?3&_-r+IvQL)ZfcytH{Yu;)~V*D`R@$)q7wDSOPuf_*TdEdnB;7uQ2x9jByUhzWik5g|0S0Xe zYX+1UQT_sgi-=8$lEAUR-5h(-zWSDl7`QQZ;mV`>T(n zJoAaFmig{c%AmapKYDF_jK!EAmH>i`){D-WYv`tFcQCzT*U~B6>VY&E{XL`l63YT7 zW*IQQWxj_D{LUy%#Pp9t#;w-29jVraTxX=VvpM5bS9kmSDP8@F!DsXPH3DG*_n{(b z1?0?o9Vq?%eClYfU&B^lx}=iOB0kXT9h9_ka@B`^*a&m6%iNZOc1rnPS(NuJLX&{T zyGa<%P)&h2M1uoxub@LycHYb9A#;b@Xj&ZEia?)_SNsT^awbgE_}Cw5_Y*G039JeBdJ4E-x87>z{t8oI&ihQc_XOOuL)p{e zT#O^C1di!kQ%icq13E=Z0!|vKRK-b*F!MN6c@+g8Eb^_m8fVV#j01ymQ7h@9lQEe% zZ+Ha^(hw%+oEo=(ztH9Fc_xO4suhkm4gK`Y=1p+ZjS|oAAPB$=&SF;( z@3%#wKLQ`V@BEH;jBfXJUQGEJh?#eJaTOq`Q~POd%hiC<13RB<=2oQ;_1qTuFYV@? 
z%>vzR`KO;;Xm{*D3w1sk90#2{z2l4B_)Uvp6)6R=@SZ-s71>|ONh;#P9tEL z%88W-Y{H@rzT4P(_Qy=z-k6r!-bA}STB%CiYEhCwiN!LBCu|PAYF(^pS24^QLJRec ziI0mMr{JG#_Xt*A%lKn|q^0OV+BoXZKxf{QN*fl=q>3w^w44l&x#ije@UbRoOB3{m zy(*F7P?7lpt?N>G$bPmbmqA!p9LO0HTkvA`)NFdOU2y$na?SpO5E*VEG1SB1zwr1B zTiV40i*e0U>&x*RIVZT2`#RRLQ5%#oA;_UPK?H;U!Z&d_HkfoUj8l`jO{cSgH#!8< z^xh56O?dwDNSfWD{JANZe|;BcUH{>XE^g2FGJvXpNYIoVWLAXs-R<{-N|wREP5Po7 z#_qUtmY7&zv=N11La8Y-+>S+~#pvZJQT;b6ZtFX6k1x~7$LLPqThe~yXtB0G$Tujg z$0vl5p{q!rxw5{z+TrFye!#cl^H-@>v+m~F%WYgzx+BqKy#%4XK!I&JUM5SIHt;cZ zElipItjxN_^ag9AJ_-~!$896FavTUmBzW{#M;K3XBn?QF!%^=iCWyC&qINHuO{&H5 zF9q%`YI#K_flv0wm=XH3VjEIZ$!YNIuG)%fd~}nzhLM|d%aaacM3n@(dRTQ?iFy@( zIwI7ZRd&l9z-`bJ$fw`$FV;eL46^(Y3CoWsLb3 z%DtSM4`Kke{0)bb!~Mj&7#LKYbW-4IZvQDrS9Vj*iE|(eBrf^vI2rieb)=~xS2ex^ zkSjwyR+%`V%{%+rx*N{ymx*asNVz@v+6lA)K97tNK1Y1LQQOMRUt{r*?OJznCM}nP zvDgEnaB`)ym(T$pe!|o?=$)~@UENQhZ(qg`I+}=%!Mx2by`PpN@VtdxxasogAtKmS zI{bqBHc#aM3=N0yL6|>S?#aeW^kTkAd6y-NbP_-f^OIocPiO}^_h9U_+Jrm_$?ag( z@VEW(Uh%9JLa5QycZK3kP!s=f8FQb_Wvlt+@o0BHpkFzpu~#2^Bu&^i?7-WT&?tor z*A$oXK&087Sjr-h73F=`t8UifNsNmpuvN36YRyR_} ztW!16yU5U6e&#%wX!?0#)4tg#k2X%noMF(YGpOlJZVbYqGn=tzh$bQ$NUVB6tijxk zV=nHnbQQDKdD>=&dJ+;Kh0_xD;lS}fMX3t&>;S)mAJC)KmFkZA!=PC*+o`^>Q}$*n z@)RnkNi?K*ZNnn`A2z*FZL#|lF(~T$UN`TJU3WNgih!i=l*|$MNju1jWxi^>a$@E- zPgdmYhf5q9KQe3Yn-I>gKd!FQ7&a^@NtjPWbCV2OEfdQgFJW^T%?XtfL@lJ;g8kHr zSAlXxS3G9lbmgG2M^P(tQKav#9a95S)YE@lq^$L&vx{z@h;!NkMGe^mXWsU*S~?6c z{AI72nHafumf+se~6A{owTc^4@$tSsE<%bJpYj$vU5?Olv>^S6?e*V%uszp5Ow0+56LOqm`K1sJHBaHfFA6g1xgyPUd z3VR`lSIZ_)mUjsvFWJY5QYC;!vH6dj2C6%XPRVOly_(+q_cqc#U)$jcsfk@~f4=x6 zqMtf8hqr#noH)k#31|y@gE(mexBre%Ars}BZijHqH~b@TWXZ(ZxZQC0=Esd)>65e= z08GKwkI3Etxdckww%N=0>6)xS3%I5%F8DE#&ye!c2X#_!7u&@v)=-R=@i|e6fQ6IC zxn%}=-6s9N*#jd=Z`&F~=dXpidz9!5dUSvVY7`4QhoeFHzLggJ(_Yiro_{gENL#1) zE*jjnDy1LjvBWB;VPat9RV|MsH})%`ak|YlD5ITCz?48KqR){_K%-<8_z`+>1a408 z{$PkO6c(t$hah!KSP6l&N8b1Fhpm9G`f`u}YCp`4+zbzmM70#~@Uv8j=g|x4u9E5C 
zKXrLZwiQB-;{jqprPT>TZ=}GacPNk=S}?B&6m+w7F5$e1rC#t1`?8G}qOXr}s6ZE5#}qA^i@&*i<=^$q)PM$`{PU~DwvM8L)<|XDmyVySfwXbKS@JW0;VW^!%RZg zgV$7x%P9>T>sOjn$~&QOY_Dmah%z`rYSz;1a^mBxUHcNP%Xbg8V$|My92k8x z>+Zs=d1pZj5|K{>)!>NRm*F=2h@f2X=iF8KPf>N) z+$CJeBa|o(`sA9f%WGEEFJ2zm{xqO2x4LM0&a~bd?sX#2=CO*Z|A6{3Hvsxw+?z-;NqQ(h7+UKB|Q}%jLfRz}}(jH52 z(?ZArL&*{$^&>P_g0W*OQlyt7)GmcnaoouN}+d4H-+JnVY8#*$D)EfP!SWU=X{xlFS9y0@}D=IBcUR6au#RkkgJ z895JdA{Se6brJocM4KHsr2nu?sKaj>;?=g6PNmrK^u)*q4Zu@($g4GZFKSe0UD>*y zXG{HsLyktUJmqp<<_!rwrX{vi8{m8$+s|wqPo8L@xXueBcEtrh@a4=d(K*`Xc{Kh) z8gHli5%?IW%$T-L1=#~r@41Bn^2s!$H9o@=_!yhFI{@rEeo^)+9AR*^GPy`m- z5LmgXMT`~hb;d5?3-SG0xN_v1!V;x_ANcE5tdC;M*%cVSHum&Y?-|O9X{-42dEaQB zD(4aE8xP9+N_iwIB`B}H{Qlv15ghy*SBUH5!*foe+)rd3)MOI>U*b}S<04q&611=J zJrwa-8|qZmKnyIZeL;9%@rajQLY@C`{13tHes)1^QP^Yrs?05}za~`UP!5;l+S`DiM3F+x|02A!bNGjN%|Z;}b!;QHll$ynwG2Ua&g)1h z7Fp*K^wVN+gcDA!>?D4d8RJ8)Z4B^?g(RN#P@)IUVVQy*PX*Y|CY+*xUN~y* z)aRj?Ue2)KZ=&6#KCJ9^6v&yQa~Jc@1kua)5bQR3RZn&jkATU7RU@?iJ~rZ>0UD7L zvin27_kHUc!a#-i3 zIPK-6rTgEzW_4gOA#MCF{v~Jw_iSy;Mge-Ntu8VN*?LI5UGp;BPj?}vac-$-NhdXf zbxYfK^5GW(Q)l`2>bkf5BqH2GzOpN#WB$gn?!ya?*5{SSzVDx9%Q_X#?5*JT^FoSdPI3EgYf(6XDx zeKa!7xBTes@K=ND@PpC8e)kJ73Jx0*Syfo8fMm>@;~39XH^tG%F?IU23k(-LV$Ome zilvrQ-B`}j%iLLJl^9-^5l@Z>lv{48$YRU&2CyDO>jDMe``@$goHLXw@$!YRR9|YW z!3^D-{lB8m+wEQYO78+#GqIqL{;mL3q@j`GPmH%^I`3?Mqxk%}v$pX2F2e~w zhG@wZG8U!v598FE~ES8JV5F-aF83sZJg>m83HZFc4EKZ-&LGVkYr?Fzw<-q8h@`M zQ#5a+e(|yOlb6c2CNa+J58!$?$Kv3(Av`{Z*qun0oVGVx;AOdY7h@XtJZ~(2@7XjK z?~e5fUgi^A9B~NrpWTp9BP|h3KcwDcGRI@CF4|F3WP8@xc9x-0rbBth)bN1}amwxY zi-3zs+f-^OcaTsCcm3Y0zk6x}2YQgK71F(<^8Pef z42-I$pukrxQfb+cu-Gm;F53sqy<}V>4b)yDA}cdG4P{f}QJh%ao%eUVy=IMKsf_jk zpu&Q@i9JoVIpk@`B=#cFmeu|^|{fM-;?!;JbFq}KXb+7ItJH;L7 z(SNg&BO=l5q)`(}3EPoBTz}_rgA(4pRYw1%beE+m@p0yL_jj10Omceg!v~^eX8zg# z6%lc-#fUeQq-7>MQ<08%pN7l8KtQ2iaYT8X|J!jIGCsO%ViqkZp)E!{Q<-C|HKMWyH~)-f4E2!w5m52MjTB4 z=JsV~2lOPU1@cJve#29n%v71+b^nZ`m9sr7jft)J{yW;uaPKx=KlIR;V+hv0P3J>B zs`qyqmrY)o%~csu8i0Zq(XZy@jR>B?19V3Nrza%MDZw(^EyNiBwcb>1jGd@5sovGL 
z&P3;hNV~`nAaqNn2zxBh?}Tq-;?7%6_zf_8=>$^}$xf1@od*z4ei72}4D|77d!w_T z!PXb=q{1a%H_AmJ8)7IP|EH2Oe~0?(!}yFQ>llnJTgVb(?Av4=Ta-aF@l}YDEXi21 zOcEL+GRQXeY*|Ce*s?_=TN9PCjiD@M-x`#jkLUaR1J8ASKi756`7Gyszwi5{^0!<| zF|1NikE?=JQ~o;Xx7P!6&d0lF9z=?OKrrDW7JztXOWqyI*h;9iJGSuN#}H(vdt~qY zLG?@&h!;&DOMtE4kLL7|RNBT{y%Mu1Y)x6L5zYs1L+&lQ&lAU7OX0@Vl5eR|)!%-* z7$z^h)MDOG4qCV8oNk&g<_0u3iS!-q|6Q*~+U;4kMSKFv6Etgm?Id{preGaA{Vx`} zj>l{1r!-0bbJz4S+1W41KQ3U2{y)pBCE$z53{)ZJJjVMZhWs861ggQvW))#~`PD}} z)A0x=siz%bj0(-V&{4#W_Y7*?p1J63bnPhuB)1o^4hwbM7Bbi%V@By6eFNmVM4gKY zhSOOegD4Y|IAkvS%BYR!DWGqlS!z<(@bf8M=DqNgB&7Mo>#S+|pJDF9tQ;rZnxCTZ z?CIKHH^w80X~o8?=G>?UP<4WM4`NDx4$r=$jq#{S4KYkaxOdOF>Z)f(Z*Jta;qZS* zV{pq~wl#4t2JgNByz>Od z<~Ek*MDwelRJ*HS5*$K+f1*j;*0bCD{3JLzRzY2A^R~5)EN_}o*SJTY;-#lt#%@0i2_E$>clHXAqMR47L<3h7l z=w?dH=9^umAN*n{CjSP`=N~TNWu%oOiTF%_U!aNKl5OwzUy-jZZB1=~N>k-vcAaAW z`aJDP@pHjMA;5da^{dClrf`}hSjy{pjx3fLaC67xj6yiY7LgB~3b&{E9z&1J%TjH* z)LY*#b$xY6orwQ=D$*-Uet3)=1(eK=eA~88ytssR^;pd0@cL}ZFZVTvkH=)!MmRO% zuK{1FpDN*e7X<%pEqM6U@vhydSNDsNC}#1|x;ULQ5Vs_ZY&xz%4)#KfkvwlsGtTK> zU=6RUVd=jCp~}S-cbYUIh}zrQ69!+DD{kLh=b&bWRvj-DMLUv`t3uROx3?O=Wl_XC zev_2EY7y~cZM&z7BZ-`fBga_j0eovq7rzeAXctt_-^?fze#&Kz=$VA>j3j)bLEI)H zF-Cy1|A>!pJmkBaA1t+!hxzjjQEGXL5nXT=2@w|ykiMSSRcmxC0=Pv_Ey!icdM8HH zSbF>?Iz2@@;4g3%b-FPu#Bl1@D0ypb23;ppm>qM#nY~Bc2a!l%g{#qC-ZFbr4nkZS?)?0c zg5X$@xiA}xKh~`)P-9TGFdBLoj}wd5?!+$nrMysy^&4F+k5$cD#BMVp7L(bjB?F*z zc;%RzJ%|KXP%u4NqxE66z2c0XehmT|%+pny|l$4=HSRJxn*l`aN_QX|bb+*nFK(g+!d&F?`? 
zU%BxF7MK{5W=sw%72-BmTxoKf;K-(7m0T_%ijJ8fE#bC;i3LogxJH!I7!$@Yi*-nP zi*qBg*N?dz`)3`gH~AC#CrwJQJ@0 zYK+&^u#tYkP(@Pw;N|Vgz5?PTL+8h4!-pM}a-X|)3#%-KURRtl5sn^rl6n<`u+e&x zvGjw^u>c@{`WQpt+I26&m+$%ACcobemnS=aROlyMMHGsUcC@58V{T$sHed zz1c3Ag|YOFOzxKdp^0L>x1eY==W}^fMvqftDJ5Kg)2wH>XQ5|(?vh%?_}q&thr>(j z>*M`tgLeN6Y2z3fKXiZEOF(LuW%%}dPegB)GdnL^qhoF=duXFpE*zQd;pW^3yNX9f z3#N?4wD4B9i$!{0m88#moP99W5M@zTT}`fljSSP%NUzHrL*Fnv4~5ikbpV(OMn?ja z9jqeeZyXs1OMUhTn&;Ysl({k%%nTF?fwV8OK(ofsopR^i;vBBOm+&3ZZ~9B(hm@mV z`zND0QZ8m#M(qN-WYi0V$Jjxaaz9?S6@e_f`&%=cW?jYpOUr}WCCKpe1lgTLl4^N% zXz#Rj1>Pd|5FoJ#|4p|{JN3#fMTGdJSm-ihIb#TO|E1@|=gadVgJ&w0ESmy+${bMU z-`79Vu&+`o(Dg3BIx)%{&0DIVQf~RcOl&~_4n->*vC>LVy``Gsf1QybcTaUOJ92FY zr(h?E-_B`V?M@8_r%b3jEJBA%W(#f4T4)ZyQ~Xn?H! za^$nQqLps|?&Fi{=5AHZdE-ULu87^WYO+xS7C&Df8IWwzS+UKQ<#?WWMC)jy5iPw0 zEZSZ+=9F3#rOq_Bgod=KQ`t(SfvR6VuKG|M?7ErE9&2fW^}c<9qpJR;k_Ok-sn$Rp zv*(3_l&yaNtwB$qsUK#9A^FtER*ycDqkTYxv6Xw_2Rw$@KsmOX6-&xLq@vb7EnNRo z9SHXvVVqcikII|L;5IGan%7DR;u-Gg4_STzX@#idl{%v&S;39v-mz@AY_-6=)`N|R zS2jUTaQMR)&y!5baPW7F4q_?5*2hA!yw!Ma;8<6hKBv>RKPB^)t1-Y z0QXAq`jF!>Ts-dm$wbvQXR@`>R^$Fr?c09B!-gqR`EJN@h7VBaRP45Z; z6!KnF7c18lo`AC`NA9qG(`-#D<2e?r6DDy6DlUdzs_;A4J$no0HpizGdroqoYF}#z z*z(kQribKvruTV!&nhwU>nI1nkE>|GM%7@#DZwlZBgs<3G@a+W$;#9L;BsuOK)Kv~ z-;xyoYQvka9-UuH)cP#WJ~_g=*l_e(*_+-sxv?R)GwAzDLqXh-pf&mlKUX5oE^92T zL6w%1>u41aLdO`{U;)=RhV%A-Sxz+WO4oErZB(y_?%8|}P#J4nU2G~4OHsXboc!SG zu)h70BW{r#4_)d^!NwP8em^CwhX;yk4PJz&5r{FTNcPWKh2uR9CbCfsw!+@{u zp`vcCf~{j&AX2Y4Xh)jbYZmf~pM)G#1mwPV&o&d@aRMviV$=Dsv>x9u-Fwz=v+hqOvrFy|pxT(yk?Q7g9$-FskZYAf00ry>mvbJh z5$zyFYXgxF=q3$>rM?6}In7Q`8gCIY=@yS~)X5D9LgrTC-zLZqs)LfSE{)Q(U#L);9R-@CU0+G)M6tOWswc+igcrrc9h77g!^V z8+=#WBuy?Q3zp=c2;Wh0OV5YBJPIn+H7@(=vdr+!g2AVeVuy?M6!^fGjQpAg{^G>? zBQ?uRV!y$h36y)yDQ3pEtXubjuzooSD)yyv{y{b^0J7YJ&6e7sWJChP=mHz)Rsj8U zm&MfV_p>mAy(gS3zm^Zwgkv9hGYVCCN4WlFr~OC)8I?$d3xzBFT4|)M8bkQsHQ#w! 
z_{4|J6ByFcU{Jqp=&s8m(8GuC+jf)yT(tjx*!r>kzF%;DxP^}%bLohKXlaHvtwFg) F{|7(I6OjM_ literal 0 HcmV?d00001 diff --git a/docs_nnx/flip/0000-template.md b/docs_nnx/flip/0000-template.md new file mode 100644 index 0000000000..97604a0ea8 --- /dev/null +++ b/docs_nnx/flip/0000-template.md @@ -0,0 +1,25 @@ +- Start Date: (fill me in with today's date, YYYY-MM-DD) +- FLIP PR: [#0000](https://github.com/google/flax/pull/0000) +- FLIP Issue: [#0000](https://github.com/google/flax/issues/0000) + +(Below sections are just a possible structure - please adapt to your FLIP.) + +# Summary +[summary]: #summary + +One paragraph explanation of the FLIP. + +# Motivation +[motivation]: #motivation + +Why are we doing this? What use cases does it support? What is the expected outcome? + +# Implementation +[implementation]: #implementation + +The technical part. + +# Discussion +[discussion]: #discussion + +Summarize the discussion from the original issue and from the pull request. diff --git a/docs_nnx/flip/1009-optimizer-api.md b/docs_nnx/flip/1009-optimizer-api.md new file mode 100644 index 0000000000..ec8d157372 --- /dev/null +++ b/docs_nnx/flip/1009-optimizer-api.md @@ -0,0 +1,504 @@ +- Start Date: 2021-02-08 +- FLIP PR: [#1011](https://github.com/google/flax/pull/1011) +- FLIP Issue: [#1009](https://github.com/google/flax/issues/1009) + +Table of contents: + +- [Summary] +- [Motivation] +- [Using Optax] + - [Gradient Transformations] + - [Optax Training Step] + - [Multi Optimizer] + - [Train State] +- [Previous API] + - [Optimizer and OptimizerDef] + - [Previous Training Step] +- [Update Plan] +- [Appendix] + - [Setup Code] + +# Summary +[Summary]: #summary + +This FLIP proposes to replace our current `flax.optim` API (referred to as +[previous API] in this document) with [Optax], DeepMind's optimizer library. 
+ +# Motivation +[motivation]: #motivation + +Our current API (referred to as [previous API] in this document) uses a pattern +where an `Optimizer` dataclass is created from a pytree of `target` variables +and from an `OptimizerDef` that defines how to update optimizer state, +hyperparameters, and target variables. This pattern is relatively complex for +implementing a simple optimizer, while being quite verbose in the typical Linen +train step (especially when using mutable state collections). + +This package `flax.optim` contains some optimizers, but that list is far from +exhaustive and ideally we would instead use JAX optimizers from a dedicated PyPi +package. + +DeepMind already has a dedicated library — [Optax] — that implements a wide +range of interesting optimizers and provides a framework to compose new +optimizers from reusable gradient transformations. + +[Optax]: https://github.com/deepmind/optax + +# Using Optax +[Using Optax]: #using-optax + +## Gradient Transformations +[Gradient Transformations]: #gradient-transformations + +While [Optax] does provide predefined optimizers (like `optax.adam`, or +`optax.sgd` with momentum), it is really a library of *gradient transformations* +and the idiomatic way of instantiating an optimizer is by providing a +combination of these gradient transformations. To emulate the momentum +optimizer from the example when using the [previous API] we would write: + +```python +import optax + +tx = optax.chain( + optax.trace(decay=0.9, nesterov=False), + optax.scale_by_schedule(lambda step: -get_learning_rate(step)), +) +``` + +Remarks: + +- Above gradient transformation would be equivalent with the example under + [Optimizer and OptimizerDef] where we define a Momentum optimizer without + Nesterov momentum (note that the `beta` parameter corresponds to the `decay` + parameter of the `optax.trace()` transformation, and the learning rate is + applied in a second chained transformation). 
+- Note that hyper parameters like `decay` or `nesterov` only exist in the inner + scope of the higher order functions returning the `GradientTransformation`. + Such a gradient transformation is currently defined as a `NamedTuple` of the + `init()` and the `update()` function. In principle this pattern could be + extended to also store hyperparameters, maybe a point to discuss on the + [Optax] repo. +- We can use a `get_learning_rate()` that returns the learning rate depending on + the step number when defining the Optax gradient update transformation. Above + code illustrates how this can be a drop-in replacement for a function we also + use in our [previous training step], where this update function already exists + (notice how we need to invert the sign because we add the gradient update to + the parameters). In addition, you can use + [`inject_hyperparams()`](https://github.com/deepmind/optax/pull/48) to + schedule arbitrary hyper parameters with Optax. + +## Optax Training Step +[Optax Training Step]: #optax-training-step + +```python +@functools.partial(jax.jit, static_argnums=(4, 5)) +def train_step(opt_state, variables, inputs, labels, apply_fn, tx_update_fn): + + def loss_fn(params): + logits, new_model_state = apply_fn( + {**variables, 'params': params}, inputs, mutable=['batch_stats']) + loss = xent_loss(logits, labels) + return loss, new_model_state + + variables, params = variables.pop('params') + (loss, new_model_state), grads = jax.value_and_grad(loss_fn, has_aux=True)( + params) + updates, new_opt_state = tx_update_fn(grads, opt_state, params) + new_params = optax.apply_updates(params, updates) + new_variables = {**variables, **new_model_state, 'params': new_params} + return new_opt_state, new_variables, loss + + +opt_state = tx.init(variables['params']) +for batch in ds.as_numpy_iterator(): + opt_state, variables, loss = train_step( + opt_state, variables, batch['image'], batch['label'], model.apply, + tx.update) + print(loss) +``` + +Remarks: + +- 
Since `tx.update()` only transforms the gradient, we still need to call + `optax.apply_updates()` to apply these transformed gradients to the + parameters. +- Compared with the [previous API], we can now keep the entire `variables` + including the `params` as an input and output to the `train_step()`. +- Splitting `params` from `variables` is still necessary inside the train step + because we only want to compute gradients with respect to `params` and not the + entire `variables`. +- We can still log internal optimizer state, such as the learning rate, as long + as Optax transformations expose that information in their respective state. + For example, `optax.scale_by_schedule()` currently only exposes + `opt_state.count` but could easily be extended to also expose the `step_size`. + The same is true for internal optimizer states that change over time. + +## Multi Optimizer +[Multi Optimizer]: #multi-optimizer + +The [previous API] defined `flax.optim.MultiOptimizer` for processing different +parts of the parameter tree with different optimizers: + +```python +biases_traversal = flax.optim.ModelParamTraversal( + lambda path, _: path.endswith('/bias')) +not_biases_traversal = flax.optim.ModelParamTraversal( + lambda path, _: not path.endswith('/bias')) + +optimizer_def = flax.optim.MultiOptimizer( + (biases_traversal, flax.optim.GradientDescent(learning_rate=0.1)), + (not_biases_traversal, flax.optim.GradientDescent(learning_rate=0.05)), +) +``` + +Note how we first define a traversal that selects parameters based on their +path (which is the concatenation of module scopes and variable name), and then +create a `MultiOptimizer` that binds a different optimizer for each of these +separate traversals.
+ +Optax has recently implemented `optax.masked()` that can be used for specifying +gradient transformations that are only applied to a subset of the gradients: + +```python +def flattened_traversal(fn): + def mask(data): + flat = traverse_util.flatten_dict(data) + return traverse_util.unflatten_dict({k: fn(k, v) for k, v in flat.items()}) + return mask + +tx = optax.chain( + optax.masked(optax.sgd(learning_rate=0.1), + mask=flattened_traversal(lambda path, _: path[-1] == 'bias')), + optax.masked(optax.sgd(learning_rate=0.05), + mask=flattened_traversal(lambda path, _: path[-1] != 'bias')), +) +``` + +## Train State +[Train State]: #train-state + +In Flax it is common to hand around a `TrainState` object that can then be +used for checkpointing. This simplifies the above [Optax training step] a bit by +reducing the number of arguments and getting rid of the `static_argnums`. + +We can define a `TrainState` dataclass that wraps the common pattern of updating +the optimizer state and parameters by applying the gradients.
+ +```python +# Small helper class in flax.training +class TrainState(flax.struct.PyTreeNode): + step: int + apply_fn: Callable = flax.struct.field(pytree_node=False) + params: flax.core.FrozenDict[str, Any] + tx: optax.GradientTransformation = flax.struct.field(pytree_node=False) + opt_state: optax.OptState + + def apply_gradients(self, *, grads, **kwargs): + updates, new_opt_state = self.tx.update( + grads, self.opt_state, self.params) + new_params = optax.apply_updates(self.params, updates) + return self.replace( + step=self.step + 1, + params=new_params, + opt_state=new_opt_state, + **kwargs, + ) + + @classmethod + def create(cls, *, apply_fn, params, tx, **kwargs): + opt_state = tx.init(params) + return cls( + step=0, + apply_fn=apply_fn, + params=params, + tx=tx, + opt_state=opt_state, + **kwargs, + ) +``` + +Users can then derive from this dataclass and add more fields, for example +mutable model state: + +```python +from flax.training import train_state + +class TrainState(train_state.TrainState): + batch_stats: flax.core.FrozenDict[str, Any] +``` + +With this the [Optax Training Step] becomes: + +```python +@jax.jit +def train_step(state, inputs, labels): + + def loss_fn(params): + outputs, new_model_state = state.apply_fn( + {'params': params, 'batch_stats': state.batch_stats}, + inputs, + mutable=['batch_stats']) + loss = xent_loss(outputs, labels) + return loss, new_model_state + + (loss, new_model_state), grads = jax.value_and_grad( + loss_fn, has_aux=True)(state.params) + new_state = state.apply_gradients( + grads=grads, + batch_stats=new_model_state['batch_stats'], + ) + + return new_state, loss + + +state = TrainState.create( + apply_fn=model.apply, + params=variables['params'], + tx=tx, + batch_stats=variables['batch_stats'], +) +for batch in ds.as_numpy_iterator(): + state, loss = train_step(state, batch['image'], batch['label']) +``` + +The train step without mutable state reduces to: + +```python +@jax.jit +def train_step(state, inputs, labels): 
+ + def loss_fn(params): + outputs = state.apply_fn({'params': params}, inputs) + loss = xent_loss(outputs, labels) + return loss + + loss, grads = jax.value_and_grad(loss_fn)(state.params) + new_state = state.update(grads=grads) + + return new_state, loss + + +state = flax.training.TrainState.create( + apply_fn=model.apply, + params=variables['params'], + tx=tx, +) +for batch in ds.as_numpy_iterator(): + state, loss = train_step(state, batch['image'], batch['label']) +``` + +Remarks: + +- It is a common pattern in Flax training loops to have a `TrainState` dataclass + that is updated with new state after every step. +- The simple solution proposed in `flax.training.train_state` can be extended + with additional data, but advanced use cases (e.g. multiple different models + and/or optimizers) are not supported. Users should instead fork the dataclass + and re-implement it to their needs. +- As opposed to the `Optimizer` abstraction in the [previous API], the + `TrainState` now directly contains the `.params`, without having to go through + `.optimizer` + +# Previous API +[previous API]: #previous-api + +## Optimizer and OptimizerDef +[Optimizer and OptimizerDef]: #optimizer-and-optimizerdef + +The optimizer itself would be implemented by creating a new class derived +from `OptimizerDef`: + +```python +# flax/optim/momentum.py + +@flax.struct.dataclass +class _MomentumHyperParams: + learning_rate: jnp.ndarray + beta: jnp.ndarray + + +@flax.struct.dataclass +class _MomentumParamState: + momentum: np.ndarray + + +class Momentum(flax.optim.OptimizerDef): + + def __init__(self, learning_rate=None, beta=0.9): + super().__init__( + _MomentumHyperParams(learning_rate, beta) + ) + + def init_param_state(self, param): + return _MomentumParamState(jnp.zeros_like(param)) + + def apply_param_gradient(self, step, hyper_params, param, state, grad): + del step + assert hyper_params.learning_rate is not None + new_momentum = state.momentum * hyper_params.beta + grad + new_params =
param - hyper_params.learning_rate * new_momentum + return new_params, _MomentumParamState(new_momentum) +``` + +Remarks: + +- Note the relationship between `OptimizerDef` and `Optimizer` : When the + function `Optimizer.apply_gradient()` is called from the user code, it calls + into `OptimizerDef.apply_gradient()` (among other things) which in turn will + call `OptimizerDef.apply_param_gradient()` (implemented by subclasses of + `OptimizerDef`). +- The functions `init_param_state()` and `apply_param_gradient()` are called + for every leaf in the params/grads pytree. This makes it possible to write the + calculations directly without `jax.tree_util.tree_map()`. +- The interface was defined in pre-Linen without the distinction of `params` vs. + other collections in `variables` in mind. The original API was elegant because + one only needed to pass around the optimizer, which included the parameters, + optimizer state, optimizer hyperparameters, and a reference to the + `OptimizerDef` to perform the param/state update. 
+ +## Previous Training Step +[Previous Training Step]: #previous-training-step + +An optimizer would first be constructed from its definition and the pytree of +target params: + +```python +optimizer_def = flax.optim.Momentum(learning_rate=0.1, beta=0.9) +optimizer = optimizer_def.create(variables['params']) +``` + +Then, the target variables would be optimized in the train step (assuming a single +non-params collection "batch_stats"): + +```python +def make_train_step(apply_fn): + @jax.jit + def train_step(optimizer, batch_stats, inputs, labels): + + def loss_fn(params): + variables = {'params': params, 'batch_stats': batch_stats} + logits, new_model_state = apply_fn( + variables, inputs, mutable=['batch_stats']) + loss = xent_loss(logits, labels) + return loss, new_model_state['batch_stats'] + + (loss, new_batch_stats), grad = jax.value_and_grad(loss_fn, has_aux=True)( + optimizer.target) + lr = get_learning_rate(step) + new_optimizer = optimizer.apply_gradient(grad, learning_rate=lr) + return new_optimizer, new_batch_stats, loss + + return train_step + + +batch_stats = variables['batch_stats'] +train_step = make_train_step(model.apply) +for step, batch in enumerate(ds) + optimizer, batch_stats, loss = train_step( + optimizer, batch_stats, batch['image'], batch['label']) +``` + +Remarks: + +- Notice how `optimizer.apply_gradient()` can take additional arguments to + update hyperparameters, such as learning rate from an independent function + `get_learning_rate()` in this case. + + +# Update Plan +[Update Plan]: #update-plan + +1. Finalize discussions on this FLIP +2. Add [equivalence tests] to Optax that guarantee that existing `flax.optim` + optimizers return identical values with corresponding `optax` optimizers. +3. Update examples to use Optax and verify that they reach the same final + performance with the same computational cost. +4. Port missing optimizers to Optax (e.g. Adafactor) - and verify above points. +5.
Update all documentation (including README, Flax guided tour, HOWTOs, ...) to + talk exclusively about Optax optimizers. +6. Create a transition guide for updating users from `flax.optim` to using + Optax. This transition guide should also point to Optax's [equivalence tests] + and the pull requests updating the examples. +7. Mark optimizers in `flax.optim` as deprecated. + +[equivalence tests]: https://github.com/deepmind/optax/blob/master/optax/_src/equivalence_test.py + +Note that all current Flax examples use an optimizer that is already available +in Optax: + +| Example | Flax | Optax | Comments | +| -------- | -------------- | ----------- | ----------------------------------- | +| imagenet | optim.Momentum | optax.sgd | DynamicScale can be used unchanged. | +| mnist | optim.Momentum | optax.sgd | | +| nlp_seq | optim.Adam | optax.adamw | | +| pixelcnn | optim.Adam | optax.adam | | +| ppo | optim.Adam | optax.adam | | +| seq2seq | optim.Adam | optax.adam | | +| vae | optim.Adam | optax.adam | | +| wmt | optim.Adam | optax.adamw | | + +(Flax's Adam implementation has an optional parameter for weight decay, but in +Optax Adam with and without weight decay are two different aliases.) 
+ +# Appendix +[Appendix]: #appendix + +## Setup Code +[Setup Code]: #setup-code + +The following setup code can be used for running the code snippets in this +FLIP: + +```python +import functools +from typing import Callable, Sequence + +import jax +import jax.numpy as jnp +import flax +import flax.linen as nn +import tensorflow as tf +import tensorflow_datasets as tfds + + +def pp(features): + return { + 'image': tf.cast(features['image'], tf.float32) / 255 - 0.5, + 'label': features['label'], + } + + +class Model(nn.Module): + + @nn.compact + def __call__(self, inputs): + x = inputs.reshape([inputs.shape[0], -1]) + x = nn.normalization.BatchNorm(True)(x) + x = nn.Dense(10)(x) + x = nn.log_softmax(x) + return x + + +def onehot(labels, num_classes, on_value=1.0, off_value=0.0): + x = (labels[..., None] == jnp.arange(num_classes)[None]) + x = jax.lax.select( + x, jnp.full(x.shape, on_value), jnp.full(x.shape, off_value)) + return x.astype(jnp.float32) + + +def xent_loss(logits, labels): + return -jnp.sum( + onehot(labels, num_classes=10) * logits) / labels.size + + +def get_learning_rate(step): + return 0.1 + + +model = Model() +rng = jax.random.key(0) +ds = tfds.load('mnist')['train'].take(160).map(pp).batch(16) +batch = next(iter(ds)) +variables = model.init(rng, jnp.array(batch['image'][:1])) +jax.tree_util.tree_map(jnp.shape, variables) +``` diff --git a/docs_nnx/flip/1777-default-dtype.md b/docs_nnx/flip/1777-default-dtype.md new file mode 100644 index 0000000000..6344b6bb0e --- /dev/null +++ b/docs_nnx/flip/1777-default-dtype.md @@ -0,0 +1,133 @@ +# FLIP: Default dtypes + + +- Start Date: 2022-01-11 +- FLIP PR: [#1776](https://github.com/google/flax/pull/1776) +- FLIP Issue: [#1777](https://github.com/google/flax/issues/1777) +- Status: Implemented + + +## Summary + +This FLIP proposes to replace the default dtype which is currently fixed to float32, and instead use the JAX type promotion results to derive a default dtype from the input and parameters of a 
layer. + + +## Motivation + +Currently, Linen Modules always produce `module.dtype` (defaults to float32) outputs regardless of input and parameter dtypes. Half-precision types like float16 and bfloat16 are supported by explicitly passing the half-precision type to each Module. The way this is currently implemented is that each Module has a dtype argument with float32 as the default value. The layer guarantees that this dtype will be the return type of the result returned by `__call__`. + +The current behavior is problematic and results in silent bugs, especially for dtypes that do not fit inside float32 (complex, float64). Also, the Linen dtype behavior is significantly different from how NumPy and by extension JAX handle dtypes. + + +### Dtypes in JAX + +JAX uses a NumPy-inspired [dtype promotion](https://github.com/google/jax/blob/main/jax/_src/dtypes.py) mechanism as explained [here](https://jax.readthedocs.io/en/latest/type_promotion.html?highlight=lattice#type-promotion-semantics). The type promotion rules are summarized by the following type lattice: + +![JAX type promotion lattice](https://jax.readthedocs.io/en/latest/_images/type_lattice.svg) + + +## Dtypes in Linen + +Besides input arguments, state and in particular parameters could affect dtype promotion. For example: we might feed a float64 input to a Dense layer with float32 parameters. Currently, the result would be truncated to float32. If the input is a complex number the result is even worse because the imaginary part will be silently dropped when casting to float32. + +By using the dtype promotion rules already available in JAX we can avoid this issue. A public API is available called `jax.numpy.result_type(*args)`, which returns the dtype that JAX would promote the given arguments to, in accordance with the type promotion lattice. For Linen layers the arguments would be the layer inputs together with the parameters. For example, for a linear layer this would be inputs, kernel, and bias.
+ +Note that there is also a `param_dtype` attribute in standard Linen Modules that also defaults to float32. This behavior is left untouched and encodes the common case of having float32 parameters. +There are a few reasons why float32 is almost always the correct dtype for parameters: +1. Storing weights in half-precision often leads to underflow during optimization. +2. Double precision is rarely used because it severely slows down modern accelerators (GPU, TPU). Therefore, such a cost should be explicitly opted-in for. +3. Complex Modules are relatively uncommon. Even within complex networks, the complex inputs can be projected with a real matrix. + + +# Implementation + +A simplified example implementation: + + +```python +def promote_arrays(*xs, dtype): + if dtype is None: + dtype = jnp.result_type(*jax.tree_util.tree_leaves(xs)) + return jax.tree_util.tree_map(lambda x: jnp.asarray(x, dtype), xs) + +Dtype = Any +class Dense(nn.Module): + features: int + kernel_init: Callable + bias_init: Callable + dtype: Optional[Dtype] = None + param_dtype: Dtype = jnp.float32 + + @nn.compact + def __call__(self, x): + kernel = self.param("kernel", + self.kernel_init, + (x.shape[-1], self.features), self.param_dtype) + bias = self.param("bias", self.bias_init, (self.features,), self.param_dtype) + x, kernel, bias = promote_arrays(x, kernel, bias, dtype=self.dtype) + return x @ kernel + bias +``` + + +## Half-precision dtypes + +Some layers don’t work with half-precision dtypes internally. For example: The normalization layers currently compute mean and variance in float32 even when a half-precision dtype is specified to avoid numerical issues. We can replicate this behavior by calling `jnp.result_type` with a dummy argument that has the minimum precision for the sub computation to work correctly. + + +## Backward compatibility + +This proposal causes some layers to behave differently in cases where the dtype is not specified to a Linen Module.
By default, parameters are in float32. Therefore, passing in half or float32 precision inputs will cause a float32 dtype and no functional differences with current behavior. + +When passing complex or float64 precision, the result will no longer truncate the imaginary component or the precision. The silent truncation is problematic and has caused [user complaints](https://github.com/google/flax/issues/805#issuecomment-981468837). Therefore, this change can be considered a bugfix. + +Thus, although this proposal strictly speaking changes behavior it is unlikely to cause problems for users. There are 2 exceptions to this which should be rare and easy to fix: +1. A user relies on the enforced float32 to downcast a double precision value. +2. A user relies on the float32 to explicitly upcast a half precision value even though the weights are in half precision. + + +## Corner cases + +In this section we describe corner cases where the implementation of the proposal is not obvious. The two main concerns are how complex numbers are handled in existing layers and how to determine the dtype of state variables. + +**Autoregressive decoding cache** + +Currently, only attention implements autoregressive caching and the stored key and value mirror the dtype of the key and value passed to the layer. Forcing the cache dtype to be the same as the output dtype could result in reduced precision during cached decoding vs uncached. This seems undesirable. Decision: keep the current behavior. + +**Batch statistics** + +BatchNorm layers are often used with a half precision output dtype. However, calculating statistics is by default always done in float32 to avoid numerical precision issues and over/underflow for float16. With float64 this would actually cause a downcast so we should now use `np.promote_types(float32, dtype)` such that the precision is at least float32. The running batch statistics will be stored with the same dtype for consistency. 
+ +**Complex number support** + +Currently, our complex number support is brittle because the default behavior is to truncate the output to the real part. This issue will be fixed by the automatic type promotion proposed in this FLIP. However, some layers require some additional thought to extend to complex numbers correctly: + +1. Normalization layers use the complex conjugate to calculate norms instead of normal squaring. +2. Attention: It’s not exactly clear how the dot product and softmax are defined in this case. Raise an error on complex inputs. +3. Recurrent layers: might require special gating / activation functions to function correctly, but these can be specified by the user. + + +# Discussion + +Summarizing the main points from the discussion: + + +## Consider implicit complex truncation an error + +Q: +I'm wondering if we should always raise an error if one of the xs tree leaves is complex but dtype is not. Users should maybe remove imaginary part by themselves if that's really what they want to do. +(Maybe it's a contrived example, but I can imagine cases where layers have their dtype set by parent modules based on assumptions without complex numbers in mind) + +A: +This is worth considering in a follow-up CL but this might as well be solved in JAX directly where the safeguard would apply more generally. In NumPy this was also considered but abandoned because it is not backwards compatible. + + +## Dtype attribute names + +Q: +Are the dtype and param_dtype arguments confusing? In particular, should dtype perhaps be called output_dtype to make the difference between the two dtypes more explicit? + +A: +This would be a large and orthogonal change with respect to this proposal so leaving it out for now. +Also, this breaks with the standard dtype argument in NumPy/JAX. +Although dtype indeed constrains the output dtype it is also a hint for the dtype we would like the computation to happen in.
+ diff --git a/docs_nnx/flip/2396-rnn.md b/docs_nnx/flip/2396-rnn.md new file mode 100644 index 0000000000..9785e6006b --- /dev/null +++ b/docs_nnx/flip/2396-rnn.md @@ -0,0 +1,238 @@ +# RNN Flip + +- Start Date: 2022-08-18 +- FLIP PR: [#2604](https://github.com/google/flax/pull/2604) +- FLIP Issue: [#2396](https://github.com/google/flax/issues/2396) +- Authors: Jasmijn Bastings (@bastings) and Cristian Garcia (@cgarciae) + +## Summary +This FLIP adds support for higher-level recurrent layers (RNN, GRU, LSTM) that can help users process input sequences using the recurrent cells already available in Flax. + +## Motivation +Implementing well known recurrent architectures is tricky and prone to user errors, even a simple LSTM layers involves the manual creation and handling of the carry/memory and correctly setting up `nn.scan`: + +```python +@nn.compact +def __call__(self, x): + LSTM = nn.scan( + nn.LSTMCell, variable_broadcast="params", split_rngs={"params": False} + ) + carry = LSTM.initialize_carry( + jax.random.key(0), batch_dims=x.shape[:1], size=self.hidden_size + ) + carry, x = LSTM()(carry, x) + return x +``` +Slightly more complicated cases involving padding like in the [seq2seq](https://github.com/google/flax/blob/main/examples/seq2seq/models.py) example require even more work but couple potentially be simplified to a couple of lines with the right abstractions. We propose providing users with clean, correct, and efficient abstractions to use recurrent cells. + +## Requirements + +* **Masking**: We need to support a batch of sequences that contain padding at the end of each sequence. + * We do not intend to support non-contiguous padding, i.e. padding that is not at the end of a sequence, for performance reasons, except in the case of packing (see below). 
+* **Bidirectionality**: The ability to process a sequence in both the forward and reverse directions, respecting padding (i.e., the reverse direction should start with the actual inputs, not with padding values). +* **Performance**: The proposed classes should be benchmarked to provide the best performance in terms of step time and/or memory use. +* **Recurrent Dropout**: Support for recurrent dropout in cells (e.g. dropout on the state of the cell). + +## Implementation +### High-level structure + +We propose to have these 3 levels of abstraction: + +* **Cells (unchanged)**: all RNNCellBase subclasses such as LSTMCell and GRUCell, these implement the stepwise logic. These already exist in Flax today. +* **Layers (new)**: a class (RNN) that takes a cell and scans over a sequence respecting possible padding values and optionally also allows packed sequences. +* **Bidirectional (new)**: a single class that takes a forward and a backward RNN instance and correctly processes the input sequence in both directions and merges the results. + +### Example of proposed API +We start with a code example of what you could do with the proposed API, and then we discuss the API in detail below. + +```python +cell = nn.LSTMCell() +# Encodes a batch of input sequences. +carry, outputs = nn.RNN(cell, cell_size)(inputs, seq_lengths) +``` + +A Bidirectional layer with a LSTM RNNs for the forward and backward directions respectively would look like this: + +```python +forward_rnn = nn.RNN(nn.LSTMCell(), cell_size=32) +backward_rnn = nn.RNN(nn.LSTMCell(), cell_size=32) +# Bidirectional combinator. +bi_rnn = nn.Bidirectional(forward_rnn, backward_rnn) +# Encodes a batch of input sequences in both directions. +carry, outputs = bi_rnn(inputs, seq_lengths) +``` + +Next we will discuss `RNN`, `Bidirectional`, and proposed changes to `RNNCellBase`. 
+ +### RNNBase +The `RNNBase` class serves as a base class for the `RNN` class, it specifies +the API that all RNN layers should implement to be compatible with the `Bidirectional`. +`RNNBase` contains the `__call__` and `flip_sequences` methods: + +```python +class RNNBase(Protocol): + def __call__( + self, + inputs: jax.Array, + *, + initial_carry: Optional[Carry] = None, + init_key: Optional[random.KeyArray] = None, + seq_lengths: Optional[Array] = None, + return_carry: Optional[bool] = None, + time_major: Optional[bool] = None, + reverse: Optional[bool] = None, + keep_order: Optional[bool] = None, + ) -> Union[Output, Tuple[Carry, Output]]: + ... +``` +Where: + +* `inputs`: the input sequence. +* `initial_carry`: the initial carry, if not provided it will be initialized + using the cell's :meth:`RNNCellBase.initialize_carry` method. +* `init_key`: a PRNG key used to initialize the carry, if not provided + ``jax.random.key(0)`` will be used. Most cells will ignore this + argument. +* `seq_lengths`: an optional integer array of shape ``(*batch)`` indicating + the length of each sequence, elements whose index in the time dimension + is greater than the corresponding length will be considered padding and + will be ignored. +* `return_carry`: if ``return_carry=False`` (default) only the output sequence is returned, + else it will return a tuple of the final carry and the output sequence. +* `time_major`: if ``time_major=False`` (default) it will expect inputs with shape + ``(*batch, time, *features)``, else it will expect inputs with shape ``(time, *batch, *features)``. +* `reverse`: if ``reverse=False`` (default) the sequence is + processed from left to right and returned in the original order, else it will be processed + from right to left, and returned in reverse order. If ``seq_lengths`` is passed, + padding will always remain at the end of the sequence. 
+* `keep_order`: if ``keep_order=True``, when ``reverse=True`` + the output will be reversed back to the original order after processing, this is + useful to align sequences in bidirectional RNNs. If ``keep_order=False`` (default), + the output will remain in the order specified by ``reverse``. +* `Returns`: if ``return_carry=False`` (default) only the output sequence is returned, +else it will return a tuple of the final carry and the output sequence. + +### RNN +The `RNN` module inherits from `RNNBase`, it main function is to apply an `RNNCellBase` instance over a batch of input sequences, it can be used with any type of cell (e.g., `GRUCell`, `LSTMCell`, etc). It accepts the following parameters: + +```python +class RNN(RNNBase): + cell: RNNCellBase, + cell_size: int | Tuple[int, ...] + time_axis: int = -2, + variable_axes = FrozenDict(), + variable_broadcast: CollectionFilter = 'params' + variable_carry: CollectionFilter = False + split_rngs = FrozenDict({'params': False}) + # implement RNNBase + ... +``` + +Attributes like `variable_axes`, `variable_broadcast`, `variable_carry`, and `split_rngs` are directly passed to `nn.scan`, their default values are set such that common cells like `LSTMCell` and `GRUCell` work out of the box. + +### Masking +`seq_lengths` is defined as an integer array of shape `(*batch,)` indicating the length of each sequence. + +
Discussion + +There are various masking formats found in other frameworks, here are some of the most popular ones: + +* **Binary masking**: specifies per-sample and timestep whether that data point should be included or not in the computation, it can be non-contiguous (e.g., [1, 1, 0, 1]). This is used by Keras. +* **Sequence length masking**: specifies per-sample the number of non-padding examples contained in the sequence, any padding contained in the sequence should be stacked at the end. This is used by FlaxFormer. +* **Segmentation Mask**: specifies, per row and timestep, which sample the data point belongs to; this format allows more than one sample per row which potentially reduces the total amount of padding needed (e.g. [1, 1, 1, 2, 2, 0, 0]). PyTorch uses this representation (see [pack_padded_sequence](https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pack_padded_sequence.html)). + +While Sequence packing (see [LM1B example](https://github.com/google/flax/blob/main/examples/lm1b/input_pipeline.py#L90-L92)) is more powerful, its implementation is more complex and it is not clear whether it is worth the effort. The simplest format is sequence length masking, which is the one we propose to use. + +
+ +### Bidirectional +Bidirectional processing can be achieved via a Module that accepts a `forward_rnn` Module and a `backward_rnn` Module, both of which should be `RNN` instances, in order to process the input sequence in both directions. Here we present some pseudo code of the implementation: + +```python +def __call__(self, inputs, seq_lengths): + # Encode in the forward direction. + carry_forward, outputs_forward = self.forward_rnn( + inputs, seq_lengths=seq_lengths, + return_carry=True, reverse=False, + ) + # Encode in the reverse order. + carry_backward, outputs_backward = self.backward_rnn( + inputs, seq_lengths=seq_lengths, + return_carry=True, reverse=True, # process in reverse order + keep_order=True, # but return the sequence in the original order + ) + # Merge both sequences. + outputs = jax.tree.map(self.merge_fn, outputs_forward, outputs_backward) + + return (carry_forward, carry_backward), outputs +``` + +Here `merge_fn` is a function that takes both outputs and fuses them (`concat` by default). As showcased in the beginning of this document, usage would look like this: + +```python +forward_rnn = nn.RNN(nn.LSTMCell(), cell_size=32) +backward_rnn = nn.RNN(nn.GRUCell(), cell_size=32) +# Bidirectional combinator. +bi_rnn = nn.Bidirectional(forward_rnn, backward_rnn) +# Encodes a batch of input sequences in both directions. +carry, outputs = bi_rnn(inputs, seq_lengths) +``` + +### Recurrent Dropout +There are two main uses of dropout in RNNs: +1. Input dropout: regular dropout applied to the inputs, different for every step. +2. Recurrent dropout: applies dropout to a recurrent input/output, same for every step. + +Flax's `nn.scan` can easily express both types of dropout via `split_rngs`: input dropout would split rngs while recurrent dropout would not.
[#2540](https://github.com/google/flax/pull/2540) was introduced so that the `rng_collection` in `nn.Dropout` can now be defined by the user; this way Cells could define both types of dropout, e.g.: + +```python +self.dropout = nn.Dropout(...) # input dropout +self.recurrent_dropout = nn.Dropout(..., rng_collection='recurrent_dropout') +``` +Based on this, `nn.scan` / `nn.RNN` can now specify `split_rngs` accordingly, e.g.: +``` +nn.scan(scan_fn, ..., split_rngs={'dropout': True, 'recurrent_dropout': False}) +``` + +# Future ideas + +
show + +### Sequence Packing +Allow packing multiple sequences to make efficient use of space/memory. This might result in a trade-off where step time is higher (because at each step we need to check whether we are starting a new sequence and reset the carry/initial state), but where less padding is used, increasing efficiency overall. + +### RNNCell redesign + +#### Make initialize_carry an instance method +The first alternative is to make `initialize_carry` an instance method. With this change hyperparameters can be passed directly to the cell; its signature would look like this: + +```python +def initialize_carry(self, sample_input) -> Carry: + ... +``` + +Usage would look like this: + +```python +LSTM = nn.scan( + nn.LSTMCell, variable_broadcast='params', + split_rngs={'dropout': True}) +lstm = LSTM(features=32) +carry = lstm.initialize_carry(x[:, 0]) +carry, y = lstm(carry, x) +``` + +#### Remove initialize_carry + +An alternative is to remove `initialize_carry` entirely and have the carry state be handled as a carry collection. This would simplify usage quite a bit: + +```python +LSTM = nn.scan( + nn.LSTMCell, variable_broadcast='params', + split_rngs={'dropout': True}) +y = LSTM(features=32)(carry, x) +``` + +However, this would require `nn.scan` to support initialization of carry collections which is currently not possible. Also, users would have to specify that a collection is mutable e.g. `mutable=['carry']`, even if they are not interested in the output carry state. + +
diff --git a/docs_nnx/flip/2434-general-metadata.md b/docs_nnx/flip/2434-general-metadata.md new file mode 100644 index 0000000000..8f21e378b9 --- /dev/null +++ b/docs_nnx/flip/2434-general-metadata.md @@ -0,0 +1,230 @@ +# FLIP: Axis Metadata + + +- Start Date: 2022-08-08 +- FLIP Issue: [#2434](https://github.com/google/flax/issues/2434) +- FLIP PR: [#2435](https://github.com/google/flax/pull/2435) +- Status: Proposal + + +## Summary + +This FLIP proposes to extend Flax's variable collections with a generic axis metadata API. +The core of the API is an abstract base class that is recognized by lifting transformations that can add an axis (vmap, scan). +Users can extend the base class to keep track of per-axis metadata in a way that works with lifted transformations. + + +## Motivation + +Generally, there is no way in Flax to track metadata for variables across lifted transformations. +Axis metadata is used to keep track of semantic information about axes into other (Flax independent) APIs. +For example, optimizers like AdaFactor can be configured on a per-axis level and partitioning APIs +in JAX like xmap or pjit require per variable annotations to map effectiently to parallel hardware. + +Currently, there is an experimental [API](https://github.com/google/flax/blob/main/flax/linen/partitioning.py) +supporting partitioning annotations with wrappers around lifted transforms that change axes (``nn.scan_with_axes``, ``nn.vmap_with_axes``) +and a special APIs to create variables (``param_with_axes`` and ``variable_with_axes``). +The experimental partitioning API stores the metadata in a separate collection named "[collection]_axes". + + +The experimental API has a number of shortcomings that we like to solve: +1. The current API works for tracking PartitionSpecs but not for other types of metadata like optimizer annotations. +2. The implementation using an "xxx_axes" collection requires error-prone and non-composable string manipulation. +3. 
Special, partitioning-aware variable creators and lifted transforms are required +4. The partitioning API is hard to use with pre-existing Modules that aren't partitioning aware. + + +## Proposal + +To generalize metadata tracking and keep the specific metadata out of core Flax we propose the following abstract base class: + +```python +TAxisMetadata = TypeVar("TAxisMetadata", bound="AxisMetadata") + +class AxisMetadata(metaclass=abc.ABCMeta): + """Abstract base class for boxed Metadata. + + ``AxisMetadata`` enables arbitrary, per axis metadata for variables. + By using ``unbox`` the metadata is stripped away to obtain the original + variables. By using unboxing, most code handling variables does not need + to handle ``AxisMetadata`` specifically, but can directly operate on the JAX + arrays that they wrap. + + Additionally, ``AxisMetadata`` supports updating metadata whenever an axis + is added or removed by a functional transformation + (e.g.: ``nn.scan`` or ``nn.vmap``) using the ``add_axis`` and ``remove_axis`` + methods. + + By extending ``AxisMetadata``, custom metadata can be stored. See + ``Partitioned`` for a specific implementation. + """ + + @abc.abstractmethod + def unbox(self) -> Any: + """Returns the content of the AxisMetadata box. + + Note that unlike ``meta.unbox`` the unbox call should not recursively unbox + metadata. It should simply return the value that it wraps directly even + if that value itself is an instance of AxisMetadata. + + In practice, AxisMetadata subclasses should be registered as PyTree nodes to + support passing instances to JAX and Flax APIs. The leaves returned for this + node should correspond to the value returned by unbox. + + Returns: + The unboxed value. + """ + pass + + @abc.abstractmethod + def add_axis(self: TAxisMetadata, index: int, + params: Dict[Any, Any]) -> TAxisMetadata: + """Adds a new axis to the axis metadata. 
+ + Note that add_axis and remove_axis should act as each other's inverse + (meaning: ``x.add_axis(i, p).remove_axis(i, p) == x``) + + Args: + index: The position at which the new axis will be inserted + params: An arbitrary dictionary of parameters passed by the transformation + that introduces the new axis (e.g.: ``nn.scan`` or ``nn.vmap``). The + user passes this dictionary as the `metadata_param` argument to the + transformation. + Returns: + A new instance of the same type as self and with the same ``unbox`` + content with updated axis metadata. + """ + pass + + @abc.abstractmethod + def remove_axis(self: TAxisMetadata, index: int, + params: Dict[Any, Any]) -> TAxisMetadata: + """Removes an axis from the axis metadata. + + Note that add_axis and remove_axis should act as each other's inverse + (meaning: ``x.remove_axis(i, p).add_axis(i, p) == x``) + + Args: + index: The position of the axis that is to be removed + params: An arbitrary dictionary of parameters passed by the transformation + that introduced the axis (e.g.: ``nn.scan`` or ``nn.vmap``). The + user passes this dictionary as the `metadata_param` argument to the + transformation. + Returns: + A new instance of the same type as self and with the same ``unbox`` + content with updated axis metadata. + """ + pass +``` + +We call this type of class wrapping a value and keeping track of some additional data a **box**. +By defining an abstract base class for this box, the API does not need to be aware of the specifics of the metadata that is tracked. +This should make the API future proof and modular. + +The ``add_axis`` and ``remove_axis`` method return an instance of their own type instead of mutating in-place. +Typically, an implementation would be a ``flax.struct.PyTreeNode`` because the box should still be a valid JAX value and must therefore be handled by the PyTree API. +Calling ``jax.tree.map`` on a boxed value will simply map over the value in the box. 
+The lifted transforms that need to handle metadata will call ``jax.tree.map(..., is_leaf=lambda x: isinstance(x, AxisMetadata))`` to find the AxisMetadata instances within a PyTree. + +Advantages of the boxing approach: +1. Boxing can be used outside of Flax and metadata is automatically "inherited". For example, the optimizer state will + have the same partitioning spec as the parameters, because the state is initialized using a ``jax.tree.map`` over the boxed parameters. +2. Boxes are composable. +3. Boxing avoids string manipulation and generally avoids having to handle additional auxiliary collections like "param_axes" in the current + partitioning API. +4. No need to lift metadata collections separately. + + +Disadvantages: +1. Adding the boxes changes the PyTree hierarchy and introduces dataclasses within the otherwise plain, nested dict of variables. +3. Custom Pytree nodes have a small runtime overhead. It's hard to observe this in practise because JAX calls are async. + + +### Init syntax + + +Boxes can be created directly by the init function of a variable. Therefore, we propose to create metadata using higher-order initializers. +The main advantage of this is that we can decouple metadata handling completely from the Module definition. Also, most Modules already overwrite +attributes to override the default initialzers so users can add metadata to existing Modules without requiring any code changes. + +To illustrate this, let's consider a metadata class that keeps track of PartitionSpecs used by ``pjit``: + +```python +class Partitioned(flax.struct.PyTreeNode, AxisMetadata): + value: Any + names: Tuple[Optional[str], ...] 
= flax.struct.field(pytree_node=False) + + def add_axis(self, index: int, params: Dict[Any, Any]) -> TAxisMetadata: + axis_name = self._get_partition_name(params) + names = list(self.names) + names.insert(index, axis_name) + return self.replace(names=tuple(names)) + + def remove_axis(self, index: int, params: Dict[Any, Any]) -> TAxisMetadata: + axis_name = self._get_partition_name(params) + names = list(self.names) + assert names.pop(index) == axis_name + return self.replace(names=tuple(names)) + +def with_partitioning(init_fn, names): + def wrapper(*args, **kwargs): + return Partitioned(init_fn(*args, **kwargs), names) + return wrapper +``` + +Here we also defined a small utility called ``with_partitioning`` that we can use to wrap existing initialzers to add metadata: + + +```python +# init kernel with lecun normal and split the output features over the data axis +partitioned_dense = nn.Dense(features, kernel_init=with_partitioning(nn.initializers.lecun_normal, (None, "data"))) +``` + +Initializing a model that creates partitioned weights would result in the following variable structure: + +```python +variables = partitioned_dense.init(rng, jnp.ones((4,))) +jax.tree.map(np.shape, variables) # => {"params": {"kernel": Partitioned(value=(4, 8), names=(None, "data")), bias: (8,)}} +``` + +The variable tree with metadata can be used to integrate with other libraries and APIs. +For example, we can turn the ``Partitioned`` metadata into ``jax.pjit`` sharding annotations: + +```python +def to_sharding_spec(x): + if isinstance(x, Partitioned): + return PartitionSpec(*x.names) + else: + # fully replicated + return PartitionSpec() + +# Result: {"params": {"kernel": PartitionSpec(None, "data"), bias: PartitionSpec()}} +variables_pspec = jax.tree.map(to_sharding_spec, variables, is_leaf=lambda x: isinstance(x, Partitioned)) +``` + +### Unbox syntax + + +Metadata typically doesn't need to be handled by Modules directly. 
Therefore, we propose to make Modules agnostic to Metadata boxes by default. +The ``unbox`` method can be used to unpack a variable such that only the original JAX arrays remain. Users can manually call unbox but to make +sure Module classes don't have to call it everywhere we add an unbox keyword arg to variable returning APIs (e.g.: ``.param``, ``.variable``, ``.get_variable``). +The keyword arg ``unbox`` will default to ``True`` such that a Modules are metadata agnostic by default. This also means existing Modules will be backward compatible +with the new API. + +```python +kernel = self.param("kernel", self.kernel_init, shape) # No AxisMetadata instances +kernel_box = self.get_variable("param", "kernel", unbox=False) # AxisMetadata boxes are preserved +``` + + +### Lift syntax + +When calling a lifted transformation that adds an axis you will now be able to pass a dictionary with arguments. +These params will be passed to ``AxisMetadata`` add_axis/remove_axis callbacks: + +```python +nn.scan(..., variable_axes={"params": 0}, metadata_params={nn.Partitioned.AXIS_NAME: "layers"}) +``` + +A dict is used such that users can add their own arguments to custom AxisMetadata classes. + diff --git a/docs_nnx/flip/2974-kw-only-dataclasses.md b/docs_nnx/flip/2974-kw-only-dataclasses.md new file mode 100644 index 0000000000..7f03881d65 --- /dev/null +++ b/docs_nnx/flip/2974-kw-only-dataclasses.md @@ -0,0 +1,99 @@ +# FLIP: kw_only dataclasses +Authors: Brennan Saeta, Ivy Zheng + + - Start Date: Mar 23, 2023 + - FLIP Issue: [TBD] + - FLIP PR: #2974 + - Status: Implementing + + +## Summary + +Python 3.10 adds support for `kw_only` dataclasses. Subclasses of `flax.linen.Module` are automatically converted to `dataclasses` on users' behalf, but today, Flax doesn't allow setting the `kw_only` parameter to this dataclass transform, even if users are running Python 3.10. This proposal allows users to use this new feature with `nn.Module`'s. 
+ + +## Motivation + +In larger Flax-based codebases (e.g. [`PaxML`](https://github.com/google/paxml) / [`Praxis`](https://github.com/google/praxis)), it’s not uncommon to define an (abstract) subclass of nn.Module that contains shared functionality that is itself further subclassed for specific implementations (e.g. [`BaseLayer`](https://github.com/google/praxis/blob/main/praxis/base_layer.py), or [`StackedTransformerRepeat`](https://github.com/google/praxis/blob/81479b260fcc13de8549cdbfb0fdf5c3f188ac90/praxis/layers/transformers.py#L1836) which is further subclassed by [`PipelineCompatibleStackedTransformerRepeat`](https://github.com/google/praxis/blob/81479b260fcc13de8549cdbfb0fdf5c3f188ac90/praxis/layers/transformers.py#L2198)). + +Often, these parent types define hyperparameters (constructor arguments), often with default values. Without `kw_only` on the `dataclass` transform, default values must be specified for all child layers hyperparameters. This is suboptimal, because users could forget to set them when instantiating the modules. For example, `Child` must set a default value for `num_heads` (because a non-defaulted argument can’t come after a defaulted argument if they are positional), but no reasonable default is available: + +```python +class BaseLayer(nn.Module): + mesh: Optional[jax.experimental.mesh.Mesh] = None + + def with_sharding(self, some_variable, some_sharding): + if self.mesh: + # Do something useful here. + +class Child(BaseLayer): + num_heads: int # Don't want to have to set a default argument! + + def __call__(self, x): + ... +``` + +Note: Flax already has this problem, which is why `nn.Module` has its own fancy `kw_only_dataclasses.dataclass` transform: it moves the `name` and `parent` dataclass fields to the end, so they can have defaults. + + +## Implementation + +To allow modules to optionally opt into this `kw_only` dataclass behavior, we leverage arguments to `__init_subclass__`. 
This would look as follows: + +```python +class BaseLayer(nn.Module, kw_only=True): + ... + +class Child(BaseLayer): + ... +``` + +The implementation of `nn.Module`’s `__init_subclass__` will be tweaked as follows: + +```python +class Module(ModuleBase): + def __init_subclass__(self, kw_only: Optional[bool] = None): + # ... + if kw_only: + if is_python_310_or_above(): + dataclass_transform_args = {'kw_only': True} + else: + raise TypeError("Can't use `kw_only` before Py3.10.") + else: + dataclass_transform_args = {} + + kw_only_dataclasses.dataclass( + cls, unsafe_hash='__hash__' not in cls.__dict__, + repr=False, + **dataclass_transform_args) +``` + +### Forward compatibility + +For future simplification, if `kw_only` is requested and the Python version is 3.10 or above, bypass the `kw_only_dataclasses` implementation and just use the regular `dataclasses` transform. + +That means we may one day remove `flax/linen/kw_only_dataclasses.py` when Flax rolls over 3.10. + + +## Discussion + +### Aligned with Python `dataclass` + +We prefer to keep the behavior of `nn.Module`’s `kw_only` aligned with the Python dataclasses. Note that this means `kw_only` will not be inheritable, and this could happen: + +```python +class BaseLayer(nn.Module, kw_only=True): + base_muliplier: Optional[int] = -1 + +class ChildLayer(BaseLayer): + child_multiplier: int + +BaseLayer(2) # This will throw error +ChildLayer(2) # But this will not +``` + +### `flax.struct.dataclass` + +There’s a potentially related feature to allow `kw_only` to be specified for `flax.struct.dataclass`. This should be considered an orthogonal decision. 
+ + diff --git a/docs_nnx/flip/3099-rnnbase-refactor.md b/docs_nnx/flip/3099-rnnbase-refactor.md new file mode 100644 index 0000000000..e15e2c9b52 --- /dev/null +++ b/docs_nnx/flip/3099-rnnbase-refactor.md @@ -0,0 +1,79 @@ +# Refactor RNNCellBase in FLIP + +Authors: Cristian Garcia, Marcus Chiam, Jasmijn Bastings + + - Start Date: May 1, 2023 + - FLIP Issue: [TBD] + - FLIP PR: #3053 + - Status: Implemented + +## Summary +This proposal aims to improve the usability of the `RNNCellBase` class by refactoring the `initialize_carry` method and other relevant components. + +## Motivation + +Currently, `initialize_carry` is used to both initialize the carry and pass crucial metadata like the number of features. The API can be unintuitive as it requires users to manually calculate things that could typically be inferred by the modules themselves, such as the shape of batch dimensions and the shape of feature dimensions. + +### Example: ConvLSTM +The current API can be unintuitive in cases like `ConvLSTM` where a the `size` parameter contains both the input image shape and output feature dimensions: + +```python +x = jnp.ones((2, 4, 4, 3)) # (batch, *image_shape, channels) + +# image shape: vvvvvvv +carry = nn.ConvLSTMCell.initialize_carry(key1, (16,), (64, 64, 16)) +# batch size: ^^ ^^ :output features + +lstm = nn.ConvLSTMCell(features=6, kernel_size=(3, 3)) +(carry, y), initial_params = lstm.init_with_output(key2, carry, x) +``` + +This FLIP will propose some changes to `initialize_carry` such that the previous example can be simplified to: + +```python +x = jnp.ones((2, 4, 4, 3)) # (batch, *image_shape, channels) + +lstm = nn.ConvLSTMCell(features=6, kernel_size=(3, 3)) +carry = lstm.initialize_carry(key1, input_shape=x.shape) + +(carry, y), initial_params = lstm.init_with_output(key2, carry, x) +``` + +## Implementation +The proposal suggests the following changes: + +### initialize_carry +`initialize_carry` should be refactored as an instance method with the following 
signature: + +```python +def initialize_carry(self, key, sample_input): +``` + +`sample_input` should be an array of the same shape that will be processed by the cell, excluding the time axis. + +### Refactor RNNCellBase subclasses +`RNNCellBase` should be refactored to include the metadata required to initialize the cell and execute its forward pass. For `LSTMCell` and `GRUCell`, this means adding a `features` attribute that should be provided by the user upon construction. This change aligns with the structure of most other `Module`s, making them more familiar to users. + +```python +x = jnp.ones((2, 100, 10)) # (batch, time, features) + +cell = nn.LSTMCell(features=32) +carry = cell.initialize_carry(PRNGKey(0), x[:, 0]) # sample input + +(carry, y), variables = cell.init_with_output(PRNGKey(1), carry, x) +``` + +### num_feature_dims +To simplify the handling of `RNNCellBase` instances in abstractions like `RNN`, each cell should implement the `num_feature_dims` property. For most cells, such as `LSTMCell` and `GRUCell`, this is always 1. For cells like `ConvLSTM`, this depends on their `kernel_size`. + +## Discussion +### Alternative Approaches +* To eliminate the need for `num_feature_dims`, `RNN` could support only a single batch dimension, i.e., inputs of the form `(batch, time, *features)`. Currently, it supports both multiple batch dimensions and multiple feature dimensions. +* Another approach could be a complete redesign of how Flax deals with recurrent states. For example, a `memory` collection could be handled as part of the variables. However, this introduces challenges such as handling stateless cells during training, passing state from one layer to another, and performing initialization inside `scan`. + +### Refactor Cost +Initial TGP results showed 761 broken and 110 failed tests. However, after fixing one test, TGP results in 231 broken and 13 failed tests so there seems to be a lot +of overlap between the broken tests. 
+ +To minimize refactor costs, the current implementation will be kept for Google internal users under a deprecated name. This will allow users to migrate to the new API at their own pace. For Open Source users we should bump Flax version to +`0.7.0` so existing users can continue to depend on `0.6.x` versions. diff --git a/docs_nnx/flip/4105-jax-style-nnx-transforms.md b/docs_nnx/flip/4105-jax-style-nnx-transforms.md new file mode 100644 index 0000000000..5bb552c2aa --- /dev/null +++ b/docs_nnx/flip/4105-jax-style-nnx-transforms.md @@ -0,0 +1,177 @@ +# JAX-style NNX Transforms + +- Authors: Cristian Garcia, Anselm Levskaya +- Date: Jun/2024 +- FLIP PR: #4107 +- Status: Implementing + +## Motivation + +NNX allows users to utilize Modules at the top level due to their eager initialization and self-contained state. This naturally leads users to want to use them with transforms and soon start playing with NNX transforms. Since NNX Modules resemble PyTrees in that they contain Arrays, new users often attempt to apply JAX conventions, for example: + +```py +@nnx.vmap(in_axes=(1, 0)) +def f(m1: Module, m2: Module): + ... +``` + +However, this can be misleading. Currently, NNX transforms follow Linen's convention of treating input Modules as a single unit (all Modules are split together to preserve shared references) and provide APIs for transforming that State separately. The previous example effectively translates to: + +```py +# this is what is really happening +@nnx.vmap(in_axes=(IGNORE, IGNORE), state_axes={BatchStat: None, ...: 0}) +def f(m1: Module, m2: Module): + ... +``` + +Note that `IGNORE` is not a real symbol, but represents the fact that any value placed here won't affect the outcome, as Modules are replaced by empty PyTree placeholders (similar to `None`). The `state_axes` parameter controls how the State is vectorized through a mapping of high-level `Filter`s to their desired axes. 
In this example, `...` (ellipsis) is a filter that accepts everything, so by default all States are vectorized on the 0th axis. + +To express their original intention, users must resort to more complex custom filters that guess the index of each Module in the monolith. While this is straightforward in simple cases, users generally need to calculate the index (Modules appear in the order specified by `jax.tree.leaves` over the `args`): + +```py +select_m1 = lambda path, value: path[0] == 0 +select_m2 = lambda path, value: path[0] == 1 + +# To select modules individually, you must create a filter (which can be tricky) +@nnx.vmap(state_axes={select_m1: 1, select_m2: 0}) +def f(m1: Module, m2: Module): + ... +``` + +## What if JAX conventions Just Worked™? + +This proposal aims to align NNX transforms with users' expectations based on their JAX experience, making the syntax work as intuitively as possible. The original example would function **as if** `m1` and `m2` were PyTrees vectorized in axes `1` and `0` respectively: + +```py +@nnx.vmap(in_axes=(1, 0)) +def f(m1: Module, m2: Module): + ... +``` + +The primary advantage of this approach is that for `vmap` and `scan`, we could eliminate the `state_axes` and `split_rngs` arguments, relying solely on the `in_axes` API. This syntax alone would likely suffice for 80-90% of use cases, as users tend to manage state in predictable ways. + +### The Lift symbols + +To enable more fine-grained state control within each Module, we introduce the `Lift` API. By using special types containing State Filters in place of a tree prefix, state lifting can now be done **structurally**. This allows different Filters to be applied to different Modules in the arguments without the need for complex path-based filters. Ideally, each transform would support its own Lift type, adding the desired behavior through existing JAX APIs.
+ +For example, in `vmap`, we could allow `StateAxes` instances (vmap's Lift type) to be accepted by `in/out_axes` to control how substates are handled by mapping state `Filter`s to an axis specifier: + +```py +state_axes = StateAxes({Param: 1, BatchStat: None}) + +@nnx.vmap(in_axes=(state_axes, 0)) +def f(m1: Module, m2: Module): + ... +``` + +In this case, `m1`'s `Param`s are vectorized in axis `1` while its `BatchStat`s are broadcasted, and `m2`'s entire state is vectorized in axis `0`. + +For `nnx.grad`, we could allow `DiffState` to be used in the `argnums` parameter to specify both the position of the argument to be differentiated and a Filter specifying the differentiable State of the Module: + +```py +grads = nnx.grad(loss_fn, argnums=(DiffState(0, LoRAParam),))(model, x, y) +``` + +## Rng Handling + +To simplify RNG state handling, we propose removing the separate `split_rngs` parameter in `vmap` and `scan`. Instead, we suggest introducing a new `nnx.split_rngs` API that would manage RNG handling before and after the transformation. This approach provides more explicit control to the user and aligns better with JAX transform behavior. + +## Consistent Aliasing + +To ensure the correctness of transformations with objects that obey reference semantics, we must enforce consistent lifting/lowering specifications for all aliases of a reference. Transforms must adhere to two rules: + +1. All aliases of a reference must receive the **exact same** lifting/lowering specification. +2. Captured references are not allowed on the output of transformed functions. + +For example: + +```py +@nnx.vmap(in_axes=(m1_axes, m2_axes, m1_axes), out_axes=m2_axes) +def f(m1, m2, m1_alias): + return m2 + +m2 = f(m1, m2, m1) +``` + +Here, `m1` has two input aliases as it is passed as the first and third input to `f`, but this is acceptable because `m1_axes` is assigned to both in `in_axes`. 
`m2` is passed as the second input and has an output alias, which is also acceptable because `m2_axes` is assigned in both `in_axes` and `out_axes`. + +Let's examine some examples of programs that should be **rejected** based on these criteria: + +### Inconsistent input aliases + +Consider a function with two arguments `m1` and `m2` being vectorized in axis `0` and `1` respectively. Passing the same Module as both arguments would be inconsistent: + +```py +@nnx.vmap(in_axes=(0, 1)) +def f(m1: Module, m2: Module): + ... + +f(m, m) # This should be rejected +``` + +### Inconsistent input / output aliases + +Now consider an identity function `g` under `vmap` with `in_axes=0` and `out_axes=1`. In JAX, this would result in transposing the arrays in the inputs: + +```py +@nnx.vmap(in_axes=0, out_axes=1) +def g(m: Module): + return m +``` + +While this appears correct, in NNX this behavior is not well-defined because shared mutable references behave as auxiliary outputs. Under the hood, `g` is converted into a function that has the inputs as an extra first output, and `out_axes` is set to the same values as `in_axes` for that output: + +```py +@nnx.vmap(in_axes=0, out_axes=(0, 1)) +def g_real(m: Module): + return m, m +``` + +This return structure reveals an inconsistency: we're attempting to lower `m` with both `out_axes=0` and `out_axes=1`. + +### Inconsistent aliases in nested structures + +Similar issues can arise in less obvious cases, such as when `m` is contained within another structure: + +```py +@nnx.vmap(in_axes=0, out_axes=1) +def f(m: Module): + return SomeModule(m) +``` + +This means we must traverse the entire graph of both inputs and outputs to check for consistent assignments. The same problem occurs when passing shared reference inputs/outputs with different specifications: + +```py +shared = Shared() +m1, m2 = Foo(shared), Foo(shared) + +@nnx.vmap(in_axes=(0, 1)) +def f(m1, m2): # shared is passed through both + ... 
+``` + +### Captured Modules cannot be outputs + +Finally, let's consider the second consistent aliasing rule, which states that captured Modules cannot be outputs. The main issue here is that NNX needs to split all input references together to track changes, but captured Modules bypass this process. Treating them as new references would result in **implicit cloning**: + +```py +m = SomeModule() + +@nnx.vmap(out_axes=0, axis_size=5) +def f(): + return m + +assert m is not f() # implicit cloning +``` + +To preserve reference identity, we must disallow captured Modules as outputs. In practice, we can detect captured Modules using the trace level context machinery used to restrict stateful updates on Modules from a different level. + +## Recap + +In this document, we have: + +* Discussed issues with the current implementation that make it unintuitive for JAX users. +* Proposed refactoring NNX transforms to allow users to use regular JAX semantics when interacting with objects, removing extra arguments introduced by NNX transforms. +* Introduced the use of Lift types in JAX APIs to compensate for the lack of a "prefix" notion in NNX objects, enabling independent lifting of Module substates. +* Proposed a new `nnx.split_rngs` API to replace the `split_rngs` arguments in `vmap` and `scan`, making RNG handling an explicit operation and giving users more control. +* Analyzed edge cases resulting from aliasing shared mutable references and proposed enforcing **consistent aliasing** on all transforms with semantics over the inputs. \ No newline at end of file diff --git a/docs_nnx/flip/README.md b/docs_nnx/flip/README.md new file mode 100644 index 0000000000..a489b58719 --- /dev/null +++ b/docs_nnx/flip/README.md @@ -0,0 +1,32 @@ +# FLIP: Flax Improvement Process + +Most changes can be discussed with simple issues/discussions and pull requests. + +Some changes though are a bit larger in scope or require more discussion, and +these should be implemented as FLIPs. 
This allows for writing longer documents +that can be discussed in a pull request themselves. + +The structure of FLIPs is kept as lightweight as possible to start and might +be extended later on. + +## When you should use a FLIP + +- When your change requires a design doc. We prefer collecting the designs as + FLIPs for better discoverability and further reference. + +- When your change requires extensive discussion. It's fine to have relatively + short discussions on issues or pull requests, but when the discussion gets + longer this becomes impractical for later digestion. FLIPs allow updating the + main document with a summary of the discussion and these updates can be + discussed themselves in the pull request adding the FLIP. + +## How to start a FLIP + +First, create an issue with the [FLIP label]. All pull requests that relate to +the FLIP (i.e. adding the FLIP itself as well as any implementing pull requests) +should be linked to this issue. + +Then create a pull request that consists of a copy of the `0000-template.md` +renamed to `%04d-{short-title}.md` - with the number being the issue number. + +[FLIP label]: https://github.com/google/flax/issues?q=label%3AFLIP diff --git a/docs_nnx/glossary.rst b/docs_nnx/glossary.rst new file mode 100644 index 0000000000..39aef00050 --- /dev/null +++ b/docs_nnx/glossary.rst @@ -0,0 +1,112 @@ +********* +Glossary +********* + +For additional terms, refer to the `Jax glossary `__. + +.. glossary:: + + Bound Module + When a :class:`Module ` + is created through regular Python object construction (e.g. `module = SomeModule(args...)`), it is in an *unbound* state. This means that only + dataclass attributes are set, and no variables are bound to the module. When the pure
When the pure + functions :meth:`Module.init() ` + or :meth:`Module.apply() ` + are called, Flax clones the Module and binds the variables to it, and the module's method code is + executed in a locally bound state, allowing things like calling submodules directly without + providing variables. For more details, refer to the + `module lifecycle `__. + + Compact / Non-compact Module + Modules with a single method are able to declare submodules and variables inline by + using the :func:`@nn.compact ` decorator. + These are referred to as “compact-style modules”, + whereas modules defining a :meth:`setup() ` method + (usually but not always with multiple callable methods) + are referred to as “setup-style modules”. To learn more, refer to the + `setup vs compact guide `__. + + `Folding in `__ + Generating a new PRNG key given an input PRNG key and integer. Typically used when you want to + generate a new key but still be able to use the original rng key afterwards. You can also do this with + `jax.random.split `__ + but this will effectively create two RNG keys, which is slower. See how Flax generates new PRNG keys + automatically within ``Modules`` in our + `RNG guide `__. + + `FrozenDict `__ + An immutable dictionary which can be “`unfrozen `__” + to a regular, mutable dictionary. Internally, Flax uses FrozenDicts to ensure variable dicts + aren't accidentally mutated. Note: We are considering returning to regular dicts from our APIs, + and only using FrozenDicts internally. + (see `#1223 `__). + + Functional core + The flax core library implements the simple container Scope API for threading + variables and PRNGs through a model, as well as the lifting machinery needed to + transform functions passing Scope objects. The python class-based module API + is built on top of this core library. + + Lazy initialization + Variables in Flax are initialized late, only when needed. 
That is, during normal + execution of a module, if a requested variable name isn’t found in the provided + variable collection data, we call the initializer function to create it. This + allows us to treat initialization and application under the same code-paths, + simplifying the use of JAX transforms with layers. + + Lifted transformation + Refer to the `Flax docs `__. + + Module + A dataclass allowing the definition and initialization of parameters in a + referentially-transparent form. This is responsible for storing and updating variables + and parameters within itself. Modules can be readily transformed into functions, + allowing them to be trivially used with JAX transformations like `vmap` and `scan`. + + Params / parameters + "params" is the canonical variable collection in the variable dictionary (dict). + The “params” collection generally contains the trainable weights. + + RNG sequences + Inside Flax :class:`Modules `, you can obtain a new + `PRNG `__ + key through :meth:`Module.make_rng() `. + These keys can be used to generate random numbers through + `JAX's functional random number generators `__. + Having different RNG sequences (e.g. for "params" and "dropout") allows fine-grained + control in a multi-host setup (e.g. initializing parameters identically on different + hosts, but have different dropout masks) and treating these sequences differently when + `lifting transformations `__. + See the `RNG guide `__ + for more details. + + Scope + A container class for holding the variables and PRNG keys for each layer. + + Shape inference + Modules do not need to specify the shape of the input array in their definitions. + Flax upon initialization inspects the input array, and infers the correct shapes + for parameters in the model. + + TrainState + Refer to :class:`flax.training.train_state.TrainState`. + + Variable + The `weights / parameters / data / arrays `__ + residing in the leaves of :term:`variable collections`. 
+ Variables are defined inside modules using :meth:`Module.variable() `. + A variable of collection "params" is simply called a param and can be set using + :meth:`Module.param() `. + + Variable collections + Entries in the variable dict, containing weights / parameters / data / arrays that + are used by the model. “params” is the canonical collection in the variable dict. + They are typically differentiable, updated by an outer SGD-like loop / optimizer, + rather than modified directly by forward-pass code. + + `Variable dictionary `__ + A dictionary containing :term:`variable collections`. + Each variable collection is a mapping from a string name + (e.g., ":term:`params`" or "batch_stats") to a (possibly nested) + dictionary with :term:`Variables` as leaves, matching the submodule tree structure. + Read more about pytrees and leaves in the `Jax docs `__. \ No newline at end of file diff --git a/docs/nnx/blog.md b/docs_nnx/guides/blog.md similarity index 100% rename from docs/nnx/blog.md rename to docs_nnx/guides/blog.md diff --git a/docs/nnx/bridge_guide.ipynb b/docs_nnx/guides/bridge_guide.ipynb similarity index 100% rename from docs/nnx/bridge_guide.ipynb rename to docs_nnx/guides/bridge_guide.ipynb diff --git a/docs/nnx/bridge_guide.md b/docs_nnx/guides/bridge_guide.md similarity index 98% rename from docs/nnx/bridge_guide.md rename to docs_nnx/guides/bridge_guide.md index ddd0f5b018..8b808d7f8f 100644 --- a/docs/nnx/bridge_guide.md +++ b/docs_nnx/guides/bridge_guide.md @@ -1,6 +1,6 @@ # Use Flax NNX along with Flax Linen -This guide is for existing Flax users who want to make their codebase a mixture of Flax Linen and Flax NNX `Module`s, which is made possible thanks to the `flax.nnx.bridge` API. +This guide is for existing Flax users who want to make their codebase a mixture of Flax Linen and Flax NNX `Module`s, which is made possible thanks to the `flax.nnx.bridge` API. 
This will be helpful if you: @@ -9,9 +9,9 @@ This will be helpful if you: We hope this allows you to move and try out NNX at your own pace, and leverage the best of both worlds. We will also talk about how to resolve the caveats of interoperating the two APIs, on a few aspects that they are fundamentally different. -**Note**: +**Note**: -This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax.readthedocs.io/en/latest/nnx/haiku_linen_vs_nnx.html) guide. +This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax.readthedocs.io/en/latest/nnx/haiku_linen_vs_nnx.html) guide. And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html). @@ -31,7 +31,7 @@ from typing import * ## Submodule is all you need -A Flax model is always a tree of modules - either old Linen modules (`flax.linen.Module`, usually written as `nn.Module`) or NNX modules (`nnx.Module`). +A Flax model is always a tree of modules - either old Linen modules (`flax.linen.Module`, usually written as `nn.Module`) or NNX modules (`nnx.Module`). An `nnx.bridge` wrapper glues the two types together, in both ways: @@ -295,7 +295,7 @@ model = bridge.ToLinen(NNXAddConstant, skip_rng=True) y, var = model.init_with_output(jax.random.key(0), x) ``` -You may notice that you need to an additional `.value` to access this Flax `w` param. This is because all NNX variables will be wrapped with an `nnx.Variable` class, which will allow it to be annotated with various information, such as its partitioning. This was translated into an equivalent `nnx.bridge.NNXMeta` wrapper. +You may notice that you need an additional `.value` to access this Flax `w` param.
This is because all NNX variables will be wrapped with an `nnx.Variable` class, which will allow it to be annotated with various information, such as its partitioning. This was translated into an equivalent `nnx.bridge.NNXMeta` wrapper. If you use [partition metadata in Linen](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html), you can learn more about how that works in NNX in [Partition Metadata Section](#partition-metadata) below. @@ -309,7 +309,7 @@ print(type(variables['params']['w'].value)) # => jax.Array -Similar to `ToNNX`, you can use `ToLinen` to create a submodule of another Linen module. +Similar to `ToNNX`, you can use `ToLinen` to create a submodule of another Linen module. ```python @@ -361,7 +361,7 @@ If you convert an NNX module to Linen, the underlying NNX module's RNG states wi Now, it really depends on whether your underlying NNX module generates new random data from its RNG state, or from the passed-in argument. Fortunately, `nnx.Dropout` supports both - using passed-in keys if there is any, and use its own RNG state if not. -And this leaves you with two style options of handling the RNG keys: +And this leaves you with two style options of handling the RNG keys: * The NNX style (recommended): Let the underlying NNX state manage the RNG keys, no need to pass in extra keys in `apply()`. This means a few more lines to mutate the `variables` for every apply call, but things will look easier once your whole model no longer needs `ToLinen`. @@ -403,13 +403,13 @@ assert jnp.allclose(y3, y5) # When you use same top-level RNG, outputs are When you want to group some variables as one category, in Linen you use different collections; in NNX, since all variables shall be top-level Python attributes, you use different variable types. -Therefore, when mixing Linen and NNX modules, Flax must know the 1-to-1 mapping between Linen collections and NNX variable types, so that `ToNNX` and `ToLinen` can do the conversion automatically. 
+Therefore, when mixing Linen and NNX modules, Flax must know the 1-to-1 mapping between Linen collections and NNX variable types, so that `ToNNX` and `ToLinen` can do the conversion automatically. Flax keeps a registry for this, and it already covers all Flax's built-in Linen collections. You can register extra mapping of NNX variable type and Linen collection names using `nnx.register_variable_name_type_pair`. ### Linen to NNX -For any collection of your Linen module, `ToNNX` will convert all its endpoint arrays (aka. leaves) to a subtype of `nnx.Variable`, either from registry or automatically created on-the-fly. +For any collection of your Linen module, `ToNNX` will convert all its endpoint arrays (aka. leaves) to a subtype of `nnx.Variable`, either from registry or automatically created on-the-fly. (However, we still keep the whole collection as one class attribute, because Linen modules may have duplicated names over different collections.) @@ -515,17 +515,17 @@ print(var['params']) ## Partition metadata -Flax uses a metadata wrapper box over the raw JAX array to annotate how a variable should be sharded. +Flax uses a metadata wrapper box over the raw JAX array to annotate how a variable should be sharded. -In Linen, this is an optional feature that triggered by using `nn.with_partitioning` on initializers (see more on [Linen partition metadata guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html)). In NNX, since all NNX variables are wrapped by `nnx.Variable` class anyway, that class will hold the sharding annotations too. +In Linen, this is an optional feature that is triggered by using `nn.with_partitioning` on initializers (see more on [Linen partition metadata guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html)). In NNX, since all NNX variables are wrapped by `nnx.Variable` class anyway, that class will hold the sharding annotations too.
The `bridge.ToNNX` and `bridge.ToLinen` API will automatically convert the sharding annotations, if you use the built-in annotation methods (aka. `nn.with_partitioning` for Linen and `nnx.with_partitioning` for NNX). ### Linen to NNX -Even if you are not using any partition metadata in your Linen module, the variable JAX arrays will be converted to `nnx.Variable`s that wraps the true JAX array within. +Even if you are not using any partition metadata in your Linen module, the variable JAX arrays will be converted to `nnx.Variable`s that wrap the true JAX array within. -If you use `nn.with_partitioning` to annotate your Linen module's variables, the annotation will be converted to a `.sharding` field in the corresponding `nnx.Variable`. +If you use `nn.with_partitioning` to annotate your Linen module's variables, the annotation will be converted to a `.sharding` field in the corresponding `nnx.Variable`. You can then use `nnx.with_sharding_constraint` to explicitly put the arrays into the annotated partitions within a `jax.jit`-compiled function, to initialize the whole model with every array at the right sharding. @@ -535,7 +535,7 @@ class LinenDotWithPartitioning(nn.Module): out_dim: int @nn.compact def __call__(self, x): - w = self.param('w', nn.with_partitioning(nn.initializers.lecun_normal(), ('in', 'out')), + w = self.param('w', nn.with_partitioning(nn.initializers.lecun_normal(), ('in', 'out')), (x.shape[-1], self.out_dim)) return x @ w @@ -604,7 +604,7 @@ print(type(unboxed['params']['w'])) # The raw jax.Array ## Lifted transforms -In general, if you want to apply Linen/NNX-style lifted transforms upon an `nnx.bridge`-converted module, just go ahead and do it in the usual Linen/NNX syntax. +In general, if you want to apply Linen/NNX-style lifted transforms upon an `nnx.bridge`-converted module, just go ahead and do it in the usual Linen/NNX syntax.
For Linen-style transforms, note that `bridge.ToLinen` is the top level module class, so you may want to just use it as the first argument of your transforms (which needs to be a `linen.Module` class in most cases) diff --git a/docs/nnx/demo.ipynb b/docs_nnx/guides/demo.ipynb similarity index 100% rename from docs/nnx/demo.ipynb rename to docs_nnx/guides/demo.ipynb diff --git a/docs/nnx/demo.md b/docs_nnx/guides/demo.md similarity index 100% rename from docs/nnx/demo.md rename to docs_nnx/guides/demo.md diff --git a/docs/nnx/filters_guide.ipynb b/docs_nnx/guides/filters_guide.ipynb similarity index 96% rename from docs/nnx/filters_guide.ipynb rename to docs_nnx/guides/filters_guide.ipynb index 21591226ac..5f63191bbf 100644 --- a/docs/nnx/filters_guide.ipynb +++ b/docs_nnx/guides/filters_guide.ipynb @@ -82,8 +82,8 @@ "```\n", "where `Key` is a hashable and comparable type, `path` is a tuple of `Key`s representing the path to the value in a nested structure, and `value` is the value at the path. The function returns `True` if the value should be included in the group and `False` otherwise.\n", "\n", - "Types are obviously not functions of this form, so the reason why they are treated as Filters \n", - "is because, as we will see next, types and some other literals are converted to predicates. For example, \n", + "Types are obviously not functions of this form, so the reason why they are treated as Filters\n", + "is because, as we will see next, types and some other literals are converted to predicates. For example,\n", "`Param` is roughly converted to a predicate like this:" ] }, @@ -117,7 +117,7 @@ "id": "a8a2641e", "metadata": {}, "source": [ - "Such function matches any value that is an instance of `Param` or any value that has a \n", + "Such function matches any value that is an instance of `Param` or any value that has a\n", "`type` attribute that is a subclass of `Param`. 
Internally Flax NNX uses `OfType` which\n", "defines a callable of this form for a given type:" ] @@ -151,8 +151,8 @@ "source": [ "## The Filter DSL\n", "\n", - "To avoid users having to create these functions, Flax NNX exposes a small DSL, formalized \n", - "as the `nnx.filterlib.Filter` type, which lets users pass types, booleans, ellipsis, \n", + "To avoid users having to create these functions, Flax NNX exposes a small DSL, formalized\n", + "as the `nnx.filterlib.Filter` type, which lets users pass types, booleans, ellipsis,\n", "tuples/lists, etc, and converts them to the appropriate predicate internally.\n", "\n", "Here is a list of all the callable Filters included in Flax NNX and their DSL literals\n", @@ -286,7 +286,7 @@ " break\n", " else:\n", " raise ValueError(f'No filter matched {path = } {value = }')\n", - " \n", + "\n", " states: tuple[nnx.GraphState, ...] = tuple(\n", " nnx.State.from_flat_path(flat_state) for flat_state in flat_states\n", " )\n", @@ -306,11 +306,11 @@ "id": "7b3aeac8", "metadata": {}, "source": [ - "One very important thing to note is that **filtering is order-dependent**. The first filter that \n", - "matches a value will keep it, therefore you should place more specific filters before more general \n", - "filters. For example if we create a `SpecialParam` type that is a subclass of `Param`, and a `Bar` \n", - "object that contains both types of parameters, if we try to split the `Param`s before the \n", - "`SpecialParam`s then all the values will be placed in the `Param` group and the `SpecialParam` group \n", + "One very important thing to note is that **filtering is order-dependent**. The first filter that\n", + "matches a value will keep it, therefore you should place more specific filters before more general\n", + "filters. 
For example if we create a `SpecialParam` type that is a subclass of `Param`, and a `Bar`\n", + "object that contains both types of parameters, if we try to split the `Param`s before the\n", + "`SpecialParam`s then all the values will be placed in the `Param` group and the `SpecialParam` group\n", "will be empty because all `SpecialParam`s are also `Param`s:" ] }, diff --git a/docs/nnx/filters_guide.md b/docs_nnx/guides/filters_guide.md similarity index 95% rename from docs/nnx/filters_guide.md rename to docs_nnx/guides/filters_guide.md index 84bbe3fa7f..c403451649 100644 --- a/docs/nnx/filters_guide.md +++ b/docs_nnx/guides/filters_guide.md @@ -50,8 +50,8 @@ In general Filter are predicate functions of the form: ``` where `Key` is a hashable and comparable type, `path` is a tuple of `Key`s representing the path to the value in a nested structure, and `value` is the value at the path. The function returns `True` if the value should be included in the group and `False` otherwise. -Types are obviously not functions of this form, so the reason why they are treated as Filters -is because, as we will see next, types and some other literals are converted to predicates. For example, +Types are obviously not functions of this form, so the reason why they are treated as Filters +is because, as we will see next, types and some other literals are converted to predicates. For example, `Param` is roughly converted to a predicate like this: ```{code-cell} ipython3 @@ -64,7 +64,7 @@ print(f'{is_param((), nnx.Param(0)) = }') print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }') ``` -Such function matches any value that is an instance of `Param` or any value that has a +Such function matches any value that is an instance of `Param` or any value that has a `type` attribute that is a subclass of `Param`. 
Internally Flax NNX uses `OfType` which defines a callable of this form for a given type: @@ -77,8 +77,8 @@ print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }') ## The Filter DSL -To avoid users having to create these functions, Flax NNX exposes a small DSL, formalized -as the `nnx.filterlib.Filter` type, which lets users pass types, booleans, ellipsis, +To avoid users having to create these functions, Flax NNX exposes a small DSL, formalized +as the `nnx.filterlib.Filter` type, which lets users pass types, booleans, ellipsis, tuples/lists, etc, and converts them to the appropriate predicate internally. Here is a list of all the callable Filters included in Flax NNX and their DSL literals @@ -151,7 +151,7 @@ def split(node, *filters): break else: raise ValueError(f'No filter matched {path = } {value = }') - + states: tuple[nnx.GraphState, ...] = tuple( nnx.State.from_flat_path(flat_state) for flat_state in flat_states ) @@ -166,11 +166,11 @@ print(f'{params = }') print(f'{batch_stats = }') ``` -One very important thing to note is that **filtering is order-dependent**. The first filter that -matches a value will keep it, therefore you should place more specific filters before more general -filters. For example if we create a `SpecialParam` type that is a subclass of `Param`, and a `Bar` -object that contains both types of parameters, if we try to split the `Param`s before the -`SpecialParam`s then all the values will be placed in the `Param` group and the `SpecialParam` group +One very important thing to note is that **filtering is order-dependent**. The first filter that +matches a value will keep it, therefore you should place more specific filters before more general +filters. 
For example if we create a `SpecialParam` type that is a subclass of `Param`, and a `Bar` +object that contains both types of parameters, if we try to split the `Param`s before the +`SpecialParam`s then all the values will be placed in the `Param` group and the `SpecialParam` group will be empty because all `SpecialParam`s are also `Param`s: ```{code-cell} ipython3 diff --git a/docs_nnx/guides/graph_mutations.ipynb b/docs_nnx/guides/graph_mutations.ipynb new file mode 100644 index 0000000000..bfc9daefc8 --- /dev/null +++ b/docs_nnx/guides/graph_mutations.ipynb @@ -0,0 +1,23 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Propagating Graph Mutations\n", + "\n", + "**WORK IN PROGRESS 🚧**" + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,md:myst" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs_nnx/guides/graph_mutations.md b/docs_nnx/guides/graph_mutations.md new file mode 100644 index 0000000000..d9b2403437 --- /dev/null +++ b/docs_nnx/guides/graph_mutations.md @@ -0,0 +1,13 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.13.8 +--- + +# Propagating Graph Mutations + +**WORK IN PROGRESS 🚧** diff --git a/docs/nnx/haiku_linen_vs_nnx.rst b/docs_nnx/guides/haiku_linen_vs_nnx.rst similarity index 100% rename from docs/nnx/haiku_linen_vs_nnx.rst rename to docs_nnx/guides/haiku_linen_vs_nnx.rst diff --git a/docs/nnx/images/stateful-transforms.png b/docs_nnx/guides/images/stateful-transforms.png similarity index 100% rename from docs/nnx/images/stateful-transforms.png rename to docs_nnx/guides/images/stateful-transforms.png diff --git a/docs_nnx/guides/index.rst b/docs_nnx/guides/index.rst new file mode 100644 index 0000000000..08bd478c04 --- /dev/null +++ b/docs_nnx/guides/index.rst @@ -0,0 +1,12 @@ +Guides +------------------------ + +.. 
toctree:: + :maxdepth: 2 + + filters_guide + haiku_linen_vs_nnx + bridge_guide + surgery + graph_mutations + jax_and_nnx_transforms \ No newline at end of file diff --git a/docs/nnx/transforms.rst b/docs_nnx/guides/jax_and_nnx_transforms.rst similarity index 100% rename from docs/nnx/transforms.rst rename to docs_nnx/guides/jax_and_nnx_transforms.rst diff --git a/docs/nnx/quick_start.ipynb b/docs_nnx/guides/quick_start.ipynb similarity index 100% rename from docs/nnx/quick_start.ipynb rename to docs_nnx/guides/quick_start.ipynb diff --git a/docs/nnx/surgery.ipynb b/docs_nnx/guides/surgery.ipynb similarity index 100% rename from docs/nnx/surgery.ipynb rename to docs_nnx/guides/surgery.ipynb diff --git a/docs/nnx/surgery.md b/docs_nnx/guides/surgery.md similarity index 100% rename from docs/nnx/surgery.md rename to docs_nnx/guides/surgery.md diff --git a/docs/nnx/tiny_nnx.ipynb b/docs_nnx/guides/tiny_nnx.ipynb similarity index 100% rename from docs/nnx/tiny_nnx.ipynb rename to docs_nnx/guides/tiny_nnx.ipynb diff --git a/docs/nnx/why.ipynb b/docs_nnx/guides/why.ipynb similarity index 100% rename from docs/nnx/why.ipynb rename to docs_nnx/guides/why.ipynb diff --git a/docs/nnx/why.md b/docs_nnx/guides/why.md similarity index 100% rename from docs/nnx/why.md rename to docs_nnx/guides/why.md diff --git a/docs/nnx/index.rst b/docs_nnx/index.rst similarity index 95% rename from docs/nnx/index.rst rename to docs_nnx/index.rst index 237eb24a7f..d5e0c9d34d 100644 --- a/docs/nnx/index.rst +++ b/docs_nnx/index.rst @@ -179,12 +179,11 @@ Learn more .. 
toctree:: :hidden: - :maxdepth: 1 + :maxdepth: 2 nnx_basics mnist_tutorial - transforms - haiku_linen_vs_nnx - bridge_guide - filters_guide - surgery + guides/index + The Flax philosophy + How to contribute + api_reference/index diff --git a/docs_nnx/linen_intro.ipynb b/docs_nnx/linen_intro.ipynb new file mode 100644 index 0000000000..5f82580fef --- /dev/null +++ b/docs_nnx/linen_intro.ipynb @@ -0,0 +1,1097 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/notebooks/linen_intro.ipynb)\n", + "[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/notebooks/linen_intro.ipynb)\n", + "\n", + "# Preface\n", + "\n", + "
\n", + "
CAVEAT PROGRAMMER
\n", + "\n", + "The below is an alpha API preview and things might break. The surface syntax of the features of the API are not fixed in stone, and we welcome feedback on any points." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Useful links\n", + "\n", + "⟶ [Slides](https://docs.google.com/presentation/d/1ngKWUwsSqAwPRvATG8sAxMzu9ujv4N__cKsUofdNno0/edit?usp=sharing) for the core ideas of the new Functional Core and Linen\n", + "\n", + "⟶ \"Design tests\" guided our design process. Many are available for [functional core](https://github.com/google/flax/tree/main/examples/core_design_test) and some for the [proposed Module abstraction](https://github.com/google/flax/tree/main/examples/linen_design_test/)\n", + "\n", + "⟶ Ported examples: [ImageNet](https://github.com/google/flax/tree/main/examples/imagenet) and [WMT](https://github.com/google/flax/tree/main/examples/wmt) (to the proposed Module abstraction). TODO: Port to functional core.\n", + "\n", + "⟶ Our new [discussion forums](https://github.com/google/flax/discussions/)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Install and Import" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "skip-execution" + ] + }, + "outputs": [], + "source": [ + "# Install the newest JAXlib version.\n", + "!pip install --upgrade -q pip jax jaxlib\n", + "# Install Flax at head:\n", + "!pip install --upgrade -q git+https://github.com/google/flax.git" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import functools\n", + "from typing import Any, Callable, Sequence, Optional\n", + "import jax\n", + "from jax import lax, random, numpy as jnp\n", + "import flax\n", + "from flax import linen as nn" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Invoking Modules" + ] + }, + { + 
"attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's instantiate a `Dense` layer.\n", + " - Modules are actually objects in this API, so we provide _constructor arguments_ when initializing the Module. In this case, we only have to provide the output `features` dimension." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "model = nn.Dense(features=3)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We need to initialize the Module variables, these include the parameters of the Module as well as any other state variables.\n", + "\n", + "We call the `init` method on the instantiated Module. If the Module `__call__` method has args `(self, *args, **kwargs)` then we call `init` with `(rngs, *args, **kwargs)` so in this case, just `(rng, input)`:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "outputId": "3adfaeaf-977e-4e82-8adf-d254fae6eb91" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:absl:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" + ] + }, + { + "data": { + "text/plain": [ + "FrozenDict({\n", + " params: {\n", + " kernel: DeviceArray([[ 0.6503669 , 0.8678979 , 0.46042678],\n", + " [ 0.05673932, 0.9909285 , -0.63536596],\n", + " [ 0.76134115, -0.3250529 , -0.6522163 ],\n", + " [-0.8243032 , 0.4150194 , 0.19405058]], dtype=float32),\n", + " bias: DeviceArray([0., 0., 0.], dtype=float32),\n", + " },\n", + "})" + ] + }, + "execution_count": 4, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "# Make RNG Keys and a fake input.\n", + "key1, key2 = random.split(random.key(0), 2)\n", + "x = random.uniform(key1, (4,4))\n", + "\n", + "# provide key and fake input to get initialized variables\n", + "init_variables = model.init(key2, x)\n", + "\n", + "init_variables" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We call the `apply` method on the instantiated Module. If the Module `__call__` method has args `(self, *args, **kwargs)` then we call `apply` with `(variables, *args, rngs=<RNGS>, mutable=<MUTABLEKINDS>, **kwargs)` where\n", + " - `<RNGS>` are the optional _call time_ RNGs for things like dropout. For simple Modules this is just a single key, but if your module has multiple __kinds__ of data, it's a dictionary of rng-keys per-kind, e.g. `{'params': key0, 'dropout': key1}` for a Module with dropout layers.\n", + " - `<MUTABLEKINDS>` is an optional list of names of __kinds__ that are expected to be mutated during the call. e.g. 
`['batch_stats']` for a layer updating batchnorm statistics.\n", + "\n", + "So in this case, just `(variables, input)`:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "outputId": "e8c389a6-29f3-4f93-97ea-703e85a8b811" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "DeviceArray([[ 0.5035518 , 1.8548559 , -0.4270196 ],\n", + " [ 0.0279097 , 0.5589246 , -0.43061775],\n", + " [ 0.35471284, 1.5741 , -0.3286552 ],\n", + " [ 0.5264864 , 1.2928858 , 0.10089308]], dtype=float32)" + ] + }, + "execution_count": 5, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "y = model.apply(init_variables, x)\n", + "y" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Additional points:\n", + " - If you want to `init` or `apply` a Module using a method other than call, you need to provide the `method=` kwarg to `init` and `apply` to use it instead of the default `__call__`, e.g. `method='encode'`, `method='decode'` to apply the encode/decode methods of an autoencoder." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Defining Basic Modules" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Composing submodules" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We support declaring modules in `setup()` that can still benefit from shape inference by using __Lazy Initialization__ that sets up variables the first time the Module is called." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "outputId": "1a6c6a17-0b95-42c2-b5bf-b9ad80fd7758", + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "initialized parameter shapes:\n", + " {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}\n", + "output:\n", + " [[ 4.2292815e-02 -4.3807115e-02 2.9323792e-02 6.5492536e-03\n", + " -1.7147182e-02]\n", + " [ 1.2967804e-01 -1.4551792e-01 9.4432175e-02 1.2521386e-02\n", + " -4.5417294e-02]\n", + " [ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n", + " 0.0000000e+00]\n", + " [ 9.3024090e-04 2.7864411e-05 2.4478839e-04 8.1344356e-04\n", + " -1.0110775e-03]]\n" + ] + } + ], + "source": [ + "class ExplicitMLP(nn.Module):\n", + " features: Sequence[int]\n", + "\n", + " def setup(self):\n", + " # we automatically know what to do with lists, dicts of submodules\n", + " self.layers = [nn.Dense(feat) for feat in self.features]\n", + " # for single submodules, we would just write:\n", + " # self.layer1 = nn.Dense(feat1)\n", + "\n", + " def __call__(self, inputs):\n", + " x = inputs\n", + " for i, lyr in enumerate(self.layers):\n", + " x = lyr(x)\n", + " if i != len(self.layers) - 1:\n", + " x = nn.relu(x)\n", + " return x\n", + "\n", + "key1, key2 = random.split(random.key(0), 2)\n", + "x = random.uniform(key1, (4,4))\n", + "\n", + "model = ExplicitMLP(features=[3,4,5])\n", + "init_variables = model.init(key2, x)\n", + "y = model.apply(init_variables, x)\n", + "\n", + "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n", + "print('output:\\n', y)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we show the equivalent compact form of the MLP that declares the submodules inline using the `@compact` decorator." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "outputId": "b3709789-e66e-4e20-f6b2-04022f8a62bb", + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "initialized parameter shapes:\n", + " {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}\n", + "output:\n", + " [[ 4.2292815e-02 -4.3807115e-02 2.9323792e-02 6.5492536e-03\n", + " -1.7147182e-02]\n", + " [ 1.2967804e-01 -1.4551792e-01 9.4432175e-02 1.2521386e-02\n", + " -4.5417294e-02]\n", + " [ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n", + " 0.0000000e+00]\n", + " [ 9.3024090e-04 2.7864411e-05 2.4478839e-04 8.1344356e-04\n", + " -1.0110775e-03]]\n" + ] + } + ], + "source": [ + "class SimpleMLP(nn.Module):\n", + " features: Sequence[int]\n", + "\n", + " @nn.compact\n", + " def __call__(self, inputs):\n", + " x = inputs\n", + " for i, feat in enumerate(self.features):\n", + " x = nn.Dense(feat, name=f'layers_{i}')(x)\n", + " if i != len(self.features) - 1:\n", + " x = nn.relu(x)\n", + " # providing a name is optional though!\n", + " # the default autonames would be \"Dense_0\", \"Dense_1\", ...\n", + " # x = nn.Dense(feat)(x)\n", + " return x\n", + "\n", + "key1, key2 = random.split(random.key(0), 2)\n", + "x = random.uniform(key1, (4,4))\n", + "\n", + "model = SimpleMLP(features=[3,4,5])\n", + "init_variables = model.init(key2, x)\n", + "y = model.apply(init_variables, x)\n", + "\n", + "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n", + "print('output:\\n', y)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Declaring and using variables" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Flax uses lazy initialization, which allows declared variables to be 
initialized only at the first site of their use, using whatever shape information is available at the local call site for shape inference. Once a variable has been initialized, a reference to the data is kept for use in subsequent calls.\n", + "\n", + "For declaring parameters that aren't mutated inside the model, but rather by gradient descent, we use the syntax:\n", + "\n", + " `self.param(parameter_name, parameter_init_fn, *init_args, **init_kwargs)`\n", + "\n", + "with arguments:\n", + " - `parameter_name` just the name, a string\n", + " - `parameter_init_fn` a function taking an RNG key and a variable number of other arguments, i.e. `fn(rng, *args)`. Typically those in `nn.initializers` take an `rng` and a `shape` argument.\n", + " - the remaining arguments to feed to the init function when initializing.\n", + "\n", + "Again, we'll demonstrate declaring things inline as we typically do using the `@compact` decorator." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "outputId": "bc5cb1f2-c5e9-4159-d131-73247009e32f", + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "initialized parameters:\n", + " FrozenDict({\n", + " params: {\n", + " kernel: DeviceArray([[ 0.6503669 , 0.8678979 , 0.46042678],\n", + " [ 0.05673932, 0.9909285 , -0.63536596],\n", + " [ 0.76134115, -0.3250529 , -0.6522163 ],\n", + " [-0.8243032 , 0.4150194 , 0.19405058]], dtype=float32),\n", + " bias: DeviceArray([0., 0., 0.], dtype=float32),\n", + " },\n", + "})\n", + "output:\n", + " [[ 0.5035518 1.8548559 -0.4270196 ]\n", + " [ 0.0279097 0.5589246 -0.43061775]\n", + " [ 0.35471284 1.5741 -0.3286552 ]\n", + " [ 0.5264864 1.2928858 0.10089308]]\n" + ] + } + ], + "source": [ + "class SimpleDense(nn.Module):\n", + " features: int\n", + " kernel_init: Callable = nn.initializers.lecun_normal()\n", + " bias_init: Callable = nn.initializers.zeros_init()\n", + "\n", + " @nn.compact\n", + " def __call__(self, inputs):\n", + " 
kernel = self.param('kernel',\n", + " self.kernel_init, # RNG passed implicitly.\n", + " (inputs.shape[-1], self.features)) # shape info.\n", + " y = lax.dot_general(inputs, kernel,\n", + " (((inputs.ndim - 1,), (0,)), ((), ())),)\n", + " bias = self.param('bias', self.bias_init, (self.features,))\n", + " y = y + bias\n", + " return y\n", + "\n", + "key1, key2 = random.split(random.key(0), 2)\n", + "x = random.uniform(key1, (4,4))\n", + "\n", + "model = SimpleDense(features=3)\n", + "init_variables = model.init(key2, x)\n", + "y = model.apply(init_variables, x)\n", + "\n", + "print('initialized parameters:\\n', init_variables)\n", + "print('output:\\n', y)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also declare variables in setup, though in doing so you can't take advantage of shape inference and have to provide explicit shape information at initialization. The syntax is a little repetitive in this case right now, but we do force agreement of the assigned names." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "outputId": "1e822bd8-7a08-4e80-e0e6-a86637c46772", + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "initialized parameters:\n", + " FrozenDict({\n", + " params: {\n", + " kernel: DeviceArray([[ 0.6503669 , 0.8678979 , 0.46042678],\n", + " [ 0.05673932, 0.9909285 , -0.63536596],\n", + " [ 0.76134115, -0.3250529 , -0.6522163 ],\n", + " [-0.8243032 , 0.4150194 , 0.19405058]], dtype=float32),\n", + " bias: DeviceArray([0., 0., 0.], dtype=float32),\n", + " },\n", + "})\n", + "output:\n", + " [[ 0.5035518 1.8548559 -0.4270196 ]\n", + " [ 0.0279097 0.5589246 -0.43061775]\n", + " [ 0.35471284 1.5741 -0.3286552 ]\n", + " [ 0.5264864 1.2928858 0.10089308]]\n" + ] + } + ], + "source": [ + "class ExplicitDense(nn.Module):\n", + " features_in: int # <-- explicit input shape\n", + " features: int\n", + " kernel_init: Callable = nn.initializers.lecun_normal()\n", + " bias_init: Callable = nn.initializers.zeros_init()\n", + "\n", + " def setup(self):\n", + " self.kernel = self.param('kernel',\n", + " self.kernel_init,\n", + " (self.features_in, self.features))\n", + " self.bias = self.param('bias', self.bias_init, (self.features,))\n", + "\n", + " def __call__(self, inputs):\n", + " y = lax.dot_general(inputs, self.kernel,\n", + " (((inputs.ndim - 1,), (0,)), ((), ())),)\n", + " y = y + self.bias\n", + " return y\n", + "\n", + "key1, key2 = random.split(random.key(0), 2)\n", + "x = random.uniform(key1, (4,4))\n", + "\n", + "model = ExplicitDense(features_in=4, features=3)\n", + "init_variables = model.init(key2, x)\n", + "y = model.apply(init_variables, x)\n", + "\n", + "print('initialized parameters:\\n', init_variables)\n", + "print('output:\\n', y)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## General Variables" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + 
"source": [ + "For declaring generally mutable _variables_ that may be mutated inside the model we use the call:\n", + "\n", + " `self.variable(variable_kind, variable_name, variable_init_fn, *init_args, **init_kwargs)`\n", + "\n", + "with arguments:\n", + " - `variable_kind` the \"kind\" of state this variable is, i.e. the name of the nested-dict collection that this will be stored in inside the top Module's variables. e.g. `batch_stats` for the moving statistics for a batch norm layer or `cache` for autoregressive cache data. Note that parameters also have a kind, but they're set to the default `params` kind.\n", + " - `variable_name` just the name, a string\n", + " - `variable_init_fn` a function taking a variable number of other arguments, i.e. `fn(*args)`. Note that we __don't__ assume the need for an RNG; if you _do_ want an RNG, provide it via a `self.make_rng(variable_kind)` call in the provided arguments.\n", + " - the remaining arguments to feed to the init function when initializing.\n", + "\n", + "⚠️ Unlike parameters, we expect these to be mutated, so `self.variable` returns not a constant, but a _reference_ to the variable. To __get__ the raw value, you'd write `myvariable.value` and to __set__ it `myvariable.value = new_value`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "outputId": "2a8f5453-81b1-44dc-a431-d14b372c5710", + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "initialized variables:\n", + " FrozenDict({\n", + " counter: {\n", + " count: DeviceArray(0, dtype=int32),\n", + " },\n", + "})\n", + "mutated variables:\n", + " FrozenDict({\n", + " counter: {\n", + " count: DeviceArray(1, dtype=int32),\n", + " },\n", + "})\n", + "output:\n", + " 1\n" + ] + } + ], + "source": [ + "class Counter(nn.Module):\n", + " @nn.compact\n", + " def __call__(self):\n", + " # easy pattern to detect if we're initializing\n", + " is_initialized = self.has_variable('counter', 'count')\n", + " counter = self.variable('counter', 'count', lambda: jnp.zeros((), jnp.int32))\n", + " if is_initialized:\n", + " counter.value += 1\n", + " return counter.value\n", + "\n", + "\n", + "key1 = random.key(0)\n", + "\n", + "model = Counter()\n", + "init_variables = model.init(key1)\n", + "print('initialized variables:\\n', init_variables)\n", + "\n", + "y, mutated_variables = model.apply(init_variables, mutable=['counter'])\n", + "\n", + "print('mutated variables:\\n', mutated_variables)\n", + "print('output:\\n', y)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Another Mutability and RNGs Example" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's make an artificial, goofy example that mixes differentiable parameters, stochastic layers, and mutable variables:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "outputId": "8f299a5c-74c8-476c-93fa-e5543901ec45", + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "updated variables:\n", + " FrozenDict({\n", + " params: {\n", + " Dense_0: {\n", + " kernel: DeviceArray([[ 0.6498898 , -0.5000124 , 0.78573596],\n", 
+ " [-0.25609785, -0.7132329 , 0.2500864 ],\n", + " [-0.64630085, 0.39321756, -1.0203307 ],\n", + " [ 0.38721725, 0.86828285, 0.10860055]], dtype=float32),\n", + " bias: DeviceArray([0., 0., 0.], dtype=float32),\n", + " },\n", + " BatchNorm_0: {\n", + " scale: DeviceArray([1., 1., 1.], dtype=float32),\n", + " bias: DeviceArray([0., 0., 0.], dtype=float32),\n", + " },\n", + " },\n", + " batch_stats: {\n", + " BatchNorm_0: {\n", + " mean: DeviceArray([ 0.00059601, -0.00103457, 0.00166948], dtype=float32),\n", + " var: DeviceArray([0.9907686, 0.9923046, 0.992195 ], dtype=float32),\n", + " },\n", + " },\n", + "})\n", + "initialized variable shapes:\n", + " FrozenDict({\n", + " batch_stats: {\n", + " BatchNorm_0: {\n", + " mean: (3,),\n", + " var: (3,),\n", + " },\n", + " },\n", + " params: {\n", + " BatchNorm_0: {\n", + " bias: (3,),\n", + " scale: (3,),\n", + " },\n", + " Dense_0: {\n", + " bias: (3,),\n", + " kernel: (4, 3),\n", + " },\n", + " },\n", + "})\n", + "output:\n", + " [[[-0.21496922 0.21550177 -0.35633382]\n", + " [-0.21496922 -2.0458 1.3015485 ]\n", + " [-0.21496922 -0.925116 -0.35633382]\n", + " [-0.6595459 0.21550177 0.3749205 ]]\n", + "\n", + " [[-0.21496922 1.642865 -0.35633382]\n", + " [-0.21496922 1.3094063 -0.88034123]\n", + " [ 2.5726683 0.21550177 0.34353197]\n", + " [-0.21496922 0.21550177 1.6778195 ]]\n", + "\n", + " [[-1.6060593 0.21550177 -1.9460517 ]\n", + " [ 1.4126908 -1.4898677 1.2790381 ]\n", + " [-0.21496922 0.21550177 -0.35633382]\n", + " [-0.21496922 0.21550177 -0.7251308 ]]]\n", + "eval output:\n", + " [[[ 3.2246590e-01 2.6108384e-02 4.4821960e-01]\n", + " [ 8.5726947e-02 -5.4385906e-01 3.8821870e-01]\n", + " [-2.3933809e-01 -2.7381191e-01 -1.7526165e-01]\n", + " [-6.2515378e-02 -5.2414006e-01 1.7029770e-01]]\n", + "\n", + " [[ 1.5014435e-01 3.4498507e-01 -1.3554120e-01]\n", + " [-3.6971044e-04 2.6463276e-01 -1.2491019e-01]\n", + " [ 3.8763803e-01 2.9023719e-01 1.6291586e-01]\n", + " [ 4.1320035e-01 4.1468274e-02 4.7670874e-01]]\n", 
+ "\n", + " [[-1.9433719e-01 5.2831882e-01 -3.7554008e-01]\n", + " [ 2.2608691e-01 -4.0989807e-01 3.8292480e-01]\n", + " [-2.4945706e-01 1.6170470e-01 -2.5247774e-01]\n", + " [-7.2220474e-02 1.2077977e-01 -8.8408351e-02]]]\n" + ] + } + ], + "source": [ + "class Block(nn.Module):\n", + " features: int\n", + " training: bool\n", + " @nn.compact\n", + " def __call__(self, inputs):\n", + " x = nn.Dense(self.features)(inputs)\n", + " x = nn.Dropout(rate=0.5)(x, deterministic=not self.training)\n", + " x = nn.BatchNorm(use_running_average=not self.training)(x)\n", + " return x\n", + "\n", + "key1, key2, key3, key4 = random.split(random.key(0), 4)\n", + "x = random.uniform(key1, (3,4,4))\n", + "\n", + "model = Block(features=3, training=True)\n", + "\n", + "init_variables = model.init({'params': key2, 'dropout': key3}, x)\n", + "_, init_params = flax.core.pop(init_variables, 'params')\n", + "\n", + "# When calling `apply` with mutable kinds, returns a pair of output,\n", + "# mutated_variables.\n", + "y, mutated_variables = model.apply(\n", + " init_variables, x, rngs={'dropout': key4}, mutable=['batch_stats'])\n", + "\n", + "# Now we reassemble the full variables from the updates (in a real training\n", + "# loop, with the updated params from an optimizer).\n", + "updated_variables = flax.core.freeze(dict(params=init_params,\n", + " **mutated_variables))\n", + "\n", + "print('updated variables:\\n', updated_variables)\n", + "print('initialized variable shapes:\\n',\n", + " jax.tree_util.tree_map(jnp.shape, init_variables))\n", + "print('output:\\n', y)\n", + "\n", + "# Let's run these model variables during \"evaluation\":\n", + "eval_model = Block(features=3, training=False)\n", + "y = eval_model.apply(updated_variables, x) # Nothing mutable; single return value.\n", + "print('eval output:\\n', y)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# JAX transformations inside modules" + ] + }, + { + "attachments": {}, + 
"cell_type": "markdown", + "metadata": {}, + "source": [ + "## JIT" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It's not immediately clear what use this has, but you can compile specific submodules if there's a reason to.\n", + "\n", + "_Known Gotcha_: at the moment, the decorator changes the RNG stream slightly, so comparing jitted and unjitted initializations will look different." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "outputId": "3f324d0f-259f-40f0-8273-103f7fc281c5", + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "initialized parameter shapes:\n", + " {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}\n", + "output:\n", + " [[ 0.2524199 0.11621253 0.5246693 0.19144788 0.2096542 ]\n", + " [ 0.08557513 -0.04126885 0.2502836 0.03910369 0.16575359]\n", + " [ 0.2804383 0.27751124 0.44969672 0.26016283 0.05875347]\n", + " [ 0.2440843 0.17069656 0.45499086 0.20377949 0.13428023]]\n" + ] + } + ], + "source": [ + "class MLP(nn.Module):\n", + " features: Sequence[int]\n", + "\n", + " @nn.compact\n", + " def __call__(self, inputs):\n", + " x = inputs\n", + " for i, feat in enumerate(self.features):\n", + " # JIT the Module (its __call__ fn by default.)\n", + " x = nn.jit(nn.Dense)(feat, name=f'layers_{i}')(x)\n", + " if i != len(self.features) - 1:\n", + " x = nn.relu(x)\n", + " return x\n", + "\n", + "key1, key2 = random.split(random.key(3), 2)\n", + "x = random.uniform(key1, (4,4))\n", + "\n", + "model = MLP(features=[3,4,5])\n", + "init_variables = model.init(key2, x)\n", + "y = model.apply(init_variables, x)\n", + "\n", + "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n", + "print('output:\\n', y)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + 
"metadata": {}, + "source": [ + "## Remat" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For memory-expensive computations, we can `remat` our method to recompute a Module's output during a backwards pass.\n", + "\n", + "_Known Gotcha_: at the moment, the decorator changes the RNG stream slightly, so comparing remat'd and undecorated initializations will look different." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "outputId": "7fe8e13b-7dd6-4e55-ee50-ce334e8ed178", + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "initialized parameter shapes:\n", + " {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}\n", + "output:\n", + " [[-0.14814317 0.06889858 -0.19695625 0.12019286 0.02068037]\n", + " [-0.04439102 -0.06698258 -0.11579747 -0.19906905 -0.04342325]\n", + " [-0.08875751 -0.13392815 -0.23153095 -0.39802808 -0.0868225 ]\n", + " [-0.01606487 -0.02424064 -0.04190649 -0.07204203 -0.01571464]]\n" + ] + } + ], + "source": [ + "class RematMLP(nn.Module):\n", + " features: Sequence[int]\n", + " # For all transforms, we can annotate a method, or wrap an existing\n", + " # Module class. 
Here we annotate the method.\n", + " @nn.remat\n", + " @nn.compact\n", + " def __call__(self, inputs):\n", + " x = inputs\n", + " for i, feat in enumerate(self.features):\n", + " x = nn.Dense(feat, name=f'layers_{i}')(x)\n", + " if i != len(self.features) - 1:\n", + " x = nn.relu(x)\n", + " return x\n", + "\n", + "key1, key2 = random.split(random.key(3), 2)\n", + "x = random.uniform(key1, (4,4))\n", + "\n", + "model = RematMLP(features=[3,4,5])\n", + "init_variables = model.init(key2, x)\n", + "y = model.apply(init_variables, x)\n", + "\n", + "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n", + "print('output:\\n', y)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Vmap" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can now `vmap` Modules inside. The transform has a lot of arguments; it takes the usual jax vmap args:\n", + " - `in_axes` - an integer or `None` for each input argument\n", + " - `out_axes` - an integer or `None` for each output argument\n", + " - `axis_size` - the axis size if you need to give it explicitly\n", + "\n", + "In addition, we provide for each __kind__ of variable its axis rules:\n", + "\n", + " - `variable_in_axes` - a dict from kinds to a single integer or `None` specifying the input axes to map\n", + " - `variable_out_axes` - a dict from kinds to a single integer or `None` specifying the output axes to map\n", + " - `split_rngs` - a dict from RNG-kinds to a bool, specifying whether to split the rng along the axis.\n", + "\n", + "\n", + "Below we show an example defining a batched, multiheaded attention module from a single-headed unbatched attention implementation." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "outputId": "223d880e-c7b2-4210-ebb5-dbfcdd9aed09", + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "initialized parameter shapes:\n", + " {'params': {'attention': {'key': {'kernel': (2, 64, 32)}, 'out': {'bias': (2, 64), 'kernel': (2, 32, 64)}, 'query': {'kernel': (2, 64, 32)}, 'value': {'kernel': (2, 64, 32)}}}}\n", + "output:\n", + " (3, 13, 2)\n" + ] + } + ], + "source": [ + "class RawDotProductAttention(nn.Module):\n", + " attn_dropout_rate: float = 0.1\n", + " train: bool = False\n", + "\n", + " @nn.compact\n", + " def __call__(self, query, key, value, bias=None, dtype=jnp.float32):\n", + " assert key.ndim == query.ndim\n", + " assert key.ndim == value.ndim\n", + "\n", + " n = query.ndim\n", + " attn_weights = lax.dot_general(\n", + " query, key,\n", + " (((n-1,), (n - 1,)), ((), ())))\n", + " if bias is not None:\n", + " attn_weights += bias\n", + " norm_dims = tuple(range(attn_weights.ndim // 2, attn_weights.ndim))\n", + " attn_weights = jax.nn.softmax(attn_weights, axis=norm_dims)\n", + " attn_weights = nn.Dropout(self.attn_dropout_rate)(attn_weights,\n", + " deterministic=not self.train)\n", + " attn_weights = attn_weights.astype(dtype)\n", + "\n", + " contract_dims = (\n", + " tuple(range(n - 1, attn_weights.ndim)),\n", + " tuple(range(0, n - 1)))\n", + " y = lax.dot_general(\n", + " attn_weights, value,\n", + " (contract_dims, ((), ())))\n", + " return y\n", + "\n", + "class DotProductAttention(nn.Module):\n", + " qkv_features: Optional[int] = None\n", + " out_features: Optional[int] = None\n", + " train: bool = False\n", + "\n", + " @nn.compact\n", + " def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32):\n", + " qkv_features = self.qkv_features or inputs_q.shape[-1]\n", + " out_features = self.out_features or inputs_q.shape[-1]\n", + "\n", + " QKVDense = functools.partial(\n", + " nn.Dense, 
features=qkv_features, use_bias=False, dtype=dtype)\n", + " query = QKVDense(name='query')(inputs_q)\n", + " key = QKVDense(name='key')(inputs_kv)\n", + " value = QKVDense(name='value')(inputs_kv)\n", + "\n", + " y = RawDotProductAttention(train=self.train)(\n", + " query, key, value, bias=bias, dtype=dtype)\n", + "\n", + " y = nn.Dense(features=out_features, dtype=dtype, name='out')(y)\n", + " return y\n", + "\n", + "class MultiHeadDotProductAttention(nn.Module):\n", + " qkv_features: Optional[int] = None\n", + " out_features: Optional[int] = None\n", + " batch_axes: Sequence[int] = (0,)\n", + " num_heads: int = 1\n", + " broadcast_dropout: bool = False\n", + " train: bool = False\n", + " @nn.compact\n", + " def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32):\n", + " qkv_features = self.qkv_features or inputs_q.shape[-1]\n", + " out_features = self.out_features or inputs_q.shape[-1]\n", + "\n", + " # Make multiheaded attention from single-headed dimension.\n", + " Attn = nn.vmap(DotProductAttention,\n", + " in_axes=(None, None, None),\n", + " out_axes=2,\n", + " axis_size=self.num_heads,\n", + " variable_axes={'params': 0},\n", + " split_rngs={'params': True,\n", + " 'dropout': not self.broadcast_dropout})\n", + "\n", + " # Vmap across batch dimensions.\n", + " for axis in reversed(sorted(self.batch_axes)):\n", + " Attn = nn.vmap(Attn,\n", + " in_axes=(axis, axis, axis),\n", + " out_axes=axis,\n", + " variable_axes={'params': None},\n", + " split_rngs={'params': False, 'dropout': False})\n", + "\n", + " # Run the vmap'd class on inputs.\n", + " y = Attn(qkv_features=qkv_features // self.num_heads,\n", + " out_features=out_features,\n", + " train=self.train,\n", + " name='attention')(inputs_q, inputs_kv, bias)\n", + "\n", + " return y.mean(axis=-2)\n", + "\n", + "\n", + "key1, key2, key3, key4 = random.split(random.key(0), 4)\n", + "x = random.uniform(key1, (3, 13, 64))\n", + "\n", + "model = functools.partial(\n", + " 
MultiHeadDotProductAttention,\n", + " broadcast_dropout=False,\n", + " num_heads=2,\n", + " batch_axes=(0,))\n", + "\n", + "init_variables = model(train=False).init({'params': key2}, x, x)\n", + "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n", + "\n", + "y = model(train=True).apply(init_variables, x, x, rngs={'dropout': key4})\n", + "print('output:\\n', y.shape)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Scan" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Scan allows us to apply `lax.scan` to Modules, including their parameters and mutable variables. To use it we have to specify how we want each \"kind\" of variable to be transformed. For scanned variables we specify, similarly to vmap, via `variable_in_axes` and `variable_out_axes`, either:\n", + " - `nn.broadcast` broadcast the variable kind across the scan steps as a constant\n", + " - `<axis:int>` scan along `axis` for e.g. unique parameters at each step\n", + "\n", + "OR we specify that the variable kind is to be treated like a \"carry\" by passing it to the `variable_carry` argument.\n", + "\n", + "Further, for `scan`'d variable kinds, we specify whether or not to split the rng at each step."
+ ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "outputId": "7d9ebed3-64de-4ca8-9dce-4b09ba9e31a1", + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "initialized parameter shapes:\n", + " {'params': {'lstm_cell': {'hf': {'bias': (2,), 'kernel': (2, 2)}, 'hg': {'bias': (2,), 'kernel': (2, 2)}, 'hi': {'bias': (2,), 'kernel': (2, 2)}, 'ho': {'bias': (2,), 'kernel': (2, 2)}, 'if': {'kernel': (2, 2)}, 'ig': {'kernel': (2, 2)}, 'ii': {'kernel': (2, 2)}, 'io': {'kernel': (2, 2)}}}}\n", + "output:\n", + " ((DeviceArray([[-0.562219 , 0.92847174]], dtype=float32), DeviceArray([[-0.31570646, 0.2885693 ]], dtype=float32)), DeviceArray([[[-0.08265854, 0.01302483],\n", + " [-0.10249066, 0.21991298],\n", + " [-0.26609066, 0.22519003],\n", + " [-0.27982554, 0.28393182],\n", + " [-0.31570646, 0.2885693 ]]], dtype=float32))\n" + ] + } + ], + "source": [ + "class SimpleScan(nn.Module):\n", + " features: int\n", + "\n", + " @nn.compact\n", + " def __call__(self, xs):\n", + " LSTM = nn.scan(nn.LSTMCell,\n", + " in_axes=1, out_axes=1,\n", + " variable_broadcast='params',\n", + " split_rngs={'params': False})\n", + " lstm = LSTM(self.features, name=\"lstm_cell\")\n", + "\n", + " dummy_rng = random.key(0)\n", + " input_shape = xs[:, 0].shape\n", + " init_carry = lstm.initialize_carry(dummy_rng, input_shape)\n", + "\n", + " return lstm(init_carry, xs)\n", + "\n", + "key1, key2 = random.split(random.key(0), 2)\n", + "xs = random.uniform(key1, (1, 5, 2))\n", + "\n", + "model = SimpleScan(2)\n", + "init_variables = model.init(key2, xs)\n", + "\n", + "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n", + "\n", + "y = model.apply(init_variables, xs)\n", + "print('output:\\n', y)" + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,md:myst" + }, + "language_info": { + "name": "python", + "version": "3.8.10" + } + }, + "nbformat": 4, + 
"nbformat_minor": 0 +} diff --git a/docs_nnx/linen_intro.md b/docs_nnx/linen_intro.md new file mode 100644 index 0000000000..beea4c014b --- /dev/null +++ b/docs_nnx/linen_intro.md @@ -0,0 +1,597 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.13.8 +--- + +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/notebooks/linen_intro.ipynb) +[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/notebooks/linen_intro.ipynb) + +# Preface + +
+
CAVEAT PROGRAMMER
+ +The below is an alpha API preview and things might break. The surface syntax of the features of the API are not fixed in stone, and we welcome feedback on any points. + ++++ + +## Useful links + +⟶ [Slides](https://docs.google.com/presentation/d/1ngKWUwsSqAwPRvATG8sAxMzu9ujv4N__cKsUofdNno0/edit?usp=sharing) for the core ideas of the new Functional Core and Linen + +⟶ "Design tests" guided our design process. Many are available for [functional core](https://github.com/google/flax/tree/main/examples/core_design_test) and some for the [proposed Module abstraction](https://github.com/google/flax/tree/main/examples/linen_design_test/) + +⟶ Ported examples: [ImageNet](https://github.com/google/flax/tree/main/examples/imagenet) and [WMT](https://github.com/google/flax/tree/main/examples/wmt) (to the proposed Module abstraction). TODO: Port to functional core. + +⟶ Our new [discussion forums](https://github.com/google/flax/discussions/) + ++++ + +# Install and Import + +```{code-cell} +:tags: [skip-execution] + +# Install the newest JAXlib version. +!pip install --upgrade -q pip jax jaxlib +# Install Flax at head: +!pip install --upgrade -q git+https://github.com/google/flax.git +``` + +```{code-cell} +import functools +from typing import Any, Callable, Sequence, Optional +import jax +from jax import lax, random, numpy as jnp +import flax +from flax import linen as nn +``` + +# Invoking Modules + ++++ + +Let's instantiate a `Dense` layer. + - Modules are actually objects in this API, so we provide _constructor arguments_ when initializing the Module. In this case, we only have to provide the output `features` dimension. + +```{code-cell} +model = nn.Dense(features=3) +``` + +We need to initialize the Module variables, these include the parameters of the Module as well as any other state variables. + +We call the `init` method on the instantiated Module. 
If the Module `__call__` method has args `(self, *args, **kwargs)` then we call `init` with `(rngs, *args, **kwargs)` so in this case, just `(rng, input)`: + +```{code-cell} +:outputId: 3adfaeaf-977e-4e82-8adf-d254fae6eb91 + +# Make RNG Keys and a fake input. +key1, key2 = random.split(random.key(0), 2) +x = random.uniform(key1, (4,4)) + +# provide key and fake input to get initialized variables +init_variables = model.init(key2, x) + +init_variables +``` + +We call the `apply` method on the instantiated Module. If the Module `__call__` method has args `(self, *args, **kwargs)` then we call `apply` with `(variables, *args, rngs=<RNGS>, mutable=<MUTABLEKINDS>, **kwargs)` where + - `<RNGS>` are the optional _call time_ RNGs for things like dropout. For simple Modules this is just a single key, but if your module has multiple __kinds__ of data, it's a dictionary of rng-keys per-kind, e.g. `{'params': key0, 'dropout': key1}` for a Module with dropout layers. + - `<MUTABLEKINDS>` is an optional list of names of __kinds__ that are expected to be mutated during the call, e.g. `['batch_stats']` for a layer updating batchnorm statistics. + +So in this case, just `(variables, input)`: + +```{code-cell} +:outputId: e8c389a6-29f3-4f93-97ea-703e85a8b811 + +y = model.apply(init_variables, x) +y +``` + +Additional points: + - If you want to `init` or `apply` a Module using a method other than call, you need to provide the `method=` kwarg to `init` and `apply` to use it instead of the default `__call__`, e.g. `method='encode'`, `method='decode'` to apply the encode/decode methods of an autoencoder. + ++++ + +# Defining Basic Modules + ++++ + +## Composing submodules + ++++ + +We support declaring modules in `setup()` that can still benefit from shape inference by using __Lazy Initialization__ that sets up variables the first time the Module is called.
+ +```{code-cell} +:outputId: 1a6c6a17-0b95-42c2-b5bf-b9ad80fd7758 +:tags: [] + +class ExplicitMLP(nn.Module): + features: Sequence[int] + + def setup(self): + # we automatically know what to do with lists, dicts of submodules + self.layers = [nn.Dense(feat) for feat in self.features] + # for single submodules, we would just write: + # self.layer1 = nn.Dense(feat1) + + def __call__(self, inputs): + x = inputs + for i, lyr in enumerate(self.layers): + x = lyr(x) + if i != len(self.layers) - 1: + x = nn.relu(x) + return x + +key1, key2 = random.split(random.key(0), 2) +x = random.uniform(key1, (4,4)) + +model = ExplicitMLP(features=[3,4,5]) +init_variables = model.init(key2, x) +y = model.apply(init_variables, x) + +print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables))) +print('output:\n', y) +``` + +Here we show the equivalent compact form of the MLP that declares the submodules inline using the `@compact` decorator. + +```{code-cell} +:outputId: b3709789-e66e-4e20-f6b2-04022f8a62bb +:tags: [] + +class SimpleMLP(nn.Module): + features: Sequence[int] + + @nn.compact + def __call__(self, inputs): + x = inputs + for i, feat in enumerate(self.features): + x = nn.Dense(feat, name=f'layers_{i}')(x) + if i != len(self.features) - 1: + x = nn.relu(x) + # providing a name is optional though! + # the default autonames would be "Dense_0", "Dense_1", ... 
+ # x = nn.Dense(feat)(x) + return x + +key1, key2 = random.split(random.key(0), 2) +x = random.uniform(key1, (4,4)) + +model = SimpleMLP(features=[3,4,5]) +init_variables = model.init(key2, x) +y = model.apply(init_variables, x) + +print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables))) +print('output:\n', y) +``` + +## Declaring and using variables + ++++ + +Flax uses lazy initialization, which allows declared variables to be initialized only at the first site of their use, using whatever shape information is available at the local call site for shape inference. Once a variable has been initialized, a reference to the data is kept for use in subsequent calls. + +For declaring parameters that aren't mutated inside the model, but rather by gradient descent, we use the syntax: + + `self.param(parameter_name, parameter_init_fn, *init_args, **init_kwargs)` + +with arguments: + - `parameter_name` just the name, a string + - `parameter_init_fn` a function taking an RNG key and a variable number of other arguments, i.e. `fn(rng, *args)`. Typically those in `nn.initializers` take an `rng` and a `shape` argument. + - the remaining arguments to feed to the init function when initializing. + +Again, we'll demonstrate declaring things inline as we typically do using the `@compact` decorator. + +```{code-cell} +:outputId: bc5cb1f2-c5e9-4159-d131-73247009e32f +:tags: [] + +class SimpleDense(nn.Module): + features: int + kernel_init: Callable = nn.initializers.lecun_normal() + bias_init: Callable = nn.initializers.zeros_init() + + @nn.compact + def __call__(self, inputs): + kernel = self.param('kernel', + self.kernel_init, # RNG passed implicitly. + (inputs.shape[-1], self.features)) # shape info.
+ y = lax.dot_general(inputs, kernel, + (((inputs.ndim - 1,), (0,)), ((), ())),) + bias = self.param('bias', self.bias_init, (self.features,)) + y = y + bias + return y + +key1, key2 = random.split(random.key(0), 2) +x = random.uniform(key1, (4,4)) + +model = SimpleDense(features=3) +init_variables = model.init(key2, x) +y = model.apply(init_variables, x) + +print('initialized parameters:\n', init_variables) +print('output:\n', y) +``` + +We can also declare variables in setup, though in doing so you can't take advantage of shape inference and have to provide explicit shape information at initialization. The syntax is a little repetitive in this case right now, but we do force agreement of the assigned names. + +```{code-cell} +:outputId: 1e822bd8-7a08-4e80-e0e6-a86637c46772 +:tags: [] + +class ExplicitDense(nn.Module): + features_in: int # <-- explicit input shape + features: int + kernel_init: Callable = nn.initializers.lecun_normal() + bias_init: Callable = nn.initializers.zeros_init() + + def setup(self): + self.kernel = self.param('kernel', + self.kernel_init, + (self.features_in, self.features)) + self.bias = self.param('bias', self.bias_init, (self.features,)) + + def __call__(self, inputs): + y = lax.dot_general(inputs, self.kernel, + (((inputs.ndim - 1,), (0,)), ((), ())),) + y = y + self.bias + return y + +key1, key2 = random.split(random.key(0), 2) +x = random.uniform(key1, (4,4)) + +model = ExplicitDense(features_in=4, features=3) +init_variables = model.init(key2, x) +y = model.apply(init_variables, x) + +print('initialized parameters:\n', init_variables) +print('output:\n', y) +``` + +## General Variables + ++++ + +For declaring generally mutable _variables_ that may be mutated inside the model we use the call: + + `self.variable(variable_kind, variable_name, variable_init_fn, *init_args, **init_kwargs)` + +with arguments: + - `variable_kind` the "kind" of state this variable is, i.e. 
the name of the nested-dict collection that this will be stored in inside the top Module's variables, e.g. `batch_stats` for the moving statistics for a batch norm layer or `cache` for autoregressive cache data. Note that parameters also have a kind, but they're set to the default `params` kind. + - `variable_name` just the name, a string + - `variable_init_fn` a function taking a variable number of other arguments, i.e. `fn(*args)`. Note that we __don't__ assume the need for an RNG, if you _do_ want an RNG, provide it via a `self.make_rng(variable_kind)` call in the provided arguments. + - the remaining arguments to feed to the init function when initializing. + +⚠️ Unlike parameters, we expect these to be mutated, so `self.variable` returns not a constant, but a _reference_ to the variable. To __get__ the raw value, you'd write `myvariable.value` and to __set__ it `myvariable.value = new_value`. + +```{code-cell} +:outputId: 2a8f5453-81b1-44dc-a431-d14b372c5710 +:tags: [] + +class Counter(nn.Module): + @nn.compact + def __call__(self): + # easy pattern to detect if we're initializing + is_initialized = self.has_variable('counter', 'count') + counter = self.variable('counter', 'count', lambda: jnp.zeros((), jnp.int32)) + if is_initialized: + counter.value += 1 + return counter.value + + +key1 = random.key(0) + +model = Counter() +init_variables = model.init(key1) +print('initialized variables:\n', init_variables) + +y, mutated_variables = model.apply(init_variables, mutable=['counter']) + +print('mutated variables:\n', mutated_variables) +print('output:\n', y) +``` + +## Another Mutability and RNGs Example + ++++ + +Let's make an artificial, goofy example that mixes differentiable parameters, stochastic layers, and mutable variables: + +```{code-cell} +:outputId: 8f299a5c-74c8-476c-93fa-e5543901ec45 +:tags: [] + +class Block(nn.Module): + features: int + training: bool + @nn.compact + def __call__(self, inputs): + x = nn.Dense(self.features)(inputs) + x =
nn.Dropout(rate=0.5)(x, deterministic=not self.training) + x = nn.BatchNorm(use_running_average=not self.training)(x) + return x + +key1, key2, key3, key4 = random.split(random.key(0), 4) +x = random.uniform(key1, (3,4,4)) + +model = Block(features=3, training=True) + +init_variables = model.init({'params': key2, 'dropout': key3}, x) +_, init_params = flax.core.pop(init_variables, 'params') + +# When calling `apply` with mutable kinds, returns a pair of output, +# mutated_variables. +y, mutated_variables = model.apply( + init_variables, x, rngs={'dropout': key4}, mutable=['batch_stats']) + +# Now we reassemble the full variables from the updates (in a real training +# loop, with the updated params from an optimizer). +updated_variables = flax.core.freeze(dict(params=init_params, + **mutated_variables)) + +print('updated variables:\n', updated_variables) +print('initialized variable shapes:\n', + jax.tree_util.tree_map(jnp.shape, init_variables)) +print('output:\n', y) + +# Let's run these model variables during "evaluation": +eval_model = Block(features=3, training=False) +y = eval_model.apply(updated_variables, x) # Nothing mutable; single return value. +print('eval output:\n', y) +``` + +# JAX transformations inside modules + ++++ + +## JIT + ++++ + +It's not immediately clear what use this has, but you can compile specific submodules if there's a reason to. + +_Known Gotcha_: at the moment, the decorator changes the RNG stream slightly, so comparing jitted and unjitted initializations will look different. + +```{code-cell} +:outputId: 3f324d0f-259f-40f0-8273-103f7fc281c5 +:tags: [] + +class MLP(nn.Module): + features: Sequence[int] + + @nn.compact + def __call__(self, inputs): + x = inputs + for i, feat in enumerate(self.features): + # JIT the Module (its __call__ fn by default.)
+ x = nn.jit(nn.Dense)(feat, name=f'layers_{i}')(x) + if i != len(self.features) - 1: + x = nn.relu(x) + return x + +key1, key2 = random.split(random.key(3), 2) +x = random.uniform(key1, (4,4)) + +model = MLP(features=[3,4,5]) +init_variables = model.init(key2, x) +y = model.apply(init_variables, x) + +print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables))) +print('output:\n', y) +``` + +## Remat + ++++ + +For memory-expensive computations, we can `remat` our method to recompute a Module's output during a backwards pass. + +_Known Gotcha_: at the moment, the decorator changes the RNG stream slightly, so comparing remat'd and undecorated initializations will look different. + +```{code-cell} +:outputId: 7fe8e13b-7dd6-4e55-ee50-ce334e8ed178 +:tags: [] + +class RematMLP(nn.Module): + features: Sequence[int] + # For all transforms, we can annotate a method, or wrap an existing + # Module class. Here we annotate the method. + @nn.remat + @nn.compact + def __call__(self, inputs): + x = inputs + for i, feat in enumerate(self.features): + x = nn.Dense(feat, name=f'layers_{i}')(x) + if i != len(self.features) - 1: + x = nn.relu(x) + return x + +key1, key2 = random.split(random.key(3), 2) +x = random.uniform(key1, (4,4)) + +model = RematMLP(features=[3,4,5]) +init_variables = model.init(key2, x) +y = model.apply(init_variables, x) + +print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables))) +print('output:\n', y) +``` + +## Vmap + ++++ + +You can now `vmap` Modules inside. 
The transform has a lot of arguments, they have the usual jax vmap args: + - `in_axes` - an integer or `None` for each input argument + - `out_axes` - an integer or `None` for each output argument + - `axis_size` - the axis size if you need to give it explicitly + +In addition, we provide for each __kind__ of variable its axis rules: + + - `variable_in_axes` - a dict from kinds to a single integer or `None` specifying the input axes to map + - `variable_out_axes` - a dict from kinds to a single integer or `None` specifying the output axes to map + - `split_rngs` - a dict from RNG-kinds to a bool, specifying whether to split the rng along the axis. + + +Below we show an example defining a batched, multiheaded attention module from a single-headed unbatched attention implementation. + +```{code-cell} +:outputId: 223d880e-c7b2-4210-ebb5-dbfcdd9aed09 +:tags: [] + +class RawDotProductAttention(nn.Module): + attn_dropout_rate: float = 0.1 + train: bool = False + + @nn.compact + def __call__(self, query, key, value, bias=None, dtype=jnp.float32): + assert key.ndim == query.ndim + assert key.ndim == value.ndim + + n = query.ndim + attn_weights = lax.dot_general( + query, key, + (((n-1,), (n - 1,)), ((), ()))) + if bias is not None: + attn_weights += bias + norm_dims = tuple(range(attn_weights.ndim // 2, attn_weights.ndim)) + attn_weights = jax.nn.softmax(attn_weights, axis=norm_dims) + attn_weights = nn.Dropout(self.attn_dropout_rate)(attn_weights, + deterministic=not self.train) + attn_weights = attn_weights.astype(dtype) + + contract_dims = ( + tuple(range(n - 1, attn_weights.ndim)), + tuple(range(0, n - 1))) + y = lax.dot_general( + attn_weights, value, + (contract_dims, ((), ()))) + return y + +class DotProductAttention(nn.Module): + qkv_features: Optional[int] = None + out_features: Optional[int] = None + train: bool = False + + @nn.compact + def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32): + qkv_features = self.qkv_features or
inputs_q.shape[-1] + out_features = self.out_features or inputs_q.shape[-1] + + QKVDense = functools.partial( + nn.Dense, features=qkv_features, use_bias=False, dtype=dtype) + query = QKVDense(name='query')(inputs_q) + key = QKVDense(name='key')(inputs_kv) + value = QKVDense(name='value')(inputs_kv) + + y = RawDotProductAttention(train=self.train)( + query, key, value, bias=bias, dtype=dtype) + + y = nn.Dense(features=out_features, dtype=dtype, name='out')(y) + return y + +class MultiHeadDotProductAttention(nn.Module): + qkv_features: Optional[int] = None + out_features: Optional[int] = None + batch_axes: Sequence[int] = (0,) + num_heads: int = 1 + broadcast_dropout: bool = False + train: bool = False + @nn.compact + def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32): + qkv_features = self.qkv_features or inputs_q.shape[-1] + out_features = self.out_features or inputs_q.shape[-1] + + # Make multiheaded attention from single-headed dimension. + Attn = nn.vmap(DotProductAttention, + in_axes=(None, None, None), + out_axes=2, + axis_size=self.num_heads, + variable_axes={'params': 0}, + split_rngs={'params': True, + 'dropout': not self.broadcast_dropout}) + + # Vmap across batch dimensions. + for axis in reversed(sorted(self.batch_axes)): + Attn = nn.vmap(Attn, + in_axes=(axis, axis, axis), + out_axes=axis, + variable_axes={'params': None}, + split_rngs={'params': False, 'dropout': False}) + + # Run the vmap'd class on inputs. 
+ y = Attn(qkv_features=qkv_features // self.num_heads, + out_features=out_features, + train=self.train, + name='attention')(inputs_q, inputs_kv, bias) + + return y.mean(axis=-2) + + +key1, key2, key3, key4 = random.split(random.key(0), 4) +x = random.uniform(key1, (3, 13, 64)) + +model = functools.partial( + MultiHeadDotProductAttention, + broadcast_dropout=False, + num_heads=2, + batch_axes=(0,)) + +init_variables = model(train=False).init({'params': key2}, x, x) +print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables))) + +y = model(train=True).apply(init_variables, x, x, rngs={'dropout': key4}) +print('output:\n', y.shape) +``` + +## Scan + ++++ + +Scan allows us to apply `lax.scan` to Modules, including their parameters and mutable variables. To use it we have to specify how we want each "kind" of variable to be transformed. For scanned variables we specify, similarly to vmap, via `variable_in_axes` and `variable_out_axes`, either: + - `nn.broadcast` broadcast the variable kind across the scan steps as a constant + - `<axis:int>` scan along `axis` for e.g. unique parameters at each step + +OR we specify that the variable kind is to be treated like a "carry" by passing it to the `variable_carry` argument. + +Further, for `scan`'d variable kinds, we specify whether or not to split the rng at each step.
+ +```{code-cell} +:outputId: 7d9ebed3-64de-4ca8-9dce-4b09ba9e31a1 +:tags: [] + +class SimpleScan(nn.Module): + features: int + + @nn.compact + def __call__(self, xs): + LSTM = nn.scan(nn.LSTMCell, + in_axes=1, out_axes=1, + variable_broadcast='params', + split_rngs={'params': False}) + lstm = LSTM(self.features, name="lstm_cell") + + dummy_rng = random.key(0) + input_shape = xs[:, 0].shape + init_carry = lstm.initialize_carry(dummy_rng, input_shape) + + return lstm(init_carry, xs) + +key1, key2 = random.split(random.key(0), 2) +xs = random.uniform(key1, (1, 5, 2)) + +model = SimpleScan(2) +init_variables = model.init(key2, xs) + +print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables))) + +y = model.apply(init_variables, xs) +print('output:\n', y) +``` diff --git a/docs/nnx/mnist_tutorial.ipynb b/docs_nnx/mnist_tutorial.ipynb similarity index 99% rename from docs/nnx/mnist_tutorial.ipynb rename to docs_nnx/mnist_tutorial.ipynb index 83c8f91203..a01acbcf12 100644 --- a/docs/nnx/mnist_tutorial.ipynb +++ b/docs_nnx/mnist_tutorial.ipynb @@ -10,9 +10,9 @@ "\n", "# MNIST Tutorial\n", "\n", - "Welcome to Flax NNX! This tutorial will guide you through building and training a simple convolutional \n", + "Welcome to Flax NNX! This tutorial will guide you through building and training a simple convolutional\n", "neural network (CNN) on the MNIST dataset using the Flax NNX API. Flax NNX is a Python neural network library\n", - "built upon [JAX](https://github.com/google/jax) and currently offered as an experimental module within \n", + "built upon [JAX](https://github.com/google/jax) and currently offered as an experimental module within\n", "[Flax](https://github.com/google/flax)." ] }, @@ -23,7 +23,7 @@ "source": [ "## 1. 
Install Flax\n", "\n", - "If `flax` is not installed in your environment, you can install it from PyPI, uncomment and run the \n", + "If `flax` is not installed in your environment, you can install it from PyPI, uncomment and run the\n", "following cell:" ] }, @@ -48,8 +48,8 @@ "source": [ "## 2. Load the MNIST Dataset\n", "\n", - "First, the MNIST dataset is loaded and prepared for training and testing using \n", - "Tensorflow Datasets. Image values are normalized, the data is shuffled and divided \n", + "First, the MNIST dataset is loaded and prepared for training and testing using\n", + "Tensorflow Datasets. Image values are normalized, the data is shuffled and divided\n", "into batches, and samples are prefetched to enhance performance." ] }, @@ -235,7 +235,7 @@ "\n", "optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))\n", "metrics = nnx.MultiMetric(\n", - " accuracy=nnx.metrics.Accuracy(), \n", + " accuracy=nnx.metrics.Accuracy(),\n", " loss=nnx.metrics.Average('loss'),\n", ")\n", "\n", @@ -285,8 +285,8 @@ "id": "17", "metadata": {}, "source": [ - "The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with \n", - "[XLA](https://www.tensorflow.org/xla), optimizing performance on \n", + "The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with\n", + "[XLA](https://www.tensorflow.org/xla), optimizing performance on\n", "hardware accelerators. `nnx.jit` is similar to [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit),\n", "except it can transforms functions that contain Flax NNX objects as inputs and outputs.\n", "\n", @@ -300,7 +300,7 @@ "source": [ "## 6. 
Train and Evaluate\n", "\n", - "Now we train a model using batches of data for 10 epochs, evaluate its performance \n", + "Now we train a model using batches of data for 10 epochs, evaluate its performance\n", "on the test set after each epoch, and log the training and testing metrics (loss and\n", "accuracy) throughout the process. Typically this leads to a model with around 99% accuracy." ] diff --git a/docs/nnx/mnist_tutorial.md b/docs_nnx/mnist_tutorial.md similarity index 97% rename from docs/nnx/mnist_tutorial.md rename to docs_nnx/mnist_tutorial.md index 6cea6668de..740395331a 100644 --- a/docs/nnx/mnist_tutorial.md +++ b/docs_nnx/mnist_tutorial.md @@ -14,16 +14,16 @@ jupytext: # MNIST Tutorial -Welcome to Flax NNX! This tutorial will guide you through building and training a simple convolutional +Welcome to Flax NNX! This tutorial will guide you through building and training a simple convolutional neural network (CNN) on the MNIST dataset using the Flax NNX API. Flax NNX is a Python neural network library -built upon [JAX](https://github.com/google/jax) and currently offered as an experimental module within +built upon [JAX](https://github.com/google/jax) and currently offered as an experimental module within [Flax](https://github.com/google/flax). +++ ## 1. Install Flax -If `flax` is not installed in your environment, you can install it from PyPI, uncomment and run the +If `flax` is not installed in your environment, you can install it from PyPI, uncomment and run the following cell: ```{code-cell} ipython3 @@ -34,8 +34,8 @@ following cell: ## 2. Load the MNIST Dataset -First, the MNIST dataset is loaded and prepared for training and testing using -Tensorflow Datasets. Image values are normalized, the data is shuffled and divided +First, the MNIST dataset is loaded and prepared for training and testing using +Tensorflow Datasets. Image values are normalized, the data is shuffled and divided into batches, and samples are prefetched to enhance performance. 
```{code-cell} ipython3 @@ -127,7 +127,7 @@ momentum = 0.9 optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum)) metrics = nnx.MultiMetric( - accuracy=nnx.metrics.Accuracy(), + accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average('loss'), ) @@ -160,8 +160,8 @@ def eval_step(model: CNN, metrics: nnx.MultiMetric, batch): metrics.update(loss=loss, logits=logits, labels=batch['label']) # inplace updates ``` -The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with -[XLA](https://www.tensorflow.org/xla), optimizing performance on +The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with +[XLA](https://www.tensorflow.org/xla), optimizing performance on hardware accelerators. `nnx.jit` is similar to [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit), except it can transforms functions that contain Flax NNX objects as inputs and outputs. @@ -171,7 +171,7 @@ except it can transforms functions that contain Flax NNX objects as inputs and o ## 6. Train and Evaluate -Now we train a model using batches of data for 10 epochs, evaluate its performance +Now we train a model using batches of data for 10 epochs, evaluate its performance on the test set after each epoch, and log the training and testing metrics (loss and accuracy) throughout the process. Typically this leads to a model with around 99% accuracy. 
diff --git a/docs/nnx/nnx_basics.ipynb b/docs_nnx/nnx_basics.ipynb similarity index 99% rename from docs/nnx/nnx_basics.ipynb rename to docs_nnx/nnx_basics.ipynb index fb326e1dc6..94f1cb044d 100644 --- a/docs/nnx/nnx_basics.ipynb +++ b/docs_nnx/nnx_basics.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Flax NNX Basics\n", + "# Flax Basics\n", "\n", "Flax NNX is a new simplified API that is designed to make it easier to create, inspect, debug,\n", "and analyze neural networks in JAX. It achieves this by adding first class support\n", @@ -43,15 +43,15 @@ "metadata": {}, "source": [ "## The Module System\n", - "To begin lets see how to create a `Linear` Module using Flax. The main difference between \n", - "Flax NNX and Module systems like Haiku or Flax Linen is that everything is **explicit**. This \n", - "means among other things that 1) the Module itself holds the state (e.g. parameters) directly, \n", - "2) the RNG state is threaded by the user, and 3) all shape information must be provided on \n", + "To begin lets see how to create a `Linear` Module using Flax. The main difference between\n", + "Flax NNX and Module systems like Haiku or Flax Linen is that everything is **explicit**. This\n", + "means among other things that 1) the Module itself holds the state (e.g. parameters) directly,\n", + "2) the RNG state is threaded by the user, and 3) all shape information must be provided on\n", "initialization (no shape inference).\n", "\n", - "As shown next, dynamic state is usually stored in `nnx.Param`s, and static state \n", - "(all types not handled by Flax) such as integers or strings are stored directly. 
\n", - "Attributes of type `jax.Array` and `numpy.ndarray` are also treated as dynamic \n", + "As shown next, dynamic state is usually stored in `nnx.Param`s, and static state\n", + "(all types not handled by Flax) such as integers or strings are stored directly.\n", + "Attributes of type `jax.Array` and `numpy.ndarray` are also treated as dynamic\n", "state, although storing them inside `nnx.Variable`s such as `Param` is preferred.\n", "Also, `nnx.Rngs` can be used to get new unique keys starting from a root key." ] @@ -83,9 +83,9 @@ "to any JAX function as they implement the `__jax_array__` protocol (as long as their\n", "inner value is a JAX array).\n", "\n", - "To actually initialize a Module you simply call the constructor, all the parameters \n", - "of a Module are usually created eagerly. Since Modules hold their own state methods \n", - "can be called directly without the need for a separate `apply` method, this is very \n", + "To actually initialize a Module you simply call the constructor, all the parameters\n", + "of a Module are usually created eagerly. Since Modules hold their own state methods\n", + "can be called directly without the need for a separate `apply` method, this is very\n", "convenient for debugging as entire structure of the model can be inspected directly." ] }, @@ -135,8 +135,8 @@ "source": [ "### Stateful Computation\n", "\n", - "Implementing layers such as `BatchNorm` requires performing state updates during the \n", - "forward pass. To implement this in Flax you just create a `Variable` and update its \n", + "Implementing layers such as `BatchNorm` requires performing state updates during the\n", + "forward pass. To implement this in Flax you just create a `Variable` and update its\n", "`.value` during the forward pass." 
] }, @@ -184,7 +184,7 @@ "source": [ "### Nested Modules\n", "\n", - "As expected, Modules can be used to compose other Modules in a nested structure, these can \n", + "As expected, Modules can be used to compose other Modules in a nested structure, these can\n", "be assigned directly as attributes, or inside an attribute of any (nested) pytree type e.g.\n", " `list`, `dict`, `tuple`, etc. In the example below we define a simple `MLP` Module that\n", "consists of two `Linear` layers, a `Dropout` layer, and a `BatchNorm` layer." @@ -219,7 +219,7 @@ " def __call__(self, x: jax.Array):\n", " x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))\n", " return self.linear2(x)\n", - " \n", + "\n", "model = MLP(2, 16, 5, rngs=nnx.Rngs(0))\n", "\n", "y = model(x=jnp.ones((3, 2)))\n", @@ -240,9 +240,9 @@ "metadata": {}, "source": [ "#### Model Surgery\n", - "Flax NNX Modules are mutable by default, this means their structure can be changed at any time, \n", + "Flax NNX Modules are mutable by default, this means their structure can be changed at any time,\n", "this makes model surgery quite easy as any submodule attribute can be replaced with anything\n", - "else e.g. new Modules, existing shared Modules, Modules of different types, etc. More over, \n", + "else e.g. new Modules, existing shared Modules, Modules of different types, etc. More over,\n", "`Variable`s can also be modified or replaced / shared.\n", "\n", "The following example shows how to replace the `Linear` layers in the `MLP` model\n", @@ -299,8 +299,8 @@ "\n", "Flax Transforms extend JAX transforms to support Modules and other objects.\n", "They are supersets of their equivalent JAX counterpart with the addition of\n", - "being aware of the object's state and providing additional APIs to transform \n", - "it. One of the main features of Flax Transforms is the preservation of reference semantics, \n", + "being aware of the object's state and providing additional APIs to transform\n", + "it. 
One of the main features of Flax Transforms is the preservation of reference semantics,\n", "meaning that any mutation of the object graph that occurs inside the transform is\n", "propagated outisde as long as its legal within the transform rules. In practice this\n", "means that Flax programs can be express using imperative code, highly simplifying\n", @@ -441,13 +441,13 @@ "## The Functional API\n", "\n", "The Functional API establishes a clear boundary between reference/object semantics and\n", - "value/pytree semantics. It also allows same amount of fine-grained control over the \n", + "value/pytree semantics. It also allows same amount of fine-grained control over the\n", "state that Linen/Haiku users are used to. The Functional API consists of 3 basic methods:\n", "`split`, `merge`, and `update`.\n", "\n", "The `StatefulLinear` Module shown below will serve as an example for the use of the\n", "Functional API. It contains some `nnx.Param` Variables and a custom `Count` Variable\n", - "type which is used to keep track of integer scalar state that increases on every \n", + "type which is used to keep track of integer scalar state that increases on every\n", "forward pass." ] }, @@ -481,7 +481,7 @@ " def __call__(self, x: jax.Array):\n", " self.count += 1\n", " return x @ self.w + self.b\n", - " \n", + "\n", "model = StatefulLinear(din=3, dout=5, rngs=nnx.Rngs(0))\n", "y = model(jnp.ones((1, 3)))\n", "\n", @@ -495,8 +495,8 @@ "### State and GraphDef\n", "\n", "A Module can be decomposed into `GraphDef` and `State` using the\n", - "`split` function. State is a Mapping from strings to Variables or nested \n", - "States. GraphDef contains all the static information needed to reconstruct \n", + "`split` function. State is a Mapping from strings to Variables or nested\n", + "States. GraphDef contains all the static information needed to reconstruct\n", "a Module graph, it is analogous to JAX's `PyTreeDef`." 
] }, @@ -545,7 +545,7 @@ "`merge` is the reverse of `split`, it takes the GraphDef + State and reconstructs\n", "the Module. As shown in the example below, by using `split` and `merge` in sequence\n", "any Module can be lifted to be used in any JAX transform. `update` can\n", - "update an object inplace with the content of a given State. This pattern is used to \n", + "update an object inplace with the content of a given State. This pattern is used to\n", "propagate the state from a transform back to the source object outside." ] }, @@ -590,7 +590,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The key insight of this pattern is that using mutable references is \n", + "The key insight of this pattern is that using mutable references is\n", "fine within a transform context (including the base eager interpreter)\n", "but its necessary to use the Functional API when crossing boundaries.\n", "\n", diff --git a/docs/nnx/nnx_basics.md b/docs_nnx/nnx_basics.md similarity index 96% rename from docs/nnx/nnx_basics.md rename to docs_nnx/nnx_basics.md index 42b67dd547..bdf2d05bfc 100644 --- a/docs/nnx/nnx_basics.md +++ b/docs_nnx/nnx_basics.md @@ -8,7 +8,7 @@ jupytext: jupytext_version: 1.13.8 --- -# Flax NNX Basics +# Flax Basics Flax NNX is a new simplified API that is designed to make it easier to create, inspect, debug, and analyze neural networks in JAX. It achieves this by adding first class support @@ -30,15 +30,15 @@ import jax.numpy as jnp ``` ## The Module System -To begin lets see how to create a `Linear` Module using Flax. The main difference between -Flax NNX and Module systems like Haiku or Flax Linen is that everything is **explicit**. This -means among other things that 1) the Module itself holds the state (e.g. parameters) directly, -2) the RNG state is threaded by the user, and 3) all shape information must be provided on +To begin lets see how to create a `Linear` Module using Flax. 
The main difference between +Flax NNX and Module systems like Haiku or Flax Linen is that everything is **explicit**. This +means among other things that 1) the Module itself holds the state (e.g. parameters) directly, +2) the RNG state is threaded by the user, and 3) all shape information must be provided on initialization (no shape inference). -As shown next, dynamic state is usually stored in `nnx.Param`s, and static state -(all types not handled by Flax) such as integers or strings are stored directly. -Attributes of type `jax.Array` and `numpy.ndarray` are also treated as dynamic +As shown next, dynamic state is usually stored in `nnx.Param`s, and static state +(all types not handled by Flax) such as integers or strings are stored directly. +Attributes of type `jax.Array` and `numpy.ndarray` are also treated as dynamic state, although storing them inside `nnx.Variable`s such as `Param` is preferred. Also, `nnx.Rngs` can be used to get new unique keys starting from a root key. @@ -60,9 +60,9 @@ arithmetic expressions (as shown above). Additionally, Variables can passed to any JAX function as they implement the `__jax_array__` protocol (as long as their inner value is a JAX array). -To actually initialize a Module you simply call the constructor, all the parameters -of a Module are usually created eagerly. Since Modules hold their own state methods -can be called directly without the need for a separate `apply` method, this is very +To actually initialize a Module you simply call the constructor, all the parameters +of a Module are usually created eagerly. Since Modules hold their own state methods +can be called directly without the need for a separate `apply` method, this is very convenient for debugging as entire structure of the model can be inspected directly. 
```{code-cell} ipython3 @@ -79,8 +79,8 @@ The above visualization by `nnx.display` is generated using the awesome [Treesco ### Stateful Computation -Implementing layers such as `BatchNorm` requires performing state updates during the -forward pass. To implement this in Flax you just create a `Variable` and update its +Implementing layers such as `BatchNorm` requires performing state updates during the +forward pass. To implement this in Flax you just create a `Variable` and update its `.value` during the forward pass. ```{code-cell} ipython3 @@ -106,7 +106,7 @@ Flax provides sound mechanisms to handle them. ### Nested Modules -As expected, Modules can be used to compose other Modules in a nested structure, these can +As expected, Modules can be used to compose other Modules in a nested structure, these can be assigned directly as attributes, or inside an attribute of any (nested) pytree type e.g. `list`, `dict`, `tuple`, etc. In the example below we define a simple `MLP` Module that consists of two `Linear` layers, a `Dropout` layer, and a `BatchNorm` layer. @@ -122,7 +122,7 @@ class MLP(nnx.Module): def __call__(self, x: jax.Array): x = nnx.gelu(self.dropout(self.bn(self.linear1(x)))) return self.linear2(x) - + model = MLP(2, 16, 5, rngs=nnx.Rngs(0)) y = model(x=jnp.ones((3, 2))) @@ -136,9 +136,9 @@ new masks during the forward pass without the need for the user to pass a new ke +++ #### Model Surgery -Flax NNX Modules are mutable by default, this means their structure can be changed at any time, +Flax NNX Modules are mutable by default, this means their structure can be changed at any time, this makes model surgery quite easy as any submodule attribute can be replaced with anything -else e.g. new Modules, existing shared Modules, Modules of different types, etc. More over, +else e.g. new Modules, existing shared Modules, Modules of different types, etc. More over, `Variable`s can also be modified or replaced / shared. 
The following example shows how to replace the `Linear` layers in the `MLP` model @@ -172,8 +172,8 @@ nnx.display(model) Flax Transforms extend JAX transforms to support Modules and other objects. They are supersets of their equivalent JAX counterpart with the addition of -being aware of the object's state and providing additional APIs to transform -it. One of the main features of Flax Transforms is the preservation of reference semantics, +being aware of the object's state and providing additional APIs to transform +it. One of the main features of Flax Transforms is the preservation of reference semantics, meaning that any mutation of the object graph that occurs inside the transform is propagated outisde as long as its legal within the transform rules. In practice this means that Flax programs can be express using imperative code, highly simplifying @@ -262,13 +262,13 @@ JAX transforms lets take a look at the Functional API. ## The Functional API The Functional API establishes a clear boundary between reference/object semantics and -value/pytree semantics. It also allows same amount of fine-grained control over the +value/pytree semantics. It also allows same amount of fine-grained control over the state that Linen/Haiku users are used to. The Functional API consists of 3 basic methods: `split`, `merge`, and `update`. The `StatefulLinear` Module shown below will serve as an example for the use of the Functional API. It contains some `nnx.Param` Variables and a custom `Count` Variable -type which is used to keep track of integer scalar state that increases on every +type which is used to keep track of integer scalar state that increases on every forward pass. 
```{code-cell} ipython3 @@ -283,7 +283,7 @@ class StatefulLinear(nnx.Module): def __call__(self, x: jax.Array): self.count += 1 return x @ self.w + self.b - + model = StatefulLinear(din=3, dout=5, rngs=nnx.Rngs(0)) y = model(jnp.ones((1, 3))) @@ -293,8 +293,8 @@ nnx.display(model) ### State and GraphDef A Module can be decomposed into `GraphDef` and `State` using the -`split` function. State is a Mapping from strings to Variables or nested -States. GraphDef contains all the static information needed to reconstruct +`split` function. State is a Mapping from strings to Variables or nested +States. GraphDef contains all the static information needed to reconstruct a Module graph, it is analogous to JAX's `PyTreeDef`. ```{code-cell} ipython3 @@ -308,7 +308,7 @@ nnx.display(graphdef, state) `merge` is the reverse of `split`, it takes the GraphDef + State and reconstructs the Module. As shown in the example below, by using `split` and `merge` in sequence any Module can be lifted to be used in any JAX transform. `update` can -update an object inplace with the content of a given State. This pattern is used to +update an object inplace with the content of a given State. This pattern is used to propagate the state from a transform back to the source object outside. ```{code-cell} ipython3 @@ -334,7 +334,7 @@ nnx.update(model, state) print(f'{model.count.value = }') ``` -The key insight of this pattern is that using mutable references is +The key insight of this pattern is that using mutable references is fine within a transform context (including the base eager interpreter) but its necessary to use the Functional API when crossing boundaries. diff --git a/docs_nnx/philosophy.md b/docs_nnx/philosophy.md new file mode 100644 index 0000000000..e9a5dcd5d7 --- /dev/null +++ b/docs_nnx/philosophy.md @@ -0,0 +1,121 @@ +# The Flax philosophy + +In no particular order: + +* Library code should be easy to read and understand. +* Prefer duplicating code over a bad abstraction. 
+* Generally, prefer duplicating code over adding options to functions. +* Comment-driven design: If it's hard to document your code, consider + changing the design. +* Unit test-driven design: If it's hard to test your code, consider + changing the design. +* People start projects by copying an existing implementation — make + base implementations excellent. +* If we expose an abstraction to our developers, we own the mental + overhead. +* Developer-facing functional programming abstractions confuse some users, + expose them where the benefit is high. +* "Read the manual" is not an appropriate response to developer confusion. + The framework should guide developers + towards good solutions, such as through assertions and error messages. +* An unhelpful error message is a bug. +* "Debugging is twice as hard as writing the code in the first + place. Therefore, if you write the code as cleverly as possible, you + are, by definition, not smart enough to debug it." — Brian Kernighan + +## Design principles + +Flax is a neural network library built on [JAX](https://jax.readthedocs.io) that has been adopted by a +growing set of users, most notably in the JAX submissions for the MLPerf +0.7 benchmark. Our experience over the last year (and many conversations +with users and JAX core devs) has guided a redesign of the API called +[Linen](https://github.com/google/flax/blob/main/flax/linen/README.md) ([`flax.linen`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/index.html)) in response to the following basic design questions. + +### How does a neural network library benefit from being built on JAX and leverage JAX’s unique strengths? + +The world already has TensorFlow and PyTorch, and there’s little need to +build a clone of either. We believe that the composable +function-transformation approach that JAX takes opens up new frontiers +for making neural net code more maintainable, more scalable and more +performant than existing libraries. 
While we strive to offer an API +familiar to those experienced with Keras/Sonnet/PyTorch, Linen is +fundamentally a functional system for defining neural nets in JAX. Just +a few examples of what we believe a JAX-targeted library can enable: + +- Write models as “single-example” code and introduce batching + automatically with [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html). +- Automatically handle ragged batches in NLP and other masking issues. +- Create efficient compile-time and runtime models by utilizing + rematerialized `scan` for massive convolutional networks. +- Remove memory headaches by enabling easy rematerialization, + reversibility, and model-parallel data sharding. + +### How does one interoperate with JAX transformations? + +Arguably, the entire point of a neural net library is to offer an +implicit variable management API to save the user from having to +manually thread thousands of variables through a complex tree of +functions. However, JAX operates on pure functions. To handle both +current and future JAX transforms (configured and composed in any way), +Linen Modules are directly “functionalized”, that is, automatically cast +in-place as explicit functions of the form: + +$$f \left( v_{in}, x \right) \rightarrow v_{out}, y$$ + +Where $v_{in}$ is the variable collections and [PRNG](https://jax.readthedocs.io/en/latest/jep/263-prng.html) state used by +the model, $v_{out}$ the mutated output variable collections, +$x$ the input data and $y$ the output data. Applying JAX +transformations then simply reduces to specifying any argument-specific +transform options to the various variable collections and PRNG state. 
+This unleashes the flexibility and strength of [JAX transformations](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html) – for +example, one can achieve either device-parallel training or per-device +ensembling by using [`jax.pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) in different ways, without any explicit +library support. Moreover, **within [Modules](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module)**, we expose lightweight +wrappers around the complex JAX transforms such as [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html) and [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) +that annotate how each variable collection is to be transformed by JAX. +Importantly, we handle the nontrivial cases of creating new variables +and transformed variables under mapping and loop transforms correctly +for initialization and application. + +### How are parameters represented, and how do we handle general “differentiable algorithms” that update stateful variables? + +We follow the JAX functional conventions of storing data in “pytrees”: +JAX arrays contained in nested tuples, lists, dictionaries. Because +researchers inevitably manually interact with this data, we use nested +dictionaries with meaningful default keys and offer several utilities +(traversals, etc.) for handling them directly. Linen uses an accelerated +version of a Python frozen dictionary that caches its JAX-flattened form +to speed up [`jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html)ted function call overheads. + +Flax generalizes the operation of a neural net by allowing models to +accept collections of several different “kinds”: parameters, batch-norm +stats, autoregressive caches, debug information, fine-grained +hyperparameters, etc. Each collection is stored in a nested dictionary +of the same structure as the model. 
Importantly, we do *not* conflate +these various kinds under the single vague rubric of “state”, but keep +different logical types of variables separate that can be treated +differently under JAX transformations and under mutations (e.g. training +vs prediction). Similarly, we allow for multiple separate named PRNG +chains inside [Modules](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module) for separate treatment of randomness for different +applications such as initialization, dropout, sampling, etc. + +At every stage the data associated with a neural net is not kept in a +custom object hierarchy, but left in an explicit, Python and JAX native +form that is easy to introspect and modify. Users have utilized this to +map TF and PyTorch checkpoints to Flax, to implement submodel-specific +loss terms, and to perform fast model surgery, etc. For saving this +data, most Flax examples store these nested dictionaries via the +efficient “msgpack” binary format – but as variables are simply Python +dicts, you can use any (non-JAX-aware) serialization library directly. + +### How does one interoperate with purely functional JAX code? + +To be broadly useful to the JAX ecosystem, users shouldn’t need to +heavily refactor their code in order to add “trainability” for a given +numerical task. _“The library should not get in the way.”_ Utilizing +purely functional code from within Linen is trivial: [Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module) +implementations are just JAX code with named variables. Using Linen +Modules inside otherwise purely functional code can be as simple as +using a single top-level Module transformation to allow initialization +and pure application of any JAX program that might contain various +trainable sections. 
diff --git a/docs_nnx/quick_start.ipynb b/docs_nnx/quick_start.ipynb new file mode 100644 index 0000000000..32530b9bed --- /dev/null +++ b/docs_nnx/quick_start.ipynb @@ -0,0 +1,701 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "6eea21b3", + "metadata": {}, + "source": [ + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/quick_start.ipynb)\n", + "[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/quick_start.ipynb)\n", + "\n", + "# Quick start\n", + "\n", + "Welcome to Flax!\n", + "\n", + "Flax is an open source Python neural network library built on top of [JAX](https://github.com/google/jax). This tutorial demonstrates how to construct a simple convolutional neural\n", + "network (CNN) using the [Flax](https://flax.readthedocs.io) Linen API and train\n", + "the network for image classification on the MNIST dataset." + ] + }, + { + "cell_type": "markdown", + "id": "nwJWKIhdwxDo", + "metadata": {}, + "source": [ + "## 1. Install Flax" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bb81587e", + "metadata": { + "tags": [ + "skip-execution" + ] + }, + "outputs": [], + "source": [ + "!pip install -q flax>=0.7.5" + ] + }, + { + "cell_type": "markdown", + "id": "b529fbef", + "metadata": {}, + "source": [ + "## 2. Loading data\n", + "\n", + "Flax can use any\n", + "data-loading pipeline and this example demonstrates how to utilize TFDS. Define a function that loads and prepares the MNIST dataset and converts the\n", + "samples to floating-point numbers." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "bRlrHqZVXZvk", + "metadata": {}, + "outputs": [], + "source": [ + "import tensorflow_datasets as tfds # TFDS for MNIST\n", + "import tensorflow as tf # TensorFlow operations\n", + "\n", + "def get_datasets(num_epochs, batch_size):\n", + " \"\"\"Load MNIST train and test datasets into memory.\"\"\"\n", + " train_ds = tfds.load('mnist', split='train')\n", + " test_ds = tfds.load('mnist', split='test')\n", + "\n", + " train_ds = train_ds.map(lambda sample: {'image': tf.cast(sample['image'],\n", + " tf.float32) / 255.,\n", + " 'label': sample['label']}) # normalize train set\n", + " test_ds = test_ds.map(lambda sample: {'image': tf.cast(sample['image'],\n", + " tf.float32) / 255.,\n", + " 'label': sample['label']}) # normalize test set\n", + "\n", + " train_ds = train_ds.repeat(num_epochs).shuffle(1024) # create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from\n", + " train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1) # group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency\n", + " test_ds = test_ds.shuffle(1024) # create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from\n", + " test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1) # group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency\n", + "\n", + " return train_ds, test_ds" + ] + }, + { + "cell_type": "markdown", + "id": "7057395a", + "metadata": {}, + "source": [ + "## 3. 
Define network\n", + "\n", + "Create a convolutional neural network with the Linen API by subclassing\n", + "[Flax Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html).\n", + "Because the architecture in this example is relatively simple—you're just\n", + "stacking layers—you can define the inlined submodules directly within the\n", + "`__call__` method and wrap it with the\n", + "[`@compact`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/decorators.html#flax.linen.compact)\n", + "decorator. To learn more about the Flax Linen `@compact` decorator, refer to the [`setup` vs `compact`](https://flax.readthedocs.io/en/latest/guides/setup_or_nncompact.html) guide." + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "cbc079cd", + "metadata": {}, + "outputs": [], + "source": [ + "from flax import linen as nn # Linen API\n", + "\n", + "class CNN(nn.Module):\n", + " \"\"\"A simple CNN model.\"\"\"\n", + "\n", + " @nn.compact\n", + " def __call__(self, x):\n", + " x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n", + " x = nn.relu(x)\n", + " x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", + " x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n", + " x = nn.relu(x)\n", + " x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", + " x = x.reshape((x.shape[0], -1)) # flatten\n", + " x = nn.Dense(features=256)(x)\n", + " x = nn.relu(x)\n", + " x = nn.Dense(features=10)(x)\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "id": "hy7iRu7_zlx-", + "metadata": {}, + "source": [ + "### View model layers\n", + "\n", + "Create an instance of the Flax Module and use the [`Module.tabulate`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.tabulate) method to visualize a table of the model layers by passing an RNG key and template image input." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 50, + "id": "lDHfog81zLQa", + "metadata": { + "outputId": "2c580f41-bf5d-40ec-f1cf-ab7f319a84da" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[3m CNN Summary \u001b[0m\n", + "┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mpath \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmodule\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1minputs \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1moutputs \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mflops \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mvjp_flops\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mparams \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━┩\n", + "│ │ CNN │ \u001b[2mfloat32\u001b[0m[1… │ \u001b[2mfloat32\u001b[0m[… │ 8708106 │ 26957556 │ │\n", + "├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤\n", + "│ Conv_0 │ Conv │ \u001b[2mfloat32\u001b[0m[1… │ \u001b[2mfloat32\u001b[0m[… │ 455424 │ 1341472 │ bias: │\n", + "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[3… │\n", + "│ │ │ │ │ │ │ kernel: │\n", + "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[3… │\n", + "│ │ │ │ │ │ │ │\n", + "│ │ │ │ │ │ │ \u001b[1m320 \u001b[0m\u001b[1;2m(1.3 \u001b[0m │\n", + "│ │ │ │ │ │ │ \u001b[1;2mKB)\u001b[0m │\n", + "├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤\n", + "│ Conv_1 │ Conv │ \u001b[2mfloat32\u001b[0m[1… │ \u001b[2mfloat32\u001b[0m[… │ 6566144 │ 19704320 │ bias: │\n", + "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[6… │\n", + "│ │ │ │ │ │ │ kernel: │\n", + "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[3… │\n", + "│ │ │ │ │ │ │ │\n", + "│ │ │ │ │ │ │ \u001b[1m18,496 \u001b[0m │\n", + "│ │ │ │ │ │ │ \u001b[1;2m(74.0 KB)\u001b[0m │\n", + 
"├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤\n", + "│ Dense_0 │ Dense │ \u001b[2mfloat32\u001b[0m[1… │ \u001b[2mfloat32\u001b[0m[… │ 1605888 │ 5620224 │ bias: │\n", + "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[2… │\n", + "│ │ │ │ │ │ │ kernel: │\n", + "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[3… │\n", + "│ │ │ │ │ │ │ │\n", + "│ │ │ │ │ │ │ \u001b[1m803,072 \u001b[0m │\n", + "│ │ │ │ │ │ │ \u001b[1;2m(3.2 MB)\u001b[0m │\n", + "├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤\n", + "│ Dense_1 │ Dense │ \u001b[2mfloat32\u001b[0m[1… │ \u001b[2mfloat32\u001b[0m[… │ 5130 │ 17940 │ bias: │\n", + "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[1… │\n", + "│ │ │ │ │ │ │ kernel: │\n", + "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[2… │\n", + "│ │ │ │ │ │ │ │\n", + "│ │ │ │ │ │ │ \u001b[1m2,570 \u001b[0m │\n", + "│ │ │ │ │ │ │ \u001b[1;2m(10.3 KB)\u001b[0m │\n", + "├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤\n", + "│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m Total\u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m824,458 \u001b[0m\u001b[1m \u001b[0m│\n", + "│\u001b[1m \u001b[0m│\u001b[1m \u001b[0m│\u001b[1m \u001b[0m│\u001b[1m \u001b[0m│\u001b[1m \u001b[0m│\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1;2m(3.3 MB)\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\n", + "└─────────┴────────┴────────────┴───────────┴─────────┴───────────┴────────────┘\n", + "\u001b[1m \u001b[0m\n", + "\u001b[1m Total Parameters: 824,458 \u001b[0m\u001b[1;2m(3.3 MB)\u001b[0m\u001b[1m \u001b[0m\n", + "\n", + "\n" + ] + } + ], + "source": [ + "import jax\n", + "import jax.numpy as jnp # JAX NumPy\n", + "\n", + "cnn = CNN()\n", + 
"print(cnn.tabulate(jax.random.key(0), jnp.ones((1, 28, 28, 1)),\n", + " compute_flops=True, compute_vjp_flops=True))" + ] + }, + { + "cell_type": "markdown", + "id": "4b5ac16e", + "metadata": {}, + "source": [ + "## 4. Create a `TrainState`\n", + "\n", + "A common pattern in Flax is to create a single dataclass that represents the\n", + "entire training state, including step number, parameters, and optimizer state.\n", + "\n", + "Because this is such a common pattern, Flax provides the class\n", + "[`flax.training.train_state.TrainState`](https://flax.readthedocs.io/en/latest/flax.training.html#train-state)\n", + "that serves most basic usecases." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "qXr7JDpIxGNZ", + "metadata": { + "outputId": "1249b7fb-6787-41eb-b34c-61d736300844" + }, + "outputs": [], + "source": [ + "!pip install -q clu" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "CJDaJNijyOji", + "metadata": {}, + "outputs": [], + "source": [ + "from clu import metrics\n", + "from flax.training import train_state # Useful dataclass to keep train state\n", + "from flax import struct # Flax dataclasses\n", + "import optax # Common loss functions and optimizers" + ] + }, + { + "cell_type": "markdown", + "id": "8b86b5f1", + "metadata": {}, + "source": [ + "We will be using the `clu` library for computing metrics. For more information on `clu`, refer to the [repo](https://github.com/google/CommonLoopUtils) and [notebook](https://colab.research.google.com/github/google/CommonLoopUtils/blob/master/clu_synopsis.ipynb#scrollTo=ueom-uBWLbeQ)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 53, + "id": "7W0qf7FC9uG5", + "metadata": {}, + "outputs": [], + "source": [ + "@struct.dataclass\n", + "class Metrics(metrics.Collection):\n", + " accuracy: metrics.Accuracy\n", + " loss: metrics.Average.from_output('loss')" + ] + }, + { + "cell_type": "markdown", + "id": "f3ce5e4c", + "metadata": {}, + "source": [ + "You can then subclass `train_state.TrainState` so that it also contains metrics. This has the advantage that we only need\n", + "to pass around a single argument to functions like `train_step()` (see below) to calculate the loss, update the parameters and compute the metrics all at once." + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "id": "e0102447", + "metadata": {}, + "outputs": [], + "source": [ + "class TrainState(train_state.TrainState):\n", + " metrics: Metrics\n", + "\n", + "def create_train_state(module, rng, learning_rate, momentum):\n", + " \"\"\"Creates an initial `TrainState`.\"\"\"\n", + " params = module.init(rng, jnp.ones([1, 28, 28, 1]))['params'] # initialize parameters by passing a template image\n", + " tx = optax.sgd(learning_rate, momentum)\n", + " return TrainState.create(\n", + " apply_fn=module.apply, params=params, tx=tx,\n", + " metrics=Metrics.empty())" + ] + }, + { + "cell_type": "markdown", + "id": "a15de484", + "metadata": {}, + "source": [ + "## 5. 
Training step\n", + "\n", + "A function that:\n", + "\n", + "- Evaluates the neural network given the parameters and a batch of input images\n", + " with [`TrainState.apply_fn`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) (which contains the [`Module.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.apply)\n", + " method (forward pass)).\n", + "- Computes the cross entropy loss, using the predefined [`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api.html#optax.softmax_cross_entropy_with_integer_labels). Note that this function expects integer labels, so there is no need to convert labels to one-hot encoding.\n", + "- Evaluates the gradient of the loss function using\n", + " [`jax.grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad).\n", + "- Applies a\n", + " [pytree](https://jax.readthedocs.io/en/latest/pytrees.html#pytrees-and-jax-functions)\n", + " of gradients to the optimizer to update the model's parameters.\n", + "\n", + "Use JAX's [@jit](https://jax.readthedocs.io/en/latest/jax.html#jax.jit)\n", + "decorator to trace the entire `train_step` function and just-in-time compile\n", + "it with [XLA](https://www.tensorflow.org/xla) into fused device operations\n", + "that run faster and more efficiently on hardware accelerators."
+ ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "9b0af486", + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def train_step(state, batch):\n", + " \"\"\"Train for a single step.\"\"\"\n", + " def loss_fn(params):\n", + " logits = state.apply_fn({'params': params}, batch['image'])\n", + " loss = optax.softmax_cross_entropy_with_integer_labels(\n", + " logits=logits, labels=batch['label']).mean()\n", + " return loss\n", + " grad_fn = jax.grad(loss_fn)\n", + " grads = grad_fn(state.params)\n", + " state = state.apply_gradients(grads=grads)\n", + " return state" + ] + }, + { + "cell_type": "markdown", + "id": "0ff5145f", + "metadata": {}, + "source": [ + "## 6. Metric computation\n", + "\n", + "Create a separate function for loss and accuracy metrics. Loss is calculated using the `optax.softmax_cross_entropy_with_integer_labels` function, while accuracy is calculated using `clu.metrics`." + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "id": "961bf70b", + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def compute_metrics(*, state, batch):\n", + " logits = state.apply_fn({'params': state.params}, batch['image'])\n", + " loss = optax.softmax_cross_entropy_with_integer_labels(\n", + " logits=logits, labels=batch['label']).mean()\n", + " metric_updates = state.metrics.single_from_model_output(\n", + " logits=logits, labels=batch['label'], loss=loss)\n", + " metrics = state.metrics.merge(metric_updates)\n", + " state = state.replace(metrics=metrics)\n", + " return state" + ] + }, + { + "cell_type": "markdown", + "id": "497241c3", + "metadata": {}, + "source": [ + "## 7. 
Download data" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "id": "bff5393e", + "metadata": {}, + "outputs": [], + "source": [ + "num_epochs = 10\n", + "batch_size = 32\n", + "\n", + "train_ds, test_ds = get_datasets(num_epochs, batch_size)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "809ae1a0", + "metadata": {}, + "source": [ + "## 8. Seed randomness\n", + "\n", + "- Set the TF random seed to ensure dataset shuffling (with `tf.data.Dataset.shuffle`) is reproducible.\n", + "- Get one\n", + " [PRNGKey](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.PRNGKey.html#jax.random.PRNGKey)\n", + " and use it for parameter initialization. (Learn\n", + " more about\n", + " [JAX PRNG design](https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html)\n", + " and [PRNG chains](https://flax.readthedocs.io/en/latest/philosophy.html#how-are-parameters-represented-and-how-do-we-handle-general-differentiable-algorithms-that-update-stateful-variables).)" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "id": "xC4MFyBsfT-U", + "metadata": {}, + "outputs": [], + "source": [ + "tf.random.set_seed(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "id": "e4f6f4d3", + "metadata": {}, + "outputs": [], + "source": [ + "init_rng = jax.random.key(0)" + ] + }, + { + "cell_type": "markdown", + "id": "80fbb60b", + "metadata": {}, + "source": [ + "## 9. Initialize the `TrainState`\n", + "\n", + "Remember that the function `create_train_state` initializes the model parameters, optimizer and metrics\n", + "and puts them into the training state dataclass that is returned." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 60, + "id": "445fcab0", + "metadata": {}, + "outputs": [], + "source": [ + "learning_rate = 0.01\n", + "momentum = 0.9" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "id": "5221eafd", + "metadata": {}, + "outputs": [], + "source": [ + "state = create_train_state(cnn, init_rng, learning_rate, momentum)\n", + "del init_rng # Must not be used anymore." + ] + }, + { + "cell_type": "markdown", + "id": "b1c00230", + "metadata": {}, + "source": [ + "## 10. Train and evaluate\n", + "\n", + "Create a \"shuffled\" dataset by:\n", + "- Repeating the dataset equal to the number of training epochs\n", + "- Allocating a buffer of size 1024 (containing the first 1024 samples in the dataset) from which to randomly sample batches\n", + " - Every time a sample is randomly drawn from the buffer, the next sample in the dataset is loaded into the buffer\n", + "\n", + "Define a training loop that:\n", + "- Randomly samples batches from the dataset.\n", + "- Runs an optimization step for each training batch.\n", + "- Computes the mean training metrics across each batch in an epoch.\n", + "- Computes the metrics for the test set using the updated parameters.\n", + "- Records the train and test metrics for visualization.\n", + "\n", + "Once the training and testing is done after 10 epochs, the output should show that your model was able to achieve approximately 99% accuracy."
+ ] + }, + { + "cell_type": "code", + "execution_count": 62, + "id": "74295360", + "metadata": {}, + "outputs": [], + "source": [ + "# since train_ds is replicated num_epochs times in get_datasets(), we divide by num_epochs\n", + "num_steps_per_epoch = train_ds.cardinality().numpy() // num_epochs" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "id": "cRtnMZuQFlKl", + "metadata": {}, + "outputs": [], + "source": [ + "metrics_history = {'train_loss': [],\n", + " 'train_accuracy': [],\n", + " 'test_loss': [],\n", + " 'test_accuracy': []}" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "id": "2c40ce90", + "metadata": { + "outputId": "258a2c76-2c8f-4a9e-d48b-dde57c342a87" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train epoch: 1, loss: 0.20290373265743256, accuracy: 93.87000274658203\n", + "test epoch: 1, loss: 0.07591685652732849, accuracy: 97.60617065429688\n", + "train epoch: 2, loss: 0.05760224163532257, accuracy: 98.28500366210938\n", + "test epoch: 2, loss: 0.050395529717206955, accuracy: 98.3974380493164\n", + "train epoch: 3, loss: 0.03897436335682869, accuracy: 98.83000183105469\n", + "test epoch: 3, loss: 0.04574578255414963, accuracy: 98.54767608642578\n", + "train epoch: 4, loss: 0.028721099719405174, accuracy: 99.15166473388672\n", + "test epoch: 4, loss: 0.035722777247428894, accuracy: 98.91827392578125\n", + "train epoch: 5, loss: 0.021948494017124176, accuracy: 99.37999725341797\n", + "test epoch: 5, loss: 0.035723842680454254, accuracy: 98.87820434570312\n", + "train epoch: 6, loss: 0.01705147698521614, accuracy: 99.54833221435547\n", + "test epoch: 6, loss: 0.03456473350524902, accuracy: 98.96835327148438\n", + "train epoch: 7, loss: 0.014007646590471268, accuracy: 99.6116714477539\n", + "test epoch: 7, loss: 0.04089202359318733, accuracy: 98.7880630493164\n", + "train epoch: 8, loss: 0.011265480890870094, accuracy: 99.73333740234375\n", + "test epoch: 8, loss: 
0.03337760642170906, accuracy: 98.93830108642578\n", + "train epoch: 9, loss: 0.00918484665453434, accuracy: 99.78334045410156\n", + "test epoch: 9, loss: 0.034478139132261276, accuracy: 98.96835327148438\n", + "train epoch: 10, loss: 0.007260234095156193, accuracy: 99.84166717529297\n", + "test epoch: 10, loss: 0.032822880893945694, accuracy: 99.07852172851562\n" + ] + } + ], + "source": [ + "for step,batch in enumerate(train_ds.as_numpy_iterator()):\n", + "\n", + " # Run optimization steps over training batches and compute batch metrics\n", + " state = train_step(state, batch) # get updated train state (which contains the updated parameters)\n", + " state = compute_metrics(state=state, batch=batch) # aggregate batch metrics\n", + "\n", + " if (step+1) % num_steps_per_epoch == 0: # one training epoch has passed\n", + " for metric,value in state.metrics.compute().items(): # compute metrics\n", + " metrics_history[f'train_{metric}'].append(value) # record metrics\n", + " state = state.replace(metrics=state.metrics.empty()) # reset train_metrics for next training epoch\n", + "\n", + " # Compute metrics on the test set after each training epoch\n", + " test_state = state\n", + " for test_batch in test_ds.as_numpy_iterator():\n", + " test_state = compute_metrics(state=test_state, batch=test_batch)\n", + "\n", + " for metric,value in test_state.metrics.compute().items():\n", + " metrics_history[f'test_{metric}'].append(value)\n", + "\n", + " print(f\"train epoch: {(step+1) // num_steps_per_epoch}, \"\n", + " f\"loss: {metrics_history['train_loss'][-1]}, \"\n", + " f\"accuracy: {metrics_history['train_accuracy'][-1] * 100}\")\n", + " print(f\"test epoch: {(step+1) // num_steps_per_epoch}, \"\n", + " f\"loss: {metrics_history['test_loss'][-1]}, \"\n", + " f\"accuracy: {metrics_history['test_accuracy'][-1] * 100}\")" + ] + }, + { + "cell_type": "markdown", + "id": "gfsecJzvzgCT", + "metadata": {}, + "source": [ + "## 11. 
Visualize metrics" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "id": "Zs5atiqIG9Kz", + "metadata": { + "outputId": "431a2fcd-44fa-4202-f55a-906555f060ac" + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt # Visualization\n", + "\n", + "# Plot loss and accuracy in subplots\n", + "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))\n", + "ax1.set_title('Loss')\n", + "ax2.set_title('Accuracy')\n", + "for dataset in ('train','test'):\n", + " ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')\n", + " ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')\n", + "ax1.legend()\n", + "ax2.legend()\n", + "plt.show()\n", + "plt.clf()" + ] + }, + { + "cell_type": "markdown", + "id": "qQbKS0tV3sZ1", + "metadata": {}, + "source": [ + "## 12. Perform inference on test set\n", + "\n", + "Define a jitted inference function `pred_step`. Use the learned parameters to do model inference on the test set and visualize the images and their corresponding predicted labels." + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "id": "DFwxgBQf44ks", + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def pred_step(state, batch):\n", + " logits = state.apply_fn({'params': state.params}, test_batch['image'])\n", + " return logits.argmax(axis=1)\n", + "\n", + "test_batch = test_ds.as_numpy_iterator().next()\n", + "pred = pred_step(state, test_batch)" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "id": "5d5nF3u44JFI", + "metadata": { + "outputId": "1db5a01c-9d70-4f7d-8c0d-0a3ad8252d3e" + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, axs = plt.subplots(5, 5, figsize=(12, 12))\n", + "for i, ax in enumerate(axs.flatten()):\n", + " ax.imshow(test_batch['image'][i, ..., 0], cmap='gray')\n", + " ax.set_title(f\"label={pred[i]}\")\n", + " ax.axis('off')" + ] + }, + { + "cell_type": "markdown", + "id": "edb528b6", + "metadata": {}, + "source": [ + "Congratulations! You made it to the end of the annotated MNIST example. You can revisit\n", + "the same example, but structured differently as a couple of Python modules, test\n", + "modules, config files, another Colab, and documentation in Flax's Git repo:\n", + "\n", + "[https://github.com/google/flax/tree/main/examples/mnist](https://github.com/google/flax/tree/main/examples/mnist)" + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,md:myst", + "main_language": "python" + }, + "language_info": { + "name": "python", + "version": "3.9.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs_nnx/quick_start.md b/docs_nnx/quick_start.md new file mode 100644 index 0000000000..ac8a9fb860 --- /dev/null +++ b/docs_nnx/quick_start.md @@ -0,0 +1,355 @@ +--- +jupytext: + formats: ipynb,md:myst + main_language: python + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.13.8 +--- + +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/quick_start.ipynb) +[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/quick_start.ipynb) + +# Quick start + +Welcome to Flax! + +Flax is an open source Python neural network library built on top of [JAX](https://github.com/google/jax). 
This tutorial demonstrates how to construct a simple convolutional neural +network (CNN) using the [Flax](https://flax.readthedocs.io) Linen API and train +the network for image classification on the MNIST dataset. + ++++ + +## 1. Install Flax + +```{code-cell} +:tags: [skip-execution] + +!pip install -q flax>=0.7.5 +``` + +## 2. Loading data + +Flax can use any +data-loading pipeline and this example demonstrates how to utilize TFDS. Define a function that loads and prepares the MNIST dataset and converts the +samples to floating-point numbers. + +```{code-cell} +import tensorflow_datasets as tfds # TFDS for MNIST +import tensorflow as tf # TensorFlow operations + +def get_datasets(num_epochs, batch_size): + """Load MNIST train and test datasets into memory.""" + train_ds = tfds.load('mnist', split='train') + test_ds = tfds.load('mnist', split='test') + + train_ds = train_ds.map(lambda sample: {'image': tf.cast(sample['image'], + tf.float32) / 255., + 'label': sample['label']}) # normalize train set + test_ds = test_ds.map(lambda sample: {'image': tf.cast(sample['image'], + tf.float32) / 255., + 'label': sample['label']}) # normalize test set + + train_ds = train_ds.repeat(num_epochs).shuffle(1024) # create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from + train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1) # group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency + test_ds = test_ds.shuffle(1024) # create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from + test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1) # group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency + + return train_ds, test_ds +``` + +## 3. 
Define network + +Create a convolutional neural network with the Linen API by subclassing +[Flax Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html). +Because the architecture in this example is relatively simple—you're just +stacking layers—you can define the inlined submodules directly within the +`__call__` method and wrap it with the +[`@compact`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/decorators.html#flax.linen.compact) +decorator. To learn more about the Flax Linen `@compact` decorator, refer to the [`setup` vs `compact`](https://flax.readthedocs.io/en/latest/guides/setup_or_nncompact.html) guide. + +```{code-cell} +from flax import linen as nn # Linen API + +class CNN(nn.Module): + """A simple CNN model.""" + + @nn.compact + def __call__(self, x): + x = nn.Conv(features=32, kernel_size=(3, 3))(x) + x = nn.relu(x) + x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = nn.Conv(features=64, kernel_size=(3, 3))(x) + x = nn.relu(x) + x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = x.reshape((x.shape[0], -1)) # flatten + x = nn.Dense(features=256)(x) + x = nn.relu(x) + x = nn.Dense(features=10)(x) + return x +``` + +### View model layers + +Create an instance of the Flax Module and use the [`Module.tabulate`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.tabulate) method to visualize a table of the model layers by passing an RNG key and template image input. + +```{code-cell} +:outputId: 2c580f41-bf5d-40ec-f1cf-ab7f319a84da + +import jax +import jax.numpy as jnp # JAX NumPy + +cnn = CNN() +print(cnn.tabulate(jax.random.key(0), jnp.ones((1, 28, 28, 1)), + compute_flops=True, compute_vjp_flops=True)) +``` + +## 4. Create a `TrainState` + +A common pattern in Flax is to create a single dataclass that represents the +entire training state, including step number, parameters, and optimizer state. 
+ +Because this is such a common pattern, Flax provides the class +[`flax.training.train_state.TrainState`](https://flax.readthedocs.io/en/latest/flax.training.html#train-state) +that serves most basic use cases. + +```{code-cell} +:outputId: 1249b7fb-6787-41eb-b34c-61d736300844 + +!pip install -q clu +``` + +```{code-cell} +from clu import metrics +from flax.training import train_state # Useful dataclass to keep train state +from flax import struct # Flax dataclasses +import optax # Common loss functions and optimizers +``` + +We will be using the `clu` library for computing metrics. For more information on `clu`, refer to the [repo](https://github.com/google/CommonLoopUtils) and [notebook](https://colab.research.google.com/github/google/CommonLoopUtils/blob/master/clu_synopsis.ipynb#scrollTo=ueom-uBWLbeQ). + +```{code-cell} +@struct.dataclass +class Metrics(metrics.Collection): + accuracy: metrics.Accuracy + loss: metrics.Average.from_output('loss') +``` + +You can then subclass `train_state.TrainState` so that it also contains metrics. This has the advantage that we only need +to pass around a single argument to functions like `train_step()` (see below) to calculate the loss, update the parameters and compute the metrics all at once. + +```{code-cell} +class TrainState(train_state.TrainState): + metrics: Metrics + +def create_train_state(module, rng, learning_rate, momentum): + """Creates an initial `TrainState`.""" + params = module.init(rng, jnp.ones([1, 28, 28, 1]))['params'] # initialize parameters by passing a template image + tx = optax.sgd(learning_rate, momentum) + return TrainState.create( + apply_fn=module.apply, params=params, tx=tx, + metrics=Metrics.empty()) +``` + +## 5.
Training step + +A function that: + +- Evaluates the neural network given the parameters and a batch of input images + with [`TrainState.apply_fn`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) (which contains the [`Module.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.apply) + method (forward pass)). +- Computes the cross entropy loss, using the predefined [`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api.html#optax.softmax_cross_entropy_with_integer_labels). Note that this function expects integer labels, so there is no need to convert labels to one-hot encoding. +- Evaluates the gradient of the loss function using + [`jax.grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad). +- Applies a + [pytree](https://jax.readthedocs.io/en/latest/pytrees.html#pytrees-and-jax-functions) + of gradients to the optimizer to update the model's parameters. + +Use JAX's [@jit](https://jax.readthedocs.io/en/latest/jax.html#jax.jit) +decorator to trace the entire `train_step` function and just-in-time compile +it with [XLA](https://www.tensorflow.org/xla) into fused device operations +that run faster and more efficiently on hardware accelerators. + +```{code-cell} +@jax.jit +def train_step(state, batch): + """Train for a single step.""" + def loss_fn(params): + logits = state.apply_fn({'params': params}, batch['image']) + loss = optax.softmax_cross_entropy_with_integer_labels( + logits=logits, labels=batch['label']).mean() + return loss + grad_fn = jax.grad(loss_fn) + grads = grad_fn(state.params) + state = state.apply_gradients(grads=grads) + return state +``` + +## 6. Metric computation + +Create a separate function for loss and accuracy metrics. Loss is calculated using the `optax.softmax_cross_entropy_with_integer_labels` function, while accuracy is calculated using `clu.metrics`.
+ +```{code-cell} +@jax.jit +def compute_metrics(*, state, batch): + logits = state.apply_fn({'params': state.params}, batch['image']) + loss = optax.softmax_cross_entropy_with_integer_labels( + logits=logits, labels=batch['label']).mean() + metric_updates = state.metrics.single_from_model_output( + logits=logits, labels=batch['label'], loss=loss) + metrics = state.metrics.merge(metric_updates) + state = state.replace(metrics=metrics) + return state +``` + +## 7. Download data + +```{code-cell} +num_epochs = 10 +batch_size = 32 + +train_ds, test_ds = get_datasets(num_epochs, batch_size) +``` + +## 8. Seed randomness + +- Set the TF random seed to ensure dataset shuffling (with `tf.data.Dataset.shuffle`) is reproducible. +- Get one + [PRNGKey](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.PRNGKey.html#jax.random.PRNGKey) + and use it for parameter initialization. (Learn + more about + [JAX PRNG design](https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html) + and [PRNG chains](https://flax.readthedocs.io/en/latest/philosophy.html#how-are-parameters-represented-and-how-do-we-handle-general-differentiable-algorithms-that-update-stateful-variables).) + +```{code-cell} +tf.random.set_seed(0) +``` + +```{code-cell} +init_rng = jax.random.key(0) +``` + +## 9. Initialize the `TrainState` + +Remember that the function `create_train_state` initializes the model parameters, optimizer and metrics +and puts them into the training state dataclass that is returned. + +```{code-cell} +learning_rate = 0.01 +momentum = 0.9 +``` + +```{code-cell} +state = create_train_state(cnn, init_rng, learning_rate, momentum) +del init_rng # Must not be used anymore. +``` + +## 10. 
Train and evaluate + +Create a "shuffled" dataset by: +- Repeating the dataset equal to the number of training epochs +- Allocating a buffer of size 1024 (containing the first 1024 samples in the dataset) from which to randomly sample batches + - Every time a sample is randomly drawn from the buffer, the next sample in the dataset is loaded into the buffer + +Define a training loop that: +- Randomly samples batches from the dataset. +- Runs an optimization step for each training batch. +- Computes the mean training metrics across each batch in an epoch. +- Computes the metrics for the test set using the updated parameters. +- Records the train and test metrics for visualization. + +Once the training and testing is done after 10 epochs, the output should show that your model was able to achieve approximately 99% accuracy. + +```{code-cell} +# since train_ds is replicated num_epochs times in get_datasets(), we divide by num_epochs +num_steps_per_epoch = train_ds.cardinality().numpy() // num_epochs +``` + +```{code-cell} +metrics_history = {'train_loss': [], + 'train_accuracy': [], + 'test_loss': [], + 'test_accuracy': []} +``` + +```{code-cell} +:outputId: 258a2c76-2c8f-4a9e-d48b-dde57c342a87 + +for step,batch in enumerate(train_ds.as_numpy_iterator()): + + # Run optimization steps over training batches and compute batch metrics + state = train_step(state, batch) # get updated train state (which contains the updated parameters) + state = compute_metrics(state=state, batch=batch) # aggregate batch metrics + + if (step+1) % num_steps_per_epoch == 0: # one training epoch has passed + for metric,value in state.metrics.compute().items(): # compute metrics + metrics_history[f'train_{metric}'].append(value) # record metrics + state = state.replace(metrics=state.metrics.empty()) # reset train_metrics for next training epoch + + # Compute metrics on the test set after each training epoch + test_state = state + for test_batch in test_ds.as_numpy_iterator(): + test_state =
compute_metrics(state=test_state, batch=test_batch) + + for metric,value in test_state.metrics.compute().items(): + metrics_history[f'test_{metric}'].append(value) + + print(f"train epoch: {(step+1) // num_steps_per_epoch}, " + f"loss: {metrics_history['train_loss'][-1]}, " + f"accuracy: {metrics_history['train_accuracy'][-1] * 100}") + print(f"test epoch: {(step+1) // num_steps_per_epoch}, " + f"loss: {metrics_history['test_loss'][-1]}, " + f"accuracy: {metrics_history['test_accuracy'][-1] * 100}") +``` + +## 11. Visualize metrics + +```{code-cell} +:outputId: 431a2fcd-44fa-4202-f55a-906555f060ac + +import matplotlib.pyplot as plt # Visualization + +# Plot loss and accuracy in subplots +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5)) +ax1.set_title('Loss') +ax2.set_title('Accuracy') +for dataset in ('train','test'): + ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss') + ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy') +ax1.legend() +ax2.legend() +plt.show() +plt.clf() +``` + +## 12. Perform inference on test set + +Define a jitted inference function `pred_step`. Use the learned parameters to do model inference on the test set and visualize the images and their corresponding predicted labels. + +```{code-cell} +@jax.jit +def pred_step(state, batch): + logits = state.apply_fn({'params': state.params}, test_batch['image']) + return logits.argmax(axis=1) + +test_batch = test_ds.as_numpy_iterator().next() +pred = pred_step(state, test_batch) +``` + +```{code-cell} +:outputId: 1db5a01c-9d70-4f7d-8c0d-0a3ad8252d3e + +fig, axs = plt.subplots(5, 5, figsize=(12, 12)) +for i, ax in enumerate(axs.flatten()): + ax.imshow(test_batch['image'][i, ..., 0], cmap='gray') + ax.set_title(f"label={pred[i]}") + ax.axis('off') +``` + +Congratulations! You made it to the end of the annotated MNIST example. 
You can revisit +the same example, but structured differently as a couple of Python modules, test +modules, config files, another Colab, and documentation in Flax's Git repo: + +[https://github.com/google/flax/tree/main/examples/mnist](https://github.com/google/flax/tree/main/examples/mnist) diff --git a/docs_nnx/robots.txt b/docs_nnx/robots.txt new file mode 100644 index 0000000000..17cb42594c --- /dev/null +++ b/docs_nnx/robots.txt @@ -0,0 +1,5 @@ +User-agent: * + +Disallow: /api_reference/flax.linen/_autosummary/ # for SEO, since Google still indexes this deprecated link + +Sitemap: https://flax.readthedocs.io/sitemap.xml diff --git a/flax/core/meta.py b/flax/core/meta.py index 531b463c7d..27686a40b5 100644 --- a/flax/core/meta.py +++ b/flax/core/meta.py @@ -22,7 +22,6 @@ """ import abc -import dataclasses import functools from typing import Any, Generic, TypeVar from collections.abc import Callable @@ -288,19 +287,6 @@ def get_sharding(self, mesh: jax.sharding.Mesh) -> jax.sharding.Sharding: """Returns the ``NamedSharding`` for this partitioned value.""" return jax.sharding.NamedSharding(mesh, self.get_partition_spec()) - def to_nnx_metadata(self) -> dict[str, Any]: - """Return a dict of metadata that can translate into an `nnx.Variable`.""" - metadata = vars(self) - metadata['sharding'] = metadata.pop('names') - return metadata - - @classmethod - def from_nnx_metadata(cls, metadata: dict[str, Any]): - """Given a dict of `nnx.Variable` format metadata, create a `nn.Partitioned`.""" - metadata['names'] = metadata.pop('sharding') - fields = {x.name for x in dataclasses.fields(cls)} - return cls(**{k: v for k, v in metadata.items() if k in fields}) - def with_partitioning( fn: Callable[..., Any], diff --git a/flax/errors.py b/flax/errors.py index b2ecd1be69..7284c6e3fb 100644 --- a/flax/errors.py +++ b/flax/errors.py @@ -64,15 +64,6 @@ def __reduce__(self): return (FlaxError, (str(self),)) -################################################# -# NNX errors # 
-################################################# - - -class TraceContextError(FlaxError): - pass - - ################################################# # lazy_init.py errors # ################################################# diff --git a/flax/linen/spmd.py b/flax/linen/spmd.py index cd622bbdae..93afab7646 100644 --- a/flax/linen/spmd.py +++ b/flax/linen/spmd.py @@ -328,21 +328,6 @@ def unbox(self, apply_constraint=True) -> Any: else: return self.value - def to_nnx_metadata(self) -> dict[str, Any]: - """Return a dict of metadata that can translate into an `nnx.Variable`.""" - metadata = vars(self) - metadata['sharding'] = metadata.pop('names') - metadata['sharding_rules'] = metadata.pop('rules') - return metadata - - @classmethod - def from_nnx_metadata(cls, metadata: dict[str, Any]): - """Given a dict of `nnx.Variable` format metadata, create a `nn.LogicallyPartitioned`.""" - metadata['names'] = metadata.pop('sharding') - metadata['rules'] = metadata.pop('sharding_rules') - fields = {x.name for x in dataclasses.fields(cls)} - return cls(**{k: v for k, v in metadata.items() if k in fields}) - def with_logical_partitioning( fn: Callable[..., Any], diff --git a/flax/nnx/bridge/variables.py b/flax/nnx/bridge/variables.py index d73f645f3b..9d8714274a 100644 --- a/flax/nnx/bridge/variables.py +++ b/flax/nnx/bridge/variables.py @@ -58,10 +58,10 @@ def variable_type_name(typ: tp.Type[variableslib.Variable[tp.Any]]) -> str: def register_variable_name_type_pair(name, typ, overwrite = False): - """Register a pair of Linen collection name and its NNX type.""" + """Register a pair of variable type name (like Linen collections) and its NNX type.""" if not overwrite and name in VariableTypeCache: raise ValueError(f'Name {name} already mapped to type {VariableTypeCache[name]}. 
' - 'To overwrite, call register_variable_name_type_pair() with `overwrite=True`.') + 'To overwrite, call with `overwrite=True`.') VariableTypeCache[name] = typ @@ -85,7 +85,8 @@ def _variable_parents_count(t: type): class NNXMeta(struct.PyTreeNode, meta.AxisMetadata[A]): - """Default Flax metadata class for `nnx.VariableState`.""" + """Default Flax metadata class for `nnx.VariableState`. + """ var_type: type[variableslib.Variable[tp.Any]] = struct.field(pytree_node=False) value: Any = struct.field(pytree_node=True) @@ -109,11 +110,10 @@ def remove_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]': def to_linen_var(vs: variableslib.VariableState) -> meta.AxisMetadata: metadata = vs.get_metadata() if 'linen_meta_type' in metadata: - linen_type = metadata['linen_meta_type'] - if hasattr(linen_type, 'from_nnx_metadata'): - return linen_type.from_nnx_metadata({'value': vs.value, **metadata}) - return linen_type(vs.value, **metadata) - return NNXMeta(vs.type, vs.value, metadata) + if metadata['linen_meta_type'] is not meta.Partitioned: + raise ValueError('Not supporting Linen metadata types other than nn.Partitioned') + return meta.Partitioned(vs.value, names=metadata['sharding'], mesh=metadata['mesh']) + return NNXMeta(vs.type, vs.value, vs.get_metadata()) def get_col_name(keypath: tp.Sequence[Any]) -> str: @@ -124,15 +124,15 @@ def get_col_name(keypath: tp.Sequence[Any]) -> str: def to_nnx_var(col: str, x: meta.AxisMetadata | Any) -> variableslib.Variable: - """Convert a Linen variable to an NNX variable.""" + """Convert a Linen variable to an NNX variable. 
+ This process needs the collection name, + """ vtype = variable_type(col) if isinstance(x, NNXMeta): assert vtype == x.var_type, f'Type stored in NNXMeta {x.var_type} != type inferred from collection name {vtype}' return x.var_type(x.value, **x.metadata) if isinstance(x, meta.AxisMetadata): - x_metadata = vars(x) - if hasattr(x, 'to_nnx_metadata'): - x_metadata = x.to_nnx_metadata() - assert hasattr(x, 'value') - return vtype(**x_metadata, linen_meta_type=type(x)) - return vtype(x) \ No newline at end of file + if isinstance(x, meta.Partitioned): + return vtype(x.value, sharding=x.names, mesh=x.mesh, linen_meta_type=meta.Partitioned) + raise ValueError('Not yet supporting metadata types other than nn.Partitioned and NNXMeta') + return vtype(x) diff --git a/flax/nnx/bridge/wrappers.py b/flax/nnx/bridge/wrappers.py index d209d89819..20ac7a2601 100644 --- a/flax/nnx/bridge/wrappers.py +++ b/flax/nnx/bridge/wrappers.py @@ -74,7 +74,7 @@ def lazy_init(fn: Module | tp.Callable[..., tp.Any], *args, **kwargs): module = fn assert callable(fn) else: - if not hasattr(fn, '__self__') and isinstance(fn.__self__, Module): + if not (hasattr(fn, '__self__') and isinstance(fn.__self__, Module)): raise ValueError(f'{fn = } needs to be a method of an NNX Module.') module = fn.__self__ _set_initializing(module, True) @@ -124,7 +124,6 @@ def __init__( self.linen_collections: tuple[str, ...] = () def lazy_init(self, *args, **kwargs): - """A shortcut of calling `nnx.bridge.lazy_init()` upon this module.""" return lazy_init(self, *args, **kwargs) def __call__( @@ -225,6 +224,28 @@ class ToLinen(linen.Module): skip_rng: bool = False metadata_type: tp.Type = bv.NNXMeta + def update_variables(self, module): + """Store the NNX module's graph def and state inside Linen module variables.""" + gdef, state = nnx.split(module) + # Save the graph def. + if self.is_mutable_collection('nnx'): + self.put_variable('nnx', 'graphdef', gdef) + # Sort all the variable types. 
+ types = set(jax.tree.leaves( + jax.tree.map(lambda x: x.type, state, + is_leaf=lambda x: isinstance(x, nnx.VariableState)))) + types = bv.sort_variable_types(types) + _, *state_by_types = nnx.split(module, *types) + # Each variable type goes to its own linen collection, and + # each attribute goes to its own linen variable + for typ, state in zip(types, state_by_types): + collection = bv.variable_type_name(typ) + if self.is_mutable_collection(collection): + for k, v in state.raw_mapping.items(): + v = jax.tree.map(bv.to_linen_var, v, + is_leaf=lambda x: isinstance(x, nnx.VariableState)) + self.put_variable(collection, k, v) + @linen.compact def __call__(self, *args, **kwargs): # init codepath @@ -234,7 +255,7 @@ def __call__(self, *args, **kwargs): module_kwargs |= dict(rngs=nnx.Rngs(**linen_rngs_dict(self))) module = self.nnx_class(*self.args, **module_kwargs) # TODO: add lazy_init here in case there's an `ToNNX` submodule under `module`. - self._update_variables(module) + self.update_variables(module) return module(*args, **kwargs) # apply codepath @@ -249,33 +270,11 @@ def __call__(self, *args, **kwargs): module = nnx.merge(gdef, nnx_state) nnx.reseed(module, **linen_rngs_dict(self)) # reseed with keys from linen apply call. out = module(*args, **kwargs) - self._update_variables(module) + self.update_variables(module) return out - def _update_variables(self, module): - """Store the NNX module's graph def and state inside Linen module variables.""" - gdef, state = nnx.split(module) - # Save the graph def. - if self.is_mutable_collection('nnx'): - self.put_variable('nnx', 'graphdef', gdef) - # Sort all the variable types. 
- types = set(jax.tree.leaves( - jax.tree.map(lambda x: x.type, state, - is_leaf=lambda x: isinstance(x, nnx.VariableState)))) - types = bv.sort_variable_types(types) - _, *state_by_types = nnx.split(module, *types) - # Each variable type goes to its own linen collection, and - # each attribute goes to its own linen variable - for typ, state in zip(types, state_by_types): - collection = bv.variable_type_name(typ) - if self.is_mutable_collection(collection): - for k, v in state.raw_mapping.items(): - v = jax.tree.map(bv.to_linen_var, v, - is_leaf=lambda x: isinstance(x, nnx.VariableState)) - self.put_variable(collection, k, v) - def to_linen(nnx_class: tp.Callable[..., Module], *args, name: str | None = None, **kwargs): - """Shortcut of `nnx.bridge.ToLinen` if user is not changing any of its default fields.""" + """Shortcut of `ToLinen` if user is not changing any of `ToLinen` default fields.""" return ToLinen(nnx_class, args=args, kwargs=kwargs, name=name) \ No newline at end of file diff --git a/flax/nnx/errors.py b/flax/nnx/errors.py new file mode 100644 index 0000000000..41c7d4fab5 --- /dev/null +++ b/flax/nnx/errors.py @@ -0,0 +1,17 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +class TraceContextError(Exception): + pass diff --git a/flax/nnx/extract.py b/flax/nnx/extract.py index 845544c307..6ecf6f2405 100644 --- a/flax/nnx/extract.py +++ b/flax/nnx/extract.py @@ -22,7 +22,7 @@ from flax import struct from flax.nnx.object import Object -from flax.typing import Missing, PathParts +from flax.typing import MISSING, PathParts from flax.nnx import graph @@ -59,7 +59,7 @@ def extract_graph_nodes( pytree: A, /, *, - prefix: tp.Any = Missing, + prefix: tp.Any = MISSING, validate_fn: tp.Callable[[KeyPath, Prefix, Leaf], None] | None = None, ) -> ( tuple[A, tuple[tp.Any, ...]] @@ -101,7 +101,7 @@ def extract_graph_nodes( pytree_out = jax.tree.unflatten(treedef, leaves) - if prefix is Missing: + if prefix is MISSING: return pytree_out, tuple(nodes) # type: ignore[bad-return-type] else: return pytree_out, tuple(nodes), tuple(node_prefixes) # type: ignore[bad-return-type] @@ -330,13 +330,12 @@ def to_tree( tree, /, *, - prefix: tp.Any = Missing, + prefix: tp.Any = MISSING, split_fn: tp.Callable[ [graph.SplitContext, KeyPath, Prefix, Leaf], tp.Any ] = default_split_fn, map_non_graph_nodes: bool = False, ctxtag: str | None = None, - check_aliasing: bool = True, ) -> tp.Any: leaf_prefixes = broadcast_prefix( prefix, @@ -352,10 +351,9 @@ def to_tree( with graph.split_context(ctxtag) as split_ctx: for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes): if graph.is_graph_node(leaf): - if check_aliasing: - check_consistent_aliasing( - leaf, leaf_prefix, node_prefixes=node_prefixes - ) + check_consistent_aliasing( + leaf, leaf_prefix, node_prefixes=node_prefixes + ) tree_node = split_fn(split_ctx, keypath, leaf_prefix, leaf) leaves_out.append(tree_node) else: @@ -383,7 +381,7 @@ def from_tree( tree: tp.Any, /, *, - prefix: tp.Any = Missing, + prefix: tp.Any = MISSING, merge_fn: tp.Callable[ [graph.MergeContext, KeyPath, Prefix, Leaf], tp.Any ] = merge_tree_node, diff --git a/flax/nnx/object.py b/flax/nnx/object.py index f2714ff7fd..9e14155108 
100644 --- a/flax/nnx/object.py +++ b/flax/nnx/object.py @@ -25,13 +25,13 @@ import numpy as np from flax.nnx import ( + errors, reprlib, tracers, ) from flax.nnx import graph from flax.nnx.variables import Variable, VariableState from flax.typing import Key -from flax import errors G = tp.TypeVar('G', bound='Object') diff --git a/flax/nnx/spmd.py b/flax/nnx/spmd.py index 9b20d32381..e18003276b 100644 --- a/flax/nnx/spmd.py +++ b/flax/nnx/spmd.py @@ -44,7 +44,7 @@ def _add_axis(x: tp.Any): sharding.insert(index, axis_name) x.sharding = tuple(sharding) # type: ignore - x.add_axis(index, axis_name) + x.add_axis(axis_name, index) return x return jax.tree.map( @@ -61,7 +61,7 @@ def _remove_axis(x: tp.Any): sharding = list(x.sharding) assert sharding.pop(index) == axis_name x.sharding = tuple(sharding) - x.remove_axis(index, axis_name) + x.remove_axis(axis_name, index) return x return jax.tree.map( @@ -89,15 +89,9 @@ def _maybe_replicate(x): else: return None - def from_rules(sharding, sharding_rules): - rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules} - return (rules[s] if s in rules else s for s in sharding) - def f(x): if isinstance(x, (variables.VariableState, variables.Variable)): if hasattr(x, 'sharding') and x.sharding: - if hasattr(x, 'sharding_rules') and x.sharding_rules: - return x.replace(PartitionSpec(*from_rules(x.sharding, x.sharding_rules))) return x.replace(PartitionSpec(*x.sharding)) else: return x.replace(_maybe_replicate(x.value)) diff --git a/flax/nnx/transforms/compilation.py b/flax/nnx/transforms/compilation.py index 1f63654d63..d715898ce0 100644 --- a/flax/nnx/transforms/compilation.py +++ b/flax/nnx/transforms/compilation.py @@ -324,7 +324,6 @@ def jit_wrapper(*args, **kwargs): (args, kwargs), prefix=(in_shardings, kwarg_shardings), split_fn=_jit_split_fn, - check_aliasing=in_shardings is not None, ctxtag='jit', ) pure_args_out, pure_kwargs_out, pure_out = jitted_fn( diff --git a/flax/nnx/transforms/iteration.py 
b/flax/nnx/transforms/iteration.py index c169a91fa1..36c351f34f 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -107,25 +107,17 @@ def _update_variable_sharding_metadata( ): def _update_axes_fn(tree_node): if isinstance(tree_node, extract.TreeNode) and isinstance( - tree_node.metatata, (StateAxes, int) + tree_node.metatata, StateAxes ): - if isinstance(tree_node.metatata, int): - graph_def_state = tree_node.graphdef_states[0] - assert isinstance(graph_def_state, extract.GraphDefState) - graphdef_state = axis_fn( - graph_def_state, tree_node.metatata, transform_metadata - ) - return tree_node.replace(graphdef_states=(graphdef_state,)) - else: - graphdef_states_out: list[extract.GraphDefState] = [] - for graphdef_state, axis in zip( + graphdef_states_out: list[extract.GraphDefState] = [] + for graphdef_state, axis in zip( tree_node.graphdef_states, tree_node.metatata.axes - ): - assert isinstance(graphdef_state, extract.GraphDefState) - if isinstance(axis, int): - graphdef_state = axis_fn(graphdef_state, axis, transform_metadata) - graphdef_states_out.append(graphdef_state) - return tree_node.replace(graphdef_states=tuple(graphdef_states_out)) + ): + assert isinstance(graphdef_state, extract.GraphDefState) + if isinstance(axis, int): + graphdef_state = axis_fn(graphdef_state, axis, transform_metadata) + graphdef_states_out.append(graphdef_state) + return tree_node.replace(graphdef_states=tuple(graphdef_states_out)) return tree_node return jax.tree.map( @@ -138,7 +130,7 @@ def _vmap_split_fn(ctx: graph.SplitContext, path, prefix, x): return extract.TreeNode.from_split( *ctx.split(x, *prefix.filters), metadata=prefix ) - return extract.TreeNode.from_split(*ctx.split(x), metadata=prefix) + return extract.TreeNode.from_split(*ctx.split(x)) @dataclasses.dataclass(eq=False) diff --git a/flax/nnx/variables.py b/flax/nnx/variables.py index ee6c8a003b..76805477f5 100644 --- a/flax/nnx/variables.py +++ b/flax/nnx/variables.py @@ -22,7 
+22,7 @@ import jax -from flax import errors +from flax import nnx from flax.nnx import reprlib, tracers from flax.typing import Missing import jax.tree_util as jtu @@ -36,8 +36,8 @@ CreateValueHook = tp.Callable[['Variable[A]', A], A] AxisName = str AxisIndex = int -AddAxisHook = tp.Callable[[V, AxisIndex, AxisName | None], None] -RemoveAxisHook = tp.Callable[[V, AxisIndex, AxisName | None], None] +AddAxisHook = tp.Callable[[V, AxisName, AxisIndex], None] +RemoveAxisHook = tp.Callable[[V, AxisName, AxisIndex], None] VariableTypeCache: dict[str, tp.Type[Variable[tp.Any]]] = {} @@ -150,43 +150,67 @@ def __init__( **metadata: tp.Any, ): vars(self)['_trace_state'] = tracers.TraceState() - if callable(set_value_hooks): - set_value_hooks = (set_value_hooks,) + if set_value_hooks: + if callable(set_value_hooks): + set_value_hooks = (set_value_hooks,) + else: + set_value_hooks = tuple(set_value_hooks) else: - set_value_hooks = tuple(set_value_hooks) - - if callable(get_value_hooks): - get_value_hooks = (get_value_hooks,) + set_value_hooks = () + if get_value_hooks: + if callable(get_value_hooks): + get_value_hooks = (get_value_hooks,) + else: + get_value_hooks = tuple(get_value_hooks) else: - get_value_hooks = tuple(get_value_hooks) + get_value_hooks = () - if callable(create_value_hooks): - create_value_hooks = (create_value_hooks,) + if create_value_hooks: + if callable(create_value_hooks): + create_value_hooks = (create_value_hooks,) + else: + create_value_hooks = tuple(create_value_hooks) else: - create_value_hooks = tuple(create_value_hooks) + create_value_hooks = () - if callable(add_axis_hooks): - add_axis_hooks = (add_axis_hooks,) + if add_axis_hooks: + if callable(add_axis_hooks): + add_axis_hooks = (add_axis_hooks,) + else: + add_axis_hooks = tuple(add_axis_hooks) else: - add_axis_hooks = tuple(add_axis_hooks) + add_axis_hooks = () - if callable(remove_axis_hooks): - remove_axis_hooks = (remove_axis_hooks,) + if remove_axis_hooks: + if 
callable(remove_axis_hooks): + remove_axis_hooks = (remove_axis_hooks,) + else: + remove_axis_hooks = tuple(remove_axis_hooks) else: - remove_axis_hooks = tuple(remove_axis_hooks) + remove_axis_hooks = () if isinstance(value, VariableMetadata): value_metadata = dict(value.metadata) - if value.set_value_hooks: + if set_value_hooks and value.set_value_hooks: set_value_hooks = set_value_hooks + value.set_value_hooks - if value.get_value_hooks: + elif value.set_value_hooks: + set_value_hooks = value.set_value_hooks + if get_value_hooks and value.get_value_hooks: get_value_hooks = get_value_hooks + value.get_value_hooks - if value.create_value_hooks: + elif value.get_value_hooks: + get_value_hooks = value.get_value_hooks + if create_value_hooks and value.create_value_hooks: create_value_hooks = create_value_hooks + value.create_value_hooks - if value.add_axis_hooks: + elif value.create_value_hooks: + create_value_hooks = value.create_value_hooks + if add_axis_hooks and value.add_axis_hooks: add_axis_hooks = add_axis_hooks + value.add_axis_hooks - if value.remove_axis_hooks: + elif value.add_axis_hooks: + add_axis_hooks = value.add_axis_hooks + if remove_axis_hooks and value.remove_axis_hooks: remove_axis_hooks = remove_axis_hooks + value.remove_axis_hooks + elif value.remove_axis_hooks: + remove_axis_hooks = value.remove_axis_hooks metadata.update(value_metadata) value = tp.cast(A, value.raw_value) @@ -235,7 +259,7 @@ def __setattr__(self, name: str, value: Any) -> None: def _setattr(self, name: str, value: tp.Any): if not self._trace_state.is_valid(): - raise errors.TraceContextError( + raise nnx.errors.TraceContextError( f'Cannot mutate {type(self).__name__} from a different trace level' ) @@ -294,13 +318,13 @@ def create_value(self, value: A): value = hook(self, value) return value - def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): + def add_axis(self, axis_name: AxisName, axis_index: AxisIndex): for hook in self.add_axis_hooks: - hook(self, 
axis_index, axis_name) + hook(self, axis_name, axis_index) - def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): + def remove_axis(self, axis_name: AxisName, axis_index: AxisIndex): for hook in self.remove_axis_hooks: - hook(self, axis_index, axis_name) + hook(self, axis_name, axis_index) def __eq__(self, other: object) -> bool: return type(self) is type(other) and vars(other) == vars(self) @@ -394,11 +418,11 @@ def on_set_value(self, value: A) -> A: ... def on_create_value(self, value: A) -> A: ... def on_add_axis( - self: V, axis_index: AxisIndex, axis_name: AxisName | None + self: V, axis_name: AxisName, axis_index: AxisIndex ) -> V: ... def on_remove_axis( - self: V, axis_index: AxisIndex, axis_name: AxisName | None + self: V, axis_name: AxisName, axis_index: AxisIndex ) -> V: ... def __jax_array__(self): @@ -846,13 +870,17 @@ def get_metadata(self) -> dict[str, tp.Any]: del metadata['value'] return metadata - def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): + def add_axis(self, axis_name: AxisName, axis_index: AxisIndex): + if not hasattr(self, 'add_axis_hooks'): + raise ValueError(f'No add_axis_hooks found for VariableState: {self}') for hook in self.add_axis_hooks: - hook(self, axis_index, axis_name) + hook(self, axis_name, axis_index) - def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): + def remove_axis(self, axis_name: AxisName, axis_index: AxisIndex): + if not hasattr(self, 'remove_axis_hooks'): + raise ValueError(f'No remove_axis_hooks found for VariableState: {self}') for hook in self.remove_axis_hooks: - hook(self, axis_index, axis_name) + hook(self, axis_name, axis_index) def _variable_state_flatten(x: VariableState[tp.Any], *, with_keys: bool): diff --git a/tests/nnx/bridge/wrappers_test.py b/tests/nnx/bridge/wrappers_test.py index 27f2927fd9..72d42eb6d4 100644 --- a/tests/nnx/bridge/wrappers_test.py +++ b/tests/nnx/bridge/wrappers_test.py @@ -12,8 +12,6 @@ # See the License for the 
specific language governing permissions and # limitations under the License. -import os -os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4' from absl.testing import absltest import flax @@ -26,12 +24,6 @@ class TestCompatibility(absltest.TestCase): - def setUp(self): - super().setUp() - dim1 = max(jax.device_count() // 2, 1) - device_mesh = np.array(jax.devices()).reshape(dim1, jax.device_count() // dim1) - self.mesh = jax.sharding.Mesh(devices=device_mesh, axis_names=('in', 'out')) - def test_functional(self): # Functional API for NNX Modules functional = bridge.functional(nnx.Linear)(32, 64) @@ -143,35 +135,21 @@ def vmap_fn(inner, x): def test_linen_to_nnx_metadata(self): linen_module = nn.Dense( features=64, - kernel_init=nn.with_partitioning(nn.initializers.lecun_normal(), ('in', 'out')), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros_init(), ('out-alias',), - rules=(('out-alias', 'out'),)), - ) + kernel_init=nn.with_partitioning(nn.initializers.lecun_normal(), ('in', 'out'))) x = jax.numpy.ones((1, 32)) linen_vars = linen_module.init(jax.random.key(0), x) - - @nnx.jit - def create_sharded_nnx_module(x): - model = bridge.lazy_init(bridge.ToNNX(linen_module, rngs=nnx.Rngs(0)), x) - state = nnx.state(model) - sharded_state = nnx.with_sharding_constraint(state, nnx.get_partition_spec(state)) - nnx.update(model, sharded_state) - return model - with self.mesh: - nnx_model = create_sharded_nnx_module(x) - - # nn.Partitioned metadata boxes translated into valid nnx.Variable boxes. + nnx_model = bridge.ToNNX(linen_module, rngs=nnx.Rngs(0)).lazy_init(x) + # nn.Partitioned metadata box is translated into a valid nnx.Variable / VariableState box. 
self.assertIsInstance(linen_vars['params']['kernel'], nn.Partitioned) - self.assertIsInstance(linen_vars['params']['bias'], nn.LogicallyPartitioned) self.assertIsInstance(nnx_model.params['kernel'], nnx.Variable) + np.testing.assert_array_equal(linen_vars['params']['kernel'].value, + nnx_model.params['kernel'].value) assert nnx_model.params['kernel'].sharding == ('in', 'out') - assert nnx_model.params['kernel'].value.sharding.is_equivalent_to( - jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec('in', 'out')), ndim=2) - - assert nnx_model.params['bias'].sharding == ('out-alias',) - assert nnx_model.params['bias'].sharding_rules == (('out-alias', 'out'),) - assert nnx_model.params['bias'].value.sharding.is_equivalent_to( - jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec('out',)), ndim=1) + _, nnx_state = nnx.split(nnx_model) + self.assertIsInstance(nnx_state['params']['kernel'], nnx.VariableState) + np.testing.assert_array_equal(linen_vars['params']['kernel'].value, + nnx_state['params']['kernel'].value) + assert nnx_state['params']['kernel'].sharding == ('in', 'out') ################## @@ -328,9 +306,7 @@ class LinenMiddle(nn.Module): @nn.compact def __call__(self, x): dot = bridge.to_linen(NNXInner, x.shape[-1], self.dout, self.dropout_rate, name='dot') - logical_init = nn.with_logical_partitioning( - nn.initializers.lecun_normal(), ('out-alias',), rules=(('out-alias', 'out'))) - b = self.param('b', logical_init, (1, self.dout)) + b = self.param('b', nn.initializers.lecun_normal(), (1, self.dout)) return dot(x) + b class NNXOuter(nnx.Module): @@ -359,7 +335,6 @@ def __call__(self, x): self.assertIsInstance(w, nnx.Param) np.testing.assert_allclose(model(x), x @ w + b) assert hasattr(w, 'sharding') and w.sharding == ('in', 'out') - assert hasattr(b, 'sharding') and b.sharding == ('out-alias', ) def test_linen_nnx_linen(self): # TODO: add when we can safely `lazy_init` the NNX module inside `ToLinen` without diff --git 
a/tests/nnx/module_test.py b/tests/nnx/module_test.py index a3f7bf8c22..d5aeae08cd 100644 --- a/tests/nnx/module_test.py +++ b/tests/nnx/module_test.py @@ -17,7 +17,7 @@ from typing import Any, TypeVar from absl.testing import absltest -from flax import nnx, errors +from flax import nnx import jax import jax.numpy as jnp import numpy as np @@ -39,7 +39,7 @@ def test_trace_level(self): @jax.jit def f(): with self.assertRaisesRegex( - errors.TraceContextError, + nnx.errors.TraceContextError, "Cannot mutate 'Dict' from different trace level", ): m.a = 2 diff --git a/tests/nnx/rngs_test.py b/tests/nnx/rngs_test.py index eeb65ccaed..0e42918264 100644 --- a/tests/nnx/rngs_test.py +++ b/tests/nnx/rngs_test.py @@ -21,7 +21,6 @@ from absl.testing import absltest from flax import nnx -from flax import errors class TestRngs(absltest.TestCase): @@ -59,7 +58,7 @@ def test_rng_trace_level_constraints(self): @jax.jit def f(): with self.assertRaisesRegex( - errors.TraceContextError, + nnx.errors.TraceContextError, 'Cannot call RngStream from a different trace level', ): rngs.params() @@ -77,7 +76,7 @@ def h(): self.assertIsInstance(rngs1, nnx.Rngs) with self.assertRaisesRegex( - errors.TraceContextError, + nnx.errors.TraceContextError, 'Cannot call RngStream from a different trace level', ): rngs1.params() diff --git a/tests/nnx/spmd_test.py b/tests/nnx/spmd_test.py index 6a202e8135..15808e0800 100644 --- a/tests/nnx/spmd_test.py +++ b/tests/nnx/spmd_test.py @@ -100,64 +100,6 @@ def __call__(self, x): assert state_spec.opt_state[0].mu['w'].value == PartitionSpec('row', 'col') assert state_spec.opt_state[0].nu['w'].value == PartitionSpec('row', 'col') - def test_add_remove_axis_in_transform(self): - test = self - kadds, kremoves, badds, bremoves = [], [], [], [] - class MLP(nnx.Module): - - @nnx.split_rngs(splits=5) - @nnx.vmap( - in_axes=(0, 0), - transform_metadata={nnx.PARTITION_NAME: 'layers'}, - ) - def __init__(self, rngs: nnx.Rngs): - self.linear = nnx.Linear( - 3, - 3, - 
kernel_init=nnx.with_metadata( - nnx.initializers.lecun_normal(), sharding=('din', 'dout'), - add_axis_hooks=lambda _, idx, name: kadds.append((idx, name)), - remove_axis_hooks=lambda _, idx, name: kremoves.append((idx, name)), - ), - bias_init=nnx.with_metadata( - nnx.initializers.zeros_init(), # no sharding annotation here! - add_axis_hooks=lambda _, idx, name: badds.append((idx, name)), - remove_axis_hooks=lambda _, idx, name: bremoves.append((idx, name)), - ), - rngs=rngs, - ) - - @nnx.scan( - in_axes=(0, nnx.Carry), - transform_metadata={nnx.PARTITION_NAME: 'layers'} - ) - def __call__(self, x: jax.Array): - x = self.linear(x) - # test sharding layer axes is not present inside scan - test.assertEqual(self.linear.kernel.shape, (3, 3)) - test.assertEqual(self.linear.kernel.sharding, ('din', 'dout')) - # at least a remove_axis was already called to remove the layer axis - test.assertEqual(kremoves[-1], (0, 'layers')) - test.assertEqual(bremoves[-1], (0, 'layers')) - return x, None - - m = MLP(rngs=nnx.Rngs(0)) - self.assertEqual(m.linear.kernel.shape, (5, 3, 3)) - self.assertEqual(m.linear.kernel.sharding, ('layers', 'din', 'dout')) - self.assertEqual(m.linear.bias.shape, (5, 3)) - # One add_axis called to add the `nnx.vmap` dimension - self.assertEqual(kadds, [(0, 'layers')]) - self.assertEqual(kremoves, []) - self.assertEqual(badds, [(0, 'layers')]) - self.assertEqual(bremoves, []) - - # One remove_axis and one add_axis called when in and out of `nnx.scan` - y = m(jnp.ones((5, 3))) - self.assertEqual(kadds, [(0, 'layers'), (0, 'layers')]) - self.assertEqual(kremoves, [(0, 'layers')]) - self.assertEqual(badds, [(0, 'layers'), (0, 'layers')]) - self.assertEqual(bremoves, [(0, 'layers')]) - if __name__ == '__main__': absltest.main() diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index 824e7b6b0e..be487628fe 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -323,7 +323,7 @@ def f(m: Foo): def 
test_apply_shardings(self): n_devices = max(jax.local_device_count() // 2, 1) - devices = mesh_utils.create_device_mesh((n_devices, jax.local_device_count() // n_devices)) + devices = mesh_utils.create_device_mesh((n_devices, n_devices)) mesh = jax.sharding.Mesh(devices, ('a', 'b')) def sharding(*args): @@ -2235,27 +2235,6 @@ def forward(model, x): self.assertEqual(y.shape, (5, 4, 3)) - def test_metadata(self): - @nnx.vmap( - in_axes=(None,), - out_axes=0, - axis_size=5, - transform_metadata={nnx.spmd.PARTITION_NAME: 'c'}, - ) - def create_block(rngs: nnx.Rngs): - return nnx.Linear( - 16, - 32, - rngs=rngs, - kernel_init=nnx.with_partitioning( - nnx.initializers.lecun_normal(), ('a', 'b') - ), - ) - - m = create_block(nnx.Rngs(0)) - self.assertEqual(m.kernel.value.shape, (5, 16, 32)) - self.assertEqual(m.kernel.sharding, ('c', 'a', 'b')) - class TestPmap(absltest.TestCase): diff --git a/tests/run_all_tests.sh b/tests/run_all_tests.sh index b3135a9e37..e2ded604d5 100755 --- a/tests/run_all_tests.sh +++ b/tests/run_all_tests.sh @@ -88,8 +88,10 @@ if $RUN_DOCTEST; then echo "=== RUNNING DOCTESTS ===" # test doctest sphinx-build -M doctest docs docs/_build -T + sphinx-build -M doctest docs_nnx docs_nnx/_build -T # test build html sphinx-build -M html docs docs/_build -T + sphinx-build -M html docs_nnx docs_nnx/_build -T # test docstrings pytest -n auto flax \ --doctest-modules \ diff --git a/uv.lock b/uv.lock index 29d358e255..5dbc9e8070 100644 --- a/uv.lock +++ b/uv.lock @@ -767,7 +767,7 @@ wheels = [ [[package]] name = "flax" -version = "0.9.0" +version = "0.8.6" source = { editable = "." 
} dependencies = [ { name = "jax" }, @@ -809,9 +809,7 @@ docs = [ testing = [ { name = "clu" }, { name = "einops" }, - { name = "gymnasium" }, - { name = "gymnasium", extra = ["accept-rom-license"] }, - { name = "gymnasium", extra = ["atari"] }, + { name = "gymnasium", extra = ["accept-rom-license", "atari"] }, { name = "jaxlib" }, { name = "jaxtyping" }, { name = "jraph" }, @@ -1046,11 +1044,9 @@ wheels = [ [package.optional-dependencies] accept-rom-license = [ - { name = "autorom" }, { name = "autorom", extra = ["accept-rom-license"] }, ] atari = [ - { name = "shimmy" }, { name = "shimmy", extra = ["atari"] }, ] @@ -3591,6 +3587,9 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/45/27/14cc3101409b9b4b9241d2ba7deaa93535a217a211c86c4cc7151fb12181/triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a", size = 209376304 }, { url = "https://files.pythonhosted.org/packages/33/3e/a2f59384587eff6aeb7d37b6780de7fedd2214935e27520430ca9f5b7975/triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ce8520437c602fb633f1324cc3871c47bee3b67acf9756c1a66309b60e3216c", size = 209438883 }, { url = "https://files.pythonhosted.org/packages/fe/7b/7757205dee3628f75e7991021d15cd1bd0c9b044ca9affe99b50879fc0e1/triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb", size = 209464695 }, + { url = "https://files.pythonhosted.org/packages/15/67/84e5a4b7b45bdeb11da26a67dfa2b988c512abbcbcad8cbc30aa579051b2/triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230", size = 209380247 }, + { url = 
"https://files.pythonhosted.org/packages/ea/6b/1d72cc8a7379822dadf050474add7d8b73b02c35057446b6f17d27cb9ea2/triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e", size = 209442823 }, + { url = "https://files.pythonhosted.org/packages/ae/b2/048c9ecfdba0e6b0ae3c02eed2d9dd3e9e990a6d46da98555cf0c2232168/triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253", size = 209468633 }, ] [[package]]