
**Training script** #23

Open · 8 of 9 tasks
arampacha opened this issue Jul 6, 2021 · 5 comments

arampacha (Collaborator) commented Jul 6, 2021

- add bf16 support
- check if training with bf16 weights works fine
- add resuming from checkpoint
- add wandb tracking
- complete Adafactor option
- figure out how to best utilize the profiler for training loop optimization
- add gradient accumulation
- support iterable datasets and a max_steps argument
- prefetch generator for the dataloader (see the sketch after this list)
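
For the prefetch item, here is a minimal sketch of a background-thread prefetch generator; `batch_iterator` and `buffer_size` are placeholders, not the repo's actual dataloader interface:

```python
# A background-thread prefetch generator: keeps up to `buffer_size` batches
# ready while the accelerator is busy with the current step.
import queue
import threading


def prefetch(batch_iterator, buffer_size=2):
    """Yield batches from `batch_iterator`, preparing up to `buffer_size`
    batches ahead of time in a background thread."""
    buf = queue.Queue(maxsize=buffer_size)
    sentinel = object()  # marks the end of the iterator

    def producer():
        for batch in batch_iterator:
            buf.put(batch)  # blocks while the buffer is full
        buf.put(sentinel)

    threading.Thread(target=producer, daemon=True).start()

    while True:
        batch = buf.get()
        if batch is sentinel:
            return
        yield batch
```

For prefetching sharded batches directly onto TPU devices, `flax.jax_utils.prefetch_to_device` covers similar ground.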
@arampacha (Collaborator, Author)

Casting the weights to bf16 is not recommended, so it has been removed for now.
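
For context, a minimal sketch of what is kept instead: running the computation in bfloat16 while the stored parameters stay in float32. The model name and the HF Flax API usage below are assumptions for illustration, not necessarily how the training script sets this up.

```python
# Sketch: computation dtype vs. weight dtype in Flax (model name is a
# placeholder).
import jax.numpy as jnp
from transformers import FlaxGPTNeoForCausalLM

# `dtype` only controls the dtype of the forward/backward computation;
# the parameters in `model.params` remain float32, which is what the
# optimizer should keep updating.
model = FlaxGPTNeoForCausalLM.from_pretrained(
    "EleutherAI/gpt-neo-125M", dtype=jnp.bfloat16
)

# What is being avoided: model.to_bf16(model.params) would additionally
# cast the stored weights themselves to bfloat16.
```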

shpotes (Member) commented Jul 8, 2021

Here's the gradient accumulation helper from the vision_transformer codebase:
https://github.com/google-research/vision_transformer/blob/ba9a85bdc430daf4da7b9da67b486a4e0f5bb278/vit_jax/hyper.py#L77

And here's a small example
https://github.com/google-research/vision_transformer/blob/ba9a85bdc430daf4da7b9da67b486a4e0f5bb278/vit_jax/train.py#L63-L66
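
For reference, a minimal sketch of the same idea in plain JAX (the linked hyper.py helper is the canonical version; `loss_fn` and the batch layout here are assumptions):

```python
import jax
import jax.numpy as jnp


def accumulated_grads(loss_fn, params, batch, accum_steps):
    """Average gradients over `accum_steps` micro-batches of `batch`.

    `loss_fn(params, micro_batch)` must return a scalar loss; the leading
    batch axis must be divisible by `accum_steps`.
    """
    # Reshape the leading axis into (accum_steps, micro_batch_size, ...).
    micro_batches = jax.tree_util.tree_map(
        lambda x: x.reshape((accum_steps, -1) + x.shape[1:]), batch
    )
    grad_fn = jax.grad(loss_fn)

    def step(accum, micro_batch):
        grads = grad_fn(params, micro_batch)
        return jax.tree_util.tree_map(jnp.add, accum, grads), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    total, _ = jax.lax.scan(step, zeros, micro_batches)
    # A single optimizer update is then applied with these averaged grads.
    return jax.tree_util.tree_map(lambda g: g / accum_steps, total)
```

If optax is already in use, its `MultiSteps` wrapper provides a similar accumulate-then-update mechanism.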

mrinal18 (Collaborator) commented Jul 8, 2021

For gradient accumulation I have opened a PR: #29.
Let me know if we can sync up on this.

@celsofranssa

Hello,
what are the minimum hardware requirements to run the training script?

@arampacha (Collaborator, Author)

Hi @celsofranssa, the hyperparameters in the HF model cards (for example here) are tuned for a TPU v3-8. You can still run the script on a GPU by adjusting the batch size accordingly and possibly switching the dtype from bfloat16 to float16, depending on your hardware. I'm not sure what the exact minimum requirement would be. You can also consider decreasing block_size if you run out of memory.
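
As a rough illustration of the dtype switch (a sketch only; `pick_dtype` is a made-up helper, not part of the training script):

```python
import jax
import jax.numpy as jnp


def pick_dtype():
    """Choose a computation dtype based on the available accelerator."""
    platform = jax.devices()[0].platform  # "tpu", "gpu", or "cpu"
    if platform == "tpu":
        return jnp.bfloat16  # TPUs have native bfloat16 support
    if platform == "gpu":
        return jnp.float16   # half precision is the safer default on most GPUs
    return jnp.float32       # CPU: stay in full precision


print(pick_dtype())
```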
