---
title: "CharGPT"
desc: "Train a character-level GPT-2 on Shakespeare texts."
category: 'intermediate'
---

This example is inspired by the [chargpt](https://github.com/karpathy/minGPT/tree/master/projects/chargpt) project by Andrej Karpathy.
We are going to train a character-level language model on Shakespeare texts.

We first load the libraries that we plan to use:

```{r, eval = FALSE}
library(torch)
library(luz)
library(zeallot)
```

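The zeallot package provides the destructuring operator `%<-%`, which we will use later to unpack the two tensors returned by `$topk()`. A quick illustration:

```{r, eval=FALSE}
# `%<-%` unpacks the elements of a list into separate variables
c(values, indices) %<-% list(c(10, 20), c(1, 2))
values  #> 10 20
indices #> 1 2
```
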
Next, we define the torch dataset that pre-processes the data for the model. It splits the text into a character vector,
with each element containing exactly one character.

It then stores all unique characters in the `vocab` attribute. The order of the characters in the vocabulary is used to
encode each character as an integer value, which is then used in the embedding layer.

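To see the encoding in action with a tiny, made-up vocabulary (not the one from the dataset), `match()` returns the position of each character in the vocabulary:

```{r, eval=FALSE}
# hypothetical mini-vocabulary to illustrate the character -> integer encoding
vocab <- c("a", "b", "r", "c", "d")
match(stringr::str_split_1("abracadabra", ""), vocab)
#> [1] 1 2 3 1 4 1 5 1 2 3 1
```
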
The `.getitem()` method takes a chunk of `block_size + 1` characters and encodes it into its integer representation, returning the first `block_size` values as the input `x` and the last `block_size` values as the target `y`.

```{r, eval=FALSE}
url <- "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
char_dataset <- torch::dataset(
  initialize = function(data, block_size = 128) {
    self$block_size <- block_size
    self$data <- stringr::str_split_1(data, "")
    self$data_size <- length(self$data)
    self$vocab <- unique(self$data)
    self$vocab_size <- length(self$vocab)
  },
  .getitem = function(i) {
    # take block_size + 1 characters: the extra one is needed for the target
    chunk <- self$data[i + seq_len(self$block_size + 1)]
    idx <- match(chunk, self$vocab)
    list(
      x = head(idx, self$block_size),
      y = tail(idx, self$block_size)
    )
  },
  .length = function() {
    # each item needs block_size + 1 characters, so the last positions are dropped
    self$data_size - self$block_size - 1L
  }
)
dataset <- char_dataset(readr::read_file(url))
dataset[1] # this allows us to see an element of the dataset
```

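A quick sanity check: decoding an item back to characters shows that `y` is simply `x` shifted one position to the left, i.e. the next-character prediction target:

```{r, eval=FALSE}
# decode the first item back to characters: `y` is `x` shifted by one
item <- dataset[1]
paste0(dataset$vocab[item$x[1:20]], collapse = "")
paste0(dataset$vocab[item$y[1:20]], collapse = "")
```
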
We then define the neural net we are going to train. Defining a GPT-2 model is quite verbose, so we use the minhub implementation directly. You can find the full model definition [here](https://github.com/mlverse/minhub/blob/main/R/gpt2.R); that file is entirely self-contained, so you can copy it instead of installing minhub if you prefer.

We also implement a `generate` method for the model that produces completions. It applies the model in a loop, at each iteration predicting the next character.

```{r, eval=FALSE}
model <- torch::nn_module(
  initialize = function(vocab_size) {
    # remotes::install_github("mlverse/minhub")
    self$gpt <- minhub::gpt2(
      vocab_size = vocab_size,
      n_layer = 6,
      n_head = 6,
      n_embd = 192
    )
  },
  forward = function(x) {
    # transpose so the vocabulary dimension comes right after the batch
    # dimension, as expected by nn_cross_entropy_loss()
    self$gpt(x)$transpose(2, 3)
  },
  generate = function(x, temperature = 1, iter = 50, top_k = 10) {
    # samples from the model given a context vector
    for (i in seq_len(iter)) {
      logits <- self$forward(x)[, , -1] # logits for the last position
      logits <- logits / temperature
      c(prob, ind) %<-% logits$topk(top_k)
      # keep only the top_k largest logits; everything else gets -Inf,
      # i.e. zero probability after the softmax
      logits <- torch_full_like(logits, -Inf)$scatter_(-1, ind, prob)
      logits <- nnf_softmax(logits, dim = -1)
      id_next <- torch_multinomial(logits, num_samples = 1)
      x <- torch_cat(list(x, id_next), dim = 2)
    }
    x
  }
)
```

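To make the top-k filtering step concrete, here is a standalone sketch with made-up logits: everything outside the `top_k` largest logits is set to `-Inf`, so it receives zero probability after the softmax:

```{r, eval=FALSE}
# standalone illustration of the top-k masking used in `generate()`
logits <- torch_tensor(matrix(c(2, 0.5, 0.1, 1.5, 0.2), nrow = 1))
c(prob, ind) %<-% logits$topk(2)
masked <- torch_full_like(logits, -Inf)$scatter_(-1, ind, prob)
nnf_softmax(masked, dim = -1) # only the two largest entries are non-zero
```
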
Next, we implement a callback that nicely displays generated samples during training:

```{r, eval=FALSE}
# samples from the model using the provided context
generate <- function(model, vocab, context, ...) {
  local_no_grad() # disables gradient tracking during sampling
  x <- match(stringr::str_split_1(context, ""), vocab)
  x <- torch_tensor(x)[NULL, ]$to(device = model$device)
  content <- as.integer(model$generate(x, ...)$cpu())
  paste0(vocab[content], collapse = "")
}

display_cb <- luz_callback(
  initialize = function(iter = 500) {
    self$iter <- iter # print a sample every `iter` iterations
  },
  on_train_batch_end = function() {
    if (ctx$iter %% self$iter != 0)
      return()
    ctx$model$eval()
    with_no_grad({
      # sample from the model...
      context <- "O God, O God!"
      text <- generate(ctx$model, dataset$vocab, context, iter = 100)
      cli::cli_h3(paste0("Iter ", ctx$iter))
      cli::cli_text(text)
    })
    ctx$model$train() # put the model back into training mode
  }
)
```

Finally, you can train the model using `fit`:

```{r, eval=FALSE}
fitted <- model |>
  setup(
    loss = nn_cross_entropy_loss(),
    optimizer = optim_adam
  ) |>
  set_opt_hparams(lr = 5e-4) |>
  set_hparams(vocab_size = dataset$vocab_size) |>
  fit(
    dataset,
    dataloader_options = list(batch_size = 128, shuffle = TRUE),
    epochs = 1,
    callbacks = list(
      display_cb(iter = 500),
      luz_callback_gradient_clip(max_norm = 1)
    )
  )
```

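As an aside, this is why `forward()` puts the vocabulary in the second dimension: `nn_cross_entropy_loss()` expects logits of shape `(batch, n_classes, ...)` and integer targets of shape `(batch, ...)`. A quick shape check with made-up sizes:

```{r, eval=FALSE}
# made-up sizes: batch of 2, vocabulary of 65, sequence length of 128
loss_fn <- nn_cross_entropy_loss()
logits <- torch_randn(2, 65, 128) # (batch, vocab_size, block_size)
target <- torch_randint(1, 66, size = c(2, 128), dtype = torch_long())
loss_fn(logits, target) # a scalar tensor
```
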
One epoch is reasonable for this dataset and takes about one hour on an M1 MacBook Pro.
You can generate new samples with:

```{r, eval=FALSE}
context <- "O God, O God!"
text <- generate(fitted$model, dataset$vocab, context, iter = 100)
cat(text)
```

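If you want to reuse the fitted model later without retraining, it can be serialized with luz (the file name here is arbitrary):

```{r, eval=FALSE}
# save the fitted model and reload it in a fresh session
luz_save(fitted, "chargpt.luz")
fitted <- luz_load("chargpt.luz")
```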