Add text classification example (#134)
* Add text classification example.
* update containers for tests
Showing 4 changed files with 174 additions and 2 deletions.
```diff
@@ -9,3 +9,9 @@ doc
 Meta
 mnist
 pets
+Untitled.py
+aclImdb
+hello.txt
+imdb
+tokenizer-20000.json
+tokenizer.json
```
@@ -0,0 +1,166 @@
---
title: "Text classification from scratch"
desc: "Implements text classification from scratch."
category: 'basic'
editor_options:
  chunk_output_type: console
---

This example is a port of ['Text classification from scratch'](https://keras.io/examples/nlp/text_classification_from_scratch/) from the Keras documentation
by Mark Omernick and François Chollet.

First we implement a torch dataset that downloads and pre-processes the data.
The initialize method is called when we instantiate a dataset.
Our implementation:

- Downloads the IMDB dataset if it doesn't already exist in the `root` directory.
- Extracts the files into `root`.
- Creates a tokenizer using the files in the training set.

We also implement the `.getitem` method, which is used to extract a single
element from the dataset and pre-process the file contents.

```{r, eval = FALSE}
library(torch)
library(tok)
library(luz)

vocab_size <- 20000   # maximum number of items in the vocabulary
output_length <- 500  # padding and truncation length
embedding_dim <- 128  # size of the embedding vectors

imdb_dataset <- dataset(
  initialize = function(output_length, vocab_size, root, split = "train", download = TRUE) {
    url <- "https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"
    fpath <- file.path(root, "aclImdb")

    # download if the dataset doesn't exist yet
    if (!dir.exists(fpath) && download) {
      # download into a tempfile, then extract into the root dir
      withr::with_tempfile("file", {
        download.file(url, file)
        untar(file, exdir = root)
      })
    }

    # now list files for the split
    self$data <- rbind(
      data.frame(
        fname = list.files(file.path(fpath, split, "pos"), full.names = TRUE),
        y = 1
      ),
      data.frame(
        fname = list.files(file.path(fpath, split, "neg"), full.names = TRUE),
        y = 0
      )
    )

    # train a tokenizer on the train data (if one doesn't exist yet)
    tokenizer_path <- file.path(root, glue::glue("tokenizer-{vocab_size}.json"))
    if (!file.exists(tokenizer_path)) {
      self$tok <- tok::tokenizer$new(tok::model_bpe$new())
      self$tok$pre_tokenizer <- tok::pre_tokenizer_whitespace$new()
      files <- list.files(file.path(fpath, "train"), recursive = TRUE, full.names = TRUE)
      self$tok$train(files, tok::trainer_bpe$new(vocab_size = vocab_size))
      self$tok$save(tokenizer_path)
    } else {
      self$tok <- tok::tokenizer$from_file(tokenizer_path)
    }
    self$tok$enable_padding(length = output_length)
    self$tok$enable_truncation(max_length = output_length)
  },
  .getitem = function(i) {
    item <- self$data[i, ]
    # takes item i, reads the file content into a character string,
    # then lower-cases everything and removes html tags + punctuation;
    # finally uses the tokenizer to encode the text.
    text <- item$fname %>%
      readr::read_file() %>%
      stringr::str_to_lower() %>%
      stringr::str_replace_all("<br />", " ") %>%
      stringr::str_remove_all("[:punct:]") %>%
      self$tok$encode()
    list(
      x = text$ids + 1L, # +1 because torch embeddings are 1-indexed
      y = item$y
    )
  },
  .length = function() {
    nrow(self$data)
  }
)

train_ds <- imdb_dataset(output_length, vocab_size, "./imdb", split = "train")
test_ds <- imdb_dataset(output_length, vocab_size, "./imdb", split = "test")
```

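As a quick sanity check (a small sketch, not part of the original example; it assumes the datasets created above and will trigger the download if needed), we can look at a single element of the dataset:

```{r, eval = FALSE}
# take the first element: a list with the encoded token ids and the label
item <- train_ds[1]
length(item$x) # 500, because of padding/truncation
item$y         # 1 (positive) or 0 (negative)
```
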
We now define the model we want to train. The model is a 1D convnet that starts
with an embedding layer and ends with a classifier head. Note that `nn_conv1d`
expects inputs shaped (batch, channels, length), so we transpose the embedding
output before applying the convolutions.

```{r, eval = FALSE}
model <- nn_module(
  initialize = function(vocab_size, embedding_dim) {
    self$embedding <- nn_sequential(
      nn_embedding(num_embeddings = vocab_size, embedding_dim = embedding_dim),
      nn_dropout(0.5)
    )
    self$convs <- nn_sequential(
      nn_conv1d(embedding_dim, 128, kernel_size = 7, stride = 3, padding = "valid"),
      nn_relu(),
      nn_conv1d(128, 128, kernel_size = 7, stride = 3, padding = "valid"),
      nn_relu(),
      nn_adaptive_max_pool2d(c(128, 1)) # reduces the length dimension
    )
    self$classifier <- nn_sequential(
      nn_flatten(),
      nn_linear(128, 128),
      nn_relu(),
      nn_dropout(0.5),
      nn_linear(128, 1)
    )
  },
  forward = function(x) {
    emb <- self$embedding(x)
    out <- emb$transpose(2, 3) %>% # (batch, length, channels) -> (batch, channels, length)
      self$convs() %>%
      self$classifier()
    # drop the last dimension so we get shape (B) instead of (B, 1)
    out$squeeze(2)
  }
)

# test the model for a single example batch
# m <- model(vocab_size, embedding_dim)
# x <- torch_randint(1, 20000, size = c(32, 500), dtype = "int")
# m(x)
```
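Expanding on the commented-out test above, a quick shape check could look like this (a sketch, assuming the hyper-parameters defined earlier):

```{r, eval = FALSE}
# instantiate the module and run a random batch of token ids through it
m <- model(vocab_size, embedding_dim)
x <- torch_randint(1, vocab_size, size = c(32, output_length), dtype = torch_long())
m(x)$shape # expected: a single dimension of size 32 (one logit per example)
```
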
We can finally train the model:

```{r, eval = FALSE}
fitted_model <- model %>%
  setup(
    loss = nnf_binary_cross_entropy_with_logits,
    optimizer = optim_adam,
    metrics = luz_metric_binary_accuracy_with_logits()
  ) %>%
  set_hparams(vocab_size = vocab_size, embedding_dim = embedding_dim) %>%
  fit(train_ds, epochs = 3)
```
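When passed a dataset, luz wraps it in a dataloader with default options. To control the batch size or shuffling explicitly, the datasets can be wrapped in dataloaders first (a sketch; the batch size of 32 is an illustrative choice, not from the original example):

```{r, eval = FALSE}
train_dl <- dataloader(train_ds, batch_size = 32, shuffle = TRUE)
test_dl <- dataloader(test_ds, batch_size = 32)

# same setup() / set_hparams() pipeline as above, fitted on the dataloader
fitted_model <- model %>%
  setup(
    loss = nnf_binary_cross_entropy_with_logits,
    optimizer = optim_adam,
    metrics = luz_metric_binary_accuracy_with_logits()
  ) %>%
  set_hparams(vocab_size = vocab_size, embedding_dim = embedding_dim) %>%
  fit(train_dl, epochs = 3)
```
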
We can then obtain the metrics on the test dataset:

```{r, eval = FALSE}
fitted_model %>% evaluate(test_ds)
```
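The evaluation object prints a summary of the metrics; to work with them programmatically they can be extracted as a data frame (a sketch using luz's `get_metrics()`):

```{r, eval = FALSE}
evaluation <- fitted_model %>% evaluate(test_ds)
get_metrics(evaluation) # one row per metric
```
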
Remember that in order to predict on new texts, we need to apply the same pre-processing
used in the dataset definition.
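
A minimal sketch of scoring a single new review (the review text is made up, and the tokenizer path assumes the `root` directory used above):

```{r, eval = FALSE}
# load the tokenizer trained by the dataset and configure it as in `initialize`
tok <- tok::tokenizer$from_file("./imdb/tokenizer-20000.json")
tok$enable_padding(length = output_length)
tok$enable_truncation(max_length = output_length)

# same cleaning steps as in `.getitem`
text <- "This movie was surprisingly good!" %>%
  stringr::str_to_lower() %>%
  stringr::str_replace_all("<br />", " ") %>%
  stringr::str_remove_all("[:punct:]")

ids <- tok$encode(text)$ids + 1L
x <- torch_tensor(ids, dtype = torch_long())$unsqueeze(1) # shape (1, output_length)

logit <- predict(fitted_model, x)
torch_sigmoid(logit) # probability that the review is positive
```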