More Jetstream Pytorch fixes, prepare for release (#116)
* fix(tgi): correct truncation in Jetstream Pytorch generator

* chore(ci): jetstream TGI tests also run on main on push

* refactor(generator): inputs removed from slot

It was not used anyway.

* fix(generator): correct cached_batch and set slot numbers to batch_size

The cached batch returned was wrong, because the generator expects only
one cached batch to be returned per prefill/decode call.
Also, the number of slots is now fixed: this prevents creating and
destroying elements in the slot list, which allows further optimizations
and avoids JIT recompilation (see the fixed-slots sketch after this list).

* feat(rng): improve randomness in sampling on Jetstream/Pt

The randomness of sampling has been improved by splitting the key, as
suggested by the documentation of the JAX random submodule (see the
key-splitting sketch after this list).

* test(jetstream): added prefill and decode multiple tests

A GPT2 test file already verifies the generator behaviour when using the
legacy Pytorch/XLA code; the same test has now been added to verify that
behaviour on the Jetstream/Pytorch counterpart.

* test(jetstream): added failing test to check sampling can be changed

* fix(jetstream): correct sampling for jetstream

The Jetstream/Pt engine allows passing a callback to the prefill and
generate methods. This callback is used to sample the generated token
with a custom function, but the calling function is JIT'ed, which puts a
strong constraint on the callback signature. Until now, the callback was
compiled on the first call, making it impossible to change the sampling
algorithm across requests.
This commit fixes the issue by subclassing the PyTorchEngine class and
defining a new `prefill_ex` method that is not JIT'ed. The model calls
are still compiled, so performance should not be noticeably affected
(see the engine sketch after this list).

* chore: bump version to 0.2.0

The minor version is increased mainly because of the Jetstream Pytorch
support in TGI.

* fix(version): version number was not correctly updated, fix it

* review: remove commented code leftover

* review: add docstring to explain tests goals
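
To illustrate the motivation for the fixed slot count, here is a minimal sketch (not code from this repository): `jax.jit` specializes compiled functions on input shapes, so a slot list that grows and shrinks would change batch shapes and trigger recompilation.

```python
import jax
import jax.numpy as jnp

@jax.jit
def select_tokens(logits):
    # Compiled once per distinct input shape.
    return jnp.argmax(logits, axis=-1)

select_tokens(jnp.zeros((4, 128)))  # traced and compiled for batch size 4
select_tokens(jnp.zeros((4, 128)))  # reuses the compiled function
select_tokens(jnp.zeros((5, 128)))  # new shape -> retrace and recompile

# Keeping the slot list at a fixed batch_size keeps shapes static, so the
# compiled prefill/decode paths are reused across calls.
```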
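
A self-contained illustration of the key-splitting change (the array shapes are made up): reusing the same JAX PRNG key reproduces the exact same draw, while splitting off a fresh subkey per sample restores variety, as the JAX random documentation recommends.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
logits = jnp.zeros((8,))  # uniform logits: every token is equally likely

# Reusing the same key yields the identical "random" token every time.
a = jax.random.categorical(key, logits)
b = jax.random.categorical(key, logits)
assert a == b

# Splitting returns a fresh subkey for this draw and a new key to carry
# forward, mirroring `subkey, self.key = jax.random.split(self.key)`.
subkey, key = jax.random.split(key)
c = jax.random.categorical(subkey, logits)
```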
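
A simplified sketch of the `prefill_ex` pattern (the engine below is a placeholder, not Jetstream's real API): the outer prefill stays un-JIT'ed so the sampling callback can differ per request, while the heavy model call is still compiled.

```python
import jax
import jax.numpy as jnp

class FakeEngine:
    """Stand-in for Jetstream's PyTorchEngine (hypothetical, simplified)."""

    def prefill(self, tokens, sampler):
        logits = self._call_model_prefill(tokens)
        return sampler(logits)

    def _call_model_prefill(self, tokens):
        # Placeholder for the real model forward pass.
        return jnp.ones((tokens.shape[0], 8)) * tokens.sum()


class OptimumFakeEngine(FakeEngine):
    # The outer method is *not* wrapped in jax.jit, so `sampler` can be a
    # different Python callable on every request.
    prefill_ex = FakeEngine.prefill

    def __init__(self):
        # Only the model call is compiled, preserving performance.
        self._call_model_prefill = jax.jit(self._call_model_prefill)


engine = OptimumFakeEngine()
tokens = jnp.arange(4)
engine.prefill_ex(tokens, lambda logits: jnp.argmax(logits, axis=-1))
engine.prefill_ex(tokens, lambda logits: logits[:, 0])  # different sampler, no retrace issue
```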
tengomucho authored Nov 20, 2024
1 parent baae0c4 commit 1fc59ce
Showing 7 changed files with 248 additions and 61 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/test-pytorch-xla-tpu-tgi-jetstream.yml
@@ -1,6 +1,10 @@
name: Optimum TPU / Test TGI on TPU / Jetstream Pytorch

on:
  push:
    branches: [ main ]
    paths:
      - "text-generation-inference/**"
  pull_request:
    branches: [ main ]
    paths:
2 changes: 1 addition & 1 deletion optimum/tpu/version.py
@@ -15,5 +15,5 @@
from packaging.version import parse


__version__ = "0.1.5"
__version__ = "0.2.0"
VERSION = parse(__version__)
@@ -19,12 +19,35 @@

if TYPE_CHECKING:
from transformers import PretrainedConfig

from transformers import AutoConfig

from .compatibility import model_can_use_jetstream_pt
from .models import GemmaModel, LlamaModel, MixtralModel


class OptimumJetstreamEngine(PyTorchEngine):
    """This is essentially the same as the PyTorchEngine, but it also supports a sampling callback for prefill and
    generation that can change on each request while not needing to be JIT'ed.
    """
    prefill_ex = PyTorchEngine.prefill

    def __init__(
        self,
        pt_model: torch.nn.Module,
        env: JetEngineEnvironment,
        weights=None,
    ):
        super().__init__(pt_model, env, weights)
        # The model prefill and generate calls need to be JIT'ed, because they are invoked with sharding
        # annotations and would otherwise not work for some models.
        self._call_model_prefill = jax.jit(
            self._call_model_prefill,
        )
        self._call_model_generate = jax.jit(
            self._call_model_generate,
        )

def _get_head_dim(config: "PretrainedConfig") -> int:
if hasattr(config, "head_dim"):
return config.head_dim
@@ -174,8 +197,9 @@ def create_engine(
logger.info(f"Quantization took {end - start:.2f} seconds")
model_weights = model.state_dict()
sharded_weights = shard_weights(env, model_weights, weight_shardings)
return PyTorchEngine(
engine = OptimumJetstreamEngine(
pt_model=model,
env=env,
weights=torchjax.from_torch_with_copy(sharded_weights),
)
return engine
@@ -55,7 +55,6 @@ def clear(self):
self._state = Slot.State.EMPTY
self._batch_id = None
self._request_id = None
self._inputs = ""
self._generation_config = None
self._tokens = []
self._selector = None
@@ -263,9 +262,8 @@ def __init__(
tokenizer.truncation_side = "left"
self.tokenizer = tokenizer
self.special_tokens = self.tokenizer.all_special_ids

# Slots are empty to begin with, they will be populated as new batches arrive
self.slots = []
# The number of slots is static: it cannot grow beyond the batch size
self.slots = [Slot(i, tokenizer) for i in range(self.model.config.batch_size)]
self.batch_id = 0
# Note: this index will _never_ be decremented, and that's fine.
self.slot_index = 0
@@ -363,13 +361,11 @@ def warmup(self, batch: Batch) -> int:
seq_len = self.engine.env.seq_len
return batch_size * seq_len

def _get_slot_id(self):
"""Get the next available slot id."""
batch_size = self.engine.env.batch_size
used_ids = [slot.id for slot in self.slots if slot.state != Slot.State.EMPTY]
for i in range(batch_size):
if i not in used_ids:
return i
def _get_slot(self):
"""Get the next available slot."""
for slot in self.slots:
if slot.state == Slot.State.EMPTY:
return slot
# if we reach this point, all slots were used - this should not happen
raise ValueError("All slots are used, but we should have stopped earlier")

@@ -388,6 +384,8 @@ def _token_encode(self, text: str, max_length: int) -> Tuple[jnp.ndarray, int]:
"""
if max_length == 0:
max_length = self.model.config.sequence_length
# Subtract one from max_length because BOS is going to be added when padding
max_length -= 1
input_ids = self.tokenizer.encode(
text,
return_tensors="np",
@@ -417,14 +415,9 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
"""

slots = {state: [] for state in Slot.State}
for slot in self.slots:
slots[slot.state].append(slot)
len_active_slots = len(slots[Slot.State.READY])
# Delete all empty slots, no need to have them anymore
empty_slots = slots[Slot.State.EMPTY]
for slot in empty_slots:
self.slots.remove(slot)
active_slots = [slot for slot in self.slots if slot.state == Slot.State.READY]
len_active_slots = len(active_slots)

len_requests = len(batch.requests)
model_batch_size = self.model.config.batch_size
if model_batch_size is not None and model_batch_size < len_active_slots + len_requests:
@@ -439,10 +432,10 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
# Assign each request to an empty slot
logger.debug(f"Prefilling {len_requests} new request(s) adding to {len_active_slots} active slot(s)")
generations = []

prefilled_active_slots = []
for request in batch.requests:
# Dynamically create a new slot for each request
slot = Slot(self._get_slot_id(), self.tokenizer)
slot = self._get_slot()
self.prefill_slot.set(slot)
self.slot_index += 1
slot.assign(self.batch_id, request, self.model.generation_config)
@@ -462,7 +455,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
# To allow jit'ing the select function, we need to wrap it in a partial
slot_select = jax.tree_util.Partial(self.prefill_slot.select)
# Ask for prefill and insert
prefill_results, _result_tokens = self.engine.prefill(
prefill_results, _result_tokens = self.engine.prefill_ex(
params=self.params,
padded_tokens=input_ids,
true_length=true_lengths,
@@ -473,28 +466,20 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:

self._post_generate(slot, next_token, generations)
if not slot.empty:
# append current to list of active slots
self.slots.append(slot)
len_active_slots += 1

batch = None
if len_active_slots > 0:
# Whatever initial batch these requests came from, we always return all pending requests in a single batch
request_ids = [slot.request_id for slot in self.slots if slot.state == Slot.State.READY]
batch = self._cached_batch(self.batch_id, request_ids)
else:
logger.debug("No more pending requests")
prefilled_active_slots.append(slot)

cached_batch = self._cached_batch(self.batch_id, prefilled_active_slots)
self.batch_id += 1
logger.debug("Model ready for decoding")
return generations, batch
return generations, cached_batch

def _select_from_slots(self, logits: jnp.ndarray, batch_size: int=0) -> jnp.ndarray:
pad_token_id = self.tokenizer.pad_token_id
batch_size = logits.shape[0]
tokens = jnp.full((batch_size, 1), pad_token_id)
for slot in filter(lambda slot: slot.state == slot.State.READY, self.slots):
# Every slot might have a different selection criterion, so we have to call select in a loop
next_token = slot.select(logits)
next_token = slot.select(logits[slot.id : slot.id + 1, :])
tokens = tokens.at[slot.id].set(next_token)
return tokens

@@ -539,10 +524,8 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")

# Use a custom function to select the next token for each slot
select_fn = jax.tree_util.Partial(self._select_from_slots)
self.decode_state, result_tokens = self.engine.generate(self.params, self.decode_state, select_fn)
self.decode_state, result_tokens = self.engine.generate_impl(self.params, self.decode_state, self._select_from_slots)

newly_empty = []
generations = []
for slot in active_slots:
# Get the next token.
Expand All @@ -555,20 +538,9 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
raise ValueError("Unexpected Slot is not ready for decoding")

self._post_generate(slot, next_token, generations)
if slot.empty:
newly_empty.append(slot)

# Remove empty slots
for slot in newly_empty:
self.slots.remove(slot)
batch = None
if len(self.slots) > 0:
# Whatever initial batch these requests came from, we always return all pending requests in a single batch
request_ids = [slot.request_id for slot in self.slots if slot.state == Slot.State.READY]
batch = self._cached_batch(next_batch_id, request_ids)
else:
logger.debug("No more pending requests")
return generations, batch

cached_batch = self._cached_batch(next_batch_id, active_slots)
return generations, cached_batch

def _post_generate(self, slot: Slot, next_token: int, generations: List[Generation]) -> None:
"""Post-generate a slot after the generation has been completed.
@@ -616,7 +588,13 @@ def _post_generate(self, slot: Slot, next_token: int, generations: List[Generati
)
)

def _cached_batch(self, batch_id: int, request_ids: List):
def _cached_batch(self, batch_id: int, active_slots: List):
"""Create a CachedBatch from the active slots.
"""
request_ids = [slot.request_id for slot in active_slots if slot.state == Slot.State.READY]
if len(request_ids) == 0:
logger.debug("No more pending requests")
return None
size = len(request_ids)
max_tokens = size * self.model.config.sequence_length
return CachedBatch(id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens)
@@ -175,11 +175,13 @@ def select(self, input_ids: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray:
"""
scores = self.logits_processor(input_ids, logits)
if self.mode == GenerationMode.SAMPLE:
return self._sample(scores)
# split the key to avoid reusing the same key for multiple samples
subkey, self.key = jax.random.split(self.key)
return self._sample(scores, subkey)
else:
return jnp.argmax(scores, axis=-1)

def _sample(self, scores: jnp.ndarray) -> jnp.ndarray:
def _sample(self, scores: jnp.ndarray, key) -> jnp.ndarray:
do_top_k = self.logits_warper.top_k > 0 and self.logits_warper.top_k < scores.shape[-1]
do_top_p = self.logits_warper.top_p < 1.0 and self.logits_warper.top_p > 0.0

@@ -188,14 +190,14 @@ def _sample(self, scores: jnp.ndarray) -> jnp.ndarray:
scores,
self.logits_warper.top_k,
self.logits_warper.temperature,
self.key,
key,
)
elif do_top_p:
return sampling_utils.sample_nucleus_topp_logits(
scores,
self.logits_warper.top_p,
self.logits_warper.temperature,
self.key,
key,
)

return jax.random.categorical(self.key, scores / self.logits_warper.temperature)
return jax.random.categorical(key, scores / self.logits_warper.temperature)
@@ -1,5 +1,5 @@
from pkg_resources import parse_version


__version__ = "0.1.5"
__version__ = "0.2.0"
VERSION = parse_version(__version__)