Skip to content

Commit

Permalink
add chargpt model example (#138)
Browse files Browse the repository at this point in the history
* add chargpt model

* fix generation example
  • Loading branch information
dfalbel authored Sep 15, 2023
1 parent 7a859d2 commit 68874e3
Showing 1 changed file with 157 additions and 0 deletions.
157 changes: 157 additions & 0 deletions vignettes/examples/chargpt.Rmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
---
title: "CharGPT"
desc: "Train a character-level GPT-2 on Shakespeare texts."
category: 'intermediate'
---

This example is inspired by the [chargpt](https://github.com/karpathy/minGPT/tree/master/projects/chargpt) project by Andrey Karpathy.
We are going to train character-level language model on Shakespeare texts.

We first load the libraries that we plan to use:

```{r, eval = FALSE}
library(torch)
library(luz)
library(zeallot)
```

Next we define the torch dataset that will pre-process data for the model. It splits the text into a character vector,
each element containing exactly one character.

Then lists all unique characters into the `vocab` attribute. The order of the characters in the vocabulary is used to
encode each character to an integer value, that will be used in the embedding layer.

The `.getitem()` method, can take chunks of `block_size` characters and encode them into their integer representation.

```{r, eval=FALSE}
url <- "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
char_dataset <- torch::dataset(
initialize = function(data, block_size = 128) {
self$block_size <- block_size
self$data <- stringr::str_split_1(data, "")
self$data_size <- length(self$data)
self$vocab <- unique(self$data)
self$vocab_size <- length(self$vocab)
},
.getitem = function(i) {
chunk <- self$data[i + seq_len(self$block_size + 1)]
idx <- match(chunk, self$vocab)
list(
x = head(idx, self$block_size),
y = tail(idx, self$block_size)
)
},
.length = function() {
self$data_size - self$block_size - 1L # this is to account the last value
}
)
dataset <- char_dataset(readr::read_file(url))
dataset[1] # this allows us to see an element of the dataset
```

We then define the neural net we are going to train. Defining a GPT-2 model is quite verbose, so we are going
to use the minhub implementation directly. You can find the full model definition [here](https://github.com/mlverse/minhub/blob/main/R/gpt2.R),
and this code is entirely self-contained, so you don't need to install minhub, if you don't want to.

We also implemented the `generate` method for the model, that allows one to generate completions using the model.
It applies the model in a loop, at each iteration prediction what's the next character.

```{r, eval=FALSE}
model <- torch::nn_module(
initialize = function(vocab_size) {
# remotes::install_github("mlverse/minhub")
self$gpt <- minhub::gpt2(
vocab_size = vocab_size,
n_layer = 6,
n_head = 6,
n_embd = 192
)
},
forward = function(x) {
# we have to transpose to make the vocabulary the last dimension
self$gpt(x)$transpose(2,3)
},
generate = function(x, temperature = 1, iter = 50, top_k = 10) {
# samples from the model givn a context vector.
for (i in seq_len(iter)) {
logits <- self$forward(x)[,,-1]
logits <- logits/temperature
c(prob, ind) %<-% logits$topk(top_k)
logits <- torch_full_like(logits, -Inf)$scatter_(-1, ind, prob)
logits <- nnf_softmax(logits, dim = -1)
id_next <- torch_multinomial(logits, num_samples = 1)
x <- torch_cat(list(x, id_next), dim = 2)
}
x
}
)
```

Next, we implemented a callback that is used for nicely displaying generated samples during the model
training:

```{r, eval=FALSE}
# samples from the model using the context.
generate <- function(model, vocab, context, ...) {
local_no_grad() # disables gradient for sampling
x <- match(stringr::str_split_1(context, ""), vocab)
x <- torch_tensor(x)[NULL,]$to(device = model$device)
content <- as.integer(model$generate(x, ...)$cpu())
paste0(vocab[content], collapse = "")
}
display_cb <- luz_callback(
initialize = function(iter = 500) {
self$iter <- iter # print every 500 iterations
},
on_train_batch_end = function() {
if (!(ctx$iter %% self$iter == 0))
return()
ctx$model$eval()
with_no_grad({
# sample from the model...
context <- "O God, O God!"
text <- generate(ctx$model, dataset$vocab, context, iter = 100)
cli::cli_h3(paste0("Iter ", ctx$iter))
cli::cli_text(text)
})
}
)
```

Finally, you can train the model using `fit`:

```{r, eval=FALSE}
fitted <- model |>
setup(
loss = nn_cross_entropy_loss(),
optimizer = optim_adam
) |>
set_opt_hparams(lr = 5e-4) |>
set_hparams(vocab_size = dataset$vocab_size) |>
fit(
dataset,
dataloader_options = list(batch_size = 128, shuffle = TRUE),
epochs = 1,
callbacks = list(
display_cb(iter = 500),
luz_callback_gradient_clip(max_norm = 1)
)
)
```

One epoch, is reasonable for this dataset and takes ~1h on the M1 MBP.
You can generate new samples with:

```{r, eval=FALSE}
context <- "O God, O God!"
text <- generate(fitted$model, dataset$vocab, context, iter = 100)
cat(text)
```


0 comments on commit 68874e3

Please sign in to comment.