From 3d35156dffed00b746367a319122a3559ab37814 Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Wed, 27 Nov 2024 14:38:08 +0000
Subject: [PATCH 01/13] chore(fsdp): remove a warning not relevant anymore

A warning appeared on an old version of torch_xla (2.3.0), but that version
is no longer supported.
---
 optimum/tpu/fsdp_v2.py | 11 -----------
 1 file changed, 11 deletions(-)

diff --git a/optimum/tpu/fsdp_v2.py b/optimum/tpu/fsdp_v2.py
index 9f4a5ad1..d303e44a 100644
--- a/optimum/tpu/fsdp_v2.py
+++ b/optimum/tpu/fsdp_v2.py
@@ -17,8 +17,6 @@
 """
 from typing import Any, Dict, List, Union

-from transformers.utils import logging
-

 PreTrainedModel = Any
 # NOTE: instead of the above, modeling_utils.PreTrainedModel should be used, but since the usage is only for type
@@ -92,15 +90,6 @@ def get_fsdp_training_args(model: PreTrainedModel) -> Dict:
         from .modeling_gemma import GemmaForCausalLM

         if isinstance(model, GemmaForCausalLM) or isinstance(model, HFGemmaForCausalLLM):
-            logger = logging.get_logger(__name__)
-            from torch_xla import __version__ as xla_version
-
-            if xla_version == "2.3.0":
-                logger.warning_once(
-                    "Fine-tuning Gemma on Pytorch XLA 2.3.0 might raise some issues. In case of any "
-                    "issues consider using the nightly version, and report the issue on the optimum-tpu "
-                    "GitHub repository: https://github.com/huggingface/optimum-tpu/issues/new."
-                )
             cls_to_wrap = "GemmaDecoderLayer"
             matched_model = True
     elif model_type == "llama":

From a25a56ff6a1b5f0ea2e0e38c11f511c028e12eb7 Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Fri, 22 Nov 2024 15:49:47 +0000
Subject: [PATCH 02/13] chore: update jetstream dependency to v0.2.4

---
 optimum/tpu/cli.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/optimum/tpu/cli.py b/optimum/tpu/cli.py
index 303eed8c..c43065b3 100644
--- a/optimum/tpu/cli.py
+++ b/optimum/tpu/cli.py
@@ -10,7 +10,7 @@


 TORCH_VER = "2.4.0"
-JETSTREAM_PT_VER = "02927c9f563082421abe8eedceabe8aedd7ec2f9"
+JETSTREAM_PT_VER = "jetstream-v0.2.4"
 DEFAULT_DEPS_PATH = os.path.join(Path.home(), ".jetstream-deps")

 app = typer.Typer()

From 446a9238a130f16ace1405899a043d72fd8c57c7 Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Fri, 22 Nov 2024 16:51:36 +0000
Subject: [PATCH 03/13] chore(docker): increase ulimit to avoid error

When building the Docker container, the build sometimes fails with a
"too many open files" error. Increasing the ulimit makes the error
disappear.
---
 Makefile | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/Makefile b/Makefile
index 7448a46e..5911f9f8 100644
--- a/Makefile
+++ b/Makefile
@@ -47,6 +47,7 @@ tpu-tgi:
 	docker build --rm -f text-generation-inference/docker/Dockerfile \
 		--build-arg VERSION=$(VERSION) \
 		--build-arg TGI_VERSION=$(TGI_VERSION) \
+		--ulimit nofile=100000:100000 \
 		-t huggingface/optimum-tpu:$(VERSION)-tgi .
 	docker tag huggingface/optimum-tpu:$(VERSION)-tgi huggingface/optimum-tpu:latest

@@ -55,6 +56,7 @@ tpu-tgi-ie:
 		--target inference-endpoint \
 		--build-arg VERSION=$(VERSION) \
 		--build-arg TGI_VERSION=$(TGI_VERSION) \
+		--ulimit nofile=100000:100000 \
 		-t huggingface/optimum-tpu:$(VERSION)-tgi .
docker tag huggingface/optimum-tpu:$(VERSION)-tgi huggingface/optimum-tpu:latest-ie From 780110cc518fe03c4fbeddef64f30e2c17a73939 Mon Sep 17 00:00:00 2001 From: Alvaro Moran Date: Fri, 22 Nov 2024 16:57:54 +0000 Subject: [PATCH 04/13] fix(docker): "AS" statement should be uppercase to avoid warning --- text-generation-inference/docker/Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/text-generation-inference/docker/Dockerfile b/text-generation-inference/docker/Dockerfile index 218561dc..08b7b9b4 100644 --- a/text-generation-inference/docker/Dockerfile +++ b/text-generation-inference/docker/Dockerfile @@ -13,7 +13,7 @@ WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse -FROM chef as planner +FROM chef AS planner COPY --from=tgi /tgi/Cargo.toml Cargo.toml COPY --from=tgi /tgi/Cargo.lock Cargo.lock COPY --from=tgi /tgi/rust-toolchain.toml rust-toolchain.toml @@ -134,7 +134,7 @@ RUN pip install dist/text_generation_server*.tar.gz # TPU compatible image for Inference Endpoints -FROM tpu_base as inference-endpoint +FROM tpu_base AS inference-endpoint COPY text-generation-inference/docker/entrypoint.sh entrypoint.sh RUN chmod +x entrypoint.sh From 9d0afd5b629467e531eb069fb87c45983489329c Mon Sep 17 00:00:00 2001 From: Alvaro Moran Date: Fri, 22 Nov 2024 15:59:05 +0000 Subject: [PATCH 05/13] chore: update TGI dependency to v2.4.1 Also align Dockerfile to TGI's one. --- .github/workflows/tpu-tgi-release.yml | 4 ++-- Makefile | 3 +-- text-generation-inference/docker/Dockerfile | 7 ++++--- text-generation-inference/server/Makefile | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/tpu-tgi-release.yml b/.github/workflows/tpu-tgi-release.yml index 2d81a471..400a5238 100644 --- a/.github/workflows/tpu-tgi-release.yml +++ b/.github/workflows/tpu-tgi-release.yml @@ -74,7 +74,7 @@ jobs: labels: ${{ steps.meta.outputs.labels }} build-args: | VERSION=${{ steps.version.outputs.version }} - TGI_VERSION=0ff6ff60ada291840beed63d8bf458d6f9606f7f + TGI_VERSION="v2.4.1" - name: Generate artifact attestation for TGI @@ -95,7 +95,7 @@ jobs: labels: ${{ steps.meta-ie.outputs.labels }} build-args: | VERSION=${{ steps.version.outputs.version }} - TGI_VERSION=0ff6ff60ada291840beed63d8bf458d6f9606f7f + TGI_VERSION="v2.4.1" target: inference-endpoint diff --git a/Makefile b/Makefile index 5911f9f8..7091ac00 100644 --- a/Makefile +++ b/Makefile @@ -19,8 +19,7 @@ REAL_CLONE_URL = $(if $(CLONE_URL),$(CLONE_URL),$(DEFAULT_CLONE_URL)) .PHONY: build_dist style style_check clean -# Ths is essentially v2.3.0 plus a fix to support v2 proto interface -TGI_VERSION ?= 0ff6ff60ada291840beed63d8bf458d6f9606f7f +TGI_VERSION ?= 690702b1ce9a27ce5bdf2a9dd3a80277ecea12cd rwildcard=$(wildcard $1) $(foreach d,$1,$(call rwildcard,$(addsuffix /$(notdir $d),$(wildcard $(dir $d)*)))) diff --git a/text-generation-inference/docker/Dockerfile b/text-generation-inference/docker/Dockerfile index 08b7b9b4..7611aa23 100644 --- a/text-generation-inference/docker/Dockerfile +++ b/text-generation-inference/docker/Dockerfile @@ -8,7 +8,7 @@ RUN tar -C /tgi -xf /tgi/sources.tar.gz --strip-components=1 # Build cargo components (adapted from TGI original Dockerfile) # Note that the build image is aligned on the same Linux version as the base image (Debian bookworm/ Ubuntu 22.04) -FROM lukemathwalker/cargo-chef:latest-rust-1.79-bookworm AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.80.1-bookworm AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse @@ 
-106,7 +106,7 @@ ARG ACCELERATE_VERSION='0.27.2'
 ARG SAFETENSORS_VERSION='0.4.2'

 # TGI base env
-ENV HUGGINGFACE_HUB_CACHE=/data \
+ENV HF_HOME=/data \
     HF_HUB_ENABLE_HF_TRANSFER=1 \
     PORT=80 \
     VERSION=${VERSION}
@@ -145,4 +145,5 @@ ENTRYPOINT ["./entrypoint.sh"]
 FROM tpu_base

 ENTRYPOINT ["text-generation-launcher"]
-CMD ["--json-output"]
+# This is commented out in the original TGI Dockerfile
+# CMD ["--json-output"]
diff --git a/text-generation-inference/server/Makefile b/text-generation-inference/server/Makefile
index d513e9b5..56e481b0 100644
--- a/text-generation-inference/server/Makefile
+++ b/text-generation-inference/server/Makefile
@@ -2,7 +2,7 @@ pkg_name := text_generation_server
 BUILDDIR ?= $(CURDIR)/build
 VERSION ?= 0.0.1
-TGI_VERSION ?= 0ff6ff60ada291840beed63d8bf458d6f9606f7f
+TGI_VERSION ?= "v2.4.1"

 mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
 mkfile_dir := $(dir $(mkfile_path))
 pkg_dir := $(BUILDDIR)/$(pkg_name)

From 4e5ba90685a7726af5533f6dd3cc7baa8ca06be5 Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Wed, 27 Nov 2024 14:01:06 +0000
Subject: [PATCH 06/13] chore(docker): update accelerate to v1.1.1

---
 text-generation-inference/docker/Dockerfile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/text-generation-inference/docker/Dockerfile b/text-generation-inference/docker/Dockerfile
index 7611aa23..2775cf7d 100644
--- a/text-generation-inference/docker/Dockerfile
+++ b/text-generation-inference/docker/Dockerfile
@@ -102,7 +102,7 @@ RUN pip install --upgrade pip

 # Install HuggingFace packages
 ARG TRANSFORMERS_VERSION='4.41.1'
-ARG ACCELERATE_VERSION='0.27.2'
+ARG ACCELERATE_VERSION='1.1.1'
 ARG SAFETENSORS_VERSION='0.4.2'

 # TGI base env

From c5f0aac1eb2028c0dec413a0b7194d5359adf080 Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Thu, 28 Nov 2024 15:14:05 +0000
Subject: [PATCH 07/13] fix(jetstream): correct Gemma and Mixtral config handling

The config object for these models is used by the Jetstream code, but it
does not fully match HF's config definitions. This change adds a class that
inherits from both config classes and makes the adjustments necessary to
avoid errors.
---
 .../models/gemma_model_hf.py   | 27 +++++++----------
 .../models/mixtral_model_hf.py | 29 ++++++++++---------
 2 files changed, 26 insertions(+), 30 deletions(-)

diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/gemma_model_hf.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/gemma_model_hf.py
index e61788d8..f0c0476b 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/gemma_model_hf.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/gemma_model_hf.py
@@ -4,6 +4,15 @@
 from transformers import GemmaConfig, GenerationConfig, GenerationMixin


+class GemmaConfigHf(GemmaConfig, gemma_config.GemmaConfig):
+    """This class is used to support both the HF GemmaConfig and the Jetstream Pytorch GemmaConfig at the same time.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.tokenizer = None
+
+
 class GemmaModelHf(GemmaModel, GenerationMixin):
     """Transformer module that uses HF GemmaConfig instead of Jetstream Pytorch GemmaConfig + device.
@@ -16,24 +25,8 @@ def __init__(
         device,
         env,
     ):
-        self.config = config
         self.generation_config = GenerationConfig.from_model_config(config)
-
-        args = gemma_config.GemmaConfig(
-            vocab_size=config.vocab_size,
-            max_position_embeddings=config.max_position_embeddings,
-            num_hidden_layers=config.num_hidden_layers,
-            num_attention_heads=config.num_attention_heads,
-            num_key_value_heads=config.num_key_value_heads,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.intermediate_size,
-            head_dim=config.head_dim,
-            rms_norm_eps=config.rms_norm_eps,
-            dtype="bfloat16",
-            quant=False,  # No quantization support for now
-            tokenizer=None,
-        )
-
+        args = GemmaConfigHf(**config.to_dict())
+        args.device = device
         super().__init__(args, env)
diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/mixtral_model_hf.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/mixtral_model_hf.py
index 0e476a9a..fde78b10 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/mixtral_model_hf.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/mixtral_model_hf.py
@@ -4,6 +4,20 @@
 from transformers import GenerationConfig, GenerationMixin, MixtralConfig


+class MixtralConfigHf(MixtralConfig, mixtral_config.ModelArgs):
+    """This class is used to support both the HF MixtralConfig and the Jetstream Pytorch ModelArgs at the same time.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.block_size = self.max_position_embeddings
+        self.n_layer = self.num_hidden_layers
+        self.n_head = self.num_attention_heads
+        self.dim = self.hidden_size
+        self.n_local_heads = self.num_local_experts or self.num_attention_heads
+        self.num_activated_experts = self.num_experts_per_tok
+        self.__post_init__()
+
 class MixtralModelHf(Transformer, GenerationMixin):
     """Transformer module that uses HF MixtralConfig instead of Jetstream Pytorch MixtralConfig + device.
     """
@@ -14,20 +28,9 @@ def __init__(
         device,
         env,
     ):
-        self.config = config
         self.generation_config = GenerationConfig.from_model_config(config)
-
-        args = mixtral_config.ModelArgs(
-            block_size=config.max_position_embeddings,
-            vocab_size=config.vocab_size,
-            n_layer=config.num_hidden_layers,
-            n_head=config.num_attention_heads,
-            dim=config.hidden_size,
-            intermediate_size=config.intermediate_size,
-            n_local_heads=config.num_local_experts or config.num_attention_heads,
-            num_activated_experts=config.num_experts_per_tok,
-            device=device,
-        )
+        args = MixtralConfigHf(**config.to_dict())
+        args.device = device
         super().__init__(args, env)

From cf21e04a9af79a8867ba99f3c3ff03acbe1d8463 Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Thu, 28 Nov 2024 15:04:37 +0000
Subject: [PATCH 08/13] chore(optimum): remove AutoModelForCausalLM from optimum.tpu

It is still possible to import it from optimum.tpu.modeling, but removing it
from the package root reduces the possibility of importing transformers and
torch_xla before torch_xla2.
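
For example (illustrative snippet, not part of the change itself), user code
that needs the class can still do:

    from optimum.tpu.modeling import AutoModelForCausalLM

    # any supported checkpoint works here; the model id below is only an example
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")

which is the import path the TGI generator switches to in this patch.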
--- optimum/tpu/__init__.py | 1 - .../text_generation_server/generator.py | 22 +++++++++---------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/optimum/tpu/__init__.py b/optimum/tpu/__init__.py index 4f14dfc8..848946e5 100644 --- a/optimum/tpu/__init__.py +++ b/optimum/tpu/__init__.py @@ -14,5 +14,4 @@ from .jetstream_pt_support import jetstream_pt_available # isort:skip from .fsdp_v2 import get_fsdp_config, use_fsdp_v2 -from .modeling import AutoModelForCausalLM from .version import VERSION, __version__ diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py index ab7c174b..b8f9cec7 100644 --- a/text-generation-inference/server/text_generation_server/generator.py +++ b/text-generation-inference/server/text_generation_server/generator.py @@ -16,23 +16,23 @@ from transformers.generation import GenerationConfig import optimum.tpu.xla_logger as logger -from optimum.tpu import AutoModelForCausalLM from optimum.tpu.generation import TokenSelector +from optimum.tpu.modeling import AutoModelForCausalLM from optimum.tpu.static_cache_xla import StaticCacheXla from optimum.tpu.xla_mp_comm import AgentMailbox, RootMailbox from .generator_base import Generator from .pb.generate_pb2 import ( - Batch, - CachedBatch, - FinishReason, - GeneratedText, - Generation, - InfoResponse, - NextTokenChooserParameters, - Request, - StoppingCriteriaParameters, - Tokens, + Batch, + CachedBatch, + FinishReason, + GeneratedText, + Generation, + InfoResponse, + NextTokenChooserParameters, + Request, + StoppingCriteriaParameters, + Tokens, ) From 9fd21a96e5588d54e05e561154aafd6f38322b74 Mon Sep 17 00:00:00 2001 From: Alvaro Moran Date: Wed, 27 Nov 2024 14:29:34 +0000 Subject: [PATCH 09/13] chore: update torch and torch_xla to v2.5.1 --- .github/workflows/test-pytorch-xla-tpu-tgi-jetstream.yml | 2 +- .../test-pytorch-xla-tpu-tgi-nightly-jetstream.yml | 2 +- .github/workflows/test-pytorch-xla-tpu-tgi-nightly.yml | 2 +- .github/workflows/test-pytorch-xla-tpu-tgi.yml | 2 +- .github/workflows/test-pytorch-xla-tpu.yml | 2 +- optimum/tpu/cli.py | 2 +- pyproject.toml | 6 +++--- requirements.txt | 2 +- 8 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/test-pytorch-xla-tpu-tgi-jetstream.yml b/.github/workflows/test-pytorch-xla-tpu-tgi-jetstream.yml index fb18d865..eb8f8929 100644 --- a/.github/workflows/test-pytorch-xla-tpu-tgi-jetstream.yml +++ b/.github/workflows/test-pytorch-xla-tpu-tgi-jetstream.yml @@ -22,7 +22,7 @@ jobs: runs-on: group: gcp-ct5lp-hightpu-8t container: - image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm + image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.1_3.10_tpuvm options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache env: PJRT_DEVICE: TPU diff --git a/.github/workflows/test-pytorch-xla-tpu-tgi-nightly-jetstream.yml b/.github/workflows/test-pytorch-xla-tpu-tgi-nightly-jetstream.yml index 3e4859e1..d103811c 100644 --- a/.github/workflows/test-pytorch-xla-tpu-tgi-nightly-jetstream.yml +++ b/.github/workflows/test-pytorch-xla-tpu-tgi-nightly-jetstream.yml @@ -16,7 +16,7 @@ jobs: runs-on: group: gcp-ct5lp-hightpu-8t container: - image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm + image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.1_3.10_tpuvm options: --shm-size "16gb" --ipc host --privileged ${{ 
vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache env: PJRT_DEVICE: TPU diff --git a/.github/workflows/test-pytorch-xla-tpu-tgi-nightly.yml b/.github/workflows/test-pytorch-xla-tpu-tgi-nightly.yml index 1b14f127..0ba23d5e 100644 --- a/.github/workflows/test-pytorch-xla-tpu-tgi-nightly.yml +++ b/.github/workflows/test-pytorch-xla-tpu-tgi-nightly.yml @@ -17,7 +17,7 @@ jobs: runs-on: group: gcp-ct5lp-hightpu-8t container: - image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm + image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.1_3.10_tpuvm options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache env: PJRT_DEVICE: TPU diff --git a/.github/workflows/test-pytorch-xla-tpu-tgi.yml b/.github/workflows/test-pytorch-xla-tpu-tgi.yml index 78492caf..ff57d648 100644 --- a/.github/workflows/test-pytorch-xla-tpu-tgi.yml +++ b/.github/workflows/test-pytorch-xla-tpu-tgi.yml @@ -20,7 +20,7 @@ jobs: runs-on: group: gcp-ct5lp-hightpu-8t container: - image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm + image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.1_3.10_tpuvm options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache env: PJRT_DEVICE: TPU diff --git a/.github/workflows/test-pytorch-xla-tpu.yml b/.github/workflows/test-pytorch-xla-tpu.yml index efa2e354..11e66c8d 100644 --- a/.github/workflows/test-pytorch-xla-tpu.yml +++ b/.github/workflows/test-pytorch-xla-tpu.yml @@ -20,7 +20,7 @@ jobs: runs-on: group: gcp-ct5lp-hightpu-8t container: - image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm + image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.1_3.10_tpuvm options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache env: PJRT_DEVICE: TPU diff --git a/optimum/tpu/cli.py b/optimum/tpu/cli.py index c43065b3..069ff965 100644 --- a/optimum/tpu/cli.py +++ b/optimum/tpu/cli.py @@ -9,7 +9,7 @@ import typer -TORCH_VER = "2.4.0" +TORCH_VER = "2.5.1" JETSTREAM_PT_VER = "jetstream-v0.2.4" DEFAULT_DEPS_PATH = os.path.join(Path.home(), ".jetstream-deps") diff --git a/pyproject.toml b/pyproject.toml index b9d4c9d4..22e7e3f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,8 +43,8 @@ keywords = [ dependencies = [ "transformers == 4.41.1", - "torch == 2.4.0", - "torch-xla[tpu] == 2.4.0", + "torch == 2.5.1", + "torch-xla[tpu] == 2.5.1", 'typer == 0.6.1', "loguru == 0.6.0", "sentencepiece == 0.2.0", @@ -63,7 +63,7 @@ quality = ["black", "ruff", "isort"] # Pallas is pulled because it will install a compatible version of jax[tpu]. jetstream-pt = [ "jetstream-pt", - "torch-xla[pallas] == 2.4.0" + "torch-xla[pallas] == 2.5.1" ] [project.urls] diff --git a/requirements.txt b/requirements.txt index 76a52096..64a4a636 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ # This is not a complete list of dependencies, but it allows to install torch without CUDA support --index-url https://download.pytorch.org/whl/cpu -torch==2.4.0 +torch==2.5.1 From 340d4fdc4d2d09876534d96be90498e67b220e75 Mon Sep 17 00:00:00 2001 From: Alvaro Moran Date: Thu, 28 Nov 2024 15:17:00 +0000 Subject: [PATCH 10/13] chore(jetstream): token selector operations are done in torch Conversions of scores tensors from jax to torch and back are done when calling logits processor. This will be required in newer versions of transformers. 
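
Roughly, the round-trip added below looks like this (sketch using the same
torch_xla2 calls as the change; `logits_processor` stands for whatever
torch-based processor is configured):

    logits_t = torch_xla2.tensor.j2t(logits.astype(jnp.float32))       # jax -> torch
    scores_t = logits_processor(input_ids, logits_t)                   # torch-side processing
    scores = torch_xla2.tensor.t2j(scores_t).to_device(logits.device)  # torch -> jax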
--- .../jetstream_pt_support/token_selector.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py index ce0820c4..472b601b 100644 --- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py +++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py @@ -4,6 +4,7 @@ import jax import jax.numpy as jnp +import torch_xla2 from jetstream.engine import sampling_utils from transformers.generation import ( GenerationConfig, @@ -173,7 +174,12 @@ def select(self, input_ids: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray: Return: `jnp.ndarray`: A `jnp.ndarray` containing the selected tokens. """ - scores = self.logits_processor(input_ids, logits) + # Logits processors is written in pytorch, so parameters are cast to float32 and converted to pytorch and back + # to jax with j2t/t2j (that is a bit expensive, it does copies), otherwise some operations are not supported. + logits_t = torch_xla2.tensor.j2t(logits.astype(jnp.float32)) + scores = self.logits_processor(input_ids, logits_t) + scores = torch_xla2.tensor.t2j(scores).to_device(logits.device) + if self.mode == GenerationMode.SAMPLE: # split the key to avoid reusing the same key for multiple samples subkey, self.key = jax.random.split(self.key) From 515f45ad91d5f88f333ad06e5a87893ec7c5099b Mon Sep 17 00:00:00 2001 From: Alvaro Moran Date: Mon, 25 Nov 2024 16:48:14 +0000 Subject: [PATCH 11/13] chore(dependencies): update transformers to v4.46.3 --- pyproject.toml | 2 +- text-generation-inference/docker/Dockerfile | 2 +- text-generation-inference/server/pyproject.toml | 2 +- .../server/text_generation_server/generator.py | 3 +++ .../text_generation_server/jetstream_pt_support/generator.py | 4 ++++ 5 files changed, 10 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 22e7e3f0..449c909b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,7 @@ keywords = [ ] dependencies = [ - "transformers == 4.41.1", + "transformers == 4.46.3", "torch == 2.5.1", "torch-xla[tpu] == 2.5.1", 'typer == 0.6.1', diff --git a/text-generation-inference/docker/Dockerfile b/text-generation-inference/docker/Dockerfile index 2775cf7d..3caf72a5 100644 --- a/text-generation-inference/docker/Dockerfile +++ b/text-generation-inference/docker/Dockerfile @@ -101,7 +101,7 @@ RUN apt-get update -y \ RUN pip install --upgrade pip # Install HuggingFace packages -ARG TRANSFORMERS_VERSION='4.41.1' +ARG TRANSFORMERS_VERSION='4.46.3' ARG ACCELERATE_VERSION='1.1.1' ARG SAFETENSORS_VERSION='0.4.2' diff --git a/text-generation-inference/server/pyproject.toml b/text-generation-inference/server/pyproject.toml index a10727b8..db37423b 100644 --- a/text-generation-inference/server/pyproject.toml +++ b/text-generation-inference/server/pyproject.toml @@ -15,7 +15,7 @@ dependencies = [ 'grpc-interceptor == 0.15.2', 'typer == 0.6.1', 'safetensors == 0.4.2', - 'transformers == 4.41.1', + 'transformers == 4.46.3', 'loguru == 0.6.0', "sentencepiece == 0.2.0", "numpy<2.0", diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py index b8f9cec7..cf7e1f3b 100644 --- a/text-generation-inference/server/text_generation_server/generator.py +++ 
b/text-generation-inference/server/text_generation_server/generator.py
@@ -314,6 +314,9 @@ def __init__(
         tokenizer.truncation_side = "left"
         self.tokenizer = tokenizer
         self.special_tokens = self.tokenizer.all_special_ids
+        # The token selector will use the model's generation mixin internal variables to select the next token, and it
+        # expects special tokens to be initialized in the model.
+        model._prepare_special_tokens(generation_config=model.generation_config, device=model.device)
         # Slots are empty to begin with, they will be populated as new batches arrive
         self.slots = []
         self.batch_id = 0
diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
index 97061421..45f0a549 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
@@ -262,6 +262,10 @@ def __init__(
         tokenizer.truncation_side = "left"
         self.tokenizer = tokenizer
         self.special_tokens = self.tokenizer.all_special_ids
+        # The token selector will use the model's generation mixin internal variables to select the next token, and it
+        # expects special tokens to be initialized in the model.
+        model = self.engine.pt_model
+        model._prepare_special_tokens(generation_config=model.generation_config, device='cpu')
         # Slots number is static, it cannot grow over the size of the batch
         self.slots = [Slot(i, tokenizer) for i in range(self.model.config.batch_size)]
         self.batch_id = 0

From f063f128b158ae1f648a48004f19b874fad2fa10 Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Thu, 28 Nov 2024 16:35:08 +0000
Subject: [PATCH 12/13] chore: update safetensors to v0.4.5

This is to stay consistent with the accelerate dependencies, and to update
to a newer version.
---
 text-generation-inference/docker/Dockerfile     | 2 +-
 text-generation-inference/server/pyproject.toml | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/text-generation-inference/docker/Dockerfile b/text-generation-inference/docker/Dockerfile
index 3caf72a5..319ae9e8 100644
--- a/text-generation-inference/docker/Dockerfile
+++ b/text-generation-inference/docker/Dockerfile
@@ -103,7 +103,7 @@ RUN pip install --upgrade pip
 # Install HuggingFace packages
 ARG TRANSFORMERS_VERSION='4.46.3'
 ARG ACCELERATE_VERSION='1.1.1'
-ARG SAFETENSORS_VERSION='0.4.2'
+ARG SAFETENSORS_VERSION='0.4.5'

 # TGI base env
 ENV HF_HOME=/data \
diff --git a/text-generation-inference/server/pyproject.toml b/text-generation-inference/server/pyproject.toml
index db37423b..ab00ebc9 100644
--- a/text-generation-inference/server/pyproject.toml
+++ b/text-generation-inference/server/pyproject.toml
@@ -14,7 +14,7 @@ dependencies = [
     'grpcio-reflection == 1.62.1',
     'grpc-interceptor == 0.15.2',
     'typer == 0.6.1',
-    'safetensors == 0.4.2',
+    'safetensors == 0.4.5',
     'transformers == 4.46.3',
     'loguru == 0.6.0',
     "sentencepiece == 0.2.0",

From 54183d756c9e97dc33a2e9cde6425e612f5f9747 Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Fri, 29 Nov 2024 10:41:14 +0000
Subject: [PATCH 13/13] review(mixtral): use properties in config to avoid aliasing ambiguity

Instead of assigning separate variables for Jetstream's config class,
properties are added, so the same underlying data is accessed and ambiguity
is avoided.
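
For illustration (a generic sketch, not the actual config classes), the
difference between copying a value and exposing it through a property:

    class CopiedAlias:
        def __init__(self, max_position_embeddings):
            self.max_position_embeddings = max_position_embeddings
            self.block_size = max_position_embeddings  # separate copy, can go stale

    class PropertyAlias:
        def __init__(self, max_position_embeddings):
            self.max_position_embeddings = max_position_embeddings

        @property
        def block_size(self):
            # always reads the HF attribute, so the two names cannot disagree
            return self.max_position_embeddings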
---
 .../models/mixtral_model_hf.py | 38 +++++++++++++------
 1 file changed, 27 insertions(+), 11 deletions(-)

diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/mixtral_model_hf.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/mixtral_model_hf.py
index fde78b10..25a93595 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/mixtral_model_hf.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/mixtral_model_hf.py
@@ -1,26 +1,42 @@
-
 from jetstream_pt.third_party.mixtral import config as mixtral_config
 from jetstream_pt.third_party.mixtral.model import Transformer
 from transformers import GenerationConfig, GenerationMixin, MixtralConfig


 class MixtralConfigHf(MixtralConfig, mixtral_config.ModelArgs):
-    """This class is used to support both the HF MixtralConfig and the Jetstream Pytorch ModelArgs at the same time.
-    """
+    """This class is used to support both the HF MixtralConfig and the Jetstream Pytorch ModelArgs at the same time."""

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.block_size = self.max_position_embeddings
-        self.n_layer = self.num_hidden_layers
-        self.n_head = self.num_attention_heads
-        self.dim = self.hidden_size
-        self.n_local_heads = self.num_local_experts or self.num_attention_heads
-        self.num_activated_experts = self.num_experts_per_tok
         self.__post_init__()

+    @property
+    def block_size(self):
+        return self.max_position_embeddings
+
+    @property
+    def n_layer(self):
+        return self.num_hidden_layers
+
+    @property
+    def n_head(self):
+        return self.num_attention_heads
+
+    @property
+    def dim(self):
+        return self.hidden_size
+
+    @property
+    def n_local_heads(self):
+        return self.num_local_experts or self.num_attention_heads
+
+    @property
+    def num_activated_experts(self):
+        return self.num_experts_per_tok
+
+
 class MixtralModelHf(Transformer, GenerationMixin):
-    """Transformer module that uses HF MixtralConfig instead of Jetstream Pytorch MixtralConfig + device.
-    """
+    """Transformer module that uses HF MixtralConfig instead of Jetstream Pytorch MixtralConfig + device."""

     def __init__(
         self,