Add RoPE scaling for Llama3.1
turboderp committed Jul 23, 2024
1 parent 46a803f commit 05d1352
Showing 2 changed files with 60 additions and 7 deletions.
29 changes: 22 additions & 7 deletions exllamav2/config.py
@@ -10,7 +10,9 @@
T = TypeVar('T')
no_default = object()

def read(input_dict: dict[str, Any], expected_type: type, keys: str | list[str], default = no_default) -> T:
def read(input_dict: dict[str, Any], expected_type: type | list[type], keys: str | list[str], default = no_default) -> T:

    expected_types = expected_type if isinstance(expected_type, list) else [expected_type]

    if isinstance(keys, str): keys = [keys]

@@ -34,10 +36,10 @@ def read(input_dict: dict[str, Any], expected_type: type, keys: str | list[str],
        if expected_type == int and isinstance(x, float) and x == int(x):
            x = int(x)

        if isinstance(x, expected_type):
            return cast(T, x)
        else:
            raise TypeError(f"Value for {key} is not of expected type {expected_type}")
        for t in expected_types:
            if isinstance(x, t):
                return cast(T, x)
        raise TypeError(f"Value for {key} is not of expected type {expected_type}")

    if default != no_default: return default
    raise ValueError(f"Missing any of the following keys: {keys}")
@@ -105,7 +107,10 @@ class ExLlamaV2Config:
    attn_logit_softcapping: float | None
    sliding_window: int
    norm_head: int | None

    l3_rope_factor: float | None
    l3_rope_low_freq_factor: float | None
    l3_rope_high_freq_factor: float | None
    l3_rope_original_max_position_embeddings: int | None
    checkpoint_fused_mlp: bool
    checkpoint_offset_qzeros: bool

@@ -191,10 +196,13 @@ def prepare(self, no_tensors: bool = False):
        # Vocab params

        self.bos_token_id = read(read_config, int, "bos_token_id", None) # 1
        self.eos_token_id = read(read_config, int, "eos_token_id", None) # 2
        self.eos_token_id = read(read_config, [int, list], "eos_token_id", None) # 2
        self.pad_token_id = read(read_config, int, "pad_token_id", None) # 0
        self.vocab_size = read(read_config, int, "vocab_size")

        if isinstance(self.eos_token_id, list):
            self.eos_token_id = self.eos_token_id[0] # TODO: Figure out a way to maybe use all the EOS tokens somehow

        # Standard params

        self.initializer_range = read(read_config, float, ["initializer_range"])
@@ -287,6 +295,13 @@ def prepare(self, no_tensors: bool = False):
                self.alt_rope_method = "su"
            # if scaling_type == "yarn":
            #     self.scale_alpha_value = factor
            rope_type = rs.get("rope_type", None)
            if rope_type == "llama3":
                self.alt_rope_method = "llama3"
                self.l3_rope_factor = rs["factor"]
                self.l3_rope_low_freq_factor = rs["low_freq_factor"]
                self.l3_rope_high_freq_factor = rs["high_freq_factor"]
                self.l3_rope_original_max_position_embeddings = rs["original_max_position_embeddings"]

        # Checkpoint format (for GPTQ models)

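For reference, the new llama3 branch above consumes the rope_scaling block that Llama 3.1 checkpoints carry in their config.json. A minimal sketch of such a block, written as a Python dict (the values are illustrative; they simply mirror the defaults baked into _apply_scaling in model.py below):

    # Hypothetical excerpt from a Llama 3.1 config.json, shown as a Python dict.
    # Only the keys read by prepare() above are listed; values are illustrative.
    rope_scaling = {
        "rope_type": "llama3",
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
    }

With such a block present, prepare() sets alt_rope_method = "llama3" and copies the four values into l3_rope_factor, l3_rope_low_freq_factor, l3_rope_high_freq_factor and l3_rope_original_max_position_embeddings, which prepare_sincos() in model.py then forwards to _apply_scaling().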
38 changes: 38 additions & 0 deletions exllamav2/model.py
@@ -129,6 +129,31 @@ def get_scratch_slice(self, size_bytes):
        return scratch_slice


    @staticmethod
    def _apply_scaling(
        freqs: torch.Tensor,
        scale_factor: float = 8,
        low_freq_factor: float = 1,
        high_freq_factor: float = 4,
        old_context_len: int = 8192, # original llama3 length
    ):
        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor
        new_freqs = []

        for freq in freqs:
            wavelen = 2 * math.pi / freq
            if wavelen < high_freq_wavelen:
                new_freqs.append(freq)
            elif wavelen > low_freq_wavelen:
                new_freqs.append(freq / scale_factor)
            else:
                assert low_freq_wavelen != high_freq_wavelen
                smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
                new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
        return torch.tensor(new_freqs, dtype = freqs.dtype, device = freqs.device)


    def prepare_sincos(self):

        device = _torch_device(self.device_idx)
@@ -163,6 +188,19 @@ def prepare_sincos(self):

            inv_freq = 1.0 / (ext_factors * base ** (torch.arange(0, head_dim, 2, device = device).float() / head_dim))

        # Llama 3.1

        elif cfg.alt_rope_method == "llama3":

            inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device = device).float() / head_dim))
            inv_freq = self._apply_scaling(
                inv_freq,
                cfg.l3_rope_factor,
                cfg.l3_rope_low_freq_factor,
                cfg.l3_rope_high_freq_factor,
                cfg.l3_rope_original_max_position_embeddings,
            )

        # Regular

        else:
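For intuition, _apply_scaling() partitions the rotary frequencies by wavelength: components whose wavelength is shorter than old_context_len / high_freq_factor keep their frequency, components whose wavelength is longer than old_context_len / low_freq_factor are divided by scale_factor, and the band in between is blended linearly between the two. With the defaults above (scale_factor 8, low_freq_factor 1, high_freq_factor 4, old_context_len 8192) those thresholds work out to 8192 / 4 = 2048 and 8192 / 1 = 8192 positions. The sketch below is an equivalent vectorized reformulation of that rule, for reference only (the function name is ours; this is not the code added by the commit):

    import math
    import torch

    def llama3_scale_inv_freq(inv_freq: torch.Tensor,
                              scale_factor: float = 8.0,
                              low_freq_factor: float = 1.0,
                              high_freq_factor: float = 4.0,
                              old_context_len: int = 8192) -> torch.Tensor:
        # Wavelength, in positions, covered by one full rotation of each component
        wavelen = 2 * math.pi / inv_freq
        # smooth == 1 keeps the original frequency, smooth == 0 applies the full
        # 1 / scale_factor reduction; values in between interpolate linearly
        smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
        smooth = smooth.clamp(0.0, 1.0)
        return (1 - smooth) * inv_freq / scale_factor + smooth * inv_freq

Clamping smooth to [0, 1] reproduces the three branches of the loop in _apply_scaling: values above 1 correspond to the short-wavelength case (frequency unchanged), values below 0 to the long-wavelength case (frequency divided by scale_factor).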
