Skip to content

Commit

Permalink
Cleanup unused path.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 555396460
  • Loading branch information
andsteing authored and copybara-github committed Aug 10, 2023
1 parent ac6e056 commit b81313a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8]
python-version: ['3.10']
steps:
- name: Cancel previous
uses: styfle/[email protected]
Expand Down
12 changes: 6 additions & 6 deletions vit_jax/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ def init_model():
filename = best.filename
logging.info('Selected fillename="%s" for "%s" with final_val=%.3f',
filename, model_or_filename, best.final_val)
pretrained_path = os.path.join(config.pretrained_dir,
f'{config.model.model_name}.npz')
else:
# ViT / Mixer papers
filename = config.model.model_name
Expand All @@ -140,7 +138,7 @@ def init_model():
optax.sgd(
learning_rate=lr_fn,
momentum=0.9,
accumulator_dtype='bfloat16',
accumulator_dtype=config.optim_dtype,
),
)

Expand Down Expand Up @@ -212,7 +210,7 @@ def init_model():
(step == total_steps)):

accuracies = []
lt0 = time.time()
tt0 = time.time()
for test_batch in input_pipeline.prefetch(ds_test, config.prefetch):
logits = infer_fn_repl(
dict(params=params_repl), test_batch['image'])
Expand All @@ -223,8 +221,7 @@ def init_model():
accuracy_test = np.mean(accuracies)
img_sec_core_test = (
config.batch_eval * ds_test.cardinality().numpy() /
(time.time() - lt0) / jax.device_count())
lt0 = time.time()
(time.time() - tt0) / jax.device_count())

lr = float(lr_fn(step))
logging.info(f'Step: {step} ' # pylint: disable=logging-fstring-interpolation
Expand All @@ -237,14 +234,17 @@ def init_model():
accuracy_test=accuracy_test,
lr=lr,
img_sec_core_test=img_sec_core_test))
lt0 += time.time() - tt0

# Store checkpoint.
if ((config.checkpoint_every and step % config.eval_every == 0) or
step == total_steps):
tt0 = time.time()
checkpoint_path = flax_checkpoints.save_checkpoint(
workdir, (flax.jax_utils.unreplicate(params_repl),
flax.jax_utils.unreplicate(opt_state_repl), step), step)
logging.info('Stored checkpoint at step %d to "%s"', step,
checkpoint_path)
lt0 += time.time() - tt0

return flax.jax_utils.unreplicate(params_repl)

0 comments on commit b81313a

Please sign in to comment.