Update torchtune generation to be more flexible
Differential Revision: D65480353

Pull Request resolved: pytorch#1970
RylanC24 authored Nov 8, 2024
1 parent 7bfb333 commit eb67cc5
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions torchtune/generation/_generation.py
@@ -67,7 +67,7 @@ def generate_next_token(
     model: TransformerDecoder,
     input_pos: torch.Tensor,
     x: torch.Tensor,
-    q: torch.Tensor,
+    q: Optional[torch.Tensor] = None,
     *,
     mask: Optional[torch.Tensor] = None,
     temperature: float = 1.0,
@@ -82,7 +82,7 @@ def generate_next_token(
             with shape [bsz x seq_length].
         x (torch.Tensor): tensor with the token IDs associated with the given prompt,
             with shape [bsz x seq_length].
-        q (torch.Tensor): randomly sampled tensor for softmax sampling trick.
+        q (Optional[torch.Tensor]): randomly sampled tensor for softmax sampling trick.
            See https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/generate.py#L40
         mask (Optional[torch.Tensor]): attention mask with shape [bsz x seq_length x seq_length],
             default None.
@@ -302,9 +302,11 @@ def generate(
     # tensors are of identical shape to the prompt
     curr_masks = masks[:, :prompt_length, :prompt_length]
 
-    q = torch.empty(
-        (bsz, model.tok_embeddings.num_embeddings), device=prompt.device
-    ).exponential_(1, generator=rng)
+    q = None
+    if rng is not None:
+        q = torch.empty(
+            (bsz, model.tok_embeddings.num_embeddings), device=prompt.device
+        ).exponential_(1, generator=rng)
     tokens, generated_logits = generate_next_token(
         model,
         input_pos=input_pos[:, :prompt_length].squeeze(),
@@ -360,9 +362,11 @@ def generate(
         curr_input_pos = input_pos[:, : curr_pos + 1]
         curr_masks = masks[:, : curr_pos + 1, : curr_pos + 1]
 
-        q = torch.empty(
-            (bsz, model.tok_embeddings.num_embeddings), device=prompt.device
-        ).exponential_(1, generator=rng)
+        q = None
+        if rng is not None:
+            q = torch.empty(
+                (bsz, model.tok_embeddings.num_embeddings), device=prompt.device
+            ).exponential_(1, generator=rng)
         tokens, logits = custom_generate_next_token(
             model,
             input_pos=curr_input_pos,
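For context, the q tensor feeds the exponential softmax-sampling trick from gpt-fast (linked in the docstring): drawing q ~ Exponential(1) and taking argmax(probs / q) is equivalent to sampling from the categorical distribution over probs, which is why passing a q built from a seeded generator makes sampling reproducible. Below is a minimal sketch of the trick and the optional-q fallback this diff introduces; it is not the torchtune implementation, and sample_from_probs is a hypothetical helper named here only for illustration.

    # Sketch of the exponential softmax-sampling trick (not torchtune code).
    from typing import Optional

    import torch

    def sample_from_probs(probs: torch.Tensor, q: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Sampling token ~ Categorical(probs) is equivalent to
        # argmax(probs / q) with q ~ Exponential(1), a Gumbel-max variant.
        if q is None:
            # No caller-supplied noise (e.g. rng was None): draw fresh noise here.
            q = torch.empty_like(probs).exponential_(1)
        return torch.argmax(probs / q, dim=-1, keepdim=True)

    # With a seeded generator, the caller pre-draws q for reproducible sampling:
    rng = torch.Generator().manual_seed(0)
    probs = torch.softmax(torch.randn(2, 8), dim=-1)
    q = torch.empty_like(probs).exponential_(1, generator=rng)
    print(sample_from_probs(probs, q))  # deterministic given the seed
    print(sample_from_probs(probs))     # falls back to unseeded noise

Making q optional lets callers skip constructing the noise tensor entirely when no generator is supplied, rather than forcing every call site to build one.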
