Skip to content

Commit

Permalink
[Bug] HybridCache not subscriptable (#1047)
Browse files Browse the repository at this point in the history
Use transformers `Cache` interface rather than legacy tuples (still included for backwards compatibility).

- When we overflow the size of StaticCache, HybridCache, reallocate a cache with double the size
      - Other fixed-size caches will just raise a warning and delete the cache until we adapt doubling logic to those cache types
- Use `Cache.crop` when available for backtracking the cache
- When `Cache.crop` is unavailable, try `Cache.reset` to avoid reallocation, finally falling back on deleting the cache
  • Loading branch information
hudson-ai authored Jan 15, 2025
1 parent 71f1a68 commit aaa5f00
Showing 1 changed file with 60 additions and 7 deletions.
67 changes: 60 additions & 7 deletions guidance/models/transformers/_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def __init__(self,
self.model = model.__class__.__name__
self.device = self.model_obj.device # otherwise note the current device

self._past_key_values = None
self._past_key_values: Union[transformers_package.Cache, tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]], None] = None
self._cached_logits = None
self._cached_token_ids: list[int] = []

Expand Down Expand Up @@ -479,13 +479,66 @@ def get_logits(self, token_ids):

# reset the cache length according to that number of positions
past_key_values = self._past_key_values
past_length = past_key_values[0][0].size(-2) if past_key_values is not None else 0
if past_length > num_cached:
# note we recompute the last token because we don't bother to handle the special case of just computing logits
max_cache_shape = None
if past_key_values is None:
past_length = 0
elif isinstance(past_key_values, tuple):
past_length = past_key_values[0][0].size(-2)
elif isinstance(past_key_values, transformers_package.Cache):
# TODO: use model's `cache_position` as this may be deprecated in a future version
# https://github.com/huggingface/transformers/blob/70b07d97cf2c5f61fff55700b65528a1b6845cd2/src/transformers/cache_utils.py#L64
past_length = past_key_values.get_seq_length()
# TODO: use `get_max_cache_shape` as `get_max_length` will be deprecated in a future version
# (`get_max_cache_shape` is not yet available so we can't use it yet)
# https://github.com/huggingface/transformers/blob/70b07d97cf2c5f61fff55700b65528a1b6845cd2/src/transformers/cache_utils.py#L67
max_cache_shape = past_key_values.get_max_length()
else:
raise TypeError(f"Unknown type of past_key_values: {type(past_key_values)}")

if max_cache_shape is not None and len(token_ids) > max_cache_shape:
# TODO: this seems to get set to the length of the first sequence we pass for models using
# StaticCache or HybridCache. We need to initialize our own cache with a large enough size
# if we want to continue generation with the same cache.
if isinstance(past_key_values, (transformers_package.StaticCache, transformers_package.HybridCache)):
# The __init__ API isn't consistent between different cache types, but there seems to be consistency
# between these two types, so we can use the same logic for both.
warnings.warn("Cache is too small. Re-initializing cache with larger size.")
cache_type = type(past_key_values)
config = self.model_obj.config
device = self.model_obj.device
hf_device_map = getattr(self.model_obj, "hf_device_map", {})
# hf_device_map is not always a complete mapping of layers to devices...
layer_device_map = {k: hf_device_map.get(k, device) for k in range(config.num_hidden_layers)}
self._past_key_values = cache_type(
config=config,
batch_size=past_key_values.batch_size,
# Double the cache size to be safe
max_cache_len=len(token_ids)*2,
dtype=past_key_values.dtype,
layer_device_map=layer_device_map,
)
else:
warnings.warn(f"Cache is too small. Resetting cache (no method implemented to resize cache for type {type(past_key_values)}).")
self._past_key_values = None
past_length = 0
elif past_length > num_cached:
past_length = max(0, num_cached - 1)
self._past_key_values = tuple(
tuple(p[..., :past_length, :] for p in v) for v in past_key_values
)
if isinstance(past_key_values, tuple):
self._past_key_values = tuple(
tuple(p[..., :past_length, :] for p in v) for v in past_key_values
)
else:
if hasattr(past_key_values, "crop"):
self._past_key_values.crop(past_length)
else:
warnings.warn(f"Cropping unsupported for cache type: {type(self._past_key_values)}. Resetting cache.")
if hasattr(self._past_key_values, "reset"):
# Use built-in reset method if available to avoid constructing/allocating a new cache
self._past_key_values.reset()
else:
self._past_key_values = None
past_length = 0

cache_token_ids[past_length:] = []

# call the model
Expand Down

0 comments on commit aaa5f00

Please sign in to comment.