From 5cf2ed7d2c5e878f2531b53317dcc3860b16835f Mon Sep 17 00:00:00 2001
From: Andreas Steiner
Date: Thu, 27 Oct 2022 05:08:45 -0700
Subject: [PATCH] Internal change.

PiperOrigin-RevId: 484223848
---
 README.md                | 13 ++++---
 lit.ipynb                |  1 +
 setup.py                 |  2 +-
 vit_jax/momentum_clip.py | 77 ----------------------------------------
 vit_jax/train.py         | 53 +++++++++++++++------------
 5 files changed, 41 insertions(+), 105 deletions(-)
 delete mode 100644 vit_jax/momentum_clip.py

diff --git a/README.md b/README.md
index 3584366..1ed902c 100644
--- a/README.md
+++ b/README.md
@@ -26,7 +26,7 @@ Table of contents:
 - [MLP-Mixer](#mlp-mixer)
   - [Available Mixer models](#available-mixer-models)
   - [Expected Mixer results](#expected-mixer-results)
-  - [LiT models](#lit-models)
+- [LiT models](#lit-models)
 - [Running on cloud](#running-on-cloud)
   - [Create a VM](#create-a-vm)
   - [Setup VM](#setup-vm)
@@ -296,11 +296,16 @@ ImageNet-21k | Mixer-L/16 | cifar10 | 98.34% | 10.0h | [tensorboard.
 
 ## LiT models
 
-We have just published a post on the Google AI blog
-[LiT: adding language understanding to image models](http://ai.googleblog.com/2022/04/locked-image-tuning-adding-language.html)
-about our new CVPR paper "LiT: Zero-Shot Transfer with Locked-image text Tuning"
+For details, refer to the Google AI blog post
+[LiT: adding language understanding to image models](http://ai.googleblog.com/2022/04/locked-image-tuning-adding-language.html),
+or read the CVPR paper "LiT: Zero-Shot Transfer with Locked-image text Tuning"
 (https://arxiv.org/abs/2111.07991).
 
+We have published a Transformer B/16-base model with an ImageNet zero-shot
+accuracy of 72.1%, and an L/16-large model with an ImageNet zero-shot accuracy
+of 75.7%. For more details about these models, please refer to the
+[LiT model card](model_cards/lit.md).
+
 We provide an in-browser demo with small text encoders for interactive use
 (the smallest models should even run on a modern cell phone):
 
diff --git a/lit.ipynb b/lit.ipynb
index e1fd7cc..6f20d7f 100644
--- a/lit.ipynb
+++ b/lit.ipynb
@@ -706,6 +706,7 @@
         "# # described in\n",
         "# # https://www.tensorflow.org/datasets/catalog/imagenet2012\n",
         "# # and then replace `data_dir` below with that GCS bucket.\n",
+        "# # If you get a `PermissionDeniedError`, try restarting the kernel.\n",
         "# from google.colab import auth\n",
         "# auth.authenticate_user()  # Required to access access protected GCS buckets.\n",
         "# builder = tfds.builder('imagenet2012', data_dir='gs://tensorflow-datasets/datasets')\n",
diff --git a/setup.py b/setup.py
index 1576ccf..df4f8f5 100644
--- a/setup.py
+++ b/setup.py
@@ -29,7 +29,7 @@
         'absl-py',
         'clu',
         'einops',
-        'flax==0.5.3',  # requires deprecated flax.optim 😱
+        'flax',
         'flaxformer @ git+https://github.com/google/flaxformer',
         'jax',
         'ml-collections',
diff --git a/vit_jax/momentum_clip.py b/vit_jax/momentum_clip.py
deleted file mode 100644
index c5ffe44..0000000
--- a/vit_jax/momentum_clip.py
+++ /dev/null
@@ -1,77 +0,0 @@
-# Copyright 2022 Google LLC.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import flax
-import flax.optim
-import jax
-import jax.numpy as jnp
-import numpy as np
-
-
-class Optimizer(flax.optim.OptimizerDef):
-  """Momentum optimizer that stores state using half-precision."""
-
-  @flax.struct.dataclass
-  class HyperParams:
-    learning_rate: np.ndarray
-    beta: np.ndarray
-    grad_norm_clip: np.ndarray
-
-  @flax.struct.dataclass
-  class State:
-    momentum: np.ndarray
-
-  def __init__(self,
-               learning_rate=None,
-               beta=0.9,
-               dtype='bfloat16',
-               grad_norm_clip=None):
-    hyper_params = Optimizer.HyperParams(learning_rate, beta, grad_norm_clip)
-    super().__init__(hyper_params)
-    self.dtype = dict(bfloat16=jnp.bfloat16, float32=jnp.float32)[dtype]
-
-  def init_param_state(self, param):
-    return Optimizer.State(jnp.zeros_like(param, dtype=self.dtype))
-
-  def apply_gradient(self, hyper_params, params, state, grads):
-    step = state.step
-    params_flat, treedef = jax.tree_flatten(params)
-    states_flat = treedef.flatten_up_to(state.param_states)
-    grads_flat = treedef.flatten_up_to(grads)
-
-    # Optionally resize the global gradient to a maximum norm.
-    if hyper_params.grad_norm_clip:
-      grads_l2 = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads_flat]))
-      grads_factor = jnp.minimum(1.0, hyper_params.grad_norm_clip / grads_l2)
-      grads_flat = jax.tree_map(lambda param: grads_factor * param, grads_flat)
-
-    out = [
-        self.apply_param_gradient(step, hyper_params, param, state, grad)
-        for param, state, grad in zip(params_flat, states_flat, grads_flat)
-    ]
-
-    new_params_flat, new_states_flat = list(zip(*out)) if out else ((), ())
-    new_params = jax.tree_unflatten(treedef, new_params_flat)
-    new_param_states = jax.tree_unflatten(treedef, new_states_flat)
-    new_state = flax.optim.OptimizerState(step + 1, new_param_states)
-    return new_params, new_state
-
-  def apply_param_gradient(self, step, hyper_params, param, state, grad):
-    del step
-    assert hyper_params.learning_rate is not None, 'no learning rate provided.'
-    momentum = state.momentum
-    new_momentum = hyper_params.beta * momentum + grad
-    new_param = param - hyper_params.learning_rate * new_momentum
-    new_state = Optimizer.State(new_momentum.astype(self.dtype))
-    return new_param, new_state
diff --git a/vit_jax/train.py b/vit_jax/train.py
index 6131bf5..fa546e0 100644
--- a/vit_jax/train.py
+++ b/vit_jax/train.py
@@ -25,19 +25,19 @@
 import jax.numpy as jnp
 import ml_collections
 import numpy as np
+import optax
 import tensorflow as tf
 
 from vit_jax import checkpoint
 from vit_jax import input_pipeline
 from vit_jax import models
-from vit_jax import momentum_clip
 from vit_jax import utils
 
 
-def make_update_fn(*, apply_fn, accum_steps, lr_fn):
+def make_update_fn(*, apply_fn, accum_steps, tx):
   """Returns update step for data parallel training."""
 
-  def update_fn(opt, step, batch, rng):
+  def update_fn(params, opt_state, batch, rng):
 
     _, new_rng = jax.random.split(rng)
     # Bind the rng key to the device id (which is unique across hosts)
@@ -58,13 +58,14 @@ def loss_fn(params, images, labels):
       return cross_entropy_loss(logits=logits, labels=labels)
 
     l, g = utils.accumulate_gradient(
-        jax.value_and_grad(loss_fn), opt.target, batch['image'], batch['label'],
+        jax.value_and_grad(loss_fn), params, batch['image'], batch['label'],
         accum_steps)
     g = jax.tree_map(lambda x: jax.lax.pmean(x, axis_name='batch'), g)
+    updates, opt_state = tx.update(g, opt_state)
+    params = optax.apply_updates(params, updates)
     l = jax.lax.pmean(l, axis_name='batch')
 
-    opt = opt.apply_gradient(g, learning_rate=lr_fn(step))
-    return opt, l, new_rng
+    return params, opt_state, l, new_rng
 
   return jax.pmap(update_fn, axis_name='batch', donate_argnums=(0,))
 
@@ -130,28 +131,33 @@ def init_model():
       model_config=config.model)
 
   total_steps = config.total_steps
+
   lr_fn = utils.create_learning_rate_schedule(total_steps, config.base_lr,
                                               config.decay_type,
                                               config.warmup_steps)
+  tx = optax.chain(
+      optax.clip_by_global_norm(config.grad_norm_clip),
+      optax.sgd(
+          learning_rate=lr_fn,
+          momentum=0.9,
+          accumulator_dtype='bfloat16',
+      ),
+  )
 
   update_fn_repl = make_update_fn(
-      apply_fn=model.apply, accum_steps=config.accum_steps, lr_fn=lr_fn)
+      apply_fn=model.apply, accum_steps=config.accum_steps, tx=tx)
   infer_fn_repl = jax.pmap(functools.partial(model.apply, train=False))
 
-  # Create optimizer and replicate it over all TPUs/GPUs
-  opt = momentum_clip.Optimizer(
-      dtype=config.optim_dtype,
-      grad_norm_clip=config.grad_norm_clip).create(params)
-
   initial_step = 1
-  opt, initial_step = flax_checkpoints.restore_checkpoint(
-      workdir, (opt, initial_step))
+  opt_state = tx.init(params)
+  params, opt_state, initial_step = flax_checkpoints.restore_checkpoint(
+      workdir, (params, opt_state, initial_step))
   logging.info('Will start/continue training at initial_step=%d', initial_step)
 
-  opt_repl = flax.jax_utils.replicate(opt)
+  params_repl, opt_state_repl = flax.jax_utils.replicate((params, opt_state))
 
   # Delete references to the objects that are not needed anymore
-  del opt
+  del opt_state
   del params
 
   # Prepare the learning-rate and pre-fetch it to device to avoid delays.
@@ -175,8 +181,8 @@ def init_model():
                          input_pipeline.prefetch(ds_train, config.prefetch)):
 
     with jax.profiler.StepTraceAnnotation('train', step_num=step):
-      opt_repl, loss_repl, update_rng_repl = update_fn_repl(
-          opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)
+      params_repl, opt_state_repl, loss_repl, update_rng_repl = update_fn_repl(
+          params_repl, opt_state_repl, batch, update_rng_repl)
 
     for hook in hooks:
       hook(step)
@@ -197,7 +203,7 @@ def init_model():
           train_loss=float(flax.jax_utils.unreplicate(loss_repl)),
           img_sec_core_train=img_sec_core_train))
       done = step / total_steps
-      logging.info(f'Step: {step}/{total_steps} {100*done:.1f}%, '  # pylint: disable=logging-format-interpolation
+      logging.info(f'Step: {step}/{total_steps} {100*done:.1f}%, '  # pylint: disable=logging-fstring-interpolation
                    f'img/sec/core: {img_sec_core_train:.1f}, '
                    f'ETA: {(time.time()-t0)/done*(1-done)/3600:.2f}h')
 
@@ -209,7 +215,7 @@ def init_model():
       lt0 = time.time()
       for test_batch in input_pipeline.prefetch(ds_test, config.prefetch):
         logits = infer_fn_repl(
-            dict(params=opt_repl.target), test_batch['image'])
+            dict(params=params_repl), test_batch['image'])
         accuracies.append(
             (np.argmax(logits, axis=-1) == np.argmax(test_batch['label'],
                                                      axis=-1)).mean())
@@ -221,7 +227,7 @@ def init_model():
       lt0 = time.time()
 
       lr = float(lr_fn(step))
-      logging.info(f'Step: {step} '  # pylint: disable=logging-format-interpolation
+      logging.info(f'Step: {step} '  # pylint: disable=logging-fstring-interpolation
                    f'Learning rate: {lr:.7f}, '
                    f'Test accuracy: {accuracy_test:0.5f}, '
                    f'img/sec/core: {img_sec_core_test:.1f}')
@@ -236,8 +242,9 @@ def init_model():
     if ((config.checkpoint_every and step % config.eval_every == 0) or
         step == total_steps):
       checkpoint_path = flax_checkpoints.save_checkpoint(
-          workdir, (flax.jax_utils.unreplicate(opt_repl), step), step)
+          workdir, (flax.jax_utils.unreplicate(params_repl),
+                    flax.jax_utils.unreplicate(opt_state_repl), step), step)
       logging.info('Stored checkpoint at step %d to "%s"', step,
                    checkpoint_path)
 
-  return flax.jax_utils.unreplicate(opt_repl)
+  return flax.jax_utils.unreplicate(params_repl)
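
Note on the change (editorial, not part of the patch): the optax chain added to vit_jax/train.py replaces the deleted momentum_clip.Optimizer with the same two ingredients, global-norm gradient clipping followed by SGD with momentum whose accumulator is stored in bfloat16. The minimal sketch below shows that pattern in isolation; the toy parameters, the constant schedule `lambda step: 0.01`, and `grad_norm_clip = 1.0` are illustrative assumptions, not values taken from the repository's configs.

    import jax
    import jax.numpy as jnp
    import optax

    # Illustrative stand-ins; the real values come from the training config.
    grad_norm_clip = 1.0
    lr_fn = lambda step: 0.01  # stands in for utils.create_learning_rate_schedule

    tx = optax.chain(
        optax.clip_by_global_norm(grad_norm_clip),
        optax.sgd(learning_rate=lr_fn, momentum=0.9, accumulator_dtype='bfloat16'),
    )

    params = {'w': jnp.ones((3,))}  # toy parameter pytree
    opt_state = tx.init(params)     # momentum buffer is created in bfloat16

    grads = jax.grad(lambda p: jnp.sum(p['w'] ** 2))(params)
    updates, opt_state = tx.update(grads, opt_state)  # clip, then momentum/lr scaling
    params = optax.apply_updates(params, updates)

Compared to the removed flax.optim class, checkpoints now hold (params, opt_state, step) rather than a single optimizer object, which is why restore_checkpoint and save_checkpoint in train.py switched to the three-element tuple.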