Skip to content

Commit

Permalink
Tweaking readme and docstrings because the normal samplers now run on…
Browse files Browse the repository at this point in the history
… the GPU
  • Loading branch information
murrellb committed Nov 28, 2024
1 parent f71d5c9 commit 806f833
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 16 deletions.
16 changes: 0 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,22 +155,6 @@ generate(model, prompt,
device = gpu); #Note the device keyword
```

If you're using one of the trickier samplers, some CPU operations are needed for sampling. So you need to pass `device = cpu` to the sampler, while passing `device = gpu` to the `generate` function:

```julia
#Put the model on the GPU
model = gpu(model)

prompt = smollm2_assistant_prompt(tkn,"Tell me the two worst things about Python.")
generate(model, prompt,
max_new_tokens=500,
tokenizer_for_printing=tkn,
end_token = encode(tkn, "<|im_end|>")[end],
sampler = top_nσ_sampler(; device = cpu), #cpu for the sampler
device = gpu, #gpu for generate
);
```

And if you're training, the data needs to be on the GPU:

```julia
Expand Down
2 changes: 2 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ eos = encode(tkn, "<|im_end|>")[end]
prompt = smollm2_instruct_prompt(tkn, "You are an expert in Statistics and Probability Theory who answers questions in as few words as possible.",question)
generate(model, prompt, max_new_tokens=100, tokenizer_for_printing=tkn, end_token = eos, sampler = structured_choice(choices, vocab, eos));
```
If you want to run the model on the GPU, then you need to pass `device = gpu` to the `generate` function, and `device = cpu` to the `structured_choice` function.
"""
function structured_choice(choices::Vector{String}, vocab::Vector{String}, end_token::Int; sampler = logits -> argmax_sampler(logits), device = identity)
remaining_choices = copy(choices)
Expand Down

0 comments on commit 806f833

Please sign in to comment.