Commit

TEMPORARY llama fix
JadenFiotto-Kaufman committed Jul 23, 2024
1 parent 4df63da commit 029ca1a
Showing 1 changed file with 24 additions and 21 deletions.
45 changes: 24 additions & 21 deletions src/nnsight/models/LanguageModel.py
@@ -4,9 +4,15 @@
 from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
 import torch
-from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
-                          AutoTokenizer, BatchEncoding, PreTrainedModel,
-                          PreTrainedTokenizer)
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BatchEncoding,
+    PreTrainedModel,
+    PreTrainedTokenizer,
+)
 from transformers.models.auto import modeling_auto
 from transformers.models.llama.configuration_llama import LlamaConfig

@@ -139,21 +145,20 @@ def __init__(
             if not isinstance(automodel, str)
             else getattr(modeling_auto, automodel)
         )
 
         if isinstance(model_key, torch.nn.Module):
-            setattr(model_key, 'generator', WrapperModule())
+            setattr(model_key, "generator", WrapperModule())
 
         super().__init__(model_key, *args, **kwargs)
 
     def _load(
         self, repo_id: str, tokenizer_kwargs: Optional[Dict[str, Any]] = None, **kwargs
     ) -> PreTrainedModel:
 
-        config = kwargs.pop("config", None) or AutoConfig.from_pretrained(repo_id, **kwargs)
-
-
+        config = kwargs.pop("config", None) or AutoConfig.from_pretrained(
+            repo_id, **kwargs
+        )
 
         if self.tokenizer is None:
             if tokenizer_kwargs is None:
@@ -163,31 +168,29 @@ def _load(
                 warnings.warn(
                     "NNsight LanguageModel requires padding_side='left' for tokenizers, setting it to 'left'"
                 )
 
             self.tokenizer = AutoTokenizer.from_pretrained(
                 repo_id, config=config, padding_side="left", **tokenizer_kwargs
             )
             self.tokenizer.pad_token = self.tokenizer.eos_token
 
         if self._model is None:
 
-            if isinstance(config, LlamaConfig) and "rope_type" in config.rope_scaling:
-                config.rope_scaling['rope_type'] = "default"
-
+            if isinstance(config, LlamaConfig) and isinstance(config.rope_scaling, dict) and "rope_type" in config.rope_scaling:
+                config.rope_scaling["rope_type"] = "default"
 
             model = self.automodel.from_config(config, trust_remote_code=True)
-            setattr(model, 'generator', WrapperModule())
-
+            setattr(model, "generator", WrapperModule())
 
             return model
 
-        if isinstance(config, LlamaConfig) and "rope_type" in config.rope_scaling:
-            config.rope_scaling['rope_type'] = "llama3"
-
+        if isinstance(config, LlamaConfig) and isinstance(config.rope_scaling, dict) and "rope_type" in config.rope_scaling:
+            config.rope_scaling["rope_type"] = "llama3"
 
         model = self.automodel.from_pretrained(repo_id, config=config, **kwargs)
 
-        setattr(model, 'generator', WrapperModule())
+        setattr(model, "generator", WrapperModule())
 
         return model

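Note on the change above (a sketch, not part of the commit): both edited branches of _load only rewrite rope_type when the config is a LlamaConfig whose rope_scaling is a dict carrying a "rope_type" key. A minimal Python mirror of that guard, assuming a checkpoint whose config.rope_scaling looks roughly like {"rope_type": "llama3", "factor": 8.0, ...} (the helper name and the example contents are assumptions):

# Sketch only -- mirrors the workaround in _load above; the helper name and the
# example rope_scaling contents are assumptions, not part of the commit.
from transformers.models.llama.configuration_llama import LlamaConfig

def _patch_rope_type(config, rope_type: str) -> None:
    # Only touch Llama configs that actually carry a rope_scaling dict with a
    # "rope_type" key; configs where rope_scaling is None are left alone.
    if (
        isinstance(config, LlamaConfig)
        and isinstance(config.rope_scaling, dict)
        and "rope_type" in config.rope_scaling
    ):
        config.rope_scaling["rope_type"] = rope_type

# In the diff, the equivalent of _patch_rope_type(config, "default") runs before
# automodel.from_config(...) builds the uninitialized model skeleton, and the
# equivalent of _patch_rope_type(config, "llama3") runs before
# automodel.from_pretrained(...) loads the real weights.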
