diff --git a/examples/data_augmentation/uda/README.md b/examples/data_augmentation/uda/README.md
new file mode 100644
index 000000000..eca516d93
--- /dev/null
+++ b/examples/data_augmentation/uda/README.md
@@ -0,0 +1,215 @@
+## Unsupervised Data Augmentation for Text Classification
+
+Unsupervised Data Augmentation (UDA) is a semi-supervised learning method that achieves state-of-the-art results on a wide variety of language and vision tasks. For details, please refer to the [paper](https://arxiv.org/abs/1904.12848) and the [official repository](https://github.com/google-research/uda).
+
+In this example, we demonstrate Forte's implementation of UDA using a simple BERT-based text classifier.
+
+## Quick Start
+
+### Install the dependencies
+
+You need to install [texar-pytorch](https://github.com/asyml/texar-pytorch) first.
+
+You will also need to install `tensor2tensor` if you want to perform back translation on your own data. We will cover this later.
+
+### Get the IMDB data and back-translation models
+
+We use the IMDB Text Classification dataset for this example. Use the following script to download the supervised and unsupervised training data to `data/IMDB_raw`. It will also download the pre-trained translation models for back-translation to the directory `back_trans`.
+
+```bash
+python download.py
+```
+
+### Preprocess
+
+You can use the following script to preprocess the data.
+
+```bash
+python utils/imdb_format.py
+```
+
+This script does two things. It reads the raw data in TXT format and outputs two files, `train.csv` and `test.csv`. It also splits the training set into sentences for back-translation, because the back-translation models are trained on sentences instead of long paragraphs.
+
+### Generate back-translation data
+
+**Note:** back-translation is the most time-consuming step. If you just want to see the results, you can skip to the next section. Translating the whole dataset to French takes ~2 days on a GTX 1080 Ti, and it takes another ~2 days to translate back to English.
+
+If you would like to play with the back-translation parameters or work with your own data, you need to generate the back-translation data yourself. Here we provide an example of back-translation on the IMDB dataset.
+
+First, you need to install `tensor2tensor` with TensorFlow 1.13. We provide a `requirements.txt` with the correct versions of the dependencies.
+To install:
+
+```
+cd back_trans/
+pip install -r requirements.txt
+pip install --no-deps tensor2tensor==1.13
+```
+
+Then run the following commands to perform the back-translation (adapted from the original [UDA repo](https://github.com/google-research/uda/blob/master/back_translate/run.sh)):
+
+```
+cd back_trans/
+
+# forward translation
+t2t-decoder \
+  --problem=translate_enfr_wmt32k \
+  --model=transformer \
+  --hparams_set=transformer_big \
+  --hparams="sampling_method=random,sampling_temp=0.8" \
+  --decode_hparams="beam_size=1,batch_size=16" \
+  --checkpoint_path=checkpoints/enfr/model.ckpt-500000 \
+  --output_dir=/tmp/t2t \
+  --decode_from_file=train_split_sent.txt \
+  --decode_to_file=forward_gen.txt \
+  --data_dir=checkpoints
+
+# backward translation
+t2t-decoder \
+  --problem=translate_enfr_wmt32k_rev \
+  --model=transformer \
+  --hparams_set=transformer_big \
+  --hparams="sampling_method=random,sampling_temp=0.8" \
+  --decode_hparams="beam_size=1,batch_size=16,alpha=0" \
+  --checkpoint_path=checkpoints/fren/model.ckpt-500000 \
+  --output_dir=/tmp/t2t \
+  --decode_from_file=forward_gen.txt \
+  --decode_to_file=backward_gen.txt \
+  --data_dir=checkpoints
+
+# merge sentences back to paragraphs
+python merge_back_trans_sentences.py \
+  --input_file=backward_gen.txt \
+  --output_file=back_translate.txt \
+  --doc_len_file=train_doc_len.json
+```
+
+You can tune the `sampling_temp` parameter. See [here](https://github.com/google-research/uda#guidelines-for-hyperparameters) for more details.
+
+The final result of the above commands is `back_translate.txt`. Each line in the file is a back-translated example corresponding to the same line in `train.csv` (without the header).
+
+Next, copy `back_translate.txt` to `data/IMDB/`:
+
+```
+cp back_translate.txt ../data/IMDB/
+```
+
+You can also use a different name for the back-translation file; see `config_data.py` to configure it.
+
+### Download preprocessed and augmented data
+
+For demonstration purposes, we provide the processed and augmented data files: [download link](https://drive.google.com/file/d/1OKrbS76mbGCIz3FcFQ8-qPpMTQkQy8bP/view?usp=sharing). Place the CSV and TXT files in the directory `data/IMDB`.
+
+### Train
+
+To train the baseline model without UDA:
+
+```bash
+python main.py --do-train --do-eval --do-test
+```
+
+To train with UDA:
+
+```bash
+python main.py --do-train --do-eval --do-test --use-uda
+```
+
+To change the hyperparameters, please see `config_data.py`. You can also change the number of labeled examples used for training (`num_train_data`).
+
+#### GPU Memory Issue
+
+According to the authors' [guidelines for setting hyperparameters](https://github.com/google-research/uda#general-guidelines-for-setting-hyperparameters), a longer sequence length and a larger batch size lead to better performance. Both are limited by GPU memory. By default, we use `max_seq_length=128` and `batch_size=24` to run on a GTX 1080 Ti with 11GB of memory.
+
+## Results
+
+With the provided data, you should be able to achieve performance similar to the following:
+
+| Number of Labeled Examples | BERT Accuracy | BERT+UDA Accuracy |
+| -------------------------- | ------------- | ----------------- |
+| 24                         | 61.54         | 84.92             |
+| 25000                      | 89.68         | 90.19             |
+
+When training with 24 labeled examples, we use the Training Signal Annealing (TSA) technique, which can be turned on by setting `tsa=True` in `config_data.py`.
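+
+For reference, TSA gradually raises a confidence threshold during training so that labeled examples the model already classifies confidently stop contributing to the supervised loss, which helps avoid memorizing the tiny labeled set. The snippet below is a condensed sketch of the schedule implemented in `utils/model_utils.py` (`get_tsa_threshold`); the function name and the step counts in the comments are only illustrative.
+
+```python
+import math
+
+def tsa_threshold(schedule: str, step: int, total_steps: int,
+                  start: float, end: float) -> float:
+    # Fraction of training completed so far.
+    progress = step / total_steps
+    if schedule == "linear_schedule":
+        coeff = progress
+    elif schedule == "exp_schedule":
+        coeff = math.exp((progress - 1) * 5)   # releases examples late in training
+    else:  # "log_schedule"
+        coeff = 1 - math.exp(-progress * 5)    # releases examples early in training
+    # Anneal from `start` (1 / num_classes) up to `end` (1.0).
+    return coeff * (end - start) + start
+
+# For a 2-class task, halfway through training the linear schedule gives a
+# threshold of 0.75: examples whose correct-class probability exceeds 0.75
+# are masked out of the supervised loss at that step.
+print(tsa_threshold("linear_schedule", step=500, total_steps=1000,
+                    start=0.5, end=1.0))
+```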
+
+You can further improve the performance by tuning the hyperparameters, generating better back-translation data, using a larger BERT model, or using a larger `max_seq_length`.
+
+## Using the UDAIterator
+
+Here is a brief tutorial on using Forte's `UDAIterator`. You can also refer to the `run_uda` function in `main.py`.
+
+### Initialization
+
+First, we initialize the `UDAIterator` with the supervised and unsupervised data:
+
+```
+iterator = tx.data.DataIterator(
+    {"train": train_dataset, "eval": eval_dataset}
+)
+
+unsup_iterator = tx.data.DataIterator(
+    {"unsup": unsup_dataset}
+)
+
+uda_iterator = UDAIterator(
+    iterator,
+    unsup_iterator,
+    softmax_temperature=1.0,
+    confidence_threshold=-1,
+    reduction="mean")
+```
+
+The next step is to tell the iterator which dataset to use, and to initialize the internal iterators:
+
+```
+uda_iterator.switch_to_dataset_unsup("unsup")
+uda_iterator.switch_to_dataset("train", use_unsup=True)
+uda_iterator = iter(uda_iterator)  # call iter() to initialize the internal iterators
+```
+
+### Training with UDA
+
+The UDA loss is the KL divergence between the output probabilities of the original input and the augmented input. Here, we define `unsup_forward_fn` to compute the two sets of logits:
+
+```
+def unsup_forward_fn(batch):
+    input_ids = batch["input_ids"]
+    segment_ids = batch["segment_ids"]
+    input_length = (1 - (input_ids == 0).int()).sum(dim=1)
+
+    aug_input_ids = batch["aug_input_ids"]
+    aug_segment_ids = batch["aug_segment_ids"]
+    aug_input_length = (1 - (aug_input_ids == 0).int()).sum(dim=1)
+
+    logits, _ = model(input_ids, input_length, segment_ids)
+    logits = logits.detach()  # gradient does not propagate back to the original input
+    aug_logits, _ = model(aug_input_ids, aug_input_length, aug_segment_ids)
+    return logits, aug_logits
+```
+
+Then, `UDAIterator.calculate_uda_loss` computes the UDA loss for us. Inside the training loop, we compute the supervised loss as usual (or with a TSA schedule) and add the unsupervised loss to produce the final loss:
+
+```
+# ...
+# Inside the training loop:
+# supervised loss
+logits, _ = model(input_ids, input_length, segment_ids)
+loss = _compute_loss_tsa(logits, labels, scheduler.last_epoch,
+                         num_train_steps)
+# unsupervised loss
+unsup_logits, unsup_aug_logits = unsup_forward_fn(unsup_batch)
+unsup_loss = uda_iterator.calculate_uda_loss(unsup_logits, unsup_aug_logits)
+
+loss = loss + unsup_loss  # unsup coefficient = 1
+loss.backward()
+# ...
+```
+
+You can read more about the TSA schedule in the UDA paper.
+
+### Evaluation
+
+For evaluation, we simply switch to the eval dataset. In the `for` loop we only need the supervised batch:
+
+```
+uda_iterator.switch_to_dataset("eval", use_unsup=False)
+for batch, _ in uda_iterator:
+    # do evaluation ...
+```
diff --git a/examples/data_augmentation/uda/__init__.py b/examples/data_augmentation/uda/__init__.py
new file mode 100644
index 000000000..a5dd21c1f
--- /dev/null
+++ b/examples/data_augmentation/uda/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2020 The Forte Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/examples/data_augmentation/uda/back_trans/merge_back_trans_sentences.py b/examples/data_augmentation/uda/back_trans/merge_back_trans_sentences.py new file mode 100644 index 000000000..3fb985986 --- /dev/null +++ b/examples/data_augmentation/uda/back_trans/merge_back_trans_sentences.py @@ -0,0 +1,51 @@ +# Copyright 2020 The Forte Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Compose paraphrased sentences back to paragraphs. +""" +import argparse +import json + +parser = argparse.ArgumentParser() +parser.add_argument( + '--input_file', type=str, default='backward_gen.txt', + help="Path to the back translated sentence file.") +parser.add_argument( + '--output_file', type=str, default='back_translate.txt', + help="Path to the output paragraph file.") +parser.add_argument( + '--doc_len_file', type=str, default='train_doc_len.json', + help="The file that records the length information.") +args = parser.parse_args() + + +def main(): + with open(args.input_file, encoding='utf-8') as inf: + sentences = inf.readlines() + with open(args.doc_len_file, encoding='utf-8') as inf: + doc_len_list = json.load(inf) + cnt = 0 + print("Printing paraphrases:") + with open(args.output_file, "w", encoding='utf-8') as ouf: + for i, sent_num in enumerate(doc_len_list): + para = "" + for _ in range(sent_num): + para += sentences[cnt].strip() + " " + cnt += 1 + print("Paraphrase {}: {}".format(i, para)) + ouf.write(para.strip() + "\n") + + +if __name__ == '__main__': + main() diff --git a/examples/data_augmentation/uda/back_trans/requirements.txt b/examples/data_augmentation/uda/back_trans/requirements.txt new file mode 100644 index 000000000..3fb220a39 --- /dev/null +++ b/examples/data_augmentation/uda/back_trans/requirements.txt @@ -0,0 +1,10 @@ +mesh-tensorflow==0.0.5 +tensorboard==1.13.0 +tensorboard-logger==0.1.0 +tensorboardX==1.8 +tensorflow-datasets==1.3.0 +tensorflow-estimator==1.13.0 +tensorflow-gpu==1.13.1 +tensorflow-hub==0.7.0 +tensorflow-metadata==0.13.0 +tensorflow-probability==0.6.0 diff --git a/examples/data_augmentation/uda/config_classifier.py b/examples/data_augmentation/uda/config_classifier.py new file mode 100644 index 000000000..3000603ec --- /dev/null +++ b/examples/data_augmentation/uda/config_classifier.py @@ -0,0 +1,11 @@ +name = "bert_classifier" +hidden_size = 768 +clas_strategy = "cls_time" +dropout = 0.1 +num_classes = 2 + +# This hyperparams is used in bert_with_hypertuning_main.py example +hyperparams = { + "optimizer.warmup_steps": {"start": 10000, "end": 20000, "dtype": int}, + "optimizer.static_lr": {"start": 1e-3, "end": 1e-2, "dtype": float} +} diff --git a/examples/data_augmentation/uda/config_data.py b/examples/data_augmentation/uda/config_data.py new file mode 100644 index 000000000..c13f2c9d8 --- /dev/null +++ b/examples/data_augmentation/uda/config_data.py @@ -0,0 +1,77 @@ +pickle_data_dir = "data/IMDB" +unsup_bt_file = 
"data/IMDB/back_translate.txt" +max_seq_length = 128 +num_classes = 2 +num_train_data = 24 # supervised data limit. max 25000 + +train_batch_size = 24 +max_train_epoch = 3000 +display_steps = 50 # Print training loss every display_steps; -1 to disable + +eval_steps = 100 # Eval every eval_steps; if -1 will eval every epoch +# Proportion of training to perform linear learning rate warmup for. +# E.g., 0.1 = 10% of training. +warmup_proportion = 0.1 +eval_batch_size = 8 +test_batch_size = 8 + +feature_types = { + # Reading features from pickled data file. + # E.g., Reading feature "input_ids" as dtype `int64`; + # "FixedLenFeature" indicates its length is fixed for all data instances; + # and the sequence length is limited by `max_seq_length`. + "input_ids": ["int64", "stacked_tensor", max_seq_length], + "input_mask": ["int64", "stacked_tensor", max_seq_length], + "segment_ids": ["int64", "stacked_tensor", max_seq_length], + "label_ids": ["int64", "stacked_tensor"] +} + +train_hparam = { + "allow_smaller_final_batch": False, + "batch_size": train_batch_size, + "dataset": { + "data_name": "data", + "feature_types": feature_types, + "files": "{}/train.pkl".format(pickle_data_dir) + }, + "shuffle": True, + "shuffle_buffer_size": None +} + +eval_hparam = { + "allow_smaller_final_batch": True, + "batch_size": eval_batch_size, + "dataset": { + "data_name": "data", + "feature_types": feature_types, + "files": "{}/eval.pkl".format(pickle_data_dir) + }, + "shuffle": False +} + +# UDA config +tsa = True +tsa_schedule = "linear_schedule" # linear_schedule, exp_schedule, log_schedule + +unsup_feature_types = { + "input_ids": ["int64", "stacked_tensor", max_seq_length], + "input_mask": ["int64", "stacked_tensor", max_seq_length], + "segment_ids": ["int64", "stacked_tensor", max_seq_length], + "label_ids": ["int64", "stacked_tensor"], + "aug_input_ids": ["int64", "stacked_tensor", max_seq_length], + "aug_input_mask": ["int64", "stacked_tensor", max_seq_length], + "aug_segment_ids": ["int64", "stacked_tensor", max_seq_length], + "aug_label_ids": ["int64", "stacked_tensor"] +} + +unsup_hparam = { + "allow_smaller_final_batch": True, + "batch_size": train_batch_size, + "dataset": { + "data_name": "data", + "feature_types": unsup_feature_types, + "files": "{}/unsup.pkl".format(pickle_data_dir) + }, + "shuffle": True, + "shuffle_buffer_size": None, +} diff --git a/examples/data_augmentation/uda/download.py b/examples/data_augmentation/uda/download.py new file mode 100644 index 000000000..63f61bbab --- /dev/null +++ b/examples/data_augmentation/uda/download.py @@ -0,0 +1,37 @@ +# Copyright 2020 The Forte Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +Download IMDB dataset. 
+""" +from forte.data.data_utils import maybe_download + + +def main(): + imdb_path = "data/IMDB_raw" + maybe_download(urls=[ + "https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"], + path=imdb_path, + extract=True) + bt_path = "back_trans" + maybe_download(urls=[ + "https://storage.googleapis.com/uda_model/text/" + "back_trans_checkpoints.zip"], + path=bt_path, + extract=True) + + +if __name__ == '__main__': + main() diff --git a/examples/data_augmentation/uda/main.py b/examples/data_augmentation/uda/main.py new file mode 100644 index 000000000..0cd2e51fa --- /dev/null +++ b/examples/data_augmentation/uda/main.py @@ -0,0 +1,540 @@ +# Copyright 2020 The Forte Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import functools +import logging +import os + +import torch +import torch.nn.functional as F +import texar.torch as tx + +import config_data +import config_classifier +from utils import data_utils, model_utils + +# pylint: disable=no-name-in-module + +from forte.processors.data_augment.algorithms.UDA import UDAIterator + + +parser = argparse.ArgumentParser() +parser.add_argument( + "--use-uda", action="store_true", + help="Whenther to train with UDA") +parser.add_argument( + '--pretrained-model-name', type=str, default='bert-base-uncased', + choices=tx.modules.BERTEncoder.available_checkpoints(), + help="Name of the pre-trained BERT model to load.") +parser.add_argument( + '--checkpoint', type=str, default=None, + help="Path to the checkpoint to load.") +parser.add_argument( + "--output-dir", default="output/", + help="The output directory where the model checkpoints will be written.") +parser.add_argument( + "--do-train", action="store_true", help="Whether to run training.") +parser.add_argument( + "--do-eval", action="store_true", + help="Whether to run eval on the dev set.") +parser.add_argument( + "--do-test", action="store_true", + help="Whether to run test on the test set.") +args = parser.parse_args() + + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +logging.root.setLevel(logging.INFO) + + +class IMDBClassifierTrainer: + """ + A baseline text classifier trainer for the IMDB dataset. + The input data should be CSV format with columns (content label id). + An example usage can be found at examples/text_classification. + """ + + def __init__(self, trainer_config_data, trainer_config_classifier, + checkpoint=args.checkpoint, + pretrained_model_name=args.pretrained_model_name): + """Constructs the text classifier. + Args: + trainer_config_data: data config file. + trainer_config_classifier: classifier config file + checkpoint: the saved checkpoint to use + pretrained_model_name: name of the pretrained model to use + """ + self.config_data = trainer_config_data + self.config_classifier = trainer_config_classifier + self.checkpoint = checkpoint + self.pretrained_model_name = pretrained_model_name + + def run(self, do_train, do_eval, do_test, output_dir="output/"): + """ + Builds the model and runs. 
+ """ + tx.utils.maybe_create_dir(output_dir) + + # Loads data + num_train_data = self.config_data.num_train_data + + hparams = { + k: v for k, v in self.config_classifier.__dict__.items() + if not k.startswith('__') and k != "hyperparams"} + + # Builds BERT + model = tx.modules.BERTClassifier( + pretrained_model_name=self.pretrained_model_name, + hparams=hparams) + model.to(device) + + num_train_steps = int(num_train_data / self.config_data.train_batch_size + * self.config_data.max_train_epoch) + num_warmup_steps = int(num_train_steps + * self.config_data.warmup_proportion) + + # Builds learning rate decay scheduler + static_lr = 2e-5 + + vars_with_decay = [] + vars_without_decay = [] + for name, param in model.named_parameters(): + if 'layer_norm' in name or name.endswith('bias'): + vars_without_decay.append(param) + else: + vars_with_decay.append(param) + + opt_params = [{ + 'params': vars_with_decay, + 'weight_decay': 0.01, + }, { + 'params': vars_without_decay, + 'weight_decay': 0.0, + }] + optim = tx.core.BertAdam( + opt_params, betas=(0.9, 0.999), eps=1e-6, lr=static_lr) + + scheduler = torch.optim.lr_scheduler.LambdaLR( + optim, functools.partial(model_utils.get_lr_multiplier, + total_steps=num_train_steps, + warmup_steps=num_warmup_steps)) + + train_dataset = tx.data.RecordData( + hparams=self.config_data.train_hparam, device=device) + eval_dataset = tx.data.RecordData( + hparams=self.config_data.eval_hparam, device=device) + + iterator = tx.data.DataIterator( + {"train": train_dataset, "eval": eval_dataset} + ) + + def _compute_loss(logits, labels): + r"""Compute loss. + """ + if model.is_binary: + loss = F.binary_cross_entropy( + logits.view(-1), labels.view(-1), reduction='mean') + else: + loss = F.cross_entropy( + logits.view(-1, model.num_classes), + labels.view(-1), reduction='mean') + return loss + + def _train_epoch(): + r"""Trains on the training set, and evaluates on the dev set + periodically. + """ + iterator.switch_to_dataset("train") + model.train() + + for batch in iterator: + optim.zero_grad() + input_ids = batch["input_ids"] + segment_ids = batch["segment_ids"] + labels = batch["label_ids"] + + input_length = (1 - (input_ids == 0).int()).sum(dim=1) + + logits, _ = model(input_ids, input_length, segment_ids) + + loss = _compute_loss(logits, labels) + loss.backward() + optim.step() + scheduler.step() + step = scheduler.last_epoch + + dis_steps = self.config_data.display_steps + if dis_steps > 0 and step % dis_steps == 0: + logging.info("step: %d; loss: %f", step, loss) + + eval_steps = self.config_data.eval_steps + if eval_steps > 0 and step % eval_steps == 0: + _eval_epoch() + model.train() + + @torch.no_grad() + def _eval_epoch(): + """Evaluates on the dev set. + """ + iterator.switch_to_dataset("eval") + model.eval() + + nsamples = 0 + avg_rec = tx.utils.AverageRecorder() + for batch in iterator: + input_ids = batch["input_ids"] + segment_ids = batch["segment_ids"] + labels = batch["label_ids"] + + input_length = (1 - (input_ids == 0).int()).sum(dim=1) + + logits, preds = model(input_ids, input_length, segment_ids) + + loss = _compute_loss(logits, labels) + accu = tx.evals.accuracy(labels, preds) + batch_size = input_ids.size()[0] + avg_rec.add([accu, loss], batch_size) + nsamples += batch_size + logging.info("eval accu: %.4f; loss: %.4f; nsamples: %d", + avg_rec.avg(0), avg_rec.avg(1), nsamples) + + @torch.no_grad() + def _test_epoch(): + """Does predictions on the test set. 
+ """ + iterator.switch_to_dataset("eval") + model.eval() + + _all_preds = [] + for batch in iterator: + input_ids = batch["input_ids"] + segment_ids = batch["segment_ids"] + + input_length = (1 - (input_ids == 0).int()).sum(dim=1) + + _, preds = model(input_ids, input_length, segment_ids) + + _all_preds.extend(preds.tolist()) + + output_file = os.path.join(output_dir, "test_results.tsv") + with open(output_file, "w+") as writer: + writer.write("\n".join(str(p) for p in _all_preds)) + logging.info("test output written to %s", output_file) + + if self.checkpoint: + ckpt = torch.load(self.checkpoint) + model.load_state_dict(ckpt['model']) + optim.load_state_dict(ckpt['optimizer']) + scheduler.load_state_dict(ckpt['scheduler']) + if do_train: + for _ in range(self.config_data.max_train_epoch): + _train_epoch() + if self.config_data.eval_steps == -1: + _eval_epoch() + states = { + 'model': model.state_dict(), + 'optimizer': optim.state_dict(), + 'scheduler': scheduler.state_dict(), + } + torch.save(states, os.path.join(output_dir, 'model.ckpt')) + + if do_eval: + _eval_epoch() + + if do_test: + _test_epoch() + + def run_uda(self, do_train, do_eval, do_test, output_dir="output/"): + """ + Builds the model and runs. + """ + tx.utils.maybe_create_dir(output_dir) + + logging.root.setLevel(logging.INFO) + + # Loads data + num_train_data = self.config_data.num_train_data + + hparams = { + k: v for k, v in self.config_classifier.__dict__.items() + if not k.startswith('__') and k != "hyperparams"} + + # Builds BERT + model = tx.modules.BERTClassifier( + pretrained_model_name=self.pretrained_model_name, + hparams=hparams) + model.to(device) + + num_train_steps = int(num_train_data / self.config_data.train_batch_size + * self.config_data.max_train_epoch) + num_warmup_steps = int(num_train_steps + * self.config_data.warmup_proportion) + + # Builds learning rate decay scheduler + static_lr = 2e-5 + + vars_with_decay = [] + vars_without_decay = [] + for name, param in model.named_parameters(): + if 'layer_norm' in name or name.endswith('bias'): + vars_without_decay.append(param) + else: + vars_with_decay.append(param) + + opt_params = [{ + 'params': vars_with_decay, + 'weight_decay': 0.01, + }, { + 'params': vars_without_decay, + 'weight_decay': 0.0, + }] + optim = tx.core.BertAdam( + opt_params, betas=(0.9, 0.999), eps=1e-6, lr=static_lr) + + scheduler = torch.optim.lr_scheduler.LambdaLR( + optim, functools.partial(model_utils.get_lr_multiplier, + total_steps=num_train_steps, + warmup_steps=num_warmup_steps)) + + train_dataset = tx.data.RecordData( + hparams=self.config_data.train_hparam, device=device) + eval_dataset = tx.data.RecordData( + hparams=self.config_data.eval_hparam, device=device) + unsup_dataset = tx.data.RecordData( + hparams=self.config_data.unsup_hparam, device=device) + + iterator = tx.data.DataIterator( + {"train": train_dataset, "eval": eval_dataset} + ) + + unsup_iterator = tx.data.DataIterator( + {"unsup": unsup_dataset} + ) + + def unsup_forward_fn(batch): + input_ids = batch["input_ids"] + segment_ids = batch["segment_ids"] + input_length = (1 - (input_ids == 0).int()).sum(dim=1) + + aug_input_ids = batch["aug_input_ids"] + aug_segment_ids = batch["aug_segment_ids"] + aug_input_length = (1 - (aug_input_ids == 0).int()).sum(dim=1) + + logits, _ = model(input_ids, input_length, segment_ids) + # gradient does not propagate back to original input + logits = logits.detach() + aug_logits, _ = model( + aug_input_ids, aug_input_length, aug_segment_ids) + return logits, aug_logits + + 
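+        # The UDAIterator pairs every supervised batch with an unsupervised
+        # batch: iterating over it yields (batch, unsup_batch) tuples. Its
+        # calculate_uda_loss method (used in _train_epoch below) computes the
+        # consistency (KL-divergence) loss between the two sets of logits
+        # returned by unsup_forward_fn for the original and back-translated
+        # inputs.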
uda_iterator = UDAIterator( + iterator, + unsup_iterator, + softmax_temperature=1.0, + confidence_threshold=-1, + reduction="mean") + + uda_iterator.switch_to_dataset_unsup("unsup") + uda_iterator.switch_to_dataset("train", use_unsup=True) + # call iter() to initialize the internal iterators + uda_iterator = iter(uda_iterator) + + def _compute_loss(logits, labels): + r"""Compute loss. + """ + if model.is_binary: + loss = F.binary_cross_entropy( + logits.view(-1), labels.view(-1), reduction='mean') + else: + loss = F.cross_entropy( + logits.view(-1, model.num_classes), + labels.view(-1), reduction='mean') + return loss + + def _compute_loss_tsa(logits, labels, global_step, num_train_steps): + r"""Compute loss. + """ + loss = 0 + log_probs = F.log_softmax(logits) + one_hot_labels = torch.zeros_like( + log_probs, dtype=torch.float).to(device) + one_hot_labels.scatter_(1, labels.view(-1, 1), 1) + + per_example_loss = -(one_hot_labels * log_probs).sum(dim=-1) + loss_mask = torch.ones_like( + per_example_loss, dtype=per_example_loss.dtype).to(device) + correct_label_probs = \ + (one_hot_labels * torch.exp(log_probs)).sum(dim=-1) + + if self.config_data.tsa: + tsa_start = 1. / model.num_classes + tsa_threshold = model_utils.get_tsa_threshold( + self.config_data.tsa_schedule, global_step, + num_train_steps, start=tsa_start, end=1) + larger_than_threshold = torch.gt( + correct_label_probs, tsa_threshold) + loss_mask = loss_mask * (1 - larger_than_threshold.float()) + else: + tsa_threshold = 1 + + loss_mask = loss_mask.detach() + per_example_loss = per_example_loss * loss_mask + loss_mask_sum = loss_mask.sum() + loss = per_example_loss.sum() + if loss_mask_sum > 0: + loss = loss / loss_mask_sum + return loss + + def _train_epoch(): + r"""Trains on the training set, and evaluates on the dev set + periodically. + """ + model.train() + uda_iterator.switch_to_dataset("train", use_unsup=True) + iter(uda_iterator) + nsamples = 0 + for batch, unsup_batch in uda_iterator: + optim.zero_grad() + input_ids = batch["input_ids"] + segment_ids = batch["segment_ids"] + labels = batch["label_ids"] + + batch_size = input_ids.size()[0] + nsamples += batch_size + + input_length = (1 - (input_ids == 0).int()).sum(dim=1) + + # sup loss + logits, _ = model(input_ids, input_length, segment_ids) + loss = _compute_loss_tsa(logits, labels, scheduler.last_epoch, + num_train_steps) + # unsup loss + unsup_logits, unsup_aug_logits = unsup_forward_fn(unsup_batch) + unsup_loss = uda_iterator.calculate_uda_loss( + unsup_logits, unsup_aug_logits) + + loss = loss + unsup_loss # unsup coefficient = 1 + loss.backward() + optim.step() + scheduler.step() + step = scheduler.last_epoch + + dis_steps = self.config_data.display_steps + if dis_steps > 0 and step % dis_steps == 0: + logging.info( + "step: %d; loss: %f, unsup_loss %f", + step, loss, unsup_loss) + + eval_steps = self.config_data.eval_steps + if eval_steps > 0 and step % eval_steps == 0: + _eval_epoch() + model.train() + # uda_iterator.switch_to_dataset("train", use_unsup=True) + print("Train nsamples:", nsamples) + + @torch.no_grad() + def _eval_epoch(): + """Evaluates on the dev set. 
+ """ + uda_iterator.switch_to_dataset("eval", use_unsup=False) + model.eval() + + nsamples = 0 + avg_rec = tx.utils.AverageRecorder() + for batch, _ in uda_iterator: + input_ids = batch["input_ids"] + segment_ids = batch["segment_ids"] + labels = batch["label_ids"] + + input_length = (1 - (input_ids == 0).int()).sum(dim=1) + + logits, preds = model(input_ids, input_length, segment_ids) + + loss = _compute_loss(logits, labels) + accu = tx.evals.accuracy(labels, preds) + batch_size = input_ids.size()[0] + avg_rec.add([accu, loss], batch_size) + nsamples += batch_size + logging.info("eval accu: %.4f; loss: %.4f; nsamples: %d", + avg_rec.avg(0), avg_rec.avg(1), nsamples) + + @torch.no_grad() + def _test_epoch(): + """Does predictions on the test set. + """ + uda_iterator.switch_to_dataset("eval", use_unsup=False) + model.eval() + + _all_preds = [] + for batch, _ in uda_iterator: + input_ids = batch["input_ids"] + segment_ids = batch["segment_ids"] + + input_length = (1 - (input_ids == 0).int()).sum(dim=1) + + _, preds = model(input_ids, input_length, segment_ids) + + _all_preds.extend(preds.tolist()) + + output_file = os.path.join(output_dir, "test_results.tsv") + with open(output_file, "w+") as writer: + writer.write("\n".join(str(p) for p in _all_preds)) + logging.info("test output written to %s", output_file) + + if self.checkpoint: + ckpt = torch.load(self.checkpoint) + model.load_state_dict(ckpt['model']) + optim.load_state_dict(ckpt['optimizer']) + scheduler.load_state_dict(ckpt['scheduler']) + if do_train: + for i in range(self.config_data.max_train_epoch): + print("Epoch", i) + _train_epoch() + if self.config_data.eval_steps == -1: + # eval after epoch because switch_dataset + # just resets the iterator + _eval_epoch() + states = { + 'model': model.state_dict(), + 'optimizer': optim.state_dict(), + 'scheduler': scheduler.state_dict(), + } + torch.save(states, os.path.join(output_dir, 'model.ckpt')) + + if do_eval: + _eval_epoch() + + if do_test: + _test_epoch() + + +def main(): + trainer = IMDBClassifierTrainer(config_data, config_classifier) + if not os.path.isfile("data/IMDB/train.pkl")\ + or not os.path.isfile("data/IMDB/eval.pkl")\ + or not os.path.isfile("data/IMDB/predict.pkl")\ + or not os.path.isfile("data/IMDB/unsup.pkl"): + data_utils.prepare_data( + trainer.pretrained_model_name, config_data, "data/IMDB") + if args.use_uda: + trainer.run_uda( + args.do_train, args.do_eval, args.do_test, args.output_dir) + else: + trainer.run(args.do_train, args.do_eval, args.do_test, args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/examples/data_augmentation/uda/utils/data_utils.py b/examples/data_augmentation/uda/utils/data_utils.py new file mode 100644 index 000000000..7711777fc --- /dev/null +++ b/examples/data_augmentation/uda/utils/data_utils.py @@ -0,0 +1,409 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +This is the Data Loading Pipeline for Sentence Classifier Task from: + `https://github.com/google-research/bert/blob/master/run_classifier.py` +""" + +import copy +import os +import csv +import logging +import math +import random + +import numpy as np +import texar.torch as tx + + +class InputExample(): + """A single training/test example for simple sequence classification.""" + + def __init__(self, guid, text_a, text_b=None, label=None): + """Constructs a InputExample. + Args: + guid: Unique id for the example. + text_a: string. The untokenized text of the first sequence. + For single sequence tasks, only this sequence must be specified. + text_b: (Optional) string. The untokenized text of the second + sequence. Only must be specified for sequence pair tasks. + label: (Optional) string. The label of the example. This should be + specified for train and dev examples, but not for test examples. + """ + self.guid = guid + self.text_a = text_a + self.text_b = text_b + self.label = label + + +class InputFeatures: + """A single set of features of data.""" + + def __init__(self, input_ids, input_mask, segment_ids, label_id): + self.input_ids = input_ids + self.input_mask = input_mask + self.segment_ids = segment_ids + self.label_id = label_id + + +class DataProcessor(): + """Base class for data converters for sequence classification data sets.""" + + def get_train_examples(self, data_dir): + """Gets a collection of `InputExample`s for the train set.""" + raise NotImplementedError() + + def get_dev_examples(self, data_dir): + """Gets a collection of `InputExample`s for the dev set.""" + raise NotImplementedError() + + def get_test_examples(self, data_dir): + """Gets a collection of `InputExample`s for prediction.""" + raise NotImplementedError() + + def get_labels(self): + """Gets the list of labels for this data set.""" + raise NotImplementedError() + + @classmethod + def _read_tsv(cls, input_file, quotechar=None): + """Reads a tab separated value file.""" + with open(input_file, "r", encoding="utf-8") as f: + reader = csv.reader(f, delimiter="\t", quotechar=quotechar) + lines = [] + for line in reader: + if len(line) > 0: + lines.append(line) + return lines + + +class IMDbProcessor(DataProcessor): + """Processor for the IMDb data set.""" + + def get_train_examples(self, raw_data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(raw_data_dir, "train.csv"), + quotechar='"'), "train") + + def get_dev_examples(self, raw_data_dir): + """The IMDB dataset does not have a dev set so we just use test set""" + return self._create_examples( + self._read_tsv(os.path.join(raw_data_dir, "test.csv"), + quotechar='"'), "test") + + def get_unsup_examples(self, raw_data_dir, unsup_set): + """See base class.""" + if unsup_set == "unsup_ext": + return self._create_examples( + self._read_tsv(os.path.join(raw_data_dir, "unsup_ext.csv"), + quotechar='"'), "unsup_ext", skip_unsup=False) + elif unsup_set == "unsup_in": + return self._create_examples( + self._read_tsv(os.path.join(raw_data_dir, "train.csv"), quotechar='"'), "unsup_in", skip_unsup=False) + + def get_unsup_aug_examples(self, raw_data_dir, unsup_set): + """See base class.""" + if unsup_set == "unsup_ext": + return self._create_examples( + self._read_tsv(os.path.join(raw_data_dir, "unsup_ext.csv"), + quotechar='"'), "unsup_ext", skip_unsup=False) + elif unsup_set == "unsup_in": + return self._create_examples( + self._read_tsv(os.path.join(raw_data_dir, "train_aug.csv"), + quotechar='"'), "unsup_in", 
skip_unsup=False) + + def get_labels(self): + """See base class.""" + return ["pos", "neg"] + + def _create_examples(self, lines, set_type, skip_unsup=True): + """Creates examples for the training and dev sets.""" + examples = [] + print(len(lines)) + for (i, line) in enumerate(lines): + if i == 0 or len(line) == 1: # newline + continue + if skip_unsup and line[-2] == "unsup": + continue + # Original UDA implementation + # if line[-2] == "unsup" and len(line[0]) < 500: + # tf.logging.info("skipping short samples:{:s}".format(line[0])) + # continue + guid = "%s-%s" % (set_type, line[-1]) + text_a = " ".join(line[:-2]) + label = line[-2] + if label not in ["pos", "neg", "unsup"]: + print(line) + examples.append(InputExample(guid=guid, text_a=text_a, + text_b=None, label=label)) + return examples + + def get_train_size(self): + return 25000 + + def get_dev_size(self): + return 25000 + + +def convert_single_example(ex_index, example, label_list, max_seq_length, + tokenizer): + r"""Converts a single `InputExample` into a single `InputFeatures`.""" + label_map = {} + for (i, label) in enumerate(label_list): + label_map[label] = i + + input_ids, segment_ids, input_mask = \ + tokenizer.encode_text(text_a=example.text_a, + text_b=example.text_b, + max_seq_length=max_seq_length) + + label_id = label_map[example.label] + + # here we disable the verbose printing of the data + if ex_index < 0: + logging.info("*** Example ***") + logging.info("guid: %s", example.guid) + logging.info("input_ids: %s", " ".join([str(x) for x in input_ids])) + logging.info("input_ids length: %d", len(input_ids)) + logging.info("input_mask: %s", " ".join([str(x) for x in input_mask])) + logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids])) + logging.info("label: %s (id = %d)", example.label, label_id) + + feature = InputFeatures(input_ids=input_ids, + input_mask=input_mask, + segment_ids=segment_ids, + label_id=label_id) + return feature + + +def convert_examples_to_features_and_output_to_files( + examples, label_list, max_seq_length, tokenizer, output_file, + feature_types): + r"""Convert a set of `InputExample`s to a pickled file.""" + + with tx.data.RecordData.writer(output_file, feature_types) as writer: + for (ex_index, example) in enumerate(examples): + feature = convert_single_example(ex_index, example, label_list, + max_seq_length, tokenizer) + + features = { + "input_ids": feature.input_ids, + "input_mask": feature.input_mask, + "segment_ids": feature.segment_ids, + "label_ids": feature.label_id + } + writer.write(features) + + +def convert_unsup_examples_to_features_and_output_to_files( + examples, aug_examples, label_list, max_seq_length, tokenizer, output_file, + feature_types): + r"""Convert a set of `InputExample`s and the augmented examples + to a pickled file. 
+ """ + + with tx.data.RecordData.writer(output_file, feature_types) as writer: + print(len(examples), "unsup examples") + print(len(aug_examples), "augmented unsup examples") + assert(len(examples) == len(aug_examples)) + for (ex_index, (example, aug_example)) in enumerate(zip(examples, aug_examples)): + feature = convert_single_example(ex_index, example, label_list, + max_seq_length, tokenizer) + aug_feature = convert_single_example(ex_index, aug_example, label_list, + max_seq_length, tokenizer) + + features = { + "input_ids": feature.input_ids, + "input_mask": feature.input_mask, + "segment_ids": feature.segment_ids, + "label_ids": feature.label_id, + "aug_input_ids": aug_feature.input_ids, + "aug_input_mask": aug_feature.input_mask, + "aug_segment_ids": aug_feature.segment_ids, + "aug_label_ids": aug_feature.label_id, + } + writer.write(features) + + +def replace_with_length_check( + ori_text, new_text, + use_min_length, + use_max_length_diff_ratio): + """Use new_text if the text length satisfies several constraints.""" + if len(ori_text) < use_min_length or len(new_text) < use_min_length: + if random.random() < 0.001: + print("not replacing due to short text: \n\tori: {:s}\n\tnew: {:s}\n".format( + ori_text, + new_text)) + return ori_text + length_diff_ratio = 1.0 * (len(new_text) - len(ori_text)) / len(ori_text) + if math.fabs(length_diff_ratio) > use_max_length_diff_ratio: + if random.random() < 0.001: + print("not replacing due to too different text length:\n" + "\tori: {:s}\n\tnew: {:s}\n".format( + ori_text, + new_text)) + return ori_text + return new_text + + +def back_translation(examples, back_translation_file, data_total_size): + """Load back translation.""" + use_min_length = 10 + use_max_length_diff_ratio = 0.5 + + text_per_example = 1 + + with open(back_translation_file, encoding='utf-8') as inf: + paraphrases = inf.readlines() + for i in range(len(paraphrases)): + paraphrases[i] = paraphrases[i].strip() + assert len(paraphrases) == data_total_size + + aug_examples = [] + aug_cnt = 0 + for i in range(len(examples)): + ori_example = examples[i] + text_a = replace_with_length_check( + ori_example.text_a, + paraphrases[i * text_per_example], + use_min_length, + use_max_length_diff_ratio, + ) + if text_a == paraphrases[i * text_per_example]: + aug_cnt += 1 + if ori_example.text_b is not None: + text_b = replace_with_length_check( + ori_example.text_b, + paraphrases[i * text_per_example + 1], + use_min_length, + use_max_length_diff_ratio, + ) + else: + text_b = None + + example = InputExample( + guid=ori_example.guid, + text_a=text_a, + text_b=text_b, + label=ori_example.label) + aug_examples += [example] + if i % 10000 == 0: + print("processing example # {:d}".format(i)) + logging.info("applied back translation for {:.1f} percent of data".format( + aug_cnt * 1. / len(examples) * 100)) + logging.info("finishing running back translation augmentation") + return aug_examples + + +def prepare_record_data(processor, tokenizer, + data_dir, max_seq_length, output_dir, + feature_types, unsup_feature_types=None, sup_size_limit=None, unsup_bt_file=None): + r"""Prepare record data. + Args: + processor: Data Preprocessor, which must have get_labels, + get_train/dev/test/examples methods defined. + tokenizer: The Sentence Tokenizer. Generally should be + SentencePiece Model. + data_dir: The input data directory. + max_seq_length: Max sequence length. + output_dir: The directory to save the pickled file in. + feature_types: The original type of the feature. 
+ unsup_feature_types: Feature types for the unsupervised data. + sup_size_limit: the number of supervised data to use + unsup_bt_file: the path to the back-translation of the + unsupervised dataset. + """ + label_list = processor.get_labels() + + train_file = os.path.join(output_dir, "train.pkl") + if not os.path.isfile(train_file): + train_examples = processor.get_train_examples(data_dir) + if sup_size_limit is not None: + train_examples = get_data_by_size_lim(train_examples, processor, sup_size_limit) + convert_examples_to_features_and_output_to_files( + train_examples, label_list, max_seq_length, + tokenizer, train_file, feature_types) + + eval_file = os.path.join(output_dir, "eval.pkl") + if not os.path.isfile(eval_file): + eval_examples = processor.get_dev_examples(data_dir) + convert_examples_to_features_and_output_to_files( + eval_examples, label_list, + max_seq_length, tokenizer, eval_file, feature_types) + + unsup_file = os.path.join(output_dir, "unsup.pkl") + if not os.path.isfile(unsup_file): + unsup_label_list = label_list + ["unsup"] + unsup_examples = processor.get_unsup_examples(data_dir, "unsup_in") + unsup_aug_examples = copy.deepcopy(unsup_examples) + unsup_aug_examples = back_translation(unsup_aug_examples, unsup_bt_file, len(unsup_aug_examples)) + convert_unsup_examples_to_features_and_output_to_files( + unsup_examples, unsup_aug_examples, unsup_label_list, + max_seq_length, tokenizer, unsup_file, unsup_feature_types) + + +def get_data_by_size_lim(train_examples, processor, sup_size): + """Deterministicly get a dataset with only sup_size examples.""" + # Assuming sup_size < number of labeled data and + # that there are same number of examples for each category + assert sup_size % len(processor.get_labels()) == 0 + per_label_size = sup_size // len(processor.get_labels()) + per_label_examples = {} + for i in range(len(train_examples)): + label = train_examples[i].label + if label not in per_label_examples: + per_label_examples[label] = [] + per_label_examples[label] += [train_examples[i]] + + for label in processor.get_labels(): + assert len(per_label_examples[label]) >= per_label_size, ( + "label {} only has {} examples while the limit" + "is {}".format(label, len(per_label_examples[label]), per_label_size)) + + new_train_examples = [] + for i in range(per_label_size): + for label in processor.get_labels(): + new_train_examples += [per_label_examples[label][i]] + train_examples = new_train_examples + return train_examples + + +def prepare_data(pretrained_model_name, config_data, data_dir): + r"""Prepares data. + Args: + pretrained_model_name: the pretrained BERT model name to use + config_data: the config_data module + data_dir: path to the output record data + """ + logging.info("Loading data") + + processor = IMDbProcessor() + + tokenizer = tx.data.BERTTokenizer( + pretrained_model_name=pretrained_model_name) + + prepare_record_data( + processor=processor, + tokenizer=tokenizer, + data_dir=data_dir, + max_seq_length=config_data.max_seq_length, + output_dir=data_dir, + feature_types=config_data.feature_types, + unsup_feature_types=config_data.unsup_feature_types, + sup_size_limit=config_data.num_train_data, + unsup_bt_file=config_data.unsup_bt_file, + ) diff --git a/examples/data_augmentation/uda/utils/imdb_format.py b/examples/data_augmentation/uda/utils/imdb_format.py new file mode 100644 index 000000000..88bb3711e --- /dev/null +++ b/examples/data_augmentation/uda/utils/imdb_format.py @@ -0,0 +1,87 @@ +# Copyright 2020 The Forte Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Read all data in IMDB and merge them to a csv file.""" +import os +import csv +import json + +from forte.data.multi_pack import MultiPack +from forte.data.readers import LargeMovieReader +from forte.pipeline import Pipeline +from forte.processors.nltk_processors import NLTKSentenceSegmenter +from forte.utils.utils_io import maybe_create_dir +from ft.onto.base_ontology import Document, Sentence + + +def main(): + pipeline = Pipeline[MultiPack]() + reader = LargeMovieReader() + pipeline.set_reader(reader) + pipeline.add(NLTKSentenceSegmenter()) + + pipeline.initialize() + + dataset_path = "data/IMDB_raw/aclImdb/" + input_file_path = { + "train": os.path.join(dataset_path, "train"), + "test": os.path.join(dataset_path, "test") + } + output_path = "data/IMDB/" + maybe_create_dir(output_path) + output_file_path = { + "train": os.path.join(output_path, "train.csv"), + "test": os.path.join(output_path, "test.csv") + } + set_labels = { + "train": ["pos", "neg", "unsup"], + "test": ["pos", "neg"], + } + + back_trans_data_path = "back_trans/" + maybe_create_dir(back_trans_data_path) + split_sent_output_path = os.path.join(back_trans_data_path, "train_split_sent.txt") + doc_len_path = os.path.join(back_trans_data_path, "train_doc_len.json") + + sents = [] + doc_lens = [] + for split in ["train", "test"]: + with open(output_file_path[split], "w", encoding="utf-8")\ + as output_file: + writer = csv.writer(output_file, delimiter="\t", quotechar="\"") + writer.writerow(["content", "label", "id"]) + for label in set_labels[split]: + data_packs = \ + pipeline.process_dataset( + os.path.join(input_file_path[split], label)) + for pack in data_packs: + example_id = pack.pack_name + for doc in pack.get(Document): + writer.writerow( + [doc.text.strip(), label, example_id]) + if split == "train": + doc_len = 0 + for sent in pack.get(Sentence): + sents.append(sent.text) + doc_len += 1 + doc_lens.append(doc_len) + + with open(split_sent_output_path, "w", encoding="utf-8") as output_file: + for sent in sents: + output_file.write(sent + "\n") + with open(doc_len_path, "w", encoding="utf-8") as output_file: + json.dump(doc_lens, output_file) + + +if __name__ == "__main__": + main() diff --git a/examples/data_augmentation/uda/utils/model_utils.py b/examples/data_augmentation/uda/utils/model_utils.py new file mode 100644 index 000000000..91a968140 --- /dev/null +++ b/examples/data_augmentation/uda/utils/model_utils.py @@ -0,0 +1,54 @@ +""" +Model utility functions +""" +import math +import torch + + +def get_lr_multiplier(step: int, total_steps: int, warmup_steps: int) -> float: + r"""Calculate the learning rate multiplier given current step and the number + of warm-up steps. The learning rate schedule follows a linear warm-up and + linear decay. 
+
+    Args:
+        step: the current step
+        total_steps: total number of steps
+        warmup_steps: the number of warmup steps
+    """
+    step = min(step, total_steps)
+
+    multiplier = (1 - (step - warmup_steps) / (total_steps - warmup_steps))
+
+    if warmup_steps > 0 and step < warmup_steps:
+        warmup_percent_done = step / warmup_steps
+        multiplier = warmup_percent_done
+
+    return multiplier
+
+
+def get_tsa_threshold(schedule: str, global_step: int, num_train_steps: int,
+                      start: float, end: float) -> float:
+    r"""Get the threshold for Training Signal Annealing. From the UDA paper:
+    if the model's predicted probability for the correct category pθ(y*|x) is
+    higher than a threshold ηt, we remove that example from the loss function.
+    Please see the paper for more details.
+
+    Args:
+        schedule: one of 'linear_schedule', 'exp_schedule', 'log_schedule'
+        global_step: the current global step
+        num_train_steps: the total number of training steps
+        start: starting threshold
+        end: ending threshold
+    """
+    training_progress = float(global_step) / float(num_train_steps)
+    if schedule == "linear_schedule":
+        threshold = training_progress
+    elif schedule == "exp_schedule":
+        scale = 5
+        threshold = math.exp((training_progress - 1) * scale)
+        # [exp(-5), exp(0)] = [1e-2, 1]
+    elif schedule == "log_schedule":
+        scale = 5
+        # [1 - exp(0), 1 - exp(-5)] = [0, 0.99]
+        threshold = 1 - math.exp((-training_progress) * scale)
+    return threshold * (end - start) + start
diff --git a/forte/data/readers/largemovie_reader.py b/forte/data/readers/largemovie_reader.py
index a224b0e33..876e99361 100644
--- a/forte/data/readers/largemovie_reader.py
+++ b/forte/data/readers/largemovie_reader.py
@@ -40,13 +40,12 @@
 import os
 import logging
-import re
 from typing import Iterator, List
 
 from forte.data.data_pack import DataPack
 from forte.data.data_utils_io import dataset_path_iterator
 from forte.data.readers.base_reader import PackReader
-from ft.onto.base_ontology import Document, Sentence
+from ft.onto.base_ontology import Document
 
 __all__ = [
     "LargeMovieReader"
 ]
@@ -60,17 +59,29 @@ class LargeMovieReader(PackReader):
     following the convention [[id]_[rating].txt].
     """
-    def __init__(self):
-        super().__init__()
-        self.REPLACE_NO_SPACE = re.compile(
-            r"(\:)|(\')|(\,)|(\")|(\()|(\))|(\[)|(\])")
-        self.REPLACE_WITH_NEWLINE = re.compile(
-            r"(<br\s*/><br\s*/>)|(\-)|(\/)|(\.)|(\;)|(\!)|(\?)")
-
-    def preprocess_reviews(self, para):
-        para = self.REPLACE_NO_SPACE.sub("", para.lower())
-        para = self.REPLACE_WITH_NEWLINE.sub("\n", para)
-        return para
+    def preprocess_reviews(self, st: str):
+        r"""Clean text.
+        Args:
+            st: input text string
+        """
+        st = st.replace("<br />", " ")
+        st = st.replace("&quot;", "\"")
+        st = st.replace("<p>", " ")
+        if "<a href=" in st:
+            while "<a href=" in st:
+                start_pos = st.find("<a href=")
+                end_pos = st.find(">", start_pos)
+                if end_pos != -1:
+                    st = st[:start_pos] + st[end_pos + 1:]
+                else:
+                    print("incomplete href")
+                    print("before", st)
+                    st = st[:start_pos] + st[start_pos + len("<a href=")]
+                    print("after", st)
+            st = st.replace("</a>", "")
+        st = st.replace("\\n", " ")
+        return st
 
     def _collect(self, *args, **kwargs) -> Iterator[str]:
         # pylint: disable = unused-argument
@@ -89,21 +100,15 @@ def _collect(self, *args, **kwargs) -> Iterator[str]:
 
     def _parse_pack(self, file_path: str) -> Iterator[DataPack]:
         data_pack: DataPack = DataPack()
-        sent_begin: int = 0
         doc_text: str = ""
 
         with open(file_path, encoding="utf8") as doc:
-            for para in doc:
-                para = self.preprocess_reviews(para)
-                sents = para.split("\n")
-                for sent in sents:
-                    if len(sent) > 0:
-                        sent = sent.strip()
-                        doc_text += sent + " "
-                        doc_offset = sent_begin + len(sent) + 1
-                        # Add sentences.
-                        Sentence(data_pack, sent_begin, doc_offset - 1)
-                        sent_begin = doc_offset
+            st_list = doc.readlines()
+            if len(st_list) != 1:
+                raise AssertionError("Raw data file contains more than "
+                                     "one example.")
+            doc_text = st_list[0]
+            doc_text = self.preprocess_reviews(doc_text)
 
         pos_dir: str = os.path.basename(os.path.dirname(file_path))
         movie_file: str = os.path.basename(file_path)
diff --git a/tests/forte/data/readers/largemovie_reader_test.py b/tests/forte/data/readers/largemovie_reader_test.py
index 182a4a817..86135296f 100644
--- a/tests/forte/data/readers/largemovie_reader_test.py
+++ b/tests/forte/data/readers/largemovie_reader_test.py
@@ -43,22 +43,33 @@ def setUp(self):
         # pos0 doc's leading text, neg1 doc's ending text.
         self.doc_text: Dict[str, str] = \
             {"pos":
-                 "bromwell high is a cartoon comedy it ran at the same time as "
-                 "some other programs about school life such as teachers my 35 "
-                 "years in the teaching profession lead me to believe that "
-                 "bromwell highs satire is much closer to reality than is",
+                 'Bromwell High is a cartoon comedy. It ran at the same time'
+                 ' as some other programs about school life, such as "Teache'
+                 'rs". My 35 years in the teaching profession lead me to bel'
+                 'ieve that Bromwell High\'s satire is much closer to reality '
+                 'than is "Teachers". The scramble to survive financially, the'
+                 ' insightful students who can see right through their pathetic'
+                 ' teachers\' pomp, the pettiness of the whole situation, all '
+                 'remind me of the schools I knew and their students. When I saw'
+                 ' the episode in which a student repeatedly tried to burn down '
+                 'the school, I immediately recalled ......... at .......... '
+                 'High. A classic line: INSPECTOR: I\'m here to sack one of your '
+                 'teachers. STUDENT: Welcome to Bromwell High. I expect that '
+                 'many adults of my age think that Bromwell High is far fetched.'
+                 ' What a pity that it isn\'t!',
             "neg":
-                 "this new imdb rule of requiring ten lines for every review "
-                 "when a movie is this worthless it doesnt require ten lines of "
-                 "text to let other readers know that it is a waste of time "
-                 "and tape avoid this movie"}
-        # pos3 sentence #1, neg3 sentence #5.
-        self.sent_text: Dict[str, str] = \
-            {"pos":
-                 "all the worlds a stage and its people actors in it",
-             "neg":
-                 "i was put through tears repulsion shock anger sympathy "
-                 "and misery when reading about the women of union street"}
+                 'Robert DeNiro plays the most unbelievably intelligent '
+                 'illiterate of all time. This movie is so wasteful of talent,'
+                 ' it is truly disgusting. The script is unbelievable. The '
+                 'dialog is unbelievable. Jane Fonda\'s character is a caricature'
+                 ' of herself, and not a funny one. 
The movie moves at a snail\'s' + ' pace, is photographed in an ill-advised manner, and is ' + 'insufferably preachy. It also plugs in every cliche in the ' + 'book. Swoozie Kurtz is excellent in a supporting role, but so' + ' what? Equally annoying is this new IMDB rule of requiring ' + 'ten lines for every review. When a movie is this worthless, ' + 'it doesn\'t require ten lines of text to let other readers ' + 'know that it is a waste of time and tape. Avoid this movie.'} # pos0 doc's score, neg1 doc's score. self.score: Dict[str, float] = \ {"pos": 0.9, @@ -84,15 +95,11 @@ def test_reader_text(self): docid1 = self.doc_ids[dir][1] if pack.pack_name == docid1: for doc in pack.get(Document): + print(doc.text) self.assertIn(self.doc_text[dir], doc.text) # test sentiments. self.assertEqual( doc.sentiment[docid1], self.score[dir]) - # Test sentences. - elif pack.pack_name == docid0: - sents = pack.get(Sentence) - self.assertTrue(self.sent_text[dir] in - [sent.text for sent in sents]) count_packs += 1