From 8bda9750154d21f5d159c1506e71d498025656c7 Mon Sep 17 00:00:00 2001
From: inkcherry
Date: Wed, 5 Jun 2024 06:58:38 +0000
Subject: [PATCH] fix rope precision for long context

---
 megatron/model/rotary_pos_embedding.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/megatron/model/rotary_pos_embedding.py b/megatron/model/rotary_pos_embedding.py
index 4d4497e0cd..d36cc71490 100644
--- a/megatron/model/rotary_pos_embedding.py
+++ b/megatron/model/rotary_pos_embedding.py
@@ -20,8 +20,9 @@ def __init__(self, dim, theta=10000):
             raise RuntimeError("einops is required for Rotary Embedding")
 
     def forward(self, max_seq_len, offset=0):
-        seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset
-        freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq)
+        seq = torch.arange(max_seq_len, device=self.inv_freq.device, dtype=torch.float) + offset
+        # Force float32 since bfloat16 loses precision on long contexts
+        freqs = einsum('i , j -> i j', seq, self.inv_freq.float())
         # first part even vector components, second part odd vector components,
         # 2 * dim in dimension size
         emb = torch.cat((freqs, freqs), dim=-1)
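
For context, a minimal standalone sketch (not part of the patch) of the precision loss this change addresses. bfloat16 carries only 8 significand bits, so integer token positions above 2**8 = 256 cannot be represented exactly, and the rounding error grows with position. The specific positions below are illustrative assumptions, not values from the patch:

```python
import torch

# bfloat16 stores 8 significand bits (7 explicit + 1 implicit), so the
# spacing between representable values near 90,000 is 2**16 / 2**7 = 512.
# Neighboring token positions therefore collapse to the same bf16 value.
pos_fp32 = torch.arange(90000, 90004, dtype=torch.float32)
pos_bf16 = pos_fp32.to(torch.bfloat16).float()

print(pos_fp32)             # tensor([90000., 90001., 90002., 90003.])
print(pos_bf16)             # tensor([90112., 90112., 90112., 90112.])
print(pos_fp32 - pos_bf16)  # per-position error of roughly -112 .. -109

# For the highest-frequency RoPE component (inv_freq close to 1.0), a
# position error of ~100 means an angle error of ~100 radians: the rotary
# phase is effectively scrambled at long context lengths.
```

With float32 positions, as the patch enforces, the 24 significand bits keep integer positions exact up to 2**24 (about 16.7M tokens), well beyond current context lengths.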