diff --git a/.gitignore b/.gitignore index d2b3239..c87e811 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,6 @@ dist/ **/__pycache__ wandb/** output/** -cache/** \ No newline at end of file +cache/** +output/ +*.log \ No newline at end of file diff --git a/README.md b/README.md index ede5bf9..25feed6 100644 --- a/README.md +++ b/README.md @@ -189,6 +189,32 @@ The Mistral training configuration [file](train_configs/supervised/Mistral.json) Similar configurations are also available for [Meta-Llama-3-8B](train_configs/supervised/MetaLlama3.json), [Llama-2-7B](train_configs/supervised/Llama2.json), and [Sheared-Llama-1.3B](train_configs/supervised/Sheared-Llama.json) models. +### Word-level tasks training + +To tune the model for word-level tasks, we define a classifier on top of the model and train only the classifier weights. The code is adapted from the Hugging Face token classification [example](https://huggingface.co/docs/transformers/en/tasks/token_classification). To train and test the classifier for the Llama-2-7B MNTP model on the `pos_tags` task, run the following commands: +```bash +python experiments/run_word_task.py train_configs/word-task/Llama2-bi-mntp.json +python experiments/test_word_task.py --config_file test_configs/word-task/Llama2-bi-mntp.json +``` +The config files contain all the parameters and configurations used in our paper. For instance, `Llama2-bi-mntp.json` includes: +```json +{ + "model_name_or_path": "meta-llama/Llama-2-7b-chat-hf", + "peft_addr": "McGill-NLP/LLM2Vec-Llama-2-7b-chat-hf-mntp", // or any local directory containing `adapter_model` files. + "model_class": "custom", + "bidirectional": true, + "classifier_dropout": 0.1, + "merge_subwords": true, + "retroactive_labels": "next_token", + "output_dir": "output/word-task/pos_tags/Llama2/bi-mntp", + "dataset_name": "conll2003", + "task": "pos_tags", // or ner_tags, or chunk_tags + // .... +} +``` +[train_configs/word-task](train_configs/word-task) and [test_configs/word-task](test_configs/word-task) contain similar configurations for Llama-2-7B, Mistral-7B, and Sheared-Llama-1.3B for all Uni, Bi, Bi-MNTP, and Bi-MNTP-SimCSE (LLM2Vec) variants.
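Two of the config options above are worth a brief illustration: `merge_subwords` averages the backbone's hidden states over all subword pieces of a word before classification, and `retroactive_labels: "next_token"` shifts the word labels so that the representation at a given position predicts the tag of the *following* word (the convention used for the MNTP variants in these configs). The snippet below is a minimal sketch of these two transformations only, with a made-up tokenization; it is not the exact `run_word_task.py` implementation.
```python
import torch

# Toy example: the words ["unbelievable", "stuff"] split into subword pieces.
pieces    = ["un", "believ", "able", "stuff"]
word_ids  = [0,     0,        0,      1]       # word index of each piece
word_tags = [16,    21]                        # one gold tag id per word (JJ=16, NN=21 in CoNLL-2003 pos_tags)
hidden    = torch.randn(len(pieces), 8)        # (num_pieces, hidden_size) from the frozen backbone

# merge_subwords: average the hidden states of pieces that belong to the same word
merged = hidden.clone()
for w in set(word_ids):
    mask = torch.tensor([i == w for i in word_ids])
    merged[mask] = hidden[mask].mean(dim=0)

# retroactive_labels == "same_token": the first piece of a word carries its tag,
# the remaining pieces are ignored (-100)
same_token = [16, -100, -100, 21]

# retroactive_labels == "next_token": shift left by one and pad with the ignore index,
# so the last piece of each word predicts the tag of the next word
next_token = (same_token + [-100])[1:]         # -> [-100, -100, 21, -100]
```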
+ + ## Citation If you find our work helpful, please cite us: ```bibtex diff --git a/experiments/run_word_task.py b/experiments/run_word_task.py new file mode 100644 index 0000000..7985d07 --- /dev/null +++ b/experiments/run_word_task.py @@ -0,0 +1,725 @@ +""" +The script is adapted from https://huggingface.co/docs/transformers/en/tasks/token_classification +""" +import logging +import os +import sys +import warnings +from dataclasses import dataclass, field +import numpy as np +from typing import List, Optional, Tuple, Union + +import datasets +import evaluate +from datasets import load_dataset + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +import transformers +from transformers import ( + PreTrainedModel, + MODEL_FOR_MASKED_LM_MAPPING, + AutoConfig, + AutoTokenizer, + HfArgumentParser, + Trainer, + TrainingArguments, + TrainerCallback, + set_seed, + AutoModelForTokenClassification, + DataCollatorForTokenClassification +) + +from transformers.modeling_outputs import TokenClassifierOutput +from transformers.utils import send_example_telemetry +from transformers.utils.versions import require_version + +from llm2vec import LLM2Vec + +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") + +class ModelForWordTask(PreTrainedModel): + def __init__(self, config, model, merge_subwords=False, **model_args): + PreTrainedModel.__init__(self, config) + self.model = model + self.merge_subwords = merge_subwords + + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + + self.dropout = nn.Dropout(classifier_dropout) + self.num_labels = config.num_labels + self.classifier = nn.Linear(config.hidden_size, config.num_labels).to(model_args.get("torch_dtype")) + + # Initialize weights and apply final processing + self.post_init() + + def _merge_subwords(self, hidden_states, token_type_ids, attention_mask): + new_hidden_states = hidden_states.clone() + for b in range(hidden_states.shape[0]): + for w in torch.arange(0, token_type_ids[b].max() + 1): + words_w = (token_type_ids[b] == w) * (attention_mask[b]>0) + new_hidden_states[b][words_w] = torch.mean(hidden_states[b][words_w], dim=0).repeat(sum(words_w), 1) + return new_hidden_states + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + token_type_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None + ) -> Union[Tuple, TokenClassifierOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + 
input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + if self.merge_subwords: + hidden_states = self._merge_subwords(hidden_states, token_type_ids, attention_mask) + + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs.hidden_states + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=hidden_states, + attentions=outputs.attentions, + ) + +logger = logging.getLogger(__name__) +MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) +LABELS = { + "conll2003": { + "pos_tags": {'"': 0, "''": 1, '#': 2, '$': 3, '(': 4, ')': 5, ',': 6, '.': 7, ':': 8, '``': 9, 'CC': 10, 'CD': 11, 'DT': 12, + 'EX': 13, 'FW': 14, 'IN': 15, 'JJ': 16, 'JJR': 17, 'JJS': 18, 'LS': 19, 'MD': 20, 'NN': 21, 'NNP': 22, 'NNPS': 23, + 'NNS': 24, 'NN|SYM': 25, 'PDT': 26, 'POS': 27, 'PRP': 28, 'PRP$': 29, 'RB': 30, 'RBR': 31, 'RBS': 32, 'RP': 33, + 'SYM': 34, 'TO': 35, 'UH': 36, 'VB': 37, 'VBD': 38, 'VBG': 39, 'VBN': 40, 'VBP': 41, 'VBZ': 42, 'WDT': 43, + 'WP': 44, 'WP$': 45, 'WRB': 46}, + "chunk_tags": {'O': 0, 'B-ADJP': 1, 'I-ADJP': 2, 'B-ADVP': 3, 'I-ADVP': 4, 'B-CONJP': 5, 'I-CONJP': 6, 'B-INTJ': 7, 'I-INTJ': 8, + 'B-LST': 9, 'I-LST': 10, 'B-NP': 11, 'I-NP': 12, 'B-PP': 13, 'I-PP': 14, 'B-PRT': 15, 'I-PRT': 16, 'B-SBAR': 17, + 'I-SBAR': 18, 'B-UCP': 19, 'I-UCP': 20, 'B-VP': 21, 'I-VP': 22}, + "ner_tags": {'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-ORG': 3, 'I-ORG': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8} + } +} + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. + """ + model_name_or_path: Optional[str] = field( + default=None, + metadata={ + }, + ) + config_overrides: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Override some existing default config settings when a model is trained from scratch. Example: " + "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" + ) + }, + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + token: str = field( + default=None, + metadata={ + "help": ( + "The token to use as HTTP bearer authorization for remote files. 
If not specified, will use the token " + "generated when running `huggingface-cli login` (stored in `~/.huggingface`)." + ) + }, + ) + use_auth_token: bool = field( + default=None, + metadata={ + "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead." + }, + ) + trust_remote_code: bool = field( + default=False, + metadata={ + "help": ( + "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option " + "should only be set to `True` for repositories you trust and in which you have read the code, as it will " + "execute code present on the Hub on your local machine." + ) + }, + ) + low_cpu_mem_usage: bool = field( + default=False, + metadata={ + "help": ( + "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. " + "Setting it to True will benefit LLM loading time and RAM consumption." + ) + }, + ) + torch_dtype: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " + "dtype will be automatically derived from the model's weights." + ), + "choices": ["auto", "bfloat16", "float16", "float32"], + }, + ) + attn_implementation: Optional[str] = field( + default="sdpa", + metadata={ + "help": ("The attention implementation to use in the model."), + "choices": ["eager", "sdpa", "flash_attention_2"], + }, + ) + classifier_dropout: Optional[float] = field( + default=0.1, + metadata={"help": "The dropout rate applied before the classification head."} + ) + peft_addr: Optional[str] = field( + default=None, + metadata={"help": "Path or Hugging Face Hub ID of the LoRA adapter weights."} + ) + model_class: str = field( + default="custom", + metadata={ + "help": "Either 'custom' or 'auto': 'custom' for LLM2Vec models and 'auto' for pretrained encoders such as BERT.", + "choices": ["custom", "auto"] + } + ) + merge_subwords: bool = field( + default=True, + metadata={"help": "Whether to average subword representations into a single word representation."} + ) + bidirectional: bool = field( + default=True, + metadata={"help": "Whether to use bidirectional attention."} + ) + + def __post_init__(self): + if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None): + raise ValueError( + "--config_overrides can't be used in combination with --config_name or --model_name_or_path" + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval.
+ """ + + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) + validation_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, + ) + overwrite_cache: bool = field( + default=True, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + validation_split_percentage: Optional[int] = field( + default=5, + metadata={ + "help": "The percentage of the train set used as validation set in case there's no validation split" + }, + ) + max_seq_length: Optional[int] = field( + default=None, + metadata={ + "help": ( + "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated." + ) + }, + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + mlm_probability: float = field( + default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"} + ) + line_by_line: bool = field( + default=False, + metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."}, + ) + pad_to_max_length: bool = field( + default=False, + metadata={ + "help": ( + "Whether to pad all samples to `max_seq_length`. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch." + ) + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ) + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + ) + }, + ) + streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"}) + + def __post_init__(self): + if self.streaming: + require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`") + + if self.dataset_name is None and self.train_file is None and self.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + if extension not in ["csv", "json", "txt"]: + raise ValueError("`train_file` should be a csv, a json or a txt file.") + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + if extension not in ["csv", "json", "txt"]: + raise ValueError("`validation_file` should be a csv, a json or a txt file.") + +# add more arguments +@dataclass +class CustomArguments: + """ + Custom arguments for the script + """ + stop_after_n_steps: int = field( + default=10000, + metadata={"help": "Stop training after n steps"} + ) + data_collator_type: str = field( + default="custom", + metadata={"help": "The type of data collator. 
Options: custom, default, custom_no_random"} + ) + task: Optional[str] = field( + default="pos_tags", + metadata={ + "help": "One of 'pos_tags', 'chunk_tags', or 'ner_tags'.", + "choices": ["pos_tags", "ner_tags", "chunk_tags"] + } + ) + retroactive_labels: str = field( + default="next_token", + metadata={ + "help": "Whether a token's representation predicts the label of its own word or of the next word. Options: same_token, next_token.", + "choices": ["next_token", "same_token"] + } + ) + + +class StopTrainingCallback(TrainerCallback): + def __init__(self, stop_after_n_steps: int): + self.stop_after_n_steps = stop_after_n_steps + + def on_step_end(self, args, state, control, **kwargs): + if state.global_step >= self.stop_after_n_steps: + control.should_training_stop = True + + +class WordTaskTrainer(Trainer): + def _save(self, output_dir: Optional[str] = None, state_dict=None): + # If we are executing this function, we are the process zero, so we don't check for that. + output_dir = output_dir if output_dir is not None else self.args.output_dir + os.makedirs(output_dir, exist_ok=True) + logger.info(f"Saving model checkpoint to {output_dir}") + + torch.save(self.model.classifier, os.path.join(output_dir, 'classifier.pt')) + self.tokenizer.save_pretrained(output_dir) + + # Good practice: save your training arguments together with the trained model + torch.save(self.args, os.path.join(output_dir, "training_args.bin")) + + +def main(): + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, CustomArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args, custom_args = parser.parse_json_file( + json_file=os.path.abspath(sys.argv[1]) + ) + else: + ( + model_args, + data_args, + training_args, + custom_args, + ) = parser.parse_args_into_dataclasses() + + if training_args.gradient_checkpointing: + training_args.gradient_checkpointing_kwargs = {'use_reentrant': False} + + if model_args.use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.", + FutureWarning, + ) + if model_args.token is not None: + raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.") + model_args.token = model_args.use_auth_token + + # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The + # information sent is the one passed as arguments along with your Python/PyTorch versions. + send_example_telemetry("run_word_task", model_args, data_args) + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + if training_args.should_log: + # The default of training_args.log_level is passive, so we set log level at info here to have that default.
+ transformers.utils.logging.set_verbosity_info() + + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, " + + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}" + ) + # Set the verbosity to info of the Transformers logger (on main process only): + logger.info(f"Training/evaluation parameters {training_args}") + + # Set seed before initializing model. + set_seed(training_args.seed) + + if data_args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + raw_datasets = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + cache_dir=model_args.cache_dir, + token=model_args.token, + streaming=data_args.streaming, + ) + if "validation" not in raw_datasets.keys(): + raw_datasets["validation"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=f"train[:{data_args.validation_split_percentage}%]", + cache_dir=model_args.cache_dir, + token=model_args.token, + streaming=data_args.streaming, + ) + raw_datasets["train"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=f"train[{data_args.validation_split_percentage}%:]", + cache_dir=model_args.cache_dir, + token=model_args.token, + streaming=data_args.streaming, + ) + else: + data_files = {} + if data_args.train_file is not None: + data_files["train"] = data_args.train_file + extension = data_args.train_file.split(".")[-1] + if data_args.validation_file is not None: + data_files["validation"] = data_args.validation_file + extension = data_args.validation_file.split(".")[-1] + if extension == "txt": + extension = "text" + raw_datasets = load_dataset( + extension, + data_files=data_files, + cache_dir=model_args.cache_dir, + token=model_args.token, + ) + + # If no validation data is there, validation_split_percentage will be used to divide the dataset. + if "validation" not in raw_datasets.keys(): + raw_datasets["validation"] = load_dataset( + extension, + data_files=data_files, + split=f"train[:{data_args.validation_split_percentage}%]", + cache_dir=model_args.cache_dir, + token=model_args.token, + ) + raw_datasets["train"] = load_dataset( + extension, + data_files=data_files, + split=f"train[{data_args.validation_split_percentage}%:]", + cache_dir=model_args.cache_dir, + token=model_args.token, + ) + + assert data_args.dataset_name in LABELS and custom_args.task in LABELS[data_args.dataset_name], f"LABELS[{data_args.dataset_name}][{custom_args.task}] is not defined." 
+ + config_kwargs = { + "num_labels": len(LABELS[data_args.dataset_name][custom_args.task]), + "id2label": {i: lab for (lab, i) in LABELS[data_args.dataset_name][custom_args.task].items()}, + "label2id": LABELS[data_args.dataset_name][custom_args.task], + "classifier_dropout": model_args.classifier_dropout + } + + tokenizer_kwargs = { + "cache_dir": model_args.cache_dir, + "use_fast": model_args.use_fast_tokenizer, + "revision": model_args.model_revision, + "token": model_args.token, + "trust_remote_code": model_args.trust_remote_code, + } + if model_args.tokenizer_name: + if "gpt" in model_args.tokenizer_name: + tokenizer_kwargs["add_prefix_space"] = True + tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs) + elif model_args.model_name_or_path: + if "gpt" in model_args.model_name_or_path: + tokenizer_kwargs["add_prefix_space"] = True + tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs) + else: + raise ValueError( + "You are instantiating a new tokenizer from scratch. This is not supported by this script. " + "You can do it from another script, save it, and load it from here, using --tokenizer_name." + ) + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + if model_args.model_class == "custom": + tokenizer.model_input_names.append("token_type_ids") + if model_args.model_class == "auto": + assert not model_args.merge_subwords + + if model_args.model_class == "custom": + if model_args.config_name: + config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) + elif model_args.model_name_or_path: + config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) + else: + raise ValueError('Invalid config loading') + + for k, v in config_kwargs.items(): + config.__setattr__(k, v) + + torch_dtype = ( + model_args.torch_dtype + if model_args.torch_dtype in ["auto", None] + else getattr(torch, model_args.torch_dtype) + ) + l2v = LLM2Vec.from_pretrained( + base_model_name_or_path=model_args.model_name_or_path, + enable_bidirectional=model_args.bidirectional, + peft_model_name_or_path=model_args.peft_addr, + merge_peft=False, + torch_dtype=torch_dtype, + attn_implementation=model_args.attn_implementation, + ) + + model = ModelForWordTask( + model=l2v.model, + merge_subwords=model_args.merge_subwords, + config=config, + torch_dtype=torch_dtype, + ) + + MyTrainer = WordTaskTrainer + + elif model_args.model_class == "auto": + model = AutoModelForTokenClassification.from_pretrained(model_args.model_name_or_path, + num_labels=config_kwargs["num_labels"], + id2label=config_kwargs["id2label"], + label2id=config_kwargs["label2id"]) + MyTrainer = Trainer + + else: + raise ValueError(f"{model_args.model_class} is not implemented. Only 'auto' and 'custom' model_class options are valid.") + + # only train classifier + for (n,p) in list(model.named_parameters()): + if "classifier" in n: + p.requires_grad = True + else: + p.requires_grad = False + + if data_args.max_seq_length is None: + max_seq_length = tokenizer.model_max_length + if max_seq_length > 1024: + logger.warning( + "The chosen tokenizer supports a `model_max_length` that is longer than the default `max_seq_length` value" + " of 1024. If you would like to use a longer `max_seq_length` up to `tokenizer.model_max_length` you can" + " override this default with `--max_seq_length xxx`."
+ ) + max_seq_length = 1024 + else: + if data_args.max_seq_length > tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the " + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) + + def tokenize_and_align_labels(examples): + task = custom_args.task + padding = "max_length" if data_args.pad_to_max_length else False + tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True, padding=padding, max_length=max_seq_length) + + labels = [] + words = [] + for i, label in enumerate(examples[task]): + if custom_args.retroactive_labels in ["same_token"]: + word_ids = tokenized_inputs.word_ids(batch_index=i) + previous_word_idx = None + label_ids = [] + for word_idx in word_ids: + if word_idx is None: + label_ids.append(-100) + elif word_idx != previous_word_idx: + label_ids.append(label[word_idx]) + else: + label_ids.append(-100) + previous_word_idx = word_idx + labels.append(label_ids) + word_ids = [-1 if w is None else w for w in word_ids] + words.append(word_ids) + elif custom_args.retroactive_labels == "next_token": + word_ids = tokenized_inputs.word_ids(batch_index=i) + previous_word_idx = None + label_ids = [] + for word_idx in word_ids: + if word_idx is None: + label_ids.append(-100) + elif word_idx != previous_word_idx: + label_ids.append(label[word_idx]) + else: + label_ids.append(-100) + previous_word_idx = word_idx + label_ids.append(-100) + labels.append(label_ids[1:]) + word_ids = word_ids[1:] + [None] + word_ids = [-1 if w is None else w for w in word_ids] + words.append(word_ids) + else: + raise ValueError(f"retroactive_labels {custom_args.retroactive_labels} is not implemented.") + + tokenized_inputs["labels"] = labels + if model_args.model_class == "custom": + tokenized_inputs["token_type_ids"] = words + return tokenized_inputs + + tokenized_dataset = raw_datasets.map(tokenize_and_align_labels, batched=True, remove_columns=list(LABELS[data_args.dataset_name].keys())+["tokens", "id"], load_from_cache_file=not data_args.overwrite_cache) + data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer) + seqeval = evaluate.load("seqeval") + def compute_metrics(p): + predictions, labels = p + predictions = predictions[0] + predictions = np.argmax(predictions, axis=2) + + true_predictions = [ + [config_kwargs["id2label"][p] for (p, l) in zip(prediction, label) if l != -100] + for prediction, label in zip(predictions, labels) + ] + true_labels = [ + [config_kwargs["id2label"][l] for (p, l) in zip(prediction, label) if l != -100] + for prediction, label in zip(predictions, labels) + ] + + results = seqeval.compute(predictions=true_predictions, references=true_labels) + return { + "precision": results["overall_precision"], + "recall": results["overall_recall"], + "f1": results["overall_f1"], + "accuracy": results["overall_accuracy"], + } + + trainer = MyTrainer( + model=model, + args=training_args, + train_dataset=tokenized_dataset["train"], + eval_dataset=tokenized_dataset["validation"], + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=compute_metrics, + ) + trainer.add_callback(StopTrainingCallback(custom_args.stop_after_n_steps)) + + trainer.train() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/experiments/test_word_task.py b/experiments/test_word_task.py new file mode 100644 
index 0000000..b174720 --- /dev/null +++ b/experiments/test_word_task.py @@ -0,0 +1,233 @@ +import os +import sys +import logging +import argparse +from transformers import AutoTokenizer, AutoConfig, AutoModelForTokenClassification, set_seed, HfArgumentParser +import torch +from datasets import load_dataset +import evaluate +import json +from tqdm import tqdm +from run_word_task import ModelForWordTask +from llm2vec import LLM2Vec + + +LABELS = { + "conll2003": { + "pos_tags": {'"': 0, "''": 1, '#': 2, '$': 3, '(': 4, ')': 5, ',': 6, '.': 7, ':': 8, '``': 9, 'CC': 10, 'CD': 11, 'DT': 12, + 'EX': 13, 'FW': 14, 'IN': 15, 'JJ': 16, 'JJR': 17, 'JJS': 18, 'LS': 19, 'MD': 20, 'NN': 21, 'NNP': 22, 'NNPS': 23, + 'NNS': 24, 'NN|SYM': 25, 'PDT': 26, 'POS': 27, 'PRP': 28, 'PRP$': 29, 'RB': 30, 'RBR': 31, 'RBS': 32, 'RP': 33, + 'SYM': 34, 'TO': 35, 'UH': 36, 'VB': 37, 'VBD': 38, 'VBG': 39, 'VBN': 40, 'VBP': 41, 'VBZ': 42, 'WDT': 43, + 'WP': 44, 'WP$': 45, 'WRB': 46}, + "chunk_tags": {'O': 0, 'B-ADJP': 1, 'I-ADJP': 2, 'B-ADVP': 3, 'I-ADVP': 4, 'B-CONJP': 5, 'I-CONJP': 6, 'B-INTJ': 7, 'I-INTJ': 8, + 'B-LST': 9, 'I-LST': 10, 'B-NP': 11, 'I-NP': 12, 'B-PP': 13, 'I-PP': 14, 'B-PRT': 15, 'I-PRT': 16, 'B-SBAR': 17, + 'I-SBAR': 18, 'B-UCP': 19, 'I-UCP': 20, 'B-VP': 21, 'I-VP': 22}, + "ner_tags": {'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-ORG': 3, 'I-ORG': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8} + } +} + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + parser = argparse.ArgumentParser() + parser.add_argument("--model_class", default="custom", type=str) + parser.add_argument("--model_name_or_path", default=None, type=str) + parser.add_argument("--peft_addr", default=None, type=str, help="The dir address where adapter_model.bin is saved.") + parser.add_argument("--cls_addr", default=None, type=str, help="The dir address where classifier is saved.") + parser.add_argument("--bidirectional", default=True, type=str2bool) + parser.add_argument("--merge_subwords", default=True, type=str2bool) + parser.add_argument("--output_dir", default=None, type=str) + parser.add_argument("--classifier_dropout", default=0.1, type=float) + parser.add_argument("--attn_implementation", default="sdpa", type=str, choices=["sdpa", "eager", "flash_attention_2"]) + parser.add_argument("--torch_dtype", default=None, type=str, choices=["auto", "bfloat16", "float16", "float32"]) + + parser.add_argument("--retroactive_labels", default="next_token", type=str, choices=["next_token", "same_token"]) + parser.add_argument("--dataset_name", default=None, type=str) + parser.add_argument("--task", default=None, type=str, choices=["pos_tags", "chunk_tags", "ner_tags"]) + parser.add_argument("--max_seq_length", default=1024, type=int) + parser.add_argument("--batch_size", default=32, type=int) + parser.add_argument("--seed", default=32, type=int) + + parser.add_argument("--config_file", default=None, type=str) + + args = parser.parse_args() + + if args.config_file is not None: + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. 
+ from pathlib import Path + import json + + json_text = json.load(open(os.path.abspath(args.config_file))) + argparse_dict = vars(args) + argparse_dict.update(json_text) + # args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + args = parser.parse_args() + + path_to_check = args.peft_addr if args.peft_addr else args.model_name_or_path + assert args.output_dir is not None, "If you want to evaluate a model, you have to provide the output_dir" + os.makedirs(args.output_dir,exist_ok=True) + + set_seed(args.seed) + + tokenizer_kwargs = {} + if "gpt" in args.model_name_or_path: + tokenizer_kwargs["add_prefix_space"] = True + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, **tokenizer_kwargs) + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + + if args.model_class == "custom": + tokenizer.model_input_names.append("token_type_ids") + + if args.model_class == "auto": + assert not args.merge_subwords + + assert args.dataset_name in LABELS and args.task in LABELS[args.dataset_name], f"LABELS[{args.dataset_name}][{args.task}] is not defined." + + config_kwargs = { + "num_labels": len(LABELS[args.dataset_name][args.task]), + "id2label": {i: lab for (lab, i) in LABELS[args.dataset_name][args.task].items()}, + "label2id": LABELS[args.dataset_name][args.task], + "classifier_dropout": args.classifier_dropout + } + + if args.model_class == "custom": + if args.model_name_or_path: + config = AutoConfig.from_pretrained(args.model_name_or_path, **config_kwargs) + else: + raise ValueError('Invalid config loading') + + for k, v in config_kwargs.items(): + config.__setattr__(k, v) + + torch_dtype = ( + args.torch_dtype + if args.torch_dtype in ["auto", None] + else getattr(torch, args.torch_dtype) + ) + l2v = LLM2Vec.from_pretrained( + base_model_name_or_path=args.model_name_or_path, + enable_bidirectional=args.bidirectional, + peft_model_name_or_path=args.peft_addr, + merge_peft=False, + torch_dtype=torch_dtype, + attn_implementation=args.attn_implementation, + ) + model = ModelForWordTask( + model=l2v.model, + merge_subwords=args.merge_subwords, + config=config, + torch_dtype=torch_dtype, + ) + + classifier_path = os.path.join(args.cls_addr, "classifier.pt") + if os.path.exists(classifier_path): + print(f"Loading classifier from {classifier_path}") + model.classifier = torch.load(classifier_path) + else: + raise ValueError("classifier does not exist in", classifier_path) + + elif args.model_class == "auto": + model = AutoModelForTokenClassification.from_pretrained(args.model_name_or_path, + num_labels=len(LABELS[args.dataset_name][args.task]), + id2label={i: lab for (lab, i) in LABELS[args.dataset_name][args.task].items()}, + label2id=LABELS[args.dataset_name][args.task]) + else: + raise ValueError(f"{args.model_class} is not implemented. 
Only auto and custom model_class options are valid.") + + model = model.cuda() + + raw_datasets = load_dataset( + args.dataset_name, split="test" + ) + + def tokenize_and_align_labels(examples): + task = args.task + tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True, padding="max_length", max_length=args.max_seq_length, return_tensors="pt") + + labels = [] + words = [] + for i, label in enumerate(examples[task]): + if args.retroactive_labels in ["same_token"]: + # if args.retroactive_labels == "next_word": + # label = label[1:] + [-100] + word_ids = tokenized_inputs.word_ids(batch_index=i) + previous_word_idx = None + label_ids = [] + for word_idx in word_ids: + if word_idx is None: + label_ids.append(-100) + elif word_idx != previous_word_idx: + label_ids.append(label[word_idx]) + else: + label_ids.append(-100) + previous_word_idx = word_idx + labels.append(label_ids) + word_ids = [-1 if w is None else w for w in word_ids] + words.append(word_ids) + elif args.retroactive_labels == "next_token": + word_ids = tokenized_inputs.word_ids(batch_index=i) + previous_word_idx = None + label_ids = [] + for word_idx in word_ids: + if word_idx is None: + label_ids.append(-100) + elif word_idx != previous_word_idx: + label_ids.append(label[word_idx]) + else: + label_ids.append(-100) + previous_word_idx = word_idx + label_ids.append(-100) + labels.append(label_ids[1:]) + word_ids = word_ids[1:] + [None] + word_ids = [-1 if w is None else w for w in word_ids] + words.append(word_ids) + else: + raise ValueError(f"retroactive_labels {args.retroactive_labels} is not implemented.") + + tokenized_inputs["labels"] = torch.tensor(labels) + if args.model_class == "custom": + tokenized_inputs["token_type_ids"] = words + return tokenized_inputs + + tokenized_dataset = raw_datasets.map(tokenize_and_align_labels, batched=True, remove_columns=list(LABELS[args.dataset_name].keys())+["tokens", "id"]) + with torch.no_grad(): + predictions = None + labels = None + for batch_begin in tqdm(torch.arange(0, len(tokenized_dataset), args.batch_size)): + features = { + "input_ids": torch.tensor(tokenized_dataset[batch_begin:batch_begin+args.batch_size]["input_ids"]).to(model.device), + "attention_mask": torch.tensor(tokenized_dataset[batch_begin:batch_begin+args.batch_size]["attention_mask"]).to(model.device) + } + if "token_type_ids" in tokenized_dataset[batch_begin:batch_begin+args.batch_size]: + features["token_type_ids"] = torch.tensor(tokenized_dataset[batch_begin:batch_begin+args.batch_size]["token_type_ids"]).to(model.device) + + labs = torch.tensor(tokenized_dataset[batch_begin:batch_begin+args.batch_size]["labels"]) + + logits = model(**features).logits + preds = torch.argmax(logits, dim=-1) + if predictions is None: + predictions = preds + labels = labs + else: + predictions = torch.concatenate((predictions, preds)) + labels = torch.concatenate((labels, labs)) + + precision_metric = evaluate.load("precision") + metrics = precision_metric.compute(references=labels[labels != -100], predictions=predictions[labels != -100], average="micro") + + with open(os.path.join(args.output_dir, "result_summary.json"), "w") as f: + json.dump(metrics, f) + print(metrics) \ No newline at end of file diff --git a/test_configs/word-task/Llama2-bi-mntp-simcse.json b/test_configs/word-task/Llama2-bi-mntp-simcse.json new file mode 100644 index 0000000..7341fbf --- /dev/null +++ b/test_configs/word-task/Llama2-bi-mntp-simcse.json @@ -0,0 +1,16 @@ +{ + "model_name_or_path": 
"McGill-NLP/LLM2Vec-Llama-2-7b-chat-hf-mntp", + "peft_addr": "McGill-NLP/LLM2Vec-Llama-2-7b-chat-hf-mntp-unsup-simcse", + "cls_addr": "output/word-task/pos_tags/Llama2/bi-mntp/checkpoint-3000", + "model_class": "custom", + "bidirectional": true, + "classifier_dropout": 0.1, + "torch_dtype": "bfloat16", + "attn_implementation": "flash_attention_2", + "merge_subwords": true, + "retroactive_labels": "next_token", + "output_dir": "output/word-task/pos_tags/Llama2/bi-mntp-simcse/results/checkpoint-3000", + "dataset_name": "conll2003", + "task": "pos_tags", + "max_seq_length": 512 +} \ No newline at end of file diff --git a/test_configs/word-task/Llama2-bi-mntp.json b/test_configs/word-task/Llama2-bi-mntp.json new file mode 100644 index 0000000..f8435fa --- /dev/null +++ b/test_configs/word-task/Llama2-bi-mntp.json @@ -0,0 +1,16 @@ +{ + "model_name_or_path": "meta-llama/Llama-2-7b-chat-hf", + "peft_addr": "McGill-NLP/LLM2Vec-Llama-2-7b-chat-hf-mntp", + "cls_addr": "output/word-task/pos_tags/Llama2/bi-mntp/checkpoint-3000", + "model_class": "custom", + "bidirectional": true, + "classifier_dropout": 0.1, + "torch_dtype": "bfloat16", + "attn_implementation": "flash_attention_2", + "merge_subwords": true, + "retroactive_labels": "next_token", + "output_dir": "output/word-task/pos_tags/Llama2/bi-mntp/results/checkpoint-3000", + "dataset_name": "conll2003", + "task": "pos_tags", + "max_seq_length": 512 +} \ No newline at end of file diff --git a/test_configs/word-task/Llama2-bi.json b/test_configs/word-task/Llama2-bi.json new file mode 100644 index 0000000..60b4484 --- /dev/null +++ b/test_configs/word-task/Llama2-bi.json @@ -0,0 +1,15 @@ +{ + "model_name_or_path": "meta-llama/Llama-2-7b-chat-hf", + "cls_addr": "output/word-task/pos_tags/Llama2/bi/checkpoint-3000", + "model_class": "custom", + "bidirectional": true, + "classifier_dropout": 0.1, + "torch_dtype": "bfloat16", + "attn_implementation": "flash_attention_2", + "merge_subwords": true, + "retroactive_labels": "same_token", + "output_dir": "output/word-task/pos_tags/Llama2/bi/results/checkpoint-3000", + "dataset_name": "conll2003", + "task": "pos_tags", + "max_seq_length": 512 +} \ No newline at end of file diff --git a/test_configs/word-task/Llama2-uni.json b/test_configs/word-task/Llama2-uni.json new file mode 100644 index 0000000..9cffe6c --- /dev/null +++ b/test_configs/word-task/Llama2-uni.json @@ -0,0 +1,15 @@ +{ + "model_name_or_path": "meta-llama/Llama-2-7b-chat-hf", + "cls_addr": "output/word-task/pos_tags/Llama2/uni/checkpoint-3000", + "model_class": "custom", + "bidirectional": false, + "classifier_dropout": 0.1, + "torch_dtype": "bfloat16", + "attn_implementation": "flash_attention_2", + "merge_subwords": true, + "retroactive_labels": "same_token", + "output_dir": "output/word-task/pos_tags/Llama2/uni/results/checkpoint-3000", + "dataset_name": "conll2003", + "task": "pos_tags", + "max_seq_length": 512 +} \ No newline at end of file diff --git a/test_configs/word-task/Mistral-bi-mntp-simcse.json b/test_configs/word-task/Mistral-bi-mntp-simcse.json new file mode 100644 index 0000000..6f992f7 --- /dev/null +++ b/test_configs/word-task/Mistral-bi-mntp-simcse.json @@ -0,0 +1,16 @@ +{ + "model_name_or_path": "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp", + "peft_addr": "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-unsup-simcse", + "cls_addr": "output/word-task/pos_tags/Mistral/bi-mntp-simcse/checkpoint-3000", + "model_class": "custom", + "bidirectional": true, + "classifier_dropout": 0.1, + "torch_dtype": "bfloat16", + 
"attn_implementation": "flash_attention_2", + "merge_subwords": true, + "retroactive_labels": "next_token", + "output_dir": "output/word-task/pos_tags/Mistral/bi-mntp-simcse/results/checkpoint-3000", + "dataset_name": "conll2003", + "task": "pos_tags", + "max_seq_length": 512 +} \ No newline at end of file diff --git a/test_configs/word-task/Mistral-bi-mntp.json b/test_configs/word-task/Mistral-bi-mntp.json new file mode 100644 index 0000000..c5d9a46 --- /dev/null +++ b/test_configs/word-task/Mistral-bi-mntp.json @@ -0,0 +1,16 @@ +{ + "model_name_or_path": "mistralai/Mistral-7B-Instruct-v0.2", + "peft_addr": "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp", + "cls_addr": "output/word-task/pos_tags/Mistral/bi-mntp/checkpoint-3000", + "model_class": "custom", + "bidirectional": true, + "classifier_dropout": 0.1, + "torch_dtype": "bfloat16", + "attn_implementation": "flash_attention_2", + "merge_subwords": true, + "retroactive_labels": "next_token", + "output_dir": "output/word-task/pos_tags/Mistral/bi-mntp/results/checkpoint-3000", + "dataset_name": "conll2003", + "task": "pos_tags", + "max_seq_length": 512 +} \ No newline at end of file diff --git a/test_configs/word-task/Mistral-bi.json b/test_configs/word-task/Mistral-bi.json new file mode 100644 index 0000000..6168ca7 --- /dev/null +++ b/test_configs/word-task/Mistral-bi.json @@ -0,0 +1,15 @@ +{ + "model_name_or_path": "mistralai/Mistral-7B-Instruct-v0.2", + "cls_addr": "output/word-task/pos_tags/Mistral/bi/checkpoint-3000", + "model_class": "custom", + "bidirectional": true, + "classifier_dropout": 0.1, + "torch_dtype": "bfloat16", + "attn_implementation": "flash_attention_2", + "merge_subwords": true, + "retroactive_labels": "same_token", + "output_dir": "output/word-task/pos_tags/Mistral/bi/results/checkpoint-3000", + "dataset_name": "conll2003", + "task": "pos_tags", + "max_seq_length": 512 +} \ No newline at end of file diff --git a/test_configs/word-task/Mistral-uni.json b/test_configs/word-task/Mistral-uni.json new file mode 100644 index 0000000..27fe022 --- /dev/null +++ b/test_configs/word-task/Mistral-uni.json @@ -0,0 +1,15 @@ +{ + "model_name_or_path": "mistralai/Mistral-7B-Instruct-v0.2", + "cls_addr": "output/word-task/pos_tags/Mistral/uni/checkpoint-3000", + "model_class": "custom", + "bidirectional": false, + "classifier_dropout": 0.1, + "torch_dtype": "bfloat16", + "attn_implementation": "flash_attention_2", + "merge_subwords": true, + "retroactive_labels": "same_token", + "output_dir": "output/word-task/pos_tags/Mistral/uni/results/checkpoint-3000", + "dataset_name": "conll2003", + "task": "pos_tags", + "max_seq_length": 512 +} \ No newline at end of file diff --git a/test_configs/word-task/ShearedLlama-bi-mntp-simcse.json b/test_configs/word-task/ShearedLlama-bi-mntp-simcse.json new file mode 100644 index 0000000..e7c5676 --- /dev/null +++ b/test_configs/word-task/ShearedLlama-bi-mntp-simcse.json @@ -0,0 +1,16 @@ +{ + "model_name_or_path": "McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp", + "peft_addr": "McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp-unsup-simcse", + "cls_addr": "output/word-task/pos_tags/ShearedLlama/bi-mntp-simcse/checkpoint-3000", + "model_class": "custom", + "bidirectional": true, + "classifier_dropout": 0.1, + "torch_dtype": "bfloat16", + "attn_implementation": "flash_attention_2", + "merge_subwords": true, + "retroactive_labels": "next_token", + "output_dir": "output/word-task/pos_tags/ShearedLlama/bi-mntp-simcse/results/checkpoint-3000", + "dataset_name": "conll2003", + "task": "pos_tags", + "max_seq_length": 512 
+} \ No newline at end of file diff --git a/test_configs/word-task/ShearedLlama-bi-mntp.json b/test_configs/word-task/ShearedLlama-bi-mntp.json new file mode 100644 index 0000000..b346938 --- /dev/null +++ b/test_configs/word-task/ShearedLlama-bi-mntp.json @@ -0,0 +1,16 @@ +{ + "model_name_or_path": "princeton-nlp/Sheared-LLaMA-1.3B", + "peft_addr": "McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp", + "cls_addr": "output/word-task/pos_tags/ShearedLlama/bi-mntp/checkpoint-3000", + "model_class": "custom", + "bidirectional": true, + "classifier_dropout": 0.1, + "torch_dtype": "bfloat16", + "attn_implementation": "flash_attention_2", + "merge_subwords": true, + "retroactive_labels": "next_token", + "output_dir": "output/word-task/pos_tags/ShearedLlama/bi-mntp/results/checkpoint-3000", + "dataset_name": "conll2003", + "task": "pos_tags", + "max_seq_length": 512 +} \ No newline at end of file diff --git a/test_configs/word-task/ShearedLlama-bi.json b/test_configs/word-task/ShearedLlama-bi.json new file mode 100644 index 0000000..e8a9950 --- /dev/null +++ b/test_configs/word-task/ShearedLlama-bi.json @@ -0,0 +1,15 @@ +{ + "model_name_or_path": "princeton-nlp/Sheared-LLaMA-1.3B", + "cls_addr": "output/word-task/pos_tags/ShearedLlama/bi/checkpoint-3000", + "model_class": "custom", + "bidirectional": true, + "classifier_dropout": 0.1, + "torch_dtype": "bfloat16", + "attn_implementation": "flash_attention_2", + "merge_subwords": true, + "retroactive_labels": "same_token", + "output_dir": "output/word-task/pos_tags/ShearedLlama/bi/results/checkpoint-3000", + "dataset_name": "conll2003", + "task": "pos_tags", + "max_seq_length": 512 +} \ No newline at end of file diff --git a/test_configs/word-task/ShearedLlama-uni.json b/test_configs/word-task/ShearedLlama-uni.json new file mode 100644 index 0000000..10609b3 --- /dev/null +++ b/test_configs/word-task/ShearedLlama-uni.json @@ -0,0 +1,15 @@ +{ + "model_name_or_path": "princeton-nlp/Sheared-LLaMA-1.3B", + "cls_addr": "output/word-task/pos_tags/ShearedLlama/uni/checkpoint-3000", + "model_class": "custom", + "bidirectional": false, + "classifier_dropout": 0.1, + "torch_dtype": "bfloat16", + "attn_implementation": "flash_attention_2", + "merge_subwords": true, + "retroactive_labels": "same_token", + "output_dir": "output/word-task/pos_tags/ShearedLlama/uni/results/checkpoint-3000", + "dataset_name": "conll2003", + "task": "pos_tags", + "max_seq_length": 512 +} \ No newline at end of file diff --git a/train_configs/word-task/Llama2-bi-mntp-simcse.json b/train_configs/word-task/Llama2-bi-mntp-simcse.json new file mode 100644 index 0000000..1172f68 --- /dev/null +++ b/train_configs/word-task/Llama2-bi-mntp-simcse.json @@ -0,0 +1,26 @@ +{ + "model_name_or_path": "McGill-NLP/LLM2Vec-Llama-2-7b-chat-hf-mntp", + "peft_addr": "McGill-NLP/LLM2Vec-Llama-2-7b-chat-hf-mntp-unsup-simcse", + "model_class": "custom", + "bidirectional": true, + "classifier_dropout": 0.1, + "torch_dtype": "bfloat16", + "attn_implementation": "flash_attention_2", + "merge_subwords": true, + "retroactive_labels": "next_token", + "output_dir": "output/word-task/pos_tags/Llama2/bi-mntp-simcse", + "dataset_name": "conll2003", + "task": "pos_tags", + "num_train_epochs": 10, + "learning_rate": 5e-4, + "do_train": true, + "save_steps": 250, + "logging_steps": 50, + "save_strategy": "steps", + "do_eval": true, + "eval_steps": 500, + "evaluation_strategy": "steps", + "overwrite_output_dir": true, + "stop_after_n_steps": 4000, + "remove_unused_columns": false +} \ No newline at end of file diff --git 
a/train_configs/word-task/Llama2-bi-mntp.json b/train_configs/word-task/Llama2-bi-mntp.json new file mode 100644 index 0000000..e7e36ae --- /dev/null +++ b/train_configs/word-task/Llama2-bi-mntp.json @@ -0,0 +1,26 @@ +{ + "model_name_or_path": "meta-llama/Llama-2-7b-chat-hf", + "peft_addr": "McGill-NLP/LLM2Vec-Llama-2-7b-chat-hf-mntp", + "model_class": "custom", + "bidirectional": true, + "classifier_dropout": 0.1, + "torch_dtype": "bfloat16", + "attn_implementation": "flash_attention_2", + "merge_subwords": true, + "retroactive_labels": "next_token", + "output_dir": "output/word-task/pos_tags/Llama2/bi-mntp", + "dataset_name": "conll2003", + "task": "pos_tags", + "num_train_epochs": 10, + "learning_rate": 5e-4, + "do_train": true, + "save_steps": 250, + "logging_steps": 50, + "save_strategy": "steps", + "do_eval": true, + "eval_steps": 500, + "evaluation_strategy": "steps", + "overwrite_output_dir": true, + "stop_after_n_steps": 4000, + "remove_unused_columns": false +} \ No newline at end of file diff --git a/train_configs/word-task/Llama2-bi.json b/train_configs/word-task/Llama2-bi.json new file mode 100644 index 0000000..273492f --- /dev/null +++ b/train_configs/word-task/Llama2-bi.json @@ -0,0 +1,25 @@ +{ + "model_name_or_path": "meta-llama/Llama-2-7b-chat-hf", + "model_class": "custom", + "bidirectional": true, + "classifier_dropout": 0.1, + "torch_dtype": "bfloat16", + "attn_implementation": "flash_attention_2", + "merge_subwords": true, + "retroactive_labels": "same_token", + "output_dir": "output/word-task/pos_tags/Llama2/bi", + "dataset_name": "conll2003", + "task": "pos_tags", + "num_train_epochs": 10, + "learning_rate": 5e-4, + "do_train": true, + "save_steps": 250, + "logging_steps": 50, + "save_strategy": "steps", + "do_eval": true, + "eval_steps": 500, + "evaluation_strategy": "steps", + "overwrite_output_dir": true, + "stop_after_n_steps": 4000, + "remove_unused_columns": false +} \ No newline at end of file diff --git a/train_configs/word-task/Llama2-uni.json b/train_configs/word-task/Llama2-uni.json new file mode 100644 index 0000000..82ec58d --- /dev/null +++ b/train_configs/word-task/Llama2-uni.json @@ -0,0 +1,25 @@ +{ + "model_name_or_path": "meta-llama/Llama-2-7b-chat-hf", + "model_class": "custom", + "bidirectional": false, + "classifier_dropout": 0.1, + "torch_dtype": "bfloat16", + "attn_implementation": "flash_attention_2", + "merge_subwords": true, + "retroactive_labels": "same_token", + "output_dir": "output/word-task/pos_tags/Llama2/uni", + "dataset_name": "conll2003", + "task": "pos_tags", + "num_train_epochs": 10, + "learning_rate": 5e-4, + "do_train": true, + "save_steps": 250, + "logging_steps": 50, + "save_strategy": "steps", + "do_eval": true, + "eval_steps": 500, + "evaluation_strategy": "steps", + "overwrite_output_dir": true, + "stop_after_n_steps": 4000, + "remove_unused_columns": false +} \ No newline at end of file diff --git a/train_configs/word-task/Mistral-bi-mntp-simcse.json b/train_configs/word-task/Mistral-bi-mntp-simcse.json new file mode 100644 index 0000000..dd37ca2 --- /dev/null +++ b/train_configs/word-task/Mistral-bi-mntp-simcse.json @@ -0,0 +1,26 @@ +{ + "model_name_or_path": "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp", + "peft_addr": "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-unsup-simcse", + "model_class": "custom", + "bidirectional": true, + "classifier_dropout": 0.1, + "torch_dtype": "bfloat16", + "attn_implementation": "flash_attention_2", + "merge_subwords": true, + "retroactive_labels": "next_token", + "output_dir": 
"output/word-task/pos_tags/Mistral/bi-mntp-simcse", + "dataset_name": "conll2003", + "task": "pos_tags", + "num_train_epochs": 10, + "learning_rate": 5e-4, + "do_train": true, + "save_steps": 250, + "logging_steps": 50, + "save_strategy": "steps", + "do_eval": true, + "eval_steps": 500, + "evaluation_strategy": "steps", + "overwrite_output_dir": true, + "stop_after_n_steps": 4000, + "remove_unused_columns": false +} \ No newline at end of file diff --git a/train_configs/word-task/Mistral-bi-mntp.json b/train_configs/word-task/Mistral-bi-mntp.json new file mode 100644 index 0000000..336b9ac --- /dev/null +++ b/train_configs/word-task/Mistral-bi-mntp.json @@ -0,0 +1,26 @@ +{ + "model_name_or_path": "mistralai/Mistral-7B-Instruct-v0.2", + "peft_addr": "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp", + "model_class": "custom", + "bidirectional": true, + "classifier_dropout": 0.1, + "torch_dtype": "bfloat16", + "attn_implementation": "flash_attention_2", + "merge_subwords": true, + "retroactive_labels": "next_token", + "output_dir": "output/word-task/pos_tags/Mistral/bi-mntp", + "dataset_name": "conll2003", + "task": "pos_tags", + "num_train_epochs": 10, + "learning_rate": 5e-4, + "do_train": true, + "save_steps": 250, + "logging_steps": 50, + "save_strategy": "steps", + "do_eval": true, + "eval_steps": 500, + "evaluation_strategy": "steps", + "overwrite_output_dir": true, + "stop_after_n_steps": 4000, + "remove_unused_columns": false +} \ No newline at end of file diff --git a/train_configs/word-task/Mistral-bi.json b/train_configs/word-task/Mistral-bi.json new file mode 100644 index 0000000..eedf3fd --- /dev/null +++ b/train_configs/word-task/Mistral-bi.json @@ -0,0 +1,25 @@ +{ + "model_name_or_path": "mistralai/Mistral-7B-Instruct-v0.2", + "model_class": "custom", + "bidirectional": true, + "classifier_dropout": 0.1, + "torch_dtype": "bfloat16", + "attn_implementation": "flash_attention_2", + "merge_subwords": true, + "retroactive_labels": "same_token", + "output_dir": "output/word-task/pos_tags/Mistral/bi", + "dataset_name": "conll2003", + "task": "pos_tags", + "num_train_epochs": 10, + "learning_rate": 5e-4, + "do_train": true, + "save_steps": 250, + "logging_steps": 50, + "save_strategy": "steps", + "do_eval": true, + "eval_steps": 500, + "evaluation_strategy": "steps", + "overwrite_output_dir": true, + "stop_after_n_steps": 4000, + "remove_unused_columns": false +} \ No newline at end of file diff --git a/train_configs/word-task/Mistral-uni.json b/train_configs/word-task/Mistral-uni.json new file mode 100644 index 0000000..9dcfcf4 --- /dev/null +++ b/train_configs/word-task/Mistral-uni.json @@ -0,0 +1,25 @@ +{ + "model_name_or_path": "mistralai/Mistral-7B-Instruct-v0.2", + "model_class": "custom", + "bidirectional": false, + "classifier_dropout": 0.1, + "torch_dtype": "bfloat16", + "attn_implementation": "flash_attention_2", + "merge_subwords": true, + "retroactive_labels": "same_token", + "output_dir": "output/word-task/pos_tags/Mistral/uni", + "dataset_name": "conll2003", + "task": "pos_tags", + "num_train_epochs": 10, + "learning_rate": 5e-4, + "do_train": true, + "save_steps": 250, + "logging_steps": 50, + "save_strategy": "steps", + "do_eval": true, + "eval_steps": 500, + "evaluation_strategy": "steps", + "overwrite_output_dir": true, + "stop_after_n_steps": 4000, + "remove_unused_columns": false +} \ No newline at end of file diff --git a/train_configs/word-task/ShearedLlama-bi-mntp-simcse.json b/train_configs/word-task/ShearedLlama-bi-mntp-simcse.json new file mode 100644 index 
0000000..f67fc14 --- /dev/null +++ b/train_configs/word-task/ShearedLlama-bi-mntp-simcse.json @@ -0,0 +1,26 @@ +{ + "model_name_or_path": "McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp", + "peft_addr": "McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp-unsup-simcse", + "model_class": "custom", + "bidirectional": true, + "classifier_dropout": 0.1, + "torch_dtype": "bfloat16", + "attn_implementation": "flash_attention_2", + "merge_subwords": true, + "retroactive_labels": "next_token", + "output_dir": "output/word-task/pos_tags/ShearedLlama/bi-mntp-simcse", + "dataset_name": "conll2003", + "task": "pos_tags", + "num_train_epochs": 10, + "learning_rate": 5e-4, + "do_train": true, + "save_steps": 250, + "logging_steps": 50, + "save_strategy": "steps", + "do_eval": true, + "eval_steps": 500, + "evaluation_strategy": "steps", + "overwrite_output_dir": true, + "stop_after_n_steps": 4000, + "remove_unused_columns": false +} \ No newline at end of file diff --git a/train_configs/word-task/ShearedLlama-bi-mntp.json b/train_configs/word-task/ShearedLlama-bi-mntp.json new file mode 100644 index 0000000..cb73267 --- /dev/null +++ b/train_configs/word-task/ShearedLlama-bi-mntp.json @@ -0,0 +1,26 @@ +{ + "model_name_or_path": "princeton-nlp/Sheared-LLaMA-1.3B", + "peft_addr": "McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp", + "model_class": "custom", + "bidirectional": true, + "classifier_dropout": 0.1, + "torch_dtype": "bfloat16", + "attn_implementation": "flash_attention_2", + "merge_subwords": true, + "retroactive_labels": "next_token", + "output_dir": "output/word-task/pos_tags/ShearedLlama/bi-mntp", + "dataset_name": "conll2003", + "task": "pos_tags", + "num_train_epochs": 10, + "learning_rate": 5e-4, + "do_train": true, + "save_steps": 250, + "logging_steps": 50, + "save_strategy": "steps", + "do_eval": true, + "eval_steps": 500, + "evaluation_strategy": "steps", + "overwrite_output_dir": true, + "stop_after_n_steps": 4000, + "remove_unused_columns": false +} \ No newline at end of file diff --git a/train_configs/word-task/ShearedLlama-bi.json b/train_configs/word-task/ShearedLlama-bi.json new file mode 100644 index 0000000..2ba828f --- /dev/null +++ b/train_configs/word-task/ShearedLlama-bi.json @@ -0,0 +1,25 @@ +{ + "model_name_or_path": "princeton-nlp/Sheared-LLaMA-1.3B", + "model_class": "custom", + "bidirectional": true, + "classifier_dropout": 0.1, + "torch_dtype": "bfloat16", + "attn_implementation": "flash_attention_2", + "merge_subwords": true, + "retroactive_labels": "same_token", + "output_dir": "output/word-task/pos_tags/ShearedLlama/bi", + "dataset_name": "conll2003", + "task": "pos_tags", + "num_train_epochs": 10, + "learning_rate": 5e-4, + "do_train": true, + "save_steps": 250, + "logging_steps": 50, + "save_strategy": "steps", + "do_eval": true, + "eval_steps": 500, + "evaluation_strategy": "steps", + "overwrite_output_dir": true, + "stop_after_n_steps": 4000, + "remove_unused_columns": false +} \ No newline at end of file diff --git a/train_configs/word-task/ShearedLlama-uni.json b/train_configs/word-task/ShearedLlama-uni.json new file mode 100644 index 0000000..72dc8de --- /dev/null +++ b/train_configs/word-task/ShearedLlama-uni.json @@ -0,0 +1,25 @@ +{ + "model_name_or_path": "princeton-nlp/Sheared-LLaMA-1.3B", + "model_class": "custom", + "bidirectional": false, + "classifier_dropout": 0.1, + "torch_dtype": "bfloat16", + "attn_implementation": "flash_attention_2", + "merge_subwords": true, + "retroactive_labels": "same_token", + "output_dir": "output/word-task/pos_tags/ShearedLlama/uni", + "dataset_name": 
"conll2003", + "task": "pos_tags", + "num_train_epochs": 10, + "learning_rate": 5e-4, + "do_train": true, + "save_steps": 250, + "logging_steps": 50, + "save_strategy": "steps", + "do_eval": true, + "eval_steps": 500, + "evaluation_strategy": "steps", + "overwrite_output_dir": true, + "stop_after_n_steps": 4000, + "remove_unused_columns": false +} \ No newline at end of file