More Jetstream Pytorch fixes, prepare for release (#116)
* fix(tgi): correct truncation in Jetstream Pytorch generator

* chore(ci): jetstream TGI tests also run on main on push

* refactor(generator): inputs removed from slot

  This is not used anyway.

* fix(generator): correct cached_batch and set slot numbers to batch_size

  The cached batch returned was wrong, because the generator expects only
  one cached batch to be returned per prefill/decode call.
  Also, the slot list size is now fixed: elements are no longer created and
  destroyed in the slot list, which leaves room for further optimizations
  and avoids JIT compilation.

* feat(rng): improve randomness in sampling on Jetstream/Pt

  The randomness when sampling has been improved by splitting the key, as
  suggested by the documentation of the JAX random submodule.

* test(jetstream): added prefill and decode multiple tests

  A GPT2 test file exists to verify the generator behaviour when using the
  legacy Pytorch/XLA code, so this test has been added to verify the same
  behaviour on the Jetstream/Pytorch counterpart.

* test(jetstream): added failing test to check sampling can be changed

* fix(jetstream): correct sampling for jetstream

  The Jetstream/Pt engine allows passing a callback to the prefill and
  generate methods. This callback is used to sample the generated token
  with a custom function, but the caller function is JIT'ed, which places a
  strong constraint on the callback signature. So far the callback was
  compiled on the first call, making it impossible to change the sampling
  algorithm between requests.
  This commit fixes the issue by subclassing the PytorchEngine class and
  defining a new `prefill_ex` method that is not JIT'ed. The model calls
  are still compiled, so performance should not be noticeably affected.

* chore: bump version to 0.2.0

  Minor version is increased mainly because of Jetstream Pytorch support
  on TGI.

* fix(version): version number was not correctly updated, fix it

* review: remove commented code leftover

* review: add docstring to explain tests goals
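
The fixed-size slot list described in the `fix(generator)` entry can be illustrated with a minimal sketch. This is not the code from the commit: `Slot`, `SlotState`, `SlotPool`, `acquire` and `release` are hypothetical names chosen for illustration. The only point taken from the commit message is that the list is allocated once with `batch_size` entries and that slots are reused rather than created and destroyed.

```python
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional


class SlotState(Enum):
    EMPTY = 0
    READY = 1


@dataclass
class Slot:
    """Hypothetical stand-in for a generator slot (not the real class)."""

    id: int
    state: SlotState = SlotState.EMPTY
    request_id: Optional[int] = None

    def assign(self, request_id: int) -> None:
        self.request_id = request_id
        self.state = SlotState.READY

    def clear(self) -> None:
        self.request_id = None
        self.state = SlotState.EMPTY


class SlotPool:
    """Hypothetical pool whose slot list length never changes."""

    def __init__(self, batch_size: int) -> None:
        # Allocate every slot once, up front.
        self.slots: List[Slot] = [Slot(i) for i in range(batch_size)]

    def acquire(self, request_id: int) -> Slot:
        # Reuse the first empty slot instead of appending a new one.
        for slot in self.slots:
            if slot.state is SlotState.EMPTY:
                slot.assign(request_id)
                return slot
        raise RuntimeError("no free slot: batch is full")

    def release(self, request_id: int) -> None:
        # Mark the slot as empty; it stays in the list.
        for slot in self.slots:
            if slot.request_id == request_id:
                slot.clear()


pool = SlotPool(batch_size=2)
first = pool.acquire(request_id=10)
second = pool.acquire(request_id=11)
pool.release(request_id=10)
third = pool.acquire(request_id=12)  # lands in the slot freed by request 10
```

Because the list always holds `batch_size` entries, any structure derived from it keeps a stable shape, which is presumably what helps avoid extra JIT compilation in the real generator.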
1 parent baae0c4, commit 1fc59ce
Showing 7 changed files with 248 additions and 61 deletions.
@@ -15,5 +15,5 @@
 from packaging.version import parse


-__version__ = "0.1.5"
+__version__ = "0.2.0"
 VERSION = parse(__version__)
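
As a quick check of what the bump means for code that compares `VERSION`, the sketch below parses both version strings with `packaging.version.parse`, the same function the diff imports. The variable names `old` and `new` are illustrative only and do not appear in the repository.

```python
from packaging.version import parse

# The two version strings from the diff above.
old = parse("0.1.5")
new = parse("0.2.0")

# Versions compare by release segments, not as plain strings.
is_newer = new > old

# Only the minor component increases; the micro component resets to 0.
components = (new.major, new.minor, new.micro)
```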
2 changes: 1 addition & 1 deletion
text-generation-inference/server/text_generation_server/version.py
@@ -1,5 +1,5 @@
 from pkg_resources import parse_version


-__version__ = "0.1.5"
+__version__ = "0.2.0"
 VERSION = parse_version(__version__)