An implementation of nanoGPT in JAX from scratch (other than Optax for optimization and Equinox for handling PyTrees), based on Andrej Karpathy's "Let's build GPT" lecture.
- The Shakespeare dataset is in the `data` folder. You only need to configure the hyper-parameters in `nanogpt-jax/train.py` as per your test settings and then run:

  ```sh
  $ python train.py
  ```
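For orientation, the hyper-parameters are plain Python constants. The block below is a hypothetical example following the conventions of Karpathy's nanoGPT; the names and values are illustrative, not a verbatim excerpt of `train.py`:

```python
# Hypothetical hyper-parameter block: illustrative names and values,
# not copied from nanogpt-jax/train.py.
batch_size = 64        # independent sequences per training step
block_size = 256       # maximum context length, in tokens
n_embd = 384           # embedding / residual stream width
n_head = 6             # attention heads per transformer block
n_layer = 6            # number of transformer blocks
dropout = 0.2          # dropout rate
learning_rate = 3e-4   # Adam step size
max_iters = 5000       # total training iterations
```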
TODO (a minimal sketch of each item follows this list):

- Write Dropout layers.
- LayerNorm.
- Apply weight initializers.
- Implement Adam.
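For the dropout item, here is a minimal functional sketch assuming the standard inverted-dropout formulation, where surviving activations are rescaled by `1/keep_prob` at train time; the `dropout` function below is illustrative, not the repo's code:

```python
import jax
import jax.numpy as jnp

def dropout(key, x, rate, deterministic=False):
    # Inverted dropout: zero activations with probability `rate` and
    # rescale the survivors so the expected value is unchanged.
    if deterministic or rate == 0.0:
        return x
    keep_prob = 1.0 - rate
    mask = jax.random.bernoulli(key, keep_prob, x.shape)
    return jnp.where(mask, x / keep_prob, 0.0)

key = jax.random.PRNGKey(0)
print(dropout(key, jnp.ones((2, 4)), rate=0.5))
```

Equinox also provides a ready-made `eqx.nn.Dropout` layer built around the same idea, with the PRNG key threaded through explicitly.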
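LayerNorm is similarly compact: normalize over the feature axis, then apply a learned scale `gamma` and shift `beta`. Again, this is a sketch of the textbook formulation rather than the repo's implementation:

```python
import jax.numpy as jnp

def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalize each example over its last (feature) axis...
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    x_hat = (x - mean) / jnp.sqrt(var + eps)
    # ...then apply the learned element-wise scale and shift.
    return gamma * x_hat + beta

x = jnp.arange(6.0).reshape(2, 3)
print(layer_norm(x, jnp.ones(3), jnp.zeros(3)))
```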
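For weight initialization, JAX's built-in `jax.nn.initializers` can be applied to parameter shapes directly. The GPT-2 convention (normal with std 0.02 for weights, zeros for biases) is shown as one plausible choice; nothing here is taken from `train.py`:

```python
import jax

key = jax.random.PRNGKey(42)

# One option among many: small-normal weights (GPT-2 style), zero biases.
w_init = jax.nn.initializers.normal(stddev=0.02)
b_init = jax.nn.initializers.zeros

w = w_init(key, (384, 384))  # e.g. an attention projection matrix
b = b_init(key, (384,))      # matching bias vector
```

With Equinox, one common pattern is to construct the module first and then swap re-initialized leaves in with `eqx.tree_at`.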
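Finally, since the repo currently relies on Optax for optimization, "implement Adam" presumably means replacing `optax.adam` with a hand-written version. The per-tensor update rule itself is short; a textbook sketch (illustrative; the step count `t` is 1-indexed):

```python
import jax.numpy as jnp

def adam_update(param, grad, m, v, t, lr=3e-4, b1=0.9, b2=0.999, eps=1e-8):
    # Exponential moving averages of the gradient and its square.
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad**2
    # Bias-correct the moments, which start at zero.
    m_hat = m / (1 - b1**t)
    v_hat = v / (1 - b2**t)
    # Scale the step by the corrected moments.
    return param - lr * m_hat / (jnp.sqrt(v_hat) + eps), m, v
```

In practice `m` and `v` would live in PyTrees mirroring the parameters, with the update mapped over them via `jax.tree_util.tree_map`.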
References:

- Andrej Karpathy's "Let's build GPT" lecture.
- Sabrina J. Mielke's "From PyTorch to JAX: towards neural net frameworks that purify stateful code".
- For my use case I did not want to use Haiku or Flax; I wanted something very minimal, and I found Equinox suitable. I was introduced to Equinox through this repo by Phil Wang.