Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Internal change. #244

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Table of contents:
- [MLP-Mixer](#mlp-mixer)
- [Available Mixer models](#available-mixer-models)
- [Expected Mixer results](#expected-mixer-results)
- [LiT models](#lit-models)
- [LiT models](#lit-models)
- [Running on cloud](#running-on-cloud)
- [Create a VM](#create-a-vm)
- [Setup VM](#setup-vm)
Expand Down Expand Up @@ -296,11 +296,16 @@ ImageNet-21k | Mixer-L/16 | cifar10 | 98.34% | 10.0h | [tensorboard.

## LiT models

We have just published a post on the Google AI blog
[LiT: adding language understanding to image models](http://ai.googleblog.com/2022/04/locked-image-tuning-adding-language.html)
about our new CVPR paper "LiT: Zero-Shot Transfer with Locked-image text Tuning"
For details, refer to the Google AI blog post
[LiT: adding language understanding to image models](http://ai.googleblog.com/2022/04/locked-image-tuning-adding-language.html),
or read the CVPR paper "LiT: Zero-Shot Transfer with Locked-image text Tuning"
(https://arxiv.org/abs/2111.07991).

We published a Transformer B/16-base model with an ImageNet zeroshot accuracy of
72.1%, and a L/16-large model with an ImageNet zeroshot accuracy of 75.7%. For
more details about these models, please refer to the
[LiT model card](model_cards/lit.md).

We provide a in-browser demo with small text encoders for interactive use (the
smallest models should even run on a modern cell phone):

Expand Down
1 change: 1 addition & 0 deletions lit.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,7 @@
"# # described in\n",
"# # https://www.tensorflow.org/datasets/catalog/imagenet2012\n",
"# # and then replace `data_dir` below with that GCS bucket.\n",
"# # If you get a `PermissionDeniedError`, try restarting the kernel.\n",
"# from google.colab import auth\n",
"# auth.authenticate_user() # Required to access access protected GCS buckets.\n",
"# builder = tfds.builder('imagenet2012', data_dir='gs://tensorflow-datasets/datasets')\n",
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
'absl-py',
'clu',
'einops',
'flax==0.5.3', # requires deprecated flax.optim 😱
'flax',
'flaxformer @ git+https://github.com/google/flaxformer',
'jax',
'ml-collections',
Expand Down
77 changes: 0 additions & 77 deletions vit_jax/momentum_clip.py

This file was deleted.

53 changes: 30 additions & 23 deletions vit_jax/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,19 @@
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax
import tensorflow as tf

from vit_jax import checkpoint
from vit_jax import input_pipeline
from vit_jax import models
from vit_jax import momentum_clip
from vit_jax import utils


def make_update_fn(*, apply_fn, accum_steps, lr_fn):
def make_update_fn(*, apply_fn, accum_steps, tx):
"""Returns update step for data parallel training."""

def update_fn(opt, step, batch, rng):
def update_fn(params, opt_state, batch, rng):

_, new_rng = jax.random.split(rng)
# Bind the rng key to the device id (which is unique across hosts)
Expand All @@ -58,13 +58,14 @@ def loss_fn(params, images, labels):
return cross_entropy_loss(logits=logits, labels=labels)

l, g = utils.accumulate_gradient(
jax.value_and_grad(loss_fn), opt.target, batch['image'], batch['label'],
jax.value_and_grad(loss_fn), params, batch['image'], batch['label'],
accum_steps)
g = jax.tree_map(lambda x: jax.lax.pmean(x, axis_name='batch'), g)
updates, opt_state = tx.update(g, opt_state)
params = optax.apply_updates(params, updates)
l = jax.lax.pmean(l, axis_name='batch')

opt = opt.apply_gradient(g, learning_rate=lr_fn(step))
return opt, l, new_rng
return params, opt_state, l, new_rng

return jax.pmap(update_fn, axis_name='batch', donate_argnums=(0,))

Expand Down Expand Up @@ -130,28 +131,33 @@ def init_model():
model_config=config.model)

total_steps = config.total_steps

lr_fn = utils.create_learning_rate_schedule(total_steps, config.base_lr,
config.decay_type,
config.warmup_steps)
tx = optax.chain(
optax.clip_by_global_norm(config.grad_norm_clip),
optax.sgd(
learning_rate=lr_fn,
momentum=0.9,
accumulator_dtype='bfloat16',
),
)

update_fn_repl = make_update_fn(
apply_fn=model.apply, accum_steps=config.accum_steps, lr_fn=lr_fn)
apply_fn=model.apply, accum_steps=config.accum_steps, tx=tx)
infer_fn_repl = jax.pmap(functools.partial(model.apply, train=False))

# Create optimizer and replicate it over all TPUs/GPUs
opt = momentum_clip.Optimizer(
dtype=config.optim_dtype,
grad_norm_clip=config.grad_norm_clip).create(params)

initial_step = 1
opt, initial_step = flax_checkpoints.restore_checkpoint(
workdir, (opt, initial_step))
opt_state = tx.init(params)
params, opt_state, initial_step = flax_checkpoints.restore_checkpoint(
workdir, (params, opt_state, initial_step))
logging.info('Will start/continue training at initial_step=%d', initial_step)

opt_repl = flax.jax_utils.replicate(opt)
params_repl, opt_state_repl = flax.jax_utils.replicate((params, opt_state))

# Delete references to the objects that are not needed anymore
del opt
del opt_state
del params

# Prepare the learning-rate and pre-fetch it to device to avoid delays.
Expand All @@ -175,8 +181,8 @@ def init_model():
input_pipeline.prefetch(ds_train, config.prefetch)):

with jax.profiler.StepTraceAnnotation('train', step_num=step):
opt_repl, loss_repl, update_rng_repl = update_fn_repl(
opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)
params_repl, opt_state_repl, loss_repl, update_rng_repl = update_fn_repl(
params_repl, opt_state_repl, batch, update_rng_repl)

for hook in hooks:
hook(step)
Expand All @@ -197,7 +203,7 @@ def init_model():
train_loss=float(flax.jax_utils.unreplicate(loss_repl)),
img_sec_core_train=img_sec_core_train))
done = step / total_steps
logging.info(f'Step: {step}/{total_steps} {100*done:.1f}%, ' # pylint: disable=logging-format-interpolation
logging.info(f'Step: {step}/{total_steps} {100*done:.1f}%, ' # pylint: disable=logging-fstring-interpolation
f'img/sec/core: {img_sec_core_train:.1f}, '
f'ETA: {(time.time()-t0)/done*(1-done)/3600:.2f}h')

Expand All @@ -209,7 +215,7 @@ def init_model():
lt0 = time.time()
for test_batch in input_pipeline.prefetch(ds_test, config.prefetch):
logits = infer_fn_repl(
dict(params=opt_repl.target), test_batch['image'])
dict(params=params_repl), test_batch['image'])
accuracies.append(
(np.argmax(logits,
axis=-1) == np.argmax(test_batch['label'],
Expand All @@ -221,7 +227,7 @@ def init_model():
lt0 = time.time()

lr = float(lr_fn(step))
logging.info(f'Step: {step} ' # pylint: disable=logging-format-interpolation
logging.info(f'Step: {step} ' # pylint: disable=logging-fstring-interpolation
f'Learning rate: {lr:.7f}, '
f'Test accuracy: {accuracy_test:0.5f}, '
f'img/sec/core: {img_sec_core_test:.1f}')
Expand All @@ -236,8 +242,9 @@ def init_model():
if ((config.checkpoint_every and step % config.eval_every == 0) or
step == total_steps):
checkpoint_path = flax_checkpoints.save_checkpoint(
workdir, (flax.jax_utils.unreplicate(opt_repl), step), step)
workdir, (flax.jax_utils.unreplicate(params_repl),
flax.jax_utils.unreplicate(opt_state_repl), step), step)
logging.info('Stored checkpoint at step %d to "%s"', step,
checkpoint_path)

return flax.jax_utils.unreplicate(opt_repl)
return flax.jax_utils.unreplicate(params_repl)