Word-level tasks (#45)
* Added word-level task codes

* minor change based on latest code

* minor correction according to latest version

* Updated codes with llama2 config files

* Updated word-level tasks train and test config files

* Added all word-level configs for uni, bi, bi-mntp, and bi-mntp-simcse

* Updated readme with word-level tasks explanation

* removed a minor local-specific fix

---------

Co-authored-by: vaibhavad <[email protected]>
ParishadBehnam and vaibhavad authored May 3, 2024
1 parent 125d803 commit c5cb6d7
Showing 28 changed files with 1,479 additions and 1 deletion.
4 changes: 3 additions & 1 deletion .gitignore
@@ -4,4 +4,6 @@ dist/
 **/__pycache__
 wandb/**
 output/**
-cache/**
+cache/**
+output/
+*.log
26 changes: 26 additions & 0 deletions README.md
@@ -189,6 +189,32 @@ The Mistral training configuration [file](train_configs/supervised/Mistral.json)
Similar configurations are also available for [Meta-Llama-3-8B](train_configs/supervised/MetaLlama3.json), [Llama-2-7B](train_configs/supervised/Llama2.json), and [Sheared-Llama-1.3B](train_configs/supervised/Sheared-Llama.json) models.


### Word-level task training

To tune the model for word-level tasks, we define a classifier on top of the model and train only the classifier weights. The code is adapted from the HuggingFace token classification [example](https://huggingface.co/docs/transformers/en/tasks/token_classification). To train and test the classifier for the Llama-2-7B MNTP model on the `pos_tags` task, run the following commands:
```bash
python experiments/run_word_task.py train_configs/word-task/Llama2-bi-mntp.json
python experiments/test_word_task.py --config_file test_configs/word-task/Llama2-bi-mntp.json
```
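
Conceptually, the setup resembles the sketch below: a PEFT-adapted backbone is loaded and frozen, and only a linear token-classification head on top of its hidden states is trained. This is a minimal illustration, not the repository's exact implementation; the class name, the `NUM_LABELS` value, and the usage snippet at the end are assumptions.
```python
import torch
import torch.nn as nn
from peft import PeftModel
from transformers import AutoModel

NUM_LABELS = 47  # e.g. number of POS tags in conll2003 (illustrative assumption)

class WordTaskClassifier(nn.Module):
    """Frozen backbone + trainable linear head for token classification (sketch)."""

    def __init__(self, base_name, peft_addr, num_labels, dropout=0.1):
        super().__init__()
        backbone = AutoModel.from_pretrained(base_name)
        hidden_size = backbone.config.hidden_size
        # Load the MNTP adapter on top of the base model.
        self.backbone = PeftModel.from_pretrained(backbone, peft_addr)
        for p in self.backbone.parameters():
            p.requires_grad = False  # only the classifier below is trained
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        hidden = self.backbone(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        logits = self.classifier(self.dropout(hidden))
        loss = None
        if labels is not None:
            # Positions labelled -100 (special tokens, masked sub-tokens) are ignored.
            loss = nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-100
            )
        return loss, logits

# Illustrative usage with the checkpoints mentioned in the config above.
model = WordTaskClassifier(
    "meta-llama/Llama-2-7b-chat-hf",
    "McGill-NLP/LLM2Vec-Llama-2-7b-chat-hf-mntp",
    NUM_LABELS,
)
```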
The config files contain all the parameters and configurations used in our paper. For instance, `Llama2-bi-mntp.json` includes:
```json
{
"model_name_or_path": "meta-llama/Llama-2-7b-chat-hf",
"peft_addr": "McGill-NLP/LLM2Vec-Llama-2-7b-chat-hf-mntp", // or any local directory containing `adapter_model` files.
"model_class": "custom",
"bidirectional": true,
"classifier_dropout": 0.1,
"merge_subwords": true,
"retroactive_labels": "next_token",
"output_dir": "output/word-task/pos_tags/Llama2/bi-mntp",
"dataset_name": "conll2003",
"task": "pos_tags", // or ner_tags, or chunk_tags
// ....
}
```
[train_configs/word-task](train_configs/word-task) and [test_configs/word-task](test_configs/word-task) contain similar configurations for Llama-2-7B, Mistral-7B, and Sheared-Llama-1.3B for all Uni, Bi, Bi-MNTP, and Bi-MNTP-SimCSE (LLM2Vec) variants.
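
The `merge_subwords` and `retroactive_labels` options concern how word-level labels are assigned to sub-word tokens. As a hedged illustration of the standard recipe from the HuggingFace token classification example that this code is adapted from (whether the repository's options map to exactly this behaviour is an assumption), only the first sub-token of each word carries the word's label and the rest are masked with -100:
```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

def align_labels(words, word_labels):
    """Tokenize a pre-split sentence and map word labels onto sub-tokens."""
    enc = tokenizer(words, is_split_into_words=True, truncation=True)
    labels, prev = [], None
    for wid in enc.word_ids():
        if wid is None:
            labels.append(-100)              # special tokens: ignored by the loss
        elif wid != prev:
            labels.append(word_labels[wid])  # first sub-token carries the word label
        else:
            labels.append(-100)              # remaining sub-tokens are masked out
        prev = wid
    enc["labels"] = labels
    return enc

# Pre-split sentence with illustrative tag ids, as in conll2003-style data.
example = align_labels(["EU", "rejects", "German", "call"], [3, 4, 1, 2])
```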


## Citation
If you find our work helpful, please cite us:
```bibtex
