Add text classification example (#134)
* Add text classification example.

* update containers for tests
dfalbel authored Aug 29, 2023
1 parent 5d7f08e commit e98c62f
Showing 4 changed files with 174 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/R-CMD-check.yaml
@@ -63,7 +63,7 @@ jobs:
name: 'gpu'

container:
-  image: nvidia/cuda:11.7.0-cudnn8-devel-ubuntu18.04
+  image: nvidia/cuda:11.7.1-cudnn8-devel-ubuntu18.04
options: --runtime=nvidia --gpus all

env:
2 changes: 1 addition & 1 deletion .github/workflows/test-coverage-pak.yaml
@@ -16,7 +16,7 @@ jobs:
runs-on: ['self-hosted', 'gce', 'gpu']

container:
-  image: nvidia/cuda:11.6.0-cudnn8-devel-ubuntu18.04
+  image: nvidia/cuda:11.7.1-cudnn8-devel-ubuntu18.04
options: --gpus all

env:
6 changes: 6 additions & 0 deletions .gitignore
@@ -9,3 +9,9 @@ doc
Meta
mnist
pets
+Untitled.py
+aclImdb
+hello.txt
+imdb
+tokenizer-20000.json
+tokenizer.json
166 changes: 166 additions & 0 deletions vignettes/examples/text-classification.Rmd
@@ -0,0 +1,166 @@
---
title: "Text classification from scratch"
desc: "Implements text classification from scratch."
category: 'basic'
editor_options:
chunk_output_type: console
---

This example is a port of ['Text classification from scratch'](https://keras.io/examples/nlp/text_classification_from_scratch/) from the Keras documentation, by Mark Omernick and François Chollet.

First we implement a torch dataset that downloads and pre-processes the data.
The `initialize()` method is called when we instantiate the dataset.
Our implementation:

- Downloads the IMDB dataset if it doesn't exist in the `root` directory.
- Extracts the files into `root`.
- Creates a tokenizer using the files in the training set.

We also implement the `.getitem()` method, which is used to extract a single
element from the dataset and pre-process the file contents.

```{r, eval = FALSE}
library(torch)
library(tok)
library(luz)

vocab_size <- 20000   # maximum number of items in the vocabulary
output_length <- 500  # padding and truncation length
embedding_dim <- 128  # size of the embedding vectors

imdb_dataset <- dataset(
  initialize = function(output_length, vocab_size, root, split = "train", download = TRUE) {
    url <- "https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"
    fpath <- file.path(root, "aclImdb")

    # download if the data doesn't exist yet
    if (!dir.exists(fpath) && download) {
      # download into tempdir, then extract and move to the root dir
      withr::with_tempfile("file", {
        download.file(url, file)
        untar(file, exdir = root)
      })
    }

    # now list files for the split
    self$data <- rbind(
      data.frame(
        fname = list.files(file.path(fpath, split, "pos"), full.names = TRUE),
        y = 1
      ),
      data.frame(
        fname = list.files(file.path(fpath, split, "neg"), full.names = TRUE),
        y = 0
      )
    )

    # train a tokenizer on the train data (if one doesn't exist yet)
    tokenizer_path <- file.path(root, glue::glue("tokenizer-{vocab_size}.json"))
    if (!file.exists(tokenizer_path)) {
      self$tok <- tok::tokenizer$new(tok::model_bpe$new())
      self$tok$pre_tokenizer <- tok::pre_tokenizer_whitespace$new()
      files <- list.files(file.path(fpath, "train"), recursive = TRUE, full.names = TRUE)
      self$tok$train(files, tok::trainer_bpe$new(vocab_size = vocab_size))
      self$tok$save(tokenizer_path)
    } else {
      self$tok <- tok::tokenizer$from_file(tokenizer_path)
    }
    self$tok$enable_padding(length = output_length)
    self$tok$enable_truncation(max_length = output_length)
  },
  .getitem = function(i) {
    item <- self$data[i, ]
    # takes item i, reads the file content into a character string,
    # then makes everything lower case and removes html + punctuation;
    # finally uses the tokenizer to encode the text
    text <- item$fname %>%
      readr::read_file() %>%
      stringr::str_to_lower() %>%
      stringr::str_replace_all("<br />", " ") %>%
      stringr::str_remove_all("[:punct:]") %>%
      self$tok$encode()
    list(
      x = text$ids + 1L,
      y = item$y
    )
  },
  .length = function() {
    nrow(self$data)
  }
)

train_ds <- imdb_dataset(output_length, vocab_size, "./imdb", split = "train")
test_ds <- imdb_dataset(output_length, vocab_size, "./imdb", split = "test")
```
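
As a quick sanity check, we can pull a single element and confirm it has the expected shape (a sketch; the counts assume the standard IMDB split):

```{r, eval = FALSE}
item <- train_ds$.getitem(1)
length(item$x)   # token ids, padded/truncated to output_length (500)
item$y           # label: 1 for positive reviews, 0 for negative
length(train_ds) # 25000 labeled reviews in the training split
```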

We now define the model we want to train: a 1D convnet that starts with an
embedding layer and ends in a classifier head.

```{r, eval = FALSE}
model <- nn_module(
  initialize = function(vocab_size, embedding_dim) {
    self$embedding <- nn_sequential(
      nn_embedding(num_embeddings = vocab_size, embedding_dim = embedding_dim),
      nn_dropout(0.5)
    )
    self$convs <- nn_sequential(
      nn_conv1d(embedding_dim, 128, kernel_size = 7, stride = 3, padding = "valid"),
      nn_relu(),
      nn_conv1d(128, 128, kernel_size = 7, stride = 3, padding = "valid"),
      nn_relu(),
      nn_adaptive_max_pool2d(c(128, 1)) # reduces the length dimension to 1
    )
    self$classifier <- nn_sequential(
      nn_flatten(),
      nn_linear(128, 128),
      nn_relu(),
      nn_dropout(0.5),
      nn_linear(128, 1)
    )
  },
  forward = function(x) {
    emb <- self$embedding(x)
    out <- emb$transpose(2, 3) %>% # (B, L, C) -> (B, C, L), as expected by conv1d
      self$convs() %>%
      self$classifier()
    # drop the last dimension so we get shape (B) instead of (B, 1)
    out$squeeze(2)
  }
)

# test the model for a single example batch
# m <- model(vocab_size, embedding_dim)
# x <- torch_randint(1, 20000, size = c(32, 500), dtype = "int")
# m(x)
```

We can now train the model:

```{r, eval = FALSE}
fitted_model <- model %>%
  setup(
    loss = nnf_binary_cross_entropy_with_logits,
    optimizer = optim_adam,
    metrics = luz_metric_binary_accuracy_with_logits()
  ) %>%
  set_hparams(vocab_size = vocab_size, embedding_dim = embedding_dim) %>%
  fit(train_ds, epochs = 3)
```
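
Since the fit can take a while, the fitted model can be serialized for later use with luz's `luz_save()` (a sketch; the file path is arbitrary):

```{r, eval = FALSE}
luz_save(fitted_model, "imdb-convnet.luz")
# later, in a fresh session:
# fitted_model <- luz_load("imdb-convnet.luz")
```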

We can finally obtain the metrics on the test dataset:

```{r, eval = FALSE}
fitted_model %>% evaluate(test_ds)
```
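
`evaluate()` returns an evaluation object; if we store it, the aggregated metrics can be extracted, for example with luz's `get_metrics()`:

```{r, eval = FALSE}
evaluation <- fitted_model %>% evaluate(test_ds)
get_metrics(evaluation) # a data frame with the loss and binary accuracy
```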

Remember that in order to predict on new texts, we need to apply the same
pre-processing used in the dataset definition, as sketched below.
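
As a minimal sketch of what that could look like, the helper below (hypothetical, defined here only for illustration) reuses the tokenizer stored in `train_ds` and applies the same cleaning steps as `.getitem()` before calling `predict()`:

```{r, eval = FALSE}
# hypothetical helper: same cleaning + tokenization as in `.getitem()`
preprocess_texts <- function(texts, ds) {
  ids <- lapply(texts, function(txt) {
    clean <- txt %>%
      stringr::str_to_lower() %>%
      stringr::str_replace_all("<br />", " ") %>%
      stringr::str_remove_all("[:punct:]")
    ds$tok$encode(clean)$ids + 1L # same +1 shift used by the dataset
  })
  torch_tensor(do.call(rbind, ids))
}

x <- preprocess_texts(
  c("A wonderful, heartfelt movie.", "Two hours I will never get back."),
  train_ds
)
preds <- predict(fitted_model, x)
torch_sigmoid(preds) # the model outputs logits; sigmoid maps them to probabilities
```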
