remove dtype from llama precompute_freqs_cis (tinygrad#7930)
do the cast based on the input dtype in the first forward call instead
chenyuxyz authored Nov 28, 2024
1 parent 3e2430f commit 336a9b6
Showing 1 changed file with 5 additions and 4 deletions.
extra/models/llama.py (5 additions & 4 deletions)

@@ -3,11 +3,10 @@
 from tinygrad.helpers import getenv
 
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
   freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
   freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
-  # TODO: move dtype outside this
-  return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim//2, 2)
+  return Tensor.stack(freqs.cos(), freqs.sin(), dim=-1).reshape(1, end, 1, dim//2, 2)
 
 # (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
 def complex_mult(A, c, d):
@@ -163,9 +162,11 @@ def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:
 
   def forward(self, tokens:Tensor, start_pos:Union[Variable,int], temperature:float, top_k:int, top_p:float, alpha_f:float, alpha_p:float):
     _bsz, seqlen = tokens.shape
+    h = self.tok_embeddings(tokens)
+
+    self.freqs_cis = self.freqs_cis.cast(h.dtype).realize()
     freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
 
-    h = self.tok_embeddings(tokens)
     mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1).realize() if seqlen > 1 else None
     for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
     logits = self.output(self.norm(h)).float()[:, -1, :]
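
After this change the frequency table is built in the default float dtype and only cast once the model sees real input; because self.freqs_cis is reassigned to the realized cast, later forward calls reuse it without re-casting. Below is a minimal sketch of that cast-on-first-call pattern; RopeCache and its placeholder ones-table are hypothetical stand-ins for illustration, not tinygrad's Transformer or its real precomputed table.

# sketch of the cast-on-first-forward pattern (hypothetical class, not tinygrad's Transformer)
from tinygrad import Tensor, dtypes

class RopeCache:
  def __init__(self, head_dim=64, max_context=128):
    # placeholder for the precomputed table; built once in the default float dtype
    self.freqs_cis = Tensor.ones(1, max_context, 1, head_dim // 2, 2)

  def __call__(self, x: Tensor) -> Tensor:
    # cast the cached table to the input's dtype; reassigning the attribute means
    # every later call reuses the already-cast, realized tensor
    self.freqs_cis = self.freqs_cis.cast(x.dtype).realize()
    seqlen = x.shape[1]
    return self.freqs_cis.shrink((None, (0, seqlen), None, None, None))

cache = RopeCache()
out = cache(Tensor.ones(2, 8, 64, dtype=dtypes.half))
print(cache.freqs_cis.dtype, out.dtype)  # both dtypes.half after the first call

Here the first call with a half-precision input flips the cached table, and everything sliced from it, to dtypes.half, which mirrors what the forward() hunk above does with h.dtype.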
