From 029ca1a06ff5766161d4a9bc6449761b4578c963 Mon Sep 17 00:00:00 2001
From: Jaden Fiotto-Kaufman
Date: Tue, 23 Jul 2024 17:00:46 -0400
Subject: [PATCH] TEMPORARY llama fix

---
 src/nnsight/models/LanguageModel.py | 45 ++++++++++++++++++++++---------------------
 1 file changed, 24 insertions(+), 21 deletions(-)

diff --git a/src/nnsight/models/LanguageModel.py b/src/nnsight/models/LanguageModel.py
index e80f3935..cc50c6f9 100644
--- a/src/nnsight/models/LanguageModel.py
+++ b/src/nnsight/models/LanguageModel.py
@@ -4,9 +4,15 @@
 from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
 import torch
-from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
-                          AutoTokenizer, BatchEncoding, PreTrainedModel,
-                          PreTrainedTokenizer)
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BatchEncoding,
+    PreTrainedModel,
+    PreTrainedTokenizer,
+)
 from transformers.models.auto import modeling_auto
 from transformers.models.llama.configuration_llama import LlamaConfig
 
@@ -139,10 +145,10 @@ def __init__(
             if not isinstance(automodel, str)
             else getattr(modeling_auto, automodel)
         )
-        
+
         if isinstance(model_key, torch.nn.Module):
-            
-            setattr(model_key, 'generator', WrapperModule())
+
+            setattr(model_key, "generator", WrapperModule())
 
         super().__init__(model_key, *args, **kwargs)
 
@@ -150,10 +156,9 @@ def _load(
         self, repo_id: str, tokenizer_kwargs: Optional[Dict[str, Any]] = None, **kwargs
     ) -> PreTrainedModel:
 
-        config = kwargs.pop("config", None) or AutoConfig.from_pretrained(repo_id, **kwargs)
-        
-        
-        
+        config = kwargs.pop("config", None) or AutoConfig.from_pretrained(
+            repo_id, **kwargs
+        )
 
         if self.tokenizer is None:
             if tokenizer_kwargs is None:
@@ -163,31 +168,29 @@ def _load(
                 warnings.warn(
                     "NNsight LanguageModel requires padding_side='left' for tokenizers, setting it to 'left'"
                 )
-            
+
             self.tokenizer = AutoTokenizer.from_pretrained(
                 repo_id, config=config, padding_side="left", **tokenizer_kwargs
             )
             self.tokenizer.pad_token = self.tokenizer.eos_token
 
         if self._model is None:
-            
-            
-            if isinstance(config, LlamaConfig) and "rope_type" in config.rope_scaling:
-                config.rope_scaling['rope_type'] = "default"
+
+            if isinstance(config, LlamaConfig) and isinstance(config.rope_scaling, dict) and "rope_type" in config.rope_scaling:
+                config.rope_scaling["rope_type"] = "default"
 
             model = self.automodel.from_config(config, trust_remote_code=True)
-            
-            setattr(model, 'generator', WrapperModule())
+
+            setattr(model, "generator", WrapperModule())
 
             return model
-        
-        if isinstance(config, LlamaConfig) and "rope_type" in config.rope_scaling:
-            config.rope_scaling['rope_type'] = "llama3"
+        if isinstance(config, LlamaConfig) and isinstance(config.rope_scaling, dict) and "rope_type" in config.rope_scaling:
+            config.rope_scaling["rope_type"] = "llama3"
 
         model = self.automodel.from_pretrained(repo_id, config=config, **kwargs)
 
-        setattr(model, 'generator', WrapperModule())
+        setattr(model, "generator", WrapperModule())
 
         return model
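
Note: the snippet below is a minimal standalone sketch (not part of the patch) of the rope_scaling workaround this patch applies inside LanguageModel._load, shown directly against transformers rather than through nnsight. The repo id is only an illustrative placeholder; gated checkpoints will need authentication.

    # Sketch of the temporary Llama 3.1 rope_scaling workaround, assuming a
    # transformers version whose Llama 3.1 configs carry rope_scaling["rope_type"].
    from transformers import AutoConfig, AutoModelForCausalLM
    from transformers.models.llama.configuration_llama import LlamaConfig

    repo_id = "meta-llama/Meta-Llama-3.1-8B"  # placeholder: any Llama 3.1 checkpoint

    config = AutoConfig.from_pretrained(repo_id)

    # Llama 3.1 configs ship rope_scaling={"rope_type": "llama3", ...}; older Llama
    # configs may have rope_scaling=None or no "rope_type" key, hence the guards.
    if (
        isinstance(config, LlamaConfig)
        and isinstance(config.rope_scaling, dict)
        and "rope_type" in config.rope_scaling
    ):
        # Use the plain rope implementation while instantiating a model from the
        # config alone (the patch does this for the empty/meta build)...
        config.rope_scaling["rope_type"] = "default"
        empty_model = AutoModelForCausalLM.from_config(config)

        # ...and switch back to the llama3 rope when loading the real weights.
        config.rope_scaling["rope_type"] = "llama3"
        model = AutoModelForCausalLM.from_pretrained(repo_id, config=config)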