Before anything else, you'll likely want to apply for access to the TPU Research Cloud (TRC). Combined with a Google Cloud free trial, that should let you do everything here for free. Once you're in TRC, create a project, then fill out the form that was emailed to you with the name of the new project. Use `create_tfrecords.py` from the GPT-Neo repo to prepare your data as tfrecords; I might do a separate guide on that. Another thing you might want to do is fork the mesh-transformer-jax repo to make it easier to add and modify the config files.
-
Install the Google Cloud SDK. We'll need it later.
-
If you haven't created a project and activated TPU access through TRC yet (or if you plan on paying out of pocket), create one now.
-
TPUs use Google Cloud Storage buckets for storage, so go ahead and create one now. Make sure it's in the same region the TPU VM will be in; the email from TRC will tell you which region(s) you can use free TPUs in.
-
You'll need the full pretrained weights in order to fine-tune the model. Download those here.
Now that you have a bucket on the cloud and the weights on your PC, you need to upload the weights to the bucket in two steps:
-
Decompress and extract `GPT-J-6B/step_383500.tar.zstd` so you're left with the uncompressed folder containing the sharded checkpoint.
-
Open the Google Cloud SDK and run the following command, replacing the path names as appropriate:
`gsutil -m cp -R LOCAL_PATH_TO/step_383500 gs://YOUR-BUCKET`
If that works, the console will show the files being uploaded. Note: this took about 12 hours for me, uploading to the Netherlands from California; hopefully you'll have a better geographic situation than I did! I also initially made the mistake of uploading the still-packed .tar. Don't do that: TPU VMs don't have enough local storage for you to unpack it. To avoid having to re-upload, I had to unpack it in Colab.
You'll want to upload the tfrecords of your data as well. You can do that here with `gsutil` or through the web interface, but trust me when I say you don't want to upload the nearly 70 GB of weights through the web interface.
Note that the next two steps, preparing the index and config files, can also be done later by editing the base repo in the VM's text editor. It's more efficient to instead make these changes in your own fork of the repo as follows:
-
In the `data` folder, create a new file `foo.train.index`, replacing foo with whatever name you want to use for your dataset. For each tfrecord in your bucket that you intend to train with, add its path as a line in the index. Make `foo.val.index` and do the same for your validation dataset (if you have one). See the existing files for examples.
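If you have more than a handful of tfrecords, writing the index by hand gets tedious. Here is a minimal Python sketch that writes one path per line. The helper names, bucket and file names are hypothetical, and it assumes the index should contain full `gs://` paths; compare with the existing index files in the `data` folder to confirm the format before relying on it.

```python
from pathlib import Path


def format_index(tfrecord_paths):
    """Return index-file text: one tfrecord path per line, skipping blank entries."""
    lines = [p.strip() for p in tfrecord_paths if p.strip()]
    if not lines:
        raise ValueError("no tfrecord paths given")
    return "\n".join(lines) + "\n"


def write_index(index_path, tfrecord_paths):
    """Write an index file such as data/foo.train.index and return how many paths it holds."""
    text = format_index(tfrecord_paths)
    Path(index_path).write_text(text)
    return text.count("\n")


# Example usage (hypothetical bucket and file names):
# write_index("data/foo.train.index", ["gs://YOUR-BUCKET/foo_0_1147.tfrecords"])
# write_index("data/foo.val.index", ["gs://YOUR-BUCKET/foo_val_0_120.tfrecords"])
```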
Duplicate the config file `6B_roto_256.json` in the `configs` folder and rename it to something appropriate for your project. Open it up and make these edits:
- `tpu_size`: Change from `256` to `8`.
- `bucket`: Change to your bucket.
- `model_dir`: Change to the directory you'd like to save your checkpoints in.
- `train_set` and `val_set`: Change to the index files from the last step.
- `eval_harness_tasks`: Can be removed if you don't plan on using the eval harness.
- `val_batches`, `val_every`, `ckpt_every`, `keep_every`: Usage should be intuitive. Don't set the `*_every` values to 0, though, or you'll get a divide-by-zero error. If you don't have a `val_set`, just set `val_every` to something higher than `total_steps`.
- `name`: Change to a name for your model.
- `warmup_steps`, `lr`, etc.: See the Learning Rate Notes section at the end of the guide.
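The same edits can be applied programmatically, which also makes it easy to guard against the divide-by-zero mistake. This is a hypothetical helper, not part of mesh-transformer-jax; it only touches the keys listed above, and it passes `val_set` through unchanged, so give it the same structure the base config uses.

```python
import json
from pathlib import Path

# Per the guide, setting any of these to 0 leads to a divide-by-zero error.
EVERY_KEYS = ("val_every", "ckpt_every", "keep_every")


def customize_config(base, *, bucket, model_dir, name, train_set,
                     val_set=None, total_steps=None):
    """Return a copy of `base` with the edits described above applied."""
    cfg = dict(base)
    cfg.update(tpu_size=8, bucket=bucket, model_dir=model_dir,
               name=name, train_set=train_set)
    cfg.pop("eval_harness_tasks", None)  # only needed for the eval harness
    if total_steps is not None:
        cfg["total_steps"] = total_steps
    if val_set is not None:
        cfg["val_set"] = val_set  # same structure as in the base config
    elif "total_steps" in cfg:
        # No validation data: push validation past the end of training.
        cfg["val_every"] = cfg["total_steps"] + 1
    for key in EVERY_KEYS:
        if cfg.get(key) == 0:
            raise ValueError(f"{key} must not be 0")
    return cfg


def write_config(src, dst, **edits):
    """Read a base config (e.g. configs/6B_roto_256.json) and write the edited copy."""
    base = json.loads(Path(src).read_text())
    Path(dst).write_text(json.dumps(customize_config(base, **edits), indent=2))
```

For example, `write_config("configs/6B_roto_256.json", "configs/foo.json", bucket="YOUR-BUCKET", model_dir="foo_ckpts", name="foo", train_set="foo.train.index", total_steps=72)` would produce a config for a one-epoch run without validation; all of those values are placeholders.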
-
Push the changes to your GitHub repo.
-
Follow this guide up to and including the step "Connect to your Cloud TPU VM".
At this point you should have remote access to the TPU VM!
-
In the new VM terminal, type `git clone https://github.com/kingoflolz/mesh-transformer-jax` (or, preferably, clone your own fork after pushing the config and index files).
-
Move to the new directory with `cd mesh-transformer-jax` and run `pip install -r requirements.txt`. For whatever reason, the requirements file doesn't seem to install the correct version of JAX... that is, it does, but something must override it later, and I haven't figured out what. That's okay: just run `pip install jax==0.2.12` and you'll be all set.
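To double-check which JAX version actually ended up installed, here is a small stdlib-only sketch (Python 3.8+, since it uses `importlib.metadata`). It reads package metadata instead of importing JAX, and the helper names are hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version

EXPECTED_JAX = "0.2.12"  # the version this guide pins


def installed_version(package):
    """Return the installed version string of `package`, or None if it isn't installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def jax_is_pinned(installed, expected=EXPECTED_JAX):
    """True only if the installed version exactly matches the pinned one."""
    return installed == expected


if __name__ == "__main__":
    found = installed_version("jax")
    if jax_is_pinned(found):
        print(f"jax {found} OK")
    else:
        print(f"jax is {found}; run: pip install jax=={EXPECTED_JAX}")
```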
Finally, run `python3 device_train.py --config=configs/YOUR_CONFIG.json --tune-model-path=gs://YOUR-BUCKET/step_383500/`. If everything is set up correctly, this will begin the fine-tuning process. First the model has to be loaded into memory; after `loading network` appeared on the console, it took about 10-15 minutes before the next step, setting up WandB for logging. If you aren't using WandB, choosing option 3 at its prompt lets you skip that. A step 1 checkpoint will be saved, and then the real training will start. If you have a small dataset, this will go by quickly; TPU VMs can train at a rate of ~5000 tokens/second.
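To get a feel for "quickly", you can turn that throughput into a rough time estimate. This sketch assumes every training sequence is a full 2048 tokens (GPT-J's context length) and treats each step as one batch of `gradient_accumulation_steps` sequences; it ignores model loading, WandB setup and checkpoint saves, so treat the result as a lower bound.

```python
TOKENS_PER_SECOND = 5000  # approximate TPU VM throughput quoted in this guide
SEQ_LEN = 2048            # assumption: full 2048-token sequences (GPT-J's context length)


def estimate_training_seconds(total_steps, batch_size, seq_len=SEQ_LEN,
                              tokens_per_second=TOKENS_PER_SECOND):
    """Rough training time, excluding model loading, WandB setup and checkpoint saves."""
    tokens = total_steps * batch_size * seq_len
    return tokens / tokens_per_second


if __name__ == "__main__":
    seconds = estimate_training_seconds(total_steps=72, batch_size=16)
    print(f"~{seconds / 60:.1f} minutes")
```

For the 72-step, batch-16 example in the Learning Rate Notes, that works out to 72 × 16 × 2048 = 2,359,296 tokens, or roughly 472 seconds of training (a bit under 8 minutes) on top of the 10-15 minutes of loading.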
You did it! Now don't forget any cleanup steps you need to take, like shutting down your TPU VM or removing unneeded data from buckets, so that you don't get any unexpected charges from Google later.
This guide is labeled "The Basics"; anything we haven't covered so far is out of scope, but go check out the rest of the repository! Try `python3 device_sample.py --config=configs/YOUR_CONFIG.json` for a basic sampling interface. Use `slim_model.py` to prepare an easier-to-deploy slim version of your new weights for inference. Experiment!
Learning Rate Notes
Thanks to nostalgebraist for talking about this! They're the one who explained this part on Discord; I'm really just paraphrasing:
The first thing you want to determine is how long a training epoch will be. `gradient_accumulation_steps` is your batch size; it defaults to `16`, and nostalgebraist recommends `32`. Your .tfrecord files should have a number in the file name indicating how many sequences are in the dataset. Divide that number by the batch size, rounding up, and the result is how many steps are in an epoch. Now we can write the schedule.
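The epoch arithmetic can be sketched in a few lines of Python. The filename parser is hypothetical: it assumes the sequence count is the last number in the file name (before the extension), so check that against the names your `create_tfrecords.py` run actually produced.

```python
import math
import re
from pathlib import PurePosixPath


def sequences_in_filename(path):
    """Assumption: the sequence count is the last number in the file name."""
    stem = PurePosixPath(path).stem  # drops the directory and the .tfrecords suffix
    numbers = re.findall(r"\d+", stem)
    if not numbers:
        raise ValueError(f"no sequence count found in {path!r}")
    return int(numbers[-1])


def total_sequences(paths):
    """Sum the sequence counts across all tfrecords listed in an index."""
    return sum(sequences_in_filename(p) for p in paths)


def steps_per_epoch(num_sequences, batch_size=16):
    """Steps needed to see every sequence once, rounding up so no data is dropped."""
    return math.ceil(num_sequences / batch_size)
```

With the 1147-sequence dataset from the example below, `steps_per_epoch(1147, 16)` gives 72, and `steps_per_epoch(1147, 32)` gives 36.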
`lr` is recommended to be between `1e-5` and `5e-5`, with `end_lr` set to 1/5 or 1/10 of `lr`. `weight_decay` can remain `0.1`. `total_steps` should be at least one epoch, possibly longer if you have a validation set to monitor your loss with. `warmup_steps` should be 5-10% of `total_steps`, and finally `anneal_steps` should be `total_steps - warmup_steps`.
To illustrate: I have a small dataset that tokenized into 1147 sequences as a .tfrecord. Dividing by `gradient_accumulation_steps` set to `16` (1147 / 16 ≈ 71.7), then rounding up to ensure I use all the data, gives 72 steps per epoch. I'll set `lr` to `5e-5` and `end_lr` to a fifth of that, `1e-5`; that `lr` may be too much, since it's on the high end of the recommended range. I'll set `total_steps` to `72` for one epoch, since I don't have a validation set. Then I'll set `anneal_steps` to `65` and `warmup_steps` to `7` (about 10% of 72). Simple as that, but you may need to fiddle with the specifics on your own.
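The rules of thumb above can be collected into one small Python helper. It is hypothetical, not part of mesh-transformer-jax; the defaults (one epoch, 10% warmup, `end_lr` = `lr` / 5) are chosen to reproduce the worked example, and the returned keys are meant to be copied into your config JSON.

```python
import math


def lr_schedule(num_sequences, batch_size=16, lr=5e-5, end_lr_divisor=5,
                epochs=1, warmup_fraction=0.1):
    """Build schedule values from the rules of thumb in these notes."""
    steps_per_epoch = math.ceil(num_sequences / batch_size)  # round up: use all data
    total_steps = steps_per_epoch * epochs
    warmup_steps = round(total_steps * warmup_fraction)  # aim for 5-10% of total
    return {
        "gradient_accumulation_steps": batch_size,
        "lr": lr,
        "end_lr": lr / end_lr_divisor,  # 1/5 or 1/10 of lr
        "weight_decay": 0.1,
        "total_steps": total_steps,
        "warmup_steps": warmup_steps,
        "anneal_steps": total_steps - warmup_steps,
    }
```

`lr_schedule(1147)` returns `total_steps` 72, `warmup_steps` 7 and `anneal_steps` 65, matching the example; pass `batch_size=32` for nostalgebraist's recommended batch size, or `end_lr_divisor=10` for a gentler decay target.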