From ddc6498f8f0e2fa7ca79f367caa2166ec5f35b3b Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Wed, 13 Dec 2023 20:50:27 +0100 Subject: [PATCH] Remove TGI folder from Optimum Habana (#597) --- text-generation-inference/Dockerfile | 87 -- text-generation-inference/Makefile | 16 - text-generation-inference/README.md | 66 +- .../launcher/src/main.rs | 1263 ----------------- text-generation-inference/server/Makefile | 23 - .../server/pyproject.toml | 20 - .../server/requirements.txt | 78 - .../server/text_generation_server/cli.py | 221 --- .../text_generation_server/models/__init__.py | 35 - .../text_generation_server/models/bloom.py | 48 - .../models/causal_lm.py | 929 ------------ .../text_generation_server/models/model.py | 90 -- .../models/santacoder.py | 37 - .../server/text_generation_server/server.py | 193 --- .../text_generation_server/tgi_service.py | 29 - .../text_generation_server/utils/dist.py | 91 -- .../utils/logits_process.py | 381 ----- .../text_generation_server/utils/tokens.py | 361 ----- .../text_generation_server/utils/watermark.py | 86 -- 19 files changed, 1 insertion(+), 4053 deletions(-) delete mode 100644 text-generation-inference/Dockerfile delete mode 100644 text-generation-inference/Makefile delete mode 100644 text-generation-inference/launcher/src/main.rs delete mode 100644 text-generation-inference/server/Makefile delete mode 100644 text-generation-inference/server/pyproject.toml delete mode 100644 text-generation-inference/server/requirements.txt delete mode 100644 text-generation-inference/server/text_generation_server/cli.py delete mode 100644 text-generation-inference/server/text_generation_server/models/__init__.py delete mode 100644 text-generation-inference/server/text_generation_server/models/bloom.py delete mode 100644 text-generation-inference/server/text_generation_server/models/causal_lm.py delete mode 100644 text-generation-inference/server/text_generation_server/models/model.py delete mode 100644 text-generation-inference/server/text_generation_server/models/santacoder.py delete mode 100644 text-generation-inference/server/text_generation_server/server.py delete mode 100644 text-generation-inference/server/text_generation_server/tgi_service.py delete mode 100644 text-generation-inference/server/text_generation_server/utils/dist.py delete mode 100644 text-generation-inference/server/text_generation_server/utils/logits_process.py delete mode 100644 text-generation-inference/server/text_generation_server/utils/tokens.py delete mode 100644 text-generation-inference/server/text_generation_server/utils/watermark.py diff --git a/text-generation-inference/Dockerfile b/text-generation-inference/Dockerfile deleted file mode 100644 index 455a6e7b3d..0000000000 --- a/text-generation-inference/Dockerfile +++ /dev/null @@ -1,87 +0,0 @@ -# Clone the original TGI repo -FROM python:3.9 as tgi -WORKDIR /tmp -RUN git clone --depth 1 --branch v1.1.0 https://github.com/huggingface/text-generation-inference.git -COPY launcher/src/main.rs text-generation-inference/launcher/src/main.rs -# Rust builder -FROM lukemathwalker/cargo-chef:latest-rust-1.70 AS chef -WORKDIR /usr/src - -FROM chef as planner -COPY --from=tgi /tmp/text-generation-inference/Cargo.toml Cargo.toml -COPY --from=tgi /tmp/text-generation-inference/rust-toolchain.toml rust-toolchain.toml -COPY --from=tgi /tmp/text-generation-inference/proto proto -COPY --from=tgi /tmp/text-generation-inference/benchmark benchmark -COPY --from=tgi /tmp/text-generation-inference/router 
router -COPY --from=tgi /tmp/text-generation-inference/launcher launcher -RUN cargo chef prepare --recipe-path recipe.json - -FROM chef AS builder - -RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ - curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ - unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ - unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ - rm -f $PROTOC_ZIP - -COPY --from=planner /usr/src/recipe.json recipe.json -RUN cargo chef cook --release --recipe-path recipe.json - -COPY --from=tgi /tmp/text-generation-inference/Cargo.toml Cargo.toml -COPY --from=tgi /tmp/text-generation-inference/rust-toolchain.toml rust-toolchain.toml -COPY --from=tgi /tmp/text-generation-inference/proto proto -COPY --from=tgi /tmp/text-generation-inference/benchmark benchmark -COPY --from=tgi /tmp/text-generation-inference/router router -COPY --from=tgi /tmp/text-generation-inference/launcher launcher -RUN cargo build --release - -# Text Generation Inference base image -FROM vault.habana.ai/gaudi-docker/1.13.0/ubuntu22.04/habanalabs/pytorch-installer-2.1.0:latest as base - -# Text Generation Inference base env -ENV HUGGINGFACE_HUB_CACHE=/data \ - HF_HUB_ENABLE_HF_TRANSFER=1 \ - PORT=80 - -# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it -RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \ - dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb - -WORKDIR /usr/src - -COPY --from=tgi /tmp/text-generation-inference/proto proto -COPY --from=tgi /tmp/text-generation-inference/server server -COPY --from=tgi /tmp/text-generation-inference/server/Makefile server/Makefile -# Copy files modified for running on Gaudi -COPY server/text_generation_server/cli.py server/text_generation_server/cli.py -COPY server/text_generation_server/tgi_service.py server/text_generation_server/tgi_service.py -COPY server/text_generation_server/server.py server/text_generation_server/server.py -COPY server/text_generation_server/models/__init__.py server/text_generation_server/models/__init__.py -COPY server/text_generation_server/models/bloom.py server/text_generation_server/models/bloom.py -COPY server/text_generation_server/models/causal_lm.py server/text_generation_server/models/causal_lm.py -COPY server/text_generation_server/models/model.py server/text_generation_server/models/model.py -COPY server/text_generation_server/models/santacoder.py server/text_generation_server/models/santacoder.py -COPY server/text_generation_server/utils/watermark.py server/text_generation_server/utils/watermark.py -COPY server/text_generation_server/utils/tokens.py server/text_generation_server/utils/tokens.py -COPY server/text_generation_server/utils/logits_process.py server/text_generation_server/utils/logits_process.py -COPY server/text_generation_server/utils/dist.py server/text_generation_server/utils/dist.py -COPY server/requirements.txt server/requirements.txt -COPY server/pyproject.toml server/pyproject.toml -# Install server -RUN cd server && \ - make gen-server && \ - pip install -r requirements.txt && \ - pip install . 
--no-cache-dir - -# Install benchmarker -COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark -# Install router -COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router -# Install launcher -COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher - -# Final image -FROM base - -ENTRYPOINT ["text-generation-launcher"] -CMD ["--json-output"] diff --git a/text-generation-inference/Makefile b/text-generation-inference/Makefile deleted file mode 100644 index b7561d1fc8..0000000000 --- a/text-generation-inference/Makefile +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -debug_image_build: - docker build --no-cache --progress=plain -t debug_tgi . diff --git a/text-generation-inference/README.md b/text-generation-inference/README.md index 37139ee2a8..bb867f8ce9 100644 --- a/text-generation-inference/README.md +++ b/text-generation-inference/README.md @@ -16,68 +16,4 @@ limitations under the License. # Text Generation Inference on Habana Gaudi -To use [🤗 text-generation-inference](https://github.com/huggingface/text-generation-inference) on Habana Gaudi/Gaudi2, follow these steps: - -1. Build the Docker image located in this folder with: - ```bash - docker build -t tgi_gaudi . - ``` -2. Launch a local server instance on 1 Gaudi card: - ```bash - model=meta-llama/Llama-2-7b-hf - volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run - - docker run -p 8080:80 -v $volume:/data --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --ipc=host tgi_gaudi --model-id $model - ``` -3. Launch a local server instance on 8 Gaudi cards: - ```bash - model=meta-llama/Llama-2-70b-hf - volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run - - docker run -p 8080:80 -v $volume:/data --runtime=habana -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --ipc=host tgi_gaudi --model-id $model --sharded true --num-shard 8 - ``` -4. You can then send a request: - ```bash - curl 127.0.0.1:8080/generate \ - -X POST \ - -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":17, "do_sample": true}}' \ - -H 'Content-Type: application/json' - ``` - > The first call will be slower as the model is compiled. -5. To run benchmark test, please refer [TGI's benchmark tool](https://github.com/huggingface/text-generation-inference/tree/main/benchmark). 
- - To run it on the same machine, you can do the following: - * `docker exec -it <container-id> bash`, pick the container started in step 2 or 3 using `docker ps` - * `text-generation-benchmark -t <model-id>`, pass the model-id used in the `docker run` command - * after the tests complete, press Ctrl+C to see the performance data summary. - -> For gated models such as [StarCoder](https://huggingface.co/bigcode/starcoder), you will have to pass `-e HUGGING_FACE_HUB_TOKEN=<token>` to the `docker run` command above with a valid Hugging Face Hub read token. - -For more information and documentation about Text Generation Inference, check out [the README](https://github.com/huggingface/text-generation-inference#text-generation-inference) of the original repo. - -Not all features of TGI are currently supported as this is still a work in progress. - -New changes added for the current release: -- Sharded feature with support for DeepSpeed-inference automatic tensor parallelism. HPU graphs are also used to improve performance. -- Torch profiling. - - -Environment Variables Added: - -
- -| Name | Value(s) | Default | Description | Usage | -|------------------ |:---------------|:------------|:-------------------- |:--------------------------------- | -| MAX_TOTAL_TOKENS | integer | 0 | Control the padding of input | add -e in docker run command | -| ENABLE_HPU_GRAPH | true/false | true | Enable HPU graphs or not | add -e in docker run command | -| PROF_WARMUPSTEP | integer | 0 | Enable/disable profiling and control the profiler warmup step; 0 disables profiling | add -e in docker run command | -| PROF_STEP | integer | 5 | Control the profiling step | add -e in docker run command | -| PROF_PATH | string | /root/text-generation-inference | Define the profile output folder | add -e in docker run command | -| LIMIT_HPU_GRAPH | True/False | False | Skip HPU graph usage for prefill to save memory | add -e in docker run command | -
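As an illustration of how these variables are used (this example is not part of the original README; the values are examples taken from the defaults in the table above, not tuning recommendations), they can be added to the `docker run` command from step 2 via `-e` flags:

```bash
# Illustrative sketch only: the same docker run invocation as step 2,
# extended with the environment variables documented in the table above.
model=meta-llama/Llama-2-7b-hf
volume=$PWD/data  # share a volume with the container to avoid re-downloading weights

docker run -p 8080:80 -v $volume:/data --runtime=habana \
  -e HABANA_VISIBLE_DEVICES=all \
  -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
  -e MAX_TOTAL_TOKENS=2048 \
  -e ENABLE_HPU_GRAPH=true \
  -e LIMIT_HPU_GRAPH=false \
  -e PROF_WARMUPSTEP=0 \
  --cap-add=sys_nice --ipc=host \
  tgi_gaudi --model-id $model
```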
- - -> The license to use TGI on Habana Gaudi is the one of TGI: https://github.com/huggingface/text-generation-inference/blob/main/LICENSE -> -> Please reach out to api-enterprise@huggingface.co if you have any question. +Please refer to the following fork of TGI for deploying it on Habana Gaudi: https://github.com/huggingface/tgi-gaudi diff --git a/text-generation-inference/launcher/src/main.rs b/text-generation-inference/launcher/src/main.rs deleted file mode 100644 index eb47f65e13..0000000000 --- a/text-generation-inference/launcher/src/main.rs +++ /dev/null @@ -1,1263 +0,0 @@ -use clap::{Parser, ValueEnum}; -use nix::sys::signal::{self, Signal}; -use nix::unistd::Pid; -use serde::Deserialize; -use std::env; -use std::ffi::OsString; -use std::io::{BufRead, BufReader, Lines, Read}; -use std::os::unix::process::{CommandExt, ExitStatusExt}; -use std::path::Path; -use std::process::{Child, Command, ExitStatus, Stdio}; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::mpsc::TryRecvError; -use std::sync::{mpsc, Arc}; -use std::thread; -use std::thread::sleep; -use std::time::{Duration, Instant}; -use std::{fs, io}; -use tracing_subscriber::EnvFilter; - -mod env_runtime; - -#[derive(Clone, Copy, Debug, ValueEnum)] -enum Quantization { - /// 4 bit quantization. Requires a specific GTPQ quantized model: - /// https://hf.co/models?search=awq. - /// Should replace GPTQ models whereever possible because of the better latency - Awq, - /// 8 bit quantization, doesn't require specific model. - /// Should be a drop-in replacement to bitsandbytes with much better performance. - /// Kernels are from https://github.com/NetEase-FuXi/EETQ.git - Eetq, - /// 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=gptq. - /// text-generation-inference will use exllama (faster) kernels whereever possible, and use - /// triton kernel (wider support) when it's not. - /// AWQ has faster kernels. - Gptq, - /// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, - /// but it is known that the model will be much slower to run than the native f16. - #[deprecated( - since = "1.1.0", - note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases" - )] - Bitsandbytes, - /// Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, - /// but it is known that the model will be much slower to run than the native f16. - BitsandbytesNF4, - /// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better - /// perplexity performance for you model - BitsandbytesFP4, -} - -impl std::fmt::Display for Quantization { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - // To keep in track with `server`. - match self { - Quantization::Bitsandbytes => { - write!(f, "bitsandbytes") - } - Quantization::BitsandbytesNF4 => { - write!(f, "bitsandbytes-nf4") - } - Quantization::BitsandbytesFP4 => { - write!(f, "bitsandbytes-fp4") - } - Quantization::Gptq => { - write!(f, "gptq") - } - Quantization::Awq => { - write!(f, "awq") - } - Quantization::Eetq => { - write!(f, "eetq") - } - } - } -} - -#[derive(Clone, Copy, Debug, ValueEnum)] -enum Dtype { - Float16, - #[clap(name = "bfloat16")] - BFloat16, -} - -impl std::fmt::Display for Dtype { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - // To keep in track with `server`. 
- match self { - Dtype::Float16 => { - write!(f, "float16") - } - Dtype::BFloat16 => { - write!(f, "bfloat16") - } - } - } -} - -#[derive(Clone, Copy, Debug, ValueEnum)] -enum RopeScaling { - Linear, - Dynamic, -} - -impl std::fmt::Display for RopeScaling { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - // To keep in track with `server`. - match self { - RopeScaling::Linear => { - write!(f, "linear") - } - RopeScaling::Dynamic => { - write!(f, "dynamic") - } - } - } -} - -/// App Configuration -#[derive(Parser, Debug)] -#[clap(author, version, about, long_about = None)] -struct Args { - /// The name of the model to load. - /// Can be a MODEL_ID as listed on like - /// `gpt2` or `OpenAssistant/oasst-sft-1-pythia-12b`. - /// Or it can be a local directory containing the necessary files - /// as saved by `save_pretrained(...)` methods of transformers - #[clap(default_value = "bigscience/bloom-560m", long, env)] - model_id: String, - - /// The actual revision of the model if you're referring to a model - /// on the hub. You can use a specific commit id or a branch like `refs/pr/2`. - #[clap(long, env)] - revision: Option, - - /// The number of tokenizer workers used for payload validation and truncation inside the - /// router. - #[clap(default_value = "2", long, env)] - validation_workers: usize, - - /// Whether to shard the model across multiple GPUs - /// By default text-generation-inference will use all available GPUs to run - /// the model. Setting it to `false` deactivates `num_shard`. - #[clap(long, env)] - sharded: Option, - - /// The number of shards to use if you don't want to use all GPUs on a given machine. - /// You can use `CUDA_VISIBLE_DEVICES=0,1 text-generation-launcher... --num_shard 2` - /// and `CUDA_VISIBLE_DEVICES=2,3 text-generation-launcher... --num_shard 2` to - /// launch 2 copies with 2 shard each on a given machine with 4 GPUs for instance. - #[clap(long, env)] - num_shard: Option, - - /// Whether you want the model to be quantized. - #[clap(long, env, value_enum)] - quantize: Option, - - /// The dtype to be forced upon the model. This option cannot be used with `--quantize`. - #[clap(long, env, value_enum)] - dtype: Option, - - /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is - /// encouraged when loading a model with custom code to ensure no malicious code has been - /// contributed in a newer revision. - #[clap(long, env, value_enum)] - trust_remote_code: bool, - - /// The maximum amount of concurrent requests for this particular deployment. - /// Having a low limit will refuse clients requests instead of having them - /// wait for too long and is usually good to handle backpressure correctly. - #[clap(default_value = "128", long, env)] - max_concurrent_requests: usize, - - /// This is the maximum allowed value for clients to set `best_of`. - /// Best of makes `n` generations at the same time, and return the best - /// in terms of overall log probability over the entire generated sequence - #[clap(default_value = "2", long, env)] - max_best_of: usize, - - /// This is the maximum allowed value for clients to set `stop_sequences`. - /// Stop sequences are used to allow the model to stop on more than just - /// the EOS token, and enable more complex "prompting" where users can preprompt - /// the model in a specific way and define their "own" stop token aligned with - /// their prompt. 
- #[clap(default_value = "4", long, env)] - max_stop_sequences: usize, - - /// This is the maximum allowed value for clients to set `top_n_tokens`. - /// `top_n_tokens is used to return information about the the `n` most likely - /// tokens at each generation step, instead of just the sampled token. This - /// information can be used for downstream tasks like for classification or - /// ranking. - #[clap(default_value = "5", long, env)] - max_top_n_tokens: u32, - - /// This is the maximum allowed input length (expressed in number of tokens) - /// for users. The larger this value, the longer prompt users can send which - /// can impact the overall memory required to handle the load. - /// Please note that some models have a finite range of sequence they can handle. - #[clap(default_value = "1024", long, env)] - max_input_length: usize, - - /// This is the most important value to set as it defines the "memory budget" - /// of running clients requests. - /// Clients will send input sequences and ask to generate `max_new_tokens` - /// on top. with a value of `1512` users can send either a prompt of - /// `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for - /// `1511` max_new_tokens. - /// The larger this value, the larger amount each request will be in your RAM - /// and the less effective batching can be. - #[clap(default_value = "2048", long, env)] - max_total_tokens: usize, - - /// This represents the ratio of waiting queries vs running queries where - /// you want to start considering pausing the running queries to include the waiting - /// ones into the same batch. - /// `waiting_served_ratio=1.2` Means when 12 queries are waiting and there's - /// only 10 queries left in the current batch we check if we can fit those 12 - /// waiting queries into the batching strategy, and if yes, then batching happens - /// delaying the 10 running queries by a `prefill` run. - /// - /// This setting is only applied if there is room in the batch - /// as defined by `max_batch_total_tokens`. - #[clap(default_value = "1.2", long, env)] - waiting_served_ratio: f32, - - /// Limits the number of tokens for the prefill operation. - /// Since this operation take the most memory and is compute bound, it is interesting - /// to limit the number of requests that can be sent. - #[clap(default_value = "4096", long, env)] - max_batch_prefill_tokens: u32, - - /// **IMPORTANT** This is one critical control to allow maximum usage - /// of the available hardware. - /// - /// This represents the total amount of potential tokens within a batch. - /// When using padding (not recommended) this would be equivalent of - /// `batch_size` * `max_total_tokens`. - /// - /// However in the non-padded (flash attention) version this can be much finer. - /// - /// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100` - /// or a single query of `1000` tokens. - /// - /// Overall this number should be the largest possible amount that fits the - /// remaining memory (after the model is loaded). Since the actual memory overhead - /// depends on other parameters like if you're using quantization, flash attention - /// or the model implementation, text-generation-inference cannot infer this number - /// automatically. - #[clap(long, env)] - max_batch_total_tokens: Option, - - /// This setting defines how many tokens can be passed before forcing the waiting - /// queries to be put on the batch (if the size of the batch allows for it). 
- /// New queries require 1 `prefill` forward, which is different from `decode` - /// and therefore you need to pause the running batch in order to run `prefill` - /// to create the correct values for the waiting queries to be able to join the batch. - /// - /// With a value too small, queries will always "steal" the compute to run `prefill` - /// and running queries will be delayed by a lot. - /// - /// With a value too big, waiting queries could wait for a very long time - /// before being allowed a slot in the running batch. If your server is busy - /// that means that requests that could run in ~2s on an empty server could - /// end up running in ~20s because the query had to wait for 18s. - /// - /// This number is expressed in number of tokens to make it a bit more - /// "model" agnostic, but what should really matter is the overall latency - /// for end users. - #[clap(default_value = "20", long, env)] - max_waiting_tokens: usize, - - /// The IP address to listen on - #[clap(default_value = "0.0.0.0", long, env)] - hostname: String, - - /// The port to listen on. - #[clap(default_value = "3000", long, short, env)] - port: u16, - - /// The name of the socket for gRPC communication between the webserver - /// and the shards. - #[clap(default_value = "/tmp/text-generation-server", long, env)] - shard_uds_path: String, - - /// The address the master shard will listen on. (setting used by torch distributed) - #[clap(default_value = "localhost", long, env)] - master_addr: String, - - /// The address the master port will listen on. (setting used by torch distributed) - #[clap(default_value = "29500", long, env)] - master_port: usize, - - /// The location of the huggingface hub cache. - /// Used to override the location if you want to provide a mounted disk for instance - #[clap(long, env)] - huggingface_hub_cache: Option, - - /// The location of the huggingface hub cache. - /// Used to override the location if you want to provide a mounted disk for instance - #[clap(long, env)] - weights_cache_override: Option, - - /// For some models (like bloom), text-generation-inference implemented custom - /// cuda kernels to speed up inference. Those kernels were only tested on A100. - /// Use this flag to disable them if you're running on different hardware and - /// encounter issues. - #[clap(long, env)] - disable_custom_kernels: bool, - - /// Limit the CUDA available memory. - /// The allowed value equals the total visible memory multiplied by cuda-memory-fraction. - #[clap(default_value = "1.0", long, env)] - cuda_memory_fraction: f32, - - /// Rope scaling will only be used for RoPE models - /// and allow rescaling the position rotary to accomodate for - /// larger prompts. - /// - /// Goes together with `rope_factor`. 
- /// - /// `--rope-factor 2.0` gives linear scaling with a factor of 2.0 - /// `--rope-scaling dynamic` gives dynamic scaling with a factor of 1.0 - /// `--rope-scaling linear` gives linear scaling with a factor of 1.0 (Nothing will be changed - /// basically) - /// - /// `--rope-scaling linear --rope-factor` fully describes the scaling you want - #[clap(long, env)] - rope_scaling: Option, - - /// Rope scaling will only be used for RoPE models - /// See `rope_scaling` - #[clap(long, env)] - rope_factor: Option, - - /// Outputs the logs in JSON format (useful for telemetry) - #[clap(long, env)] - json_output: bool, - - #[clap(long, env)] - otlp_endpoint: Option, - - #[clap(long, env)] - cors_allow_origin: Vec, - #[clap(long, env)] - watermark_gamma: Option, - #[clap(long, env)] - watermark_delta: Option, - - /// Enable ngrok tunneling - #[clap(long, env)] - ngrok: bool, - - /// ngrok authentication token - #[clap(long, env)] - ngrok_authtoken: Option, - - /// ngrok edge - #[clap(long, env)] - ngrok_edge: Option, - - /// Display a lot of information about your runtime environment - #[clap(long, short, action)] - env: bool, -} - -#[derive(Debug)] -enum ShardStatus { - Ready, - Failed(usize), -} - -#[allow(clippy::too_many_arguments)] -fn shard_manager( - model_id: String, - revision: Option, - quantize: Option, - dtype: Option, - trust_remote_code: bool, - uds_path: String, - rank: usize, - world_size: usize, - master_addr: String, - master_port: usize, - huggingface_hub_cache: Option, - weights_cache_override: Option, - disable_custom_kernels: bool, - watermark_gamma: Option, - watermark_delta: Option, - cuda_memory_fraction: f32, - rope_scaling: Option, - rope_factor: Option, - otlp_endpoint: Option, - status_sender: mpsc::Sender, - shutdown: Arc, - _shutdown_sender: mpsc::Sender<()>, -) { - // Enter shard-manager tracing span - let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered(); - - // Get UDS path - let uds_string = format!("{uds_path}-{rank}"); - let uds = Path::new(&uds_string); - // Clean previous runs - if uds.exists() { - fs::remove_file(uds).unwrap(); - } - - // Process args - let mut shard_args = vec![ - "serve".to_string(), - model_id, - "--uds-path".to_string(), - uds_path, - "--logger-level".to_string(), - "INFO".to_string(), - "--json-output".to_string(), - ]; - - // Activate trust remote code - if trust_remote_code { - shard_args.push("--trust-remote-code".to_string()); - } - - // Activate tensor parallelism - if world_size > 1 { - shard_args.push("--sharded".to_string()); - } - - if let Some(quantize) = quantize { - shard_args.push("--quantize".to_string()); - shard_args.push(quantize.to_string()) - } - - if let Some(dtype) = dtype { - shard_args.push("--dtype".to_string()); - shard_args.push(dtype.to_string()) - } - - // Model optional revision - if let Some(revision) = revision { - shard_args.push("--revision".to_string()); - shard_args.push(revision) - } - - let rope = match (rope_scaling, rope_factor) { - (None, None) => None, - (Some(scaling), None) => Some((scaling, 1.0)), - (Some(scaling), Some(factor)) => Some((scaling, factor)), - (None, Some(factor)) => Some((RopeScaling::Linear, factor)), - }; - // OpenTelemetry - if let Some(otlp_endpoint) = otlp_endpoint { - shard_args.push("--otlp-endpoint".to_string()); - shard_args.push(otlp_endpoint); - } - - // Copy current process env - let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); - - // Torch Distributed Env vars - if world_size == 1 { - 
envs.push(("RANK".into(), rank.to_string().into())); - } - envs.push(("WORLD_SIZE".into(), world_size.to_string().into())); - envs.push(("MASTER_ADDR".into(), master_addr.into())); - envs.push(("MASTER_PORT".into(), master_port.to_string().into())); - envs.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into())); - - // CUDA memory fraction - envs.push(( - "CUDA_MEMORY_FRACTION".into(), - cuda_memory_fraction.to_string().into(), - )); - - // Safetensors load fast - envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into())); - - // Enable hf transfer for insane download speeds - let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string()); - envs.push(( - "HF_HUB_ENABLE_HF_TRANSFER".into(), - enable_hf_transfer.into(), - )); - - // Parse Inference API token - if let Ok(api_token) = env::var("HF_API_TOKEN") { - envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) - }; - - // Detect rope scaling - // Sending as env instead of CLI args to not bloat everything - // those only can be used by RoPE models, so passing information around - // for all models will complexify code unnecessarily - if let Some((scaling, factor)) = rope { - envs.push(("ROPE_SCALING".into(), scaling.to_string().into())); - envs.push(("ROPE_FACTOR".into(), factor.to_string().into())); - } - - // If huggingface_hub_cache is some, pass it to the shard - // Useful when running inside a docker container - if let Some(huggingface_hub_cache) = huggingface_hub_cache { - envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); - }; - - // If weights_cache_override is some, pass it to the shard - // Useful when running inside a HuggingFace Inference Endpoint - if let Some(weights_cache_override) = weights_cache_override { - envs.push(( - "WEIGHTS_CACHE_OVERRIDE".into(), - weights_cache_override.into(), - )); - }; - - // If disable_custom_kernels is true, pass it to the shard as an env var - if disable_custom_kernels { - envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into())) - } - - // Watermark Gamma - if let Some(watermark_gamma) = watermark_gamma { - envs.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into())) - } - - // Watermark Delta - if let Some(watermark_delta) = watermark_delta { - envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into())) - } - - // Start process - tracing::info!("Starting shard"); - let mut p = match Command::new("text-generation-server") - .args(shard_args) - .envs(envs) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .process_group(0) - .spawn() - { - Ok(p) => p, - Err(err) => { - if err.kind() == io::ErrorKind::NotFound { - tracing::error!("text-generation-server not found in PATH"); - tracing::error!("Please install it with `make install-server`") - } - { - tracing::error!("{}", err); - } - - status_sender.send(ShardStatus::Failed(rank)).unwrap(); - return; - } - }; - - // Redirect STDOUT to the console - let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap()); - let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap()); - - //stdout tracing thread - thread::spawn(move || { - log_lines(shard_stdout_reader.lines()); - }); - - let mut ready = false; - let start_time = Instant::now(); - let mut wait_time = Instant::now(); - loop { - // Process exited - if let Some(exit_status) = p.try_wait().unwrap() { - // We read stderr in another thread as it seems that lines() can block in some cases - let (err_sender, err_receiver) = mpsc::channel(); - thread::spawn(move || { - for line in 
shard_stderr_reader.lines().flatten() { - err_sender.send(line).unwrap_or(()); - } - }); - let mut err = String::new(); - while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) { - err = err + "\n" + &line; - } - - tracing::error!("Shard complete standard error output:\n{err}"); - - if let Some(signal) = exit_status.signal() { - tracing::error!("Shard process was signaled to shutdown with signal {signal}"); - } - - status_sender.send(ShardStatus::Failed(rank)).unwrap(); - return; - } - - // We received a shutdown signal - if shutdown.load(Ordering::SeqCst) { - p.kill().unwrap(); - let _ = p.wait(); - tracing::info!("Shard terminated"); - return; - } - - // Shard is ready - if uds.exists() && !ready { - tracing::info!("Shard ready in {:?}", start_time.elapsed()); - status_sender.send(ShardStatus::Ready).unwrap(); - ready = true; - } else if !ready && wait_time.elapsed() > Duration::from_secs(10) { - tracing::info!("Waiting for shard to be ready..."); - wait_time = Instant::now(); - } - sleep(Duration::from_millis(100)); - } -} - -fn shutdown_shards(shutdown: Arc, shutdown_receiver: &mpsc::Receiver<()>) { - tracing::info!("Shutting down shards"); - // Update shutdown value to true - // This will be picked up by the shard manager - shutdown.store(true, Ordering::SeqCst); - - // Wait for shards to shutdown - // This will block till all shutdown_sender are dropped - let _ = shutdown_receiver.recv(); -} - -fn num_cuda_devices() -> Option { - let devices = match env::var("CUDA_VISIBLE_DEVICES") { - Ok(devices) => devices, - Err(_) => env::var("NVIDIA_VISIBLE_DEVICES").ok()?, - }; - let n_devices = devices.split(',').count(); - Some(n_devices) -} - -#[derive(Deserialize)] -#[serde(rename_all = "UPPERCASE")] -enum PythonLogLevelEnum { - Trace, - Debug, - Info, - Success, - Warning, - Error, - Critical, -} - -#[derive(Deserialize)] -struct PythonLogLevel { - name: PythonLogLevelEnum, -} - -#[derive(Deserialize)] -struct PythonLogRecord { - level: PythonLogLevel, -} - -#[derive(Deserialize)] -struct PythonLogMessage { - text: String, - record: PythonLogRecord, -} - -impl PythonLogMessage { - fn trace(&self) { - match self.record.level.name { - PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text), - PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text), - PythonLogLevelEnum::Info => tracing::info!("{}", self.text), - PythonLogLevelEnum::Success => tracing::info!("{}", self.text), - PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text), - PythonLogLevelEnum::Error => tracing::error!("{}", self.text), - PythonLogLevelEnum::Critical => tracing::error!("{}", self.text), - } - } -} - -impl TryFrom<&String> for PythonLogMessage { - type Error = serde_json::Error; - - fn try_from(value: &String) -> Result { - serde_json::from_str::(value) - } -} - -fn log_lines(lines: Lines) { - for line in lines.flatten() { - match PythonLogMessage::try_from(&line) { - Ok(log) => log.trace(), - Err(_) => tracing::debug!("{line}"), - } - } -} - -fn find_num_shards( - sharded: Option, - num_shard: Option, -) -> Result { - // get the number of shards given `sharded` and `num_shard` - let num_shard = match (sharded, num_shard) { - (Some(true), None) => { - // try to default to the number of available GPUs - tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES"); - let n_devices = num_cuda_devices() - .expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES are not set"); - if n_devices <= 1 { - return 
Err(LauncherError::NotEnoughCUDADevices(format!( - "`sharded` is true but only found {n_devices} CUDA devices" - ))); - } - n_devices - } - (Some(true), Some(num_shard)) => { - // we can't have only one shard while sharded - if num_shard <= 1 { - return Err(LauncherError::ArgumentValidation( - "`sharded` is true but `num_shard` <= 1".to_string(), - )); - } - num_shard - } - (Some(false), Some(num_shard)) => num_shard, - (Some(false), None) => 1, - (None, None) => num_cuda_devices().unwrap_or(1), - (None, Some(num_shard)) => num_shard, - }; - if num_shard < 1 { - return Err(LauncherError::ArgumentValidation( - "`num_shard` cannot be < 1".to_string(), - )); - } - Ok(num_shard) -} - -#[derive(Debug)] -enum LauncherError { - ArgumentValidation(String), - NotEnoughCUDADevices(String), - DownloadError, - ShardCannotStart, - ShardDisconnected, - ShardFailed, - WebserverFailed, - WebserverCannotStart, -} - -fn download_convert_model(args: &Args, running: Arc) -> Result<(), LauncherError> { - // Enter download tracing span - let _span = tracing::span!(tracing::Level::INFO, "download").entered(); - - let mut download_args = vec![ - "download-weights".to_string(), - args.model_id.to_string(), - "--extension".to_string(), - ".safetensors".to_string(), - "--logger-level".to_string(), - "INFO".to_string(), - "--json-output".to_string(), - ]; - - // Model optional revision - if let Some(revision) = &args.revision { - download_args.push("--revision".to_string()); - download_args.push(revision.to_string()) - } - - // Trust remote code for automatic peft fusion - if args.trust_remote_code { - download_args.push("--trust-remote-code".to_string()); - } - - // Copy current process env - let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); - - // If huggingface_hub_cache is set, pass it to the download process - // Useful when running inside a docker container - if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache { - envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); - }; - - // Enable hf transfer for insane download speeds - let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string()); - envs.push(( - "HF_HUB_ENABLE_HF_TRANSFER".into(), - enable_hf_transfer.into(), - )); - - // Parse Inference API token - if let Ok(api_token) = env::var("HF_API_TOKEN") { - envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) - }; - - // If args.weights_cache_override is some, pass it to the download process - // Useful when running inside a HuggingFace Inference Endpoint - if let Some(weights_cache_override) = &args.weights_cache_override { - envs.push(( - "WEIGHTS_CACHE_OVERRIDE".into(), - weights_cache_override.into(), - )); - }; - - // Start process - tracing::info!("Starting download process."); - let mut download_process = match Command::new("text-generation-server") - .args(download_args) - .envs(envs) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .process_group(0) - .spawn() - { - Ok(p) => p, - Err(err) => { - if err.kind() == io::ErrorKind::NotFound { - tracing::error!("text-generation-server not found in PATH"); - tracing::error!("Please install it with `make install-server`") - } else { - tracing::error!("{}", err); - } - - return Err(LauncherError::DownloadError); - } - }; - - // Redirect STDOUT to the console - let download_stdout = download_process.stdout.take().unwrap(); - let stdout = BufReader::new(download_stdout); - - thread::spawn(move || { - log_lines(stdout.lines()); - }); - - loop { - if let Some(status) 
= download_process.try_wait().unwrap() { - if status.success() { - tracing::info!("Successfully downloaded weights."); - break; - } - - let mut err = String::new(); - download_process - .stderr - .take() - .unwrap() - .read_to_string(&mut err) - .unwrap(); - if let Some(signal) = status.signal() { - tracing::error!( - "Download process was signaled to shutdown with signal {signal}: {err}" - ); - } else { - tracing::error!("Download encountered an error: {err}"); - } - - return Err(LauncherError::DownloadError); - } - if !running.load(Ordering::SeqCst) { - terminate("download", download_process, Duration::from_secs(10)).unwrap(); - return Ok(()); - } - sleep(Duration::from_millis(100)); - } - Ok(()) -} - -#[allow(clippy::too_many_arguments)] -fn spawn_shards( - num_shard: usize, - args: &Args, - shutdown: Arc, - shutdown_receiver: &mpsc::Receiver<()>, - shutdown_sender: mpsc::Sender<()>, - status_receiver: &mpsc::Receiver, - status_sender: mpsc::Sender, - running: Arc, -) -> Result<(), LauncherError> { - // Start shard processes - for rank in 0..1 { - let model_id = args.model_id.clone(); - let revision = args.revision.clone(); - let uds_path = args.shard_uds_path.clone(); - let master_addr = args.master_addr.clone(); - let huggingface_hub_cache = args.huggingface_hub_cache.clone(); - let weights_cache_override = args.weights_cache_override.clone(); - let status_sender = status_sender.clone(); - let shutdown = shutdown.clone(); - let shutdown_sender = shutdown_sender.clone(); - let otlp_endpoint = args.otlp_endpoint.clone(); - let quantize = args.quantize; - let dtype = args.dtype; - let trust_remote_code = args.trust_remote_code; - let master_port = args.master_port; - let disable_custom_kernels = args.disable_custom_kernels; - let watermark_gamma = args.watermark_gamma; - let watermark_delta = args.watermark_delta; - let cuda_memory_fraction = args.cuda_memory_fraction; - let rope_scaling = args.rope_scaling; - let rope_factor = args.rope_factor; - thread::spawn(move || { - shard_manager( - model_id, - revision, - quantize, - dtype, - trust_remote_code, - uds_path, - rank, - num_shard, - master_addr, - master_port, - huggingface_hub_cache, - weights_cache_override, - disable_custom_kernels, - watermark_gamma, - watermark_delta, - cuda_memory_fraction, - rope_scaling, - rope_factor, - otlp_endpoint, - status_sender, - shutdown, - shutdown_sender, - ) - }); - } - drop(shutdown_sender); - - // Wait for shard to start - let mut shard_ready = 0; - while running.load(Ordering::SeqCst) { - match status_receiver.try_recv() { - Ok(ShardStatus::Ready) => { - shard_ready += 1; - if shard_ready == 1 { - break; - } - } - Err(TryRecvError::Empty) => { - sleep(Duration::from_millis(100)); - } - Ok(ShardStatus::Failed(rank)) => { - tracing::error!("Shard {rank} failed to start"); - shutdown_shards(shutdown, shutdown_receiver); - return Err(LauncherError::ShardCannotStart); - } - Err(TryRecvError::Disconnected) => { - tracing::error!("Shard status channel disconnected"); - shutdown_shards(shutdown, shutdown_receiver); - return Err(LauncherError::ShardDisconnected); - } - } - } - Ok(()) -} - -fn spawn_webserver( - args: Args, - shutdown: Arc, - shutdown_receiver: &mpsc::Receiver<()>, -) -> Result { - // All shard started - // Start webserver - tracing::info!("Starting Webserver"); - let mut router_args = vec![ - "--max-concurrent-requests".to_string(), - args.max_concurrent_requests.to_string(), - "--max-best-of".to_string(), - args.max_best_of.to_string(), - "--max-stop-sequences".to_string(), - 
args.max_stop_sequences.to_string(), - "--max-top-n-tokens".to_string(), - args.max_top_n_tokens.to_string(), - "--max-input-length".to_string(), - args.max_input_length.to_string(), - "--max-total-tokens".to_string(), - args.max_total_tokens.to_string(), - "--max-batch-prefill-tokens".to_string(), - args.max_batch_prefill_tokens.to_string(), - "--waiting-served-ratio".to_string(), - args.waiting_served_ratio.to_string(), - "--max-waiting-tokens".to_string(), - args.max_waiting_tokens.to_string(), - "--validation-workers".to_string(), - args.validation_workers.to_string(), - "--hostname".to_string(), - args.hostname.to_string(), - "--port".to_string(), - args.port.to_string(), - "--master-shard-uds-path".to_string(), - format!("{}-0", args.shard_uds_path), - "--tokenizer-name".to_string(), - args.model_id, - ]; - - // Model optional max batch total tokens - if let Some(max_batch_total_tokens) = args.max_batch_total_tokens { - router_args.push("--max-batch-total-tokens".to_string()); - router_args.push(max_batch_total_tokens.to_string()); - } - - // Model optional revision - if let Some(ref revision) = args.revision { - router_args.push("--revision".to_string()); - router_args.push(revision.to_string()) - } - - if args.json_output { - router_args.push("--json-output".to_string()); - } - - // OpenTelemetry - if let Some(otlp_endpoint) = args.otlp_endpoint { - router_args.push("--otlp-endpoint".to_string()); - router_args.push(otlp_endpoint); - } - - // CORS origins - for origin in args.cors_allow_origin.into_iter() { - router_args.push("--cors-allow-origin".to_string()); - router_args.push(origin); - } - - // Ngrok - if args.ngrok { - router_args.push("--ngrok".to_string()); - router_args.push("--ngrok-authtoken".to_string()); - router_args.push(args.ngrok_authtoken.unwrap()); - router_args.push("--ngrok-edge".to_string()); - router_args.push(args.ngrok_edge.unwrap()); - } - - // Copy current process env - let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); - - // Parse Inference API token - if let Ok(api_token) = env::var("HF_API_TOKEN") { - envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) - }; - - let mut webserver = match Command::new("text-generation-router") - .args(router_args) - .envs(envs) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .process_group(0) - .spawn() - { - Ok(p) => p, - Err(err) => { - tracing::error!("Failed to start webserver: {}", err); - if err.kind() == io::ErrorKind::NotFound { - tracing::error!("text-generation-router not found in PATH"); - tracing::error!("Please install it with `make install-router`") - } else { - tracing::error!("{}", err); - } - - shutdown_shards(shutdown, shutdown_receiver); - return Err(LauncherError::WebserverCannotStart); - } - }; - - // Redirect STDOUT and STDERR to the console - let webserver_stdout = webserver.stdout.take().unwrap(); - let webserver_stderr = webserver.stderr.take().unwrap(); - - thread::spawn(move || { - let stdout = BufReader::new(webserver_stdout); - let stderr = BufReader::new(webserver_stderr); - for line in stdout.lines() { - println!("{}", line.unwrap()); - } - for line in stderr.lines() { - println!("{}", line.unwrap()); - } - }); - Ok(webserver) -} - -fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::Result { - tracing::info!("Terminating {process_name}"); - - let terminate_time = Instant::now(); - signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM).unwrap(); - - tracing::info!("Waiting for {process_name} to gracefully shutdown"); - 
- while terminate_time.elapsed() < timeout { - if let Some(status) = process.try_wait()? { - tracing::info!("{process_name} terminated"); - return Ok(status); - } - sleep(Duration::from_millis(100)); - } - - tracing::info!("Killing {process_name}"); - - process.kill()?; - let exit_status = process.wait()?; - - tracing::info!("{process_name} killed"); - Ok(exit_status) -} - -fn main() -> Result<(), LauncherError> { - // Pattern match configuration - let args: Args = Args::parse(); - - // Filter events with LOG_LEVEL - let env_filter = - EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info")); - - if args.json_output { - tracing_subscriber::fmt() - .with_env_filter(env_filter) - .json() - .init(); - } else { - tracing_subscriber::fmt() - .with_env_filter(env_filter) - .compact() - .init(); - } - - if args.env { - let env_runtime = env_runtime::Env::new(); - tracing::info!("{}", env_runtime); - } - - tracing::info!("{:?}", args); - - // Validate args - if args.max_input_length >= args.max_total_tokens { - return Err(LauncherError::ArgumentValidation( - "`max_input_length` must be < `max_total_tokens`".to_string(), - )); - } - if args.max_input_length as u32 > args.max_batch_prefill_tokens { - return Err(LauncherError::ArgumentValidation(format!( - "`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {} and {}", - args.max_batch_prefill_tokens, args.max_input_length - ))); - } - - if args.validation_workers == 0 { - return Err(LauncherError::ArgumentValidation( - "`validation_workers` must be > 0".to_string(), - )); - } - if args.trust_remote_code { - tracing::warn!( - "`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.", - args.model_id - ); - } - - let num_shard = find_num_shards(args.sharded, args.num_shard)?; - if num_shard > 1 { - tracing::info!("Sharding model on {num_shard} processes"); - } - - if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens { - if args.max_batch_prefill_tokens > *max_batch_total_tokens { - return Err(LauncherError::ArgumentValidation(format!( - "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", - args.max_batch_prefill_tokens, max_batch_total_tokens - ))); - } - if args.max_total_tokens as u32 > *max_batch_total_tokens { - return Err(LauncherError::ArgumentValidation(format!( - "`max_total_tokens` must be <= `max_batch_total_tokens`. 
Given: {} and {}", - args.max_total_tokens, max_batch_total_tokens - ))); - } - } - - if args.ngrok { - if args.ngrok_authtoken.is_none() { - return Err(LauncherError::ArgumentValidation( - "`ngrok-authtoken` must be set when using ngrok tunneling".to_string(), - )); - } - - if args.ngrok_edge.is_none() { - return Err(LauncherError::ArgumentValidation( - "`ngrok-edge` must be set when using ngrok tunneling".to_string(), - )); - } - } - - // Signal handler - let running = Arc::new(AtomicBool::new(true)); - let r = running.clone(); - ctrlc::set_handler(move || { - r.store(false, Ordering::SeqCst); - }) - .expect("Error setting Ctrl-C handler"); - - // Download and convert model weights - download_convert_model(&args, running.clone())?; - - if !running.load(Ordering::SeqCst) { - // Launcher was asked to stop - return Ok(()); - } - - // Shared shutdown bool - let shutdown = Arc::new(AtomicBool::new(false)); - // Shared shutdown channel - // When shutting down, the main thread will wait for all senders to be dropped - let (shutdown_sender, shutdown_receiver) = mpsc::channel(); - - // Shared channel to track shard status - let (status_sender, status_receiver) = mpsc::channel(); - - spawn_shards( - num_shard, - &args, - shutdown.clone(), - &shutdown_receiver, - shutdown_sender, - &status_receiver, - status_sender, - running.clone(), - )?; - - // We might have received a termination signal - if !running.load(Ordering::SeqCst) { - shutdown_shards(shutdown, &shutdown_receiver); - return Ok(()); - } - - let mut webserver = - spawn_webserver(args, shutdown.clone(), &shutdown_receiver).map_err(|err| { - shutdown_shards(shutdown.clone(), &shutdown_receiver); - err - })?; - - // Default exit code - let mut exit_code = Ok(()); - - while running.load(Ordering::SeqCst) { - if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() { - tracing::error!("Shard {rank} crashed"); - exit_code = Err(LauncherError::ShardFailed); - break; - }; - - match webserver.try_wait().unwrap() { - Some(_) => { - tracing::error!("Webserver Crashed"); - shutdown_shards(shutdown, &shutdown_receiver); - return Err(LauncherError::WebserverFailed); - } - None => { - sleep(Duration::from_millis(100)); - } - }; - } - - // Graceful termination - terminate("webserver", webserver, Duration::from_secs(90)).unwrap(); - shutdown_shards(shutdown, &shutdown_receiver); - - exit_code -} diff --git a/text-generation-inference/server/Makefile b/text-generation-inference/server/Makefile deleted file mode 100644 index 9e9596ae3b..0000000000 --- a/text-generation-inference/server/Makefile +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -install-poetry: - curl -sSL https://install.python-poetry.org | python3 - - -update-lock: - rm poetry.lock - poetry lock --no-update - -export-requirements: - poetry export -f requirements.txt --without-hashes --output requirements.txt diff --git a/text-generation-inference/server/pyproject.toml b/text-generation-inference/server/pyproject.toml deleted file mode 100644 index ce67c7023d..0000000000 --- a/text-generation-inference/server/pyproject.toml +++ /dev/null @@ -1,20 +0,0 @@ -[tool.poetry] -name = "text-generation-server" -version = "1.1.0" -description = "Text Generation Inference Python gRPC Server" -authors = ["Olivier Dehaene "] - -[tool.poetry.scripts] -text-generation-server = 'text_generation_server.cli:app' - -[tool.poetry.dependencies] -huggingface-hub = "^0.16.4" -deepspeed = { git = "https://github.com/HabanaAI/DeepSpeed.git", branch = "1.13.0" } -optimum-habana = { git = "https://github.com/huggingface/optimum-habana.git", branch = "main" } - -[tool.pytest.ini_options] -markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"] - -[build-system] -requires = ["poetry-core>=1.0.0"] -build-backend = "poetry.core.masonry.api" diff --git a/text-generation-inference/server/requirements.txt b/text-generation-inference/server/requirements.txt deleted file mode 100644 index c4247cc204..0000000000 --- a/text-generation-inference/server/requirements.txt +++ /dev/null @@ -1,78 +0,0 @@ -accelerate>=0.22.0 ; python_version >= "3.9" and python_version < "3.13" -aiohttp==3.8.5 ; python_version >= "3.9" and python_version < "3.13" -aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13" -async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.13" -attrs==23.1.0 ; python_version >= "3.9" and python_version < "3.13" -backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13" -certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13" -charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13" -click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" -colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") -coloredlogs==15.0.1 ; python_version >= "3.9" and python_version < "3.13" -datasets==2.14.4 ; python_version >= "3.9" and python_version < "3.13" -deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" -diffusers==0.20.1 ; python_version >= "3.9" and python_version < "3.13" -dill==0.3.7 ; python_version >= "3.9" and python_version < "3.13" -filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13" -frozenlist==1.4.0 ; python_version >= "3.9" and python_version < "3.13" -fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "3.13" -fsspec[http]==2023.6.0 ; python_version >= "3.9" and python_version < "3.13" -googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13" -grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13" -grpcio-reflection==1.48.2 ; python_version >= "3.9" and python_version < "3.13" -grpcio-status==1.48.2 ; python_version >= "3.9" and python_version < "3.13" -grpcio==1.57.0 ; python_version >= "3.9" and python_version < "3.13" -hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "3.13" -huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13" -humanfriendly==10.0 ; python_version >= "3.9" and python_version < "3.13" -idna==3.4 ; 
python_version >= "3.9" and python_version < "3.13" -importlib-metadata==6.8.0 ; python_version >= "3.9" and python_version < "3.13" -jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13" -loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" -markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13" -mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13" -multidict==6.0.4 ; python_version >= "3.9" and python_version < "3.13" -multiprocess==0.70.15 ; python_version >= "3.9" and python_version < "3.13" -networkx==3.1 ; python_version >= "3.9" and python_version < "3.13" -numpy==1.25.2 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -optimum==1.13.2 ; python_version >= "3.9" and python_version < "3.13" -packaging==23.1 ; python_version >= "3.9" and python_version < "3.13" -pandas==2.0.3 ; python_version >= "3.9" and python_version < "3.13" -peft==0.4.0 ; python_version >= "3.9" and python_version < "3.13" -pillow==10.0.0 ; python_version >= "3.9" and python_version < "3.13" -protobuf==3.20.3 ; python_version >= "3.9" and python_version < "3.13" -psutil==5.9.5 ; python_version >= "3.9" and python_version < "3.13" -pyarrow==13.0.0 ; python_version >= "3.9" and python_version < "3.13" -pyreadline3==3.4.1 ; sys_platform == "win32" and python_version >= "3.9" and python_version < "3.13" -python-dateutil==2.8.2 ; python_version >= "3.9" and python_version < "3.13" -pytz==2023.3 ; python_version >= "3.9" and python_version < "3.13" -pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" -regex==2023.8.8 ; python_version >= "3.9" and python_version < "3.13" -requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" -safetensors==0.3.2 ; python_version >= "3.9" and python_version < "3.13" -sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" -setuptools==68.1.2 ; python_version >= "3.9" and python_version < "3.13" -six==1.16.0 ; python_version >= "3.9" and python_version < "3.13" -sympy==1.12 ; python_version >= "3.9" and python_version < "3.13" -tokenizers==0.14.1 ; python_version >= "3.9" and python_version < "3.13" -tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13" -transformers==4.34.1 ; python_version >= "3.9" and python_version < "3.13" -transformers[sentencepiece]==4.34.1 ; python_version >= "3.9" and python_version < "3.13" -typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" -typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13" -tzdata==2023.3 ; python_version >= "3.9" and python_version < "3.13" -urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13" -win32-setctime==1.1.0 ; python_version >= "3.9" 
and python_version < "3.13" and sys_platform == "win32" -wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -xxhash==3.3.0 ; python_version >= "3.9" and python_version < "3.13" -yarl==1.9.2 ; python_version >= "3.9" and python_version < "3.13" -zipp==3.16.2 ; python_version >= "3.9" and python_version < "3.13" diff --git a/text-generation-inference/server/text_generation_server/cli.py b/text-generation-inference/server/text_generation_server/cli.py deleted file mode 100644 index 841c78828f..0000000000 --- a/text-generation-inference/server/text_generation_server/cli.py +++ /dev/null @@ -1,221 +0,0 @@ -import os -import sys -import typer - -from pathlib import Path -from loguru import logger -from typing import Optional -from enum import Enum - - -app = typer.Typer() - - -class Quantization(str, Enum): - bitsandbytes = "bitsandbytes" - gptq = "gptq" - - -class Dtype(str, Enum): - float16 = "float16" - bloat16 = "bfloat16" - - -@app.command() -def serve( - model_id: str, - revision: Optional[str] = None, - sharded: bool = False, - quantize: Optional[Quantization] = None, - dtype: Optional[Dtype] = None, - trust_remote_code: bool = False, - uds_path: Path = "/tmp/text-generation-server", - logger_level: str = "INFO", - json_output: bool = False, - otlp_endpoint: Optional[str] = None, -): - if sharded: - assert os.getenv("WORLD_SIZE", None) is not None, "WORLD_SIZE must be set when sharded is True" - assert os.getenv("MASTER_ADDR", None) is not None, "MASTER_ADDR must be set when sharded is True" - assert os.getenv("MASTER_PORT", None) is not None, "MASTER_PORT must be set when sharded is True" - - # Remove default handler - logger.remove() - logger.add( - sys.stdout, - format="{message}", - filter="text_generation_server", - level=logger_level, - serialize=json_output, - backtrace=True, - diagnose=False, - ) - - # Import here after the logger is added to log potential import exceptions - from text_generation_server import server - from text_generation_server.tracing import setup_tracing - - # Setup OpenTelemetry distributed tracing - if otlp_endpoint is not None: - setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint) - - # Downgrade enum into str for easier management later on - quantize = None if quantize is None else quantize.value - dtype = "bfloat16" if dtype is None else dtype.value - - logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype)) - - if sharded: - tgi_file = Path(__file__).resolve().parent / "tgi_service.py" - num_shard = int(os.getenv("WORLD_SIZE", "1")) - logger.info("CLI SHARDED = {}".format(num_shard)) - import subprocess - - cmd = f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file} --model_id {model_id} --revision {revision} --sharded {sharded} --dtype {dtype} --uds_path {uds_path}" - logger.info("CLI server start deepspeed ={} ".format(cmd)) - sys.stdout.flush() - sys.stderr.flush() - with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc: - proc.wait() - sys.stdout.flush() - sys.stderr.flush() - if proc.returncode != 0: - logger.error(f"{cmd} exited with status = {proc.returncode}") - return proc.returncode - else: - server.serve(model_id, revision, dtype, uds_path, sharded) - - -@app.command() -def download_weights( - model_id: str, - revision: Optional[str] = None, - extension: str = ".safetensors", - auto_convert: bool = True, - logger_level: str = "INFO", - json_output: bool = False, -): - # Remove default handler - logger.remove() - logger.add( - sys.stdout, - format="{message}", 
- filter="text_generation_server", - level=logger_level, - serialize=json_output, - backtrace=True, - diagnose=False, - ) - - # Import here after the logger is added to log potential import exceptions - from text_generation_server import utils - - # Test if files were already download - try: - utils.weight_files(model_id, revision, extension) - logger.info("Files are already present on the host. " "Skipping download.") - return - # Local files not found - except (utils.LocalEntryNotFoundError, FileNotFoundError): - pass - - is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv( - "WEIGHTS_CACHE_OVERRIDE", None - ) is not None - - if not is_local_model: - # Try to download weights from the hub - try: - filenames = utils.weight_hub_files(model_id, revision, extension) - utils.download_weights(filenames, model_id, revision) - # Successfully downloaded weights - return - - # No weights found on the hub with this extension - except utils.EntryNotFoundError as e: - # Check if we want to automatically convert to safetensors or if we can use .bin weights instead - if not extension == ".safetensors" or not auto_convert: - raise e - - # Try to see if there are local pytorch weights - try: - # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE - local_pt_files = utils.weight_files(model_id, revision, ".bin") - - # No local pytorch weights - except utils.LocalEntryNotFoundError: - if extension == ".safetensors": - logger.warning( - f"No safetensors weights found for model {model_id} at revision {revision}. " - f"Downloading PyTorch weights." - ) - - # Try to see if there are pytorch weights on the hub - pt_filenames = utils.weight_hub_files(model_id, revision, ".bin") - # Download pytorch weights - local_pt_files = utils.download_weights(pt_filenames, model_id, revision) - - if auto_convert: - logger.warning( - f"No safetensors weights found for model {model_id} at revision {revision}. " - f"Converting PyTorch weights to safetensors." - ) - - # Safetensors final filenames - local_st_files = [p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files] - try: - import transformers - from transformers import AutoConfig - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - ) - architecture = config.architectures[0] - - class_ = getattr(transformers, architecture) - - # Name for this varible depends on transformers version. 
- discard_names = getattr(class_, "_tied_weights_keys", []) - discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", [])) - - except Exception: - discard_names = [] - # Convert pytorch weights to safetensors - utils.convert_files(local_pt_files, local_st_files, discard_names) - - -@app.command() -def quantize( - model_id: str, - output_dir: str, - revision: Optional[str] = None, - logger_level: str = "INFO", - json_output: bool = False, - trust_remote_code: bool = False, - upload_to_model_id: Optional[str] = None, - percdamp: float = 0.01, - act_order: bool = False, -): - download_weights( - model_id=model_id, - revision=revision, - logger_level=logger_level, - json_output=json_output, - ) - from text_generation_server.utils.gptq.quantize import quantize - - quantize( - model_id=model_id, - bits=4, - groupsize=128, - output_dir=output_dir, - trust_remote_code=trust_remote_code, - upload_to_model_id=upload_to_model_id, - percdamp=percdamp, - act_order=act_order, - ) - - -if __name__ == "__main__": - app() diff --git a/text-generation-inference/server/text_generation_server/models/__init__.py b/text-generation-inference/server/text_generation_server/models/__init__.py deleted file mode 100644 index efe9b62a1d..0000000000 --- a/text-generation-inference/server/text_generation_server/models/__init__.py +++ /dev/null @@ -1,35 +0,0 @@ -import torch - -from loguru import logger -from transformers.models.auto import modeling_auto -from transformers import AutoConfig -from typing import Optional - -from text_generation_server.models.model import Model -from text_generation_server.models.causal_lm import CausalLM -from text_generation_server.models.bloom import BLOOM -from text_generation_server.models.santacoder import SantaCoder - - -# Disable gradients -torch.set_grad_enabled(False) - - -def get_model( - model_id: str, - revision: Optional[str], - dtype: Optional[torch.dtype] = None, -) -> Model: - config = AutoConfig.from_pretrained(model_id, revision=revision) - model_type = config.model_type - - if model_type == "gpt_bigcode": - return SantaCoder(model_id, revision, dtype) - - if model_type == "bloom": - return BLOOM(model_id, revision, dtype) - - if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: - return CausalLM(model_id, revision, dtype) - - raise ValueError(f"Unsupported model type {model_type}") diff --git a/text-generation-inference/server/text_generation_server/models/bloom.py b/text-generation-inference/server/text_generation_server/models/bloom.py deleted file mode 100644 index 09d8b69b4f..0000000000 --- a/text-generation-inference/server/text_generation_server/models/bloom.py +++ /dev/null @@ -1,48 +0,0 @@ -import torch - -from typing import Optional, Type - -from transformers import PreTrainedTokenizerBase - -from text_generation_server.models import CausalLM -from text_generation_server.models.causal_lm import CausalLMBatch -from text_generation_server.pb import generate_pb2 - - -class BloomCausalLMBatch(CausalLMBatch): - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, - is_optimized_for_gaudi: bool = False, - ) -> "CausalLMBatch": - batch = super().from_pb( - pb=pb, - tokenizer=tokenizer, - dtype=dtype, - device=device, - is_optimized_for_gaudi=is_optimized_for_gaudi, - ) - batch.keys_head_dim_last = False - return batch - - -class BLOOM(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - dtype: Optional[torch.dtype] 
= None, - ): - super(BLOOM, self).__init__( - model_id=model_id, - revision=revision, - dtype=dtype, - ) - - @property - def batch_type(self) -> Type[CausalLMBatch]: - return BloomCausalLMBatch diff --git a/text-generation-inference/server/text_generation_server/models/causal_lm.py b/text-generation-inference/server/text_generation_server/models/causal_lm.py deleted file mode 100644 index bf193b4fbc..0000000000 --- a/text-generation-inference/server/text_generation_server/models/causal_lm.py +++ /dev/null @@ -1,929 +0,0 @@ -import os -import tempfile - -from text_generation_server.utils.tokens import batch_top_tokens -import torch - -from dataclasses import dataclass -from opentelemetry import trace -from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, AutoConfig -from typing import Optional, Tuple, List, Type, Dict -from habana_frameworks.torch.hpu import wrap_in_hpu_graph -import habana_frameworks.torch as htorch -from contextlib import nullcontext -from optimum.habana.utils import HabanaProfile - -from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES -from optimum.habana.checkpoint_utils import ( - get_repo_root, - model_on_meta, - write_checkpoints_json, -) - -from text_generation_server.models import Model -from text_generation_server.models.types import ( - Batch, - PrefillTokens, - Generation, - GeneratedText, - TopTokens, -) -from text_generation_server.pb import generate_pb2 -from text_generation_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria, Sampling -from loguru import logger - -tracer = trace.get_tracer(__name__) - -@dataclass -class CausalLMBatch(Batch): - batch_id: int - requests: List[generate_pb2.Request] - requests_idx_mapping: Dict[int, int] - - # Decoder values - input_ids: torch.Tensor - attention_mask: torch.Tensor - position_ids: torch.Tensor - past_key_values: Optional[List[Tuple]] - - # All tokens - all_input_ids: List[torch.Tensor] - - # Lengths of all generations present in the batch - input_lengths: List[int] - prefix_offsets: List[int] - read_offsets: List[int] - - # Generation helpers - next_token_chooser: HeterogeneousNextTokenChooser - stopping_criterias: List[StoppingCriteria] - top_n_tokens: List[int] - top_n_tokens_tensor: torch.Tensor - - # Metadata used for padding - max_input_length: int - padding_right_offset: int - - # Maximum number of tokens this batch will grow to - max_tokens: int - - # Past metadata - keys_head_dim_last: bool = True - - def to_pb(self) -> generate_pb2.CachedBatch: - return generate_pb2.CachedBatch( - id=self.batch_id, - request_ids=[r.id for r in self.requests], - size=len(self), - max_tokens=self.max_tokens, - ) - - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, - is_optimized_for_gaudi: bool = False, - ) -> "CausalLMBatch": - inputs = [] - next_token_chooser_parameters = [] - stopping_criterias = [] - top_n_tokens = [] - prefix_offsets = [] - read_offsets = [] - requests_idx_mapping = {} - input_lengths = [] - - # Parse batch - max_truncation = 0 - padding_right_offset = 0 - max_decode_tokens = 0 - - # TODO: this should be set to rust side `max_total_tokens`, - # (see https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs#L177) - # but TGI does not offer an API to expose this variable to python, as this variable - # is handled by the client but it appears the model is initialized by the server. 
- # An alternative could be to initialize the buffers during warmup. - # Dummy - max_total_tokens = int(os.getenv("MAX_TOTAL_TOKENS", "0")) - logger.info("MAX_TOTAL_TOKENS = {}".format(max_total_tokens)) - - for i, r in enumerate(pb.requests): - requests_idx_mapping[r.id] = i - inputs.append(r.inputs) - next_token_chooser_parameters.append(r.parameters) - stopping_criteria = StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) - stopping_criterias.append(stopping_criteria) - top_n_tokens.append(r.top_n_tokens) - max_truncation = max(max_truncation, r.truncate) - max_decode_tokens += stopping_criteria.max_new_tokens - padding_right_offset = max(padding_right_offset, stopping_criteria.max_new_tokens) - - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - next_token_chooser_parameters, dtype, device - ) - - tokenized_inputs = tokenizer( - inputs, - return_tensors="pt", - padding="max_length", - return_token_type_ids=False, - truncation=True, - max_length=max_truncation, - ) - - for _ in pb.requests: - input_len = tokenized_inputs["input_ids"].shape[1] - input_lengths.append(input_len) - prefix_offsets.append(input_len - 5) - read_offsets.append(input_len) - - max_input_length = max(input_lengths) - if max_total_tokens == 0: - max_total_tokens = max_input_length - max_tokens = len(inputs) * max_input_length + max_decode_tokens - if is_optimized_for_gaudi and max_total_tokens > max_input_length: - # pad to max_total_tokens in case max_new_token changes per request and triggers new hpu graph generation - padding_right_offset = max_total_tokens - max_input_length - - input_ids = tokenized_inputs["input_ids"] - attention_mask = tokenized_inputs["attention_mask"] - # only move model inputs to device - attention_mask = attention_mask.to(device) - - if is_optimized_for_gaudi: - input_ids_cpu = torch.nn.functional.pad( - input_ids, (0, padding_right_offset), value=tokenizer.pad_token_id - ) - input_ids = input_ids_cpu.to(device) - attention_mask = torch.nn.functional.pad(attention_mask, (0, padding_right_offset), value=0) - all_input_ids = input_ids_cpu.T.split(1, dim=1) - else: - all_input_ids = input_ids.clone().T.split(1, dim=1) - input_ids = input_ids.to(device) - - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - htorch.core.mark_step() - - top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64) - - return cls( - batch_id=pb.id, - requests=pb.requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=None, - all_input_ids=list(all_input_ids), - input_lengths=input_lengths, - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - next_token_chooser=next_token_chooser, - stopping_criterias=stopping_criterias, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - max_input_length=max_input_length, - padding_right_offset=padding_right_offset, - max_tokens=max_tokens, - ) - - @tracer.start_as_current_span("filter") - def filter(self, request_ids: List[int], is_optimized_for_gaudi: bool = False) -> Optional["CausalLMBatch"]: - if len(request_ids) == 0: - raise ValueError("Batch must have at least one request") - if len(request_ids) == len(self): - return self - - keep_indices = [] - - # New values after filtering - requests_idx_mapping = {} - requests = [] - input_lengths = [] - prefix_offsets = [] - read_offsets = [] - all_input_ids = [] - max_input_length = 0 - - 
stopping_criterias = [] - top_n_tokens = [] - - total_remaining_decode_tokens = 0 - new_padding_right_offset = 0 - - for i, request_id in enumerate(request_ids): - idx = self.requests_idx_mapping[request_id] - requests_idx_mapping[request_id] = i - keep_indices.append(idx) - - requests.append(self.requests[idx]) - prefix_offsets.append(self.prefix_offsets[idx]) - read_offsets.append(self.read_offsets[idx]) - all_input_ids.append(self.all_input_ids[idx]) - - request_input_length = self.input_lengths[idx] - input_lengths.append(request_input_length) - max_input_length = max(max_input_length, request_input_length) - - stopping_criteria = self.stopping_criterias[idx] - stopping_criterias.append(stopping_criteria) - top_n_tokens.append(self.top_n_tokens[idx]) - remaining_decode_tokens = stopping_criteria.max_new_tokens - stopping_criteria.current_tokens - total_remaining_decode_tokens += remaining_decode_tokens - new_padding_right_offset = max(new_padding_right_offset, remaining_decode_tokens) - - # Apply indices to input_ids, attention mask, past key values and other items that need to be cached - input_ids = self.input_ids[keep_indices] - position_ids = self.position_ids[keep_indices] - next_token_chooser = self.next_token_chooser.filter(keep_indices) - if is_optimized_for_gaudi: - self.attention_mask = self.attention_mask[keep_indices] - else: - self.attention_mask = self.attention_mask[ - keep_indices, - -(self.padding_right_offset + max_input_length) : ( - self.attention_mask.shape[1] - self.padding_right_offset - ) - + new_padding_right_offset, - ] - - # Ensure that past_key_values tensors can be updated in-place - kv_tuple = False - if type(self.past_key_values[0]) == tuple: - self.past_key_values = [list(layer) for layer in self.past_key_values] - kv_tuple = True - - # Update tensors in-place to allow incremental garbage collection - past_kv_length = max_input_length - 1 - for layer in self.past_key_values: - past_keys, past_values = layer - past_keys_dims = len(past_keys.shape) - if past_keys_dims == 3: - # Force past to be of dim [self_size, num_heads, ...] 
for easy indexing - past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:]) - past_values = past_values.view(len(self), -1, *past_values.shape[-2:]) - if is_optimized_for_gaudi: - layer[0] = past_keys[keep_indices] - del past_keys - layer[1] = past_values[keep_indices] - del past_values - else: - if self.keys_head_dim_last: - layer[0] = past_keys[keep_indices, :, -past_kv_length:, :] - else: - layer[0] = past_keys[keep_indices, :, :, -past_kv_length:] - del past_keys - layer[1] = past_values[keep_indices, :, -past_kv_length:, :] - del past_values - if past_keys_dims == 3: - layer[0] = layer[0].view(layer[0].shape[0] * layer[0].shape[1], *layer[0].shape[-2:]) - layer[1] = layer[1].view(layer[1].shape[0] * layer[1].shape[1], *layer[1].shape[-2:]) - - top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices] - max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens - - if kv_tuple: - self.past_key_values = [tuple(layer) for layer in self.past_key_values] - - self.requests = requests - self.requests_idx_mapping = requests_idx_mapping - self.input_ids = input_ids - self.position_ids = position_ids - self.all_input_ids = all_input_ids - self.input_lengths = input_lengths - self.prefix_offsets = prefix_offsets - self.read_offsets = read_offsets - self.next_token_chooser = next_token_chooser - self.stopping_criterias = stopping_criterias - self.top_n_tokens = top_n_tokens - self.top_n_tokens_tensor = top_n_tokens_tensor - self.max_input_length = max_input_length - self.padding_right_offset = new_padding_right_offset - self.max_tokens = max_tokens - - return self - - @classmethod - @tracer.start_as_current_span("concatenate") - def concatenate(cls, batches: List["CausalLMBatch"], is_optimized_for_gaudi: bool = False) -> "CausalLMBatch": - # Used for padding - total_batch_size = 0 - max_input_length = 0 - padding_right_offset = 0 - max_total_tokens = 0 - for batch in batches: - total_batch_size += len(batch) - max_input_length = max(max_input_length, batch.max_input_length) - padding_right_offset = max(padding_right_offset, batch.padding_right_offset) - max_total_tokens = max(max_total_tokens, batch.max_input_length + batch.padding_right_offset) - - if is_optimized_for_gaudi and max_total_tokens > max_input_length: - padding_right_offset = max_total_tokens - max_input_length - - # Batch attributes - requests = [] - requests_idx_mapping = {} - input_lengths = [] - prefix_offsets = [] - read_offsets = [] - all_input_ids = [] - next_token_chooser_parameters = [] - stopping_criterias = [] - top_n_tokens = [] - max_tokens = 0 - - # Batch tensors - input_ids = None - attention_mask = None - position_ids = None - past_key_values = [] - top_n_tokens_tensor = None - - # Used for slicing correctly inside the tensors - # Equivalent to a cumsum on batch sizes - start_index = 0 - for i, batch in enumerate(batches): - requests.extend(batch.requests) - input_lengths.extend(batch.input_lengths) - prefix_offsets.extend(batch.prefix_offsets) - read_offsets.extend(batch.read_offsets) - all_input_ids.extend(batch.all_input_ids) - next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) - stopping_criterias.extend(batch.stopping_criterias) - top_n_tokens.extend(batch.top_n_tokens) - - if i == 0: - requests_idx_mapping = batch.requests_idx_mapping - else: - # We need to offset the mapping for each batch by the cumulative batch size - for k, v in batch.requests_idx_mapping.items(): - requests_idx_mapping[k] = v + start_index - - # Slicing end index for this batch - 
end_index = start_index + len(batch) - - # We only concatenate batches that did at least one step - if batch.past_key_values is None: - raise ValueError("only concatenate prefilled batches") - - # Create empty tensor - # input_ids is always of shape [batch_size, 1] - # We do not need to pad it - if input_ids is None: - input_ids = batch.input_ids.new_empty((total_batch_size, 1)) - # Copy to correct indices - input_ids[start_index:end_index] = batch.input_ids - - # Create padded tensor - if attention_mask is None: - attention_mask = batch.attention_mask.new_zeros( - (total_batch_size, max_input_length + padding_right_offset), - ) - - if top_n_tokens_tensor is None: - top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( - total_batch_size, - ) - top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor - - # We need to slice the attention mask to remove padding from previous steps - # and to remove unused allocated space - left_offset = max_input_length - batch.max_input_length - batch_left_offset = batch.attention_mask.shape[1] - batch.max_input_length - batch.padding_right_offset - attention_mask[start_index:end_index, left_offset:-padding_right_offset] = batch.attention_mask[ - :, - batch_left_offset : -batch.padding_right_offset, - ] - - # Create empty tensor - # position_ids is always of shape [batch_size, 1] - if position_ids is None: - position_ids = batch.position_ids.new_empty((total_batch_size, 1)) - position_ids[start_index:end_index] = batch.position_ids - - # Shenanigans to get dimensions because BLOOM outputs a past with a different shape - # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length] - # BLOOM Values: [batch_size * num_heads, seq_length, head_dim] - # And ensure that we can update tensors in-place - kv_tuple = False - past_key_values_dims = len(batch.past_key_values[0][0].shape) - if type(batch.past_key_values[0]) == tuple: - batch.past_key_values = [ - [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] for layer in batch.past_key_values - ] - kv_tuple = True - elif past_key_values_dims == 3: - for layer in batch.past_key_values: - for k, t in enumerate(layer): - layer[k] = t.view(len(batch), -1, *t.shape[-2:]) - - # Add eventual padding tokens that were added while concatenating - max_tokens += batch.max_tokens + (max_input_length - batch.max_input_length) * len(batch) - - start_index = end_index - - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - next_token_chooser_parameters, - dtype=batches[0].next_token_chooser.dtype, - device=batches[0].next_token_chooser.device, - ) - - first_past_kvs = batches[0].past_key_values - _, num_heads, _, head_dim = first_past_kvs[0][1].shape - padded_sequence_length = ( - max_input_length + padding_right_offset if is_optimized_for_gaudi else max_input_length - 1 - ) - padded_past_values_shape = ( - total_batch_size, - num_heads, - padded_sequence_length, - head_dim, - ) - - if batches[0].keys_head_dim_last: - padded_past_keys_shape = padded_past_values_shape - else: - # seq_length is last for BLOOM - padded_past_keys_shape = ( - total_batch_size, - num_heads, - head_dim, - padded_sequence_length, - ) - - # Iterate over attention layers - # Concatenate past key values layer by layer to allow incremental garbage collection - for j in range(len(first_past_kvs)): - padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape) - start_index = 0 - for batch in batches: - past_keys = batch.past_key_values[j][0] - # Clear reference to the original tensor - 
batch.past_key_values[j][0] = None - - # Slicing end index for this batch - end_index = start_index + len(batch) - # We slice the keys to remove the padding from previous batches - past_seq_len = batch.max_input_length - 1 - # recaculate the offset - left_offset = max_input_length - batch.max_input_length - batch_left_offset = batch.attention_mask.shape[1] - batch.max_input_length - batch.padding_right_offset - - if batch.keys_head_dim_last: - padded_past_keys[ - start_index:end_index, :, left_offset : left_offset + past_seq_len, : - ] = past_keys[:, :, batch_left_offset : batch_left_offset + past_seq_len, :] - else: - # BLOOM case - padded_past_keys[ - start_index:end_index, :, :, left_offset : left_offset + past_seq_len - ] = past_keys[:, :, :, batch_left_offset : batch_left_offset + past_seq_len] - del past_keys - - start_index = end_index - - padded_past_values = first_past_kvs[j][1].new_zeros(padded_past_values_shape) - start_index = 0 - for batch in batches: - past_values = batch.past_key_values[j][1] - # Clear reference to the original tensor - batch.past_key_values[j][1] = None - - # Slicing end index for this batch - end_index = start_index + len(batch) - # We slice the past values to remove the padding from previous batches - past_seq_len = batch.max_input_length - 1 - # recaculate the offset - left_offset = max_input_length - batch.max_input_length - batch_left_offset = batch.attention_mask.shape[1] - batch.max_input_length - batch.padding_right_offset - - padded_past_values[ - start_index:end_index, :, left_offset : left_offset + past_seq_len, : - ] = past_values[:, :, batch_left_offset : batch_left_offset + past_seq_len, :] - del past_values - - # Update values - start_index = end_index - - if past_key_values_dims == 3: - padded_past_keys = padded_past_keys.view( - padded_past_keys.shape[0] * padded_past_keys.shape[1], *padded_past_keys.shape[-2:] - ) - padded_past_values = padded_past_values.view( - padded_past_values.shape[0] * padded_past_values.shape[1], *padded_past_values.shape[-2:] - ) - - if kv_tuple: - past_key_values.append((padded_past_keys, padded_past_values)) - else: - past_key_values.append([padded_past_keys, padded_past_values]) - - return cls( - batch_id=batches[0].batch_id, - requests=requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - all_input_ids=all_input_ids, - input_lengths=input_lengths, - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - next_token_chooser=next_token_chooser, - stopping_criterias=stopping_criterias, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - max_input_length=max_input_length, - padding_right_offset=padding_right_offset, - keys_head_dim_last=batches[0].keys_head_dim_last, - max_tokens=max_tokens, - ) - - def __len__(self): - return len(self.requests) - - -class CausalLM(Model): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - ): - device = torch.device("hpu") - - dtype = torch.bfloat16 if dtype is None else dtype - - from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi - - adapt_transformers_to_gaudi() - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - ) - - model_kwargs = { - "revision": revision, - } - - world_size = int(os.getenv("WORLD_SIZE", "1")) - rank = int(os.getenv("RANK"), 0) - 
self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" - self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true" - - if world_size > 1: - import habana_frameworks.torch.hpu as torch_hpu - - # Get world size, rank and local rank - from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu - - world_size, rank, local_rank = initialize_distributed_hpu() - import deepspeed - - # Initialize process(es) for DeepSpeed - deepspeed.init_distributed(dist_backend="hccl") - logger.info( - "DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(world_size, rank, local_rank) - ) - config = AutoConfig.from_pretrained(model_id, **model_kwargs) - load_to_meta = model_on_meta(config) - - if load_to_meta: - # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load - with deepspeed.OnDevice(dtype=dtype, device="meta"): - model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype) - else: - get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK")) - # TODO: revisit placement on CPU when auto-injection is possible - with deepspeed.OnDevice(dtype=dtype, device="cpu"): - model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, **model_kwargs) - model = model.eval() - - # Initialize the model - ds_inference_kwargs = {"dtype": dtype} - ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size} - ds_inference_kwargs["enable_cuda_graph"] = self.enable_hpu_graph - - if load_to_meta: - # model loaded to meta is managed differently - checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w") - write_checkpoints_json(model_id, local_rank, checkpoints_json) - ds_inference_kwargs["checkpoint"] = checkpoints_json.name - model = deepspeed.init_inference(model, **ds_inference_kwargs) - model = model.module - else: - get_repo_root(model_id) - model = AutoModelForCausalLM.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - ) - model = model.eval().to(device) - #wrap in hpu_graph only if self.enable_hpu_graph is set - if self.enable_hpu_graph: - model = wrap_in_hpu_graph(model) - - if model.config.model_type in MODELS_OPTIMIZED_WITH_STATIC_SHAPES: - self.is_optimized_for_gaudi = True - else: - self.is_optimized_for_gaudi = False - - if tokenizer.pad_token_id is None: - if model.config.pad_token_id is not None: - tokenizer.pad_token_id = model.config.pad_token_id - elif model.config.eos_token_id is not None: - tokenizer.pad_token_id = model.config.eos_token_id - elif tokenizer.eos_token_id is not None: - tokenizer.pad_token_id = tokenizer.eos_token_id - else: - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - - kwargs = { - "use_cache": True, - "return_dict": True, - } - - if model.config.model_type == "llama": - kwargs["attn_softmax_bf16"] = True - kwargs["trim_logits"] = True - - super(CausalLM, self).__init__( - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - kwargs=kwargs, - ) - self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) - self.profiling_steps = int(os.getenv("PROF_STEP", "5")) - output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile") - self.hb_profer = HabanaProfile( - warmup=self.profiling_warmup_steps, active=self.profiling_steps, output_dir=output_dir - ) - if self.profiling_warmup_steps > 0: - self.hb_profer_started = True - self.hb_profer.start() - else: - self.hb_profer = None - self.hb_profer_started = False - self.step = 0 - - 
@property - def batch_type(self) -> Type[CausalLMBatch]: - return CausalLMBatch - - def decode(self, generated_ids: List[int]) -> str: - return self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) - - def forward( - self, - input_ids, - attention_mask, - position_ids, - token_idx: Optional = None, - past_key_values: Optional = None, - bypass_hpu_graph: Optional = None, - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: - # Model Forward - kwargs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "past_key_values": past_key_values, - } - - if self.is_optimized_for_gaudi: - kwargs["token_idx"] = token_idx - - if self.has_position_ids: - kwargs["position_ids"] = position_ids - - if bypass_hpu_graph != None: - kwargs["bypass_hpu_graphs"] = bypass_hpu_graph - - kwargs.update(self.kwargs) - outputs = self.model.forward(**kwargs) - return outputs.logits, outputs.past_key_values - - @tracer.start_as_current_span("generate_token") - def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Optional[CausalLMBatch]]: - self.step = self.step + 1 - if self.hb_profer_started == True and self.step > self.profiling_warmup_steps + self.profiling_steps: - self.hb_profer.stop() - self.hb_profer_started = False - - if self.is_optimized_for_gaudi: - token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.padding_right_offset).to(self.device) - attention_mask = batch.attention_mask - - else: - token_idx = None - # slice the attention mask to the correct shape - attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] - prefill = batch.past_key_values is None - if batch.past_key_values: - if token_idx is not None: - input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1) - else: - input_ids = batch.input_ids - - logits, past = self.forward( - input_ids, - attention_mask, - batch.position_ids, - token_idx, - batch.past_key_values, - bypass_hpu_graph = prefill and self.limit_hpu_graph if self.enable_hpu_graph else None - ) - - # Results - generations: List[Generation] = [] - stopped = True - - # Select next token - input_length = batch.input_lengths[0] - if self.is_optimized_for_gaudi and logits.shape[-2] > 1: - next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser( - batch.input_ids[:, :token_idx], logits[:, input_length - 1 : input_length, :].squeeze(-2) - ) - else: - next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser( - batch.input_ids[:, :token_idx], logits.squeeze(-2) - ) - - batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( - batch.top_n_tokens, - batch.top_n_tokens_tensor, - logprobs, - ) - - htorch.core.mark_step() - logits = logits.to("cpu") - - next_token_logprobs = next_token_logprobs.tolist() - next_token_ids = next_input_ids - - # Zipped iterator - iterator = zip( - batch.requests, - batch.input_lengths, - batch.prefix_offsets, - batch.read_offsets, - logits, - batch.next_token_chooser.do_sample, - batch.next_token_chooser.seeds, - batch.stopping_criterias, - batch.all_input_ids, - batch.top_n_tokens, - next_token_ids, - next_token_logprobs, - batch_top_token_ids, - batch_top_token_logprobs, - ) - # For each member of the batch - for i, ( - request, - input_length, - prefix_offset, - read_offset, - logits, - do_sample, - seed, - stopping_criteria, - all_input_ids, - top_n_tokens, - next_token_id, - next_token_logprob, - top_token_ids, - top_token_logprobs, - ) in enumerate(iterator): - # Append next token to all 
tokens - if self.is_optimized_for_gaudi: - all_input_ids[input_length] = next_token_id - else: - all_input_ids = torch.cat([all_input_ids, next_token_id]) - new_input_length = input_length + 1 - - # Generated token - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids[0:new_input_length, 0], prefix_offset, read_offset - ) - - # Evaluate stopping criteria - stop, reason = stopping_criteria( - next_token_id, - next_token_text, - ) - - if not stop: - stopped = False - - # Shard generations - # All generations will be appended in the rust sharded client - if i % self.world_size == self.rank: - if stop: - # Decode generated tokens - output_text = self.decode( - all_input_ids[new_input_length - stopping_criteria.current_tokens : new_input_length, 0] - ) - generated_text = GeneratedText( - output_text, - stopping_criteria.current_tokens, - reason, - seed if do_sample else None, - ) - else: - generated_text = None - - # Prefill - if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: - # Remove generated token to only have prefill and add nan for first prompt token - prefill_logprobs = [float("nan")] + next_token_logprobs - prefill_token_ids = all_input_ids[0 : new_input_length - 1] - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - prefill_tokens = PrefillTokens(prefill_token_ids, prefill_logprobs, prefill_texts) - else: - prefill_tokens = None - - if top_n_tokens > 0: - toptoken_texts = self.tokenizer.batch_decode( - top_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - special_toptokens = [token_id in self.all_special_ids for token_id in top_token_ids] - top_tokens = TopTokens( - top_token_ids, - top_token_logprobs, - toptoken_texts, - special_toptokens, - ) - else: - top_tokens = None - - generation = Generation( - request.id, - prefill_tokens, - next_token_id, - next_token_logprob, - next_token_text, - next_token_id in self.all_special_ids, - generated_text, - top_tokens, - ) - - generations.append(generation) - - batch.all_input_ids[i] = all_input_ids - batch.input_lengths[i] = new_input_length - batch.prefix_offsets[i] = prefix_offset - batch.read_offsets[i] = read_offset - batch.max_input_length = max(batch.max_input_length, new_input_length) - - next_tokens = torch.tensor(next_token_ids, dtype=torch.int64).to(self.device) - if token_idx is None: - batch.input_ids[:, 0] = next_tokens[:, 0] - else: - batch.input_ids[:, token_idx] = next_tokens - # We finished all generations in the batch; there is no next batch - if stopped: - if self.hb_profer_started == True: - self.hb_profer.step() - return generations, None - - # Slice unused values from prefill, use it to store next token - if token_idx is None: - batch.input_ids = batch.input_ids[:, :1] - - # Update attention_mask as we added a new token to input_ids - if self.is_optimized_for_gaudi: - batch.attention_mask.index_fill_(1, token_idx, 1) - else: - batch.attention_mask[:, -batch.padding_right_offset] = 1 - # Decrease right offset - batch.padding_right_offset -= 1 - - # Update position_ids - if prefill: - batch.position_ids = batch.position_ids[:, token_idx - 1 : token_idx] + 1 - else: - batch.position_ids += 1 - # Update past key values - batch.past_key_values = past - if self.hb_profer_started == True: - self.hb_profer.step() - - return generations, batch diff --git a/text-generation-inference/server/text_generation_server/models/model.py 
b/text-generation-inference/server/text_generation_server/models/model.py deleted file mode 100644 index 73e1f1af33..0000000000 --- a/text-generation-inference/server/text_generation_server/models/model.py +++ /dev/null @@ -1,90 +0,0 @@ -import inspect -import torch - -from abc import ABC, abstractmethod -from typing import List, Optional, Tuple, Type, TypeVar -from transformers import PreTrainedTokenizerBase - -from text_generation_server.models.types import Batch, GeneratedText -from text_generation_server.pb.generate_pb2 import InfoResponse - -B = TypeVar("B", bound=Batch) - - -class Model(ABC): - def __init__( - self, - model: torch.nn.Module, - tokenizer: PreTrainedTokenizerBase, - requires_padding: bool, - dtype: torch.dtype, - device: torch.device, - rank: int = 0, - world_size: int = 1, - kwargs: dict = {}, - ): - self.model = model - self.tokenizer = tokenizer - self.all_special_ids = set(tokenizer.all_special_ids) - self.requires_padding = requires_padding - self.dtype = dtype - self.device = device - self.rank = rank - self.world_size = world_size - self.kwargs = kwargs - self.has_position_ids = inspect.signature(model.forward).parameters.get("position_ids", None) is not None - - self.check_initialized() - - @property - def info(self) -> InfoResponse: - return InfoResponse( - requires_padding=self.requires_padding, - dtype=str(self.dtype), - device_type=self.device.type, - ) - - @property - @abstractmethod - def batch_type(self) -> Type[B]: - raise NotImplementedError - - @abstractmethod - def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]: - raise NotImplementedError - - def warmup(self, batch: B, max_total_tokens: int): - self.generate_token(batch) - - def decode_token( - self, - all_input_ids: List[int], - prefix_offset: int = 0, - read_offset: int = 0, - ) -> Tuple[str, int, int]: - """Hack to hopefully support generate_stream for the maximum number of tokenizers""" - - # The prefix text is necessary only to defeat cleanup algorithms in the decode - # which decide to add a space or not depending on the surrounding ids. - prefix_text = self.tokenizer.decode(all_input_ids[prefix_offset:read_offset], skip_special_tokens=False) - new_text = self.tokenizer.decode(all_input_ids[prefix_offset:], skip_special_tokens=False) - - if len(new_text) > len(prefix_text) and not new_text.endswith("�"): - # utf-8 char at the end means it's a potential unfinished byte sequence - # from byte fallback tokenization. 
- # If it's in the middle, it's probably a real invalid id generated - # by the model - new_text = new_text[len(prefix_text) :] - return new_text, read_offset, len(all_input_ids) - else: - return "", prefix_offset, read_offset - - def check_initialized(self): - uninitialized_parameters = [] - for n, p in self.model.named_parameters(): - if p.data.device == torch.device("meta"): - uninitialized_parameters.append(n) - if uninitialized_parameters: - raise RuntimeError( - f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}" - ) diff --git a/text-generation-inference/server/text_generation_server/models/santacoder.py b/text-generation-inference/server/text_generation_server/models/santacoder.py deleted file mode 100644 index ee37a03a34..0000000000 --- a/text-generation-inference/server/text_generation_server/models/santacoder.py +++ /dev/null @@ -1,37 +0,0 @@ -from typing import Optional, List -import torch - -from text_generation_server.models import CausalLM - -FIM_PREFIX = "" -FIM_MIDDLE = "" -FIM_SUFFIX = "" -FIM_PAD = "" -EOD = "<|endoftext|>" - - -class SantaCoder(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - ): - super().__init__(model_id=model_id, revision=revision, dtype=dtype) - - self.tokenizer.add_special_tokens( - { - "additional_special_tokens": [ - EOD, - FIM_PREFIX, - FIM_MIDDLE, - FIM_SUFFIX, - FIM_PAD, - ], - "pad_token": EOD, - } - ) - - def decode(self, generated_ids: List[int]) -> str: - # Do not skip special tokens as they are used for custom parsing rules of the generated text - return self.tokenizer.decode(generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False) diff --git a/text-generation-inference/server/text_generation_server/server.py b/text-generation-inference/server/text_generation_server/server.py deleted file mode 100644 index b7ab751bdc..0000000000 --- a/text-generation-inference/server/text_generation_server/server.py +++ /dev/null @@ -1,193 +0,0 @@ -import asyncio -import os -import sys -import torch - -from grpc import aio -from loguru import logger - -from grpc_reflection.v1alpha import reflection -from pathlib import Path -from typing import List, Optional - -from text_generation_server.cache import Cache -from text_generation_server.interceptor import ExceptionInterceptor -from text_generation_server.models import Model, get_model -from text_generation_server.pb import generate_pb2_grpc, generate_pb2 -from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor - - -class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): - def __init__(self, model: Model, cache: Cache, server_urls: List[str]): - self.cache = cache - self.model = model - self.server_urls = server_urls - # For some reason, inference_mode does not work well with GLOO which we use on CPU - # TODO: The inferecemode set messes up the autograd op dispatch. And results in aten::matmul - # op not optimized issue. Will investigate further. 
- # if model.device.type == "hpu": - # Force inference mode for the lifetime of TextGenerationService - # self._inference_mode_raii_guard = torch._C._InferenceMode(True) - - async def Info(self, request, context): - return self.model.info - - async def Health(self, request, context): - if self.model.device.type == "hpu": - torch.zeros((2, 2)).to("hpu") - return generate_pb2.HealthResponse() - - async def ServiceDiscovery(self, request, context): - return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls) - - async def ClearCache(self, request, context): - if request.HasField("id"): - self.cache.delete(request.id) - else: - self.cache.clear() - return generate_pb2.ClearCacheResponse() - - async def FilterBatch(self, request, context): - batch = self.cache.pop(request.batch_id) - if batch is None: - raise ValueError(f"Batch ID {request.batch_id} not found in cache.") - filtered_batch = batch.filter(request.request_ids, self.model.is_optimized_for_gaudi) - self.cache.set(filtered_batch) - - return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) - - async def Warmup(self, request, context): - # batch = self.model.batch_type.from_pb( - # request.batch, self.model.tokenizer, self.model.dtype, self.model.device - # ) - # max_supported_total_tokens = self.model.warmup(batch) - - # return generate_pb2.WarmupResponse( - # max_supported_total_tokens=max_supported_total_tokens - # ) - logger.warning("Warmup is not enabled on HPU.") - return generate_pb2.WarmupResponse() - - async def Prefill(self, request, context): - batch = self.model.batch_type.from_pb( - request.batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi - ) - - generations, next_batch = self.model.generate_token(batch) - self.cache.set(next_batch) - - return generate_pb2.PrefillResponse( - generations=[generation.to_pb() for generation in generations], - batch=next_batch.to_pb() if next_batch else None, - ) - - async def Decode(self, request, context): - if len(request.batches) == 0: - raise ValueError("Must provide at least one batch") - - batches = [] - for batch_pb in request.batches: - batch = self.cache.pop(batch_pb.id) - if batch is None: - raise ValueError(f"Batch ID {batch_pb.id} not found in cache.") - batches.append(batch) - - if len(batches) == 0: - raise ValueError("All batches are empty") - - if len(batches) > 1: - batch = self.model.batch_type.concatenate(batches, self.model.is_optimized_for_gaudi) - else: - batch = batches[0] - - generations, next_batch = self.model.generate_token(batch) - self.cache.set(next_batch) - - return generate_pb2.DecodeResponse( - generations=[generation.to_pb() for generation in generations], - batch=next_batch.to_pb() if next_batch else None, - ) - - -def serve( - model_id: str, - revision: Optional[str], - dtype: Optional[str], - uds_path: Path, - sharded: bool, -): - # Remove default handler - logger.remove() - logger.add( - sys.stdout, - format="{message}", - filter="text_generation_server", - level="INFO", - serialize=False, - backtrace=True, - diagnose=False, - ) - - async def serve_inner( - model_id: str, - revision: Optional[str], - dtype: Optional[str] = None, - sharded: bool = False, - ): - unix_socket_template = "unix://{}-{}" - logger.info("Server:server_inner: sharded ={}".format(sharded)) - - if sharded: - rank = int(os.environ["RANK"]) - logger.info("Server:server_inner: rank ={}".format(rank)) - server_urls = [ - unix_socket_template.format(uds_path, rank) for rank in range(int(os.environ["WORLD_SIZE"])) - ] 
- local_url = server_urls[int(os.environ["RANK"])] - else: - local_url = unix_socket_template.format(uds_path, 0) - server_urls = [local_url] - - logger.info("Server:server_inner: data type = {}, local_url = {}".format(dtype, local_url)) - if dtype == "bfloat16" or None: - data_type = torch.bfloat16 - else: - data_type = torch.float - try: - model = get_model(model_id, revision=revision, dtype=data_type) - except Exception: - logger.exception("Error when initializing model") - raise - - server = aio.server( - interceptors=[ - ExceptionInterceptor(), - UDSOpenTelemetryAioServerInterceptor(), - ] - ) - generate_pb2_grpc.add_TextGenerationServiceServicer_to_server( - TextGenerationService(model, Cache(), server_urls), server - ) - SERVICE_NAMES = ( - generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name, - reflection.SERVICE_NAME, - ) - reflection.enable_server_reflection(SERVICE_NAMES, server) - server.add_insecure_port(local_url) - - await server.start() - - logger.info("Server started at {}".format(local_url)) - - try: - await server.wait_for_termination() - except KeyboardInterrupt: - logger.info("Signal received. Shutting down") - await server.stop(0) - - logger.info( - "Starting Server : model_id= {}, revision = {} dtype = {} sharded = {} ".format( - model_id, revision, dtype, sharded - ) - ) - asyncio.run(serve_inner(model_id, revision, dtype, sharded)) diff --git a/text-generation-inference/server/text_generation_server/tgi_service.py b/text-generation-inference/server/text_generation_server/tgi_service.py deleted file mode 100644 index bf1bab4096..0000000000 --- a/text-generation-inference/server/text_generation_server/tgi_service.py +++ /dev/null @@ -1,29 +0,0 @@ -import os -from pathlib import Path -from loguru import logger -import sys -from text_generation_server import server -import argparse - - -def main(args): - logger.info("TGIService: starting tgi service .... 
") - logger.info( - "TGIService: --model_id {}, --revision {}, --sharded {}, --dtype {}, --uds_path {} ".format( - args.model_id, args.revision, args.sharded, args.dtype, args.uds_path - ) - ) - server.serve( - model_id=args.model_id, revision=args.revision, dtype=args.dtype, uds_path=args.uds_path, sharded=args.sharded - ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--model_id", type=str) - parser.add_argument("--revision", type=str) - parser.add_argument("--sharded", type=bool) - parser.add_argument("--dtype", type=str) - parser.add_argument("--uds_path", type=Path) - args = parser.parse_args() - main(args) diff --git a/text-generation-inference/server/text_generation_server/utils/dist.py b/text-generation-inference/server/text_generation_server/utils/dist.py deleted file mode 100644 index ad170e4435..0000000000 --- a/text-generation-inference/server/text_generation_server/utils/dist.py +++ /dev/null @@ -1,91 +0,0 @@ -import os -import torch - -from datetime import timedelta -from loguru import logger - -# Tensor Parallelism settings -RANK = int(os.getenv("RANK", "0")) -WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) - -# CUDA memory fraction -MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0")) - - -class FakeBarrier: - def wait(self): - pass - - -class FakeGroup: - def __init__(self, rank, size): - self._rank = rank - self._size = size - - def allreduce(self, *args, **kwargs): - return FakeBarrier() - - def allgather(self, inputs, local_tensor, **kwargs): - assert ( - len(inputs[0]) == len(local_tensor) == 1 - ), f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors" - for input_ in inputs: - input_[0].data = local_tensor[0].data - return FakeBarrier() - - def barrier(self, *args, **kwargs): - return FakeBarrier() - - def size(self): - return self._size - - def rank(self): - return self._rank - - -def initialize_torch_distributed(): - import habana_frameworks.torch.core as htcore - - rank = int(os.getenv("RANK", "0")) - world_size = int(os.getenv("WORLD_SIZE", "1")) - - options = None - if torch.cuda.is_available(): - from torch.distributed import ProcessGroupNCCL - - # Set the device id. - assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one gpu" - device = RANK % torch.cuda.device_count() - torch.cuda.set_device(device) - torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device) - backend = "nccl" - options = ProcessGroupNCCL.Options() - options.is_high_priority_stream = True - options._timeout = timedelta(seconds=60) - elif torch.hpu.is_available(): - backend = "hccl" - n_hpus = torch.hpu.device_count() - if world_size > n_hpus: - raise ValueError(f"WORLD_SIZE ({world_size}) is higher than the number of available HPUs ({n_hpus}).") - else: - backend = "gloo" - - if WORLD_SIZE == 1: - return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE - else: - if os.getenv("DEBUG", None) == "1": - return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE - - if not torch.distributed.is_initialized(): - # Call the init process. 
- torch.distributed.init_process_group( - backend=backend, - world_size=WORLD_SIZE, - rank=RANK, - timeout=timedelta(seconds=60), - pg_options=options, - ) - else: - logger.warning("torch.distributed is already initialized.") - - return torch.distributed.group.WORLD, RANK, WORLD_SIZE diff --git a/text-generation-inference/server/text_generation_server/utils/logits_process.py b/text-generation-inference/server/text_generation_server/utils/logits_process.py deleted file mode 100644 index c515e4d386..0000000000 --- a/text-generation-inference/server/text_generation_server/utils/logits_process.py +++ /dev/null @@ -1,381 +0,0 @@ -import math -import torch -import habana_frameworks.torch.core as htcore - -from functools import lru_cache -from typing import Optional, List, Dict, Union - -from transformers import ( - LogitsWarper, - LogitsProcessor, - TemperatureLogitsWarper, - TopKLogitsWarper, - TopPLogitsWarper, - TypicalLogitsWarper, -) - -mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None - - -class StaticWarper: - def __init__( - self, - temperature=1.0, - top_k=None, - top_p=None, - typical_p=None, - ): - self.warpers = [] - - if temperature is not None and temperature != 1.0: - temperature = float(temperature) - self.warpers.append(TemperatureLogitsWarper(temperature)) - if top_k is not None and top_k != 0: - self.warpers.append(TopKLogitsWarper(top_k=top_k)) - if top_p is not None and top_p < 1.0: - self.warpers.append(TopPLogitsWarper(top_p=top_p)) - if typical_p is not None and typical_p < 1.0: - self.warpers.append(TypicalLogitsWarper(mass=typical_p)) - - self.hpu_graph = None - self.static_scores = None - self.static_warped_scores = None - self.static_next_logprob = None - - def __call__(self, scores): - if self.hpu_graph is None: - self.static_scores = scores.clone().contiguous() - self.static_warped_scores = scores.clone().contiguous() - self.static_next_logprob = scores.clone().contiguous() - self.hpu_graph = htcore.hpu.HPUGraph() - - with htcore.hpu.graph(self.hpu_graph): - local_scores = self.static_scores - for warper in self.warpers: - local_scores = warper(None, local_scores) - - self.static_warped_scores.copy_(local_scores) - # Compute logprobs - self.static_next_logprob.copy_(torch.log_softmax(self.static_warped_scores, -1)) - - self.static_scores.copy_(scores) - self.hpu_graph.replay() - - return self.static_warped_scores, self.static_next_logprob - - -@lru_cache(10) -def static_warper( - temperature: Optional[float], - top_k: Optional[int], - top_p: Optional[float], - typical_p: Optional[float], -) -> StaticWarper: - return StaticWarper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p) - - -class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor): - r""" - [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences. - This version allows for a separate value for each sample and runs inplace when possible. - It doesn't validate inputs. - - Args: - repetition_penalty (`List[float]`): - The parameter for repetition penalty. 1.0 means no penalty. See [this - paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. 
- """ - - def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device): - self.penalty = penalty - self.penalty_tensor = torch.tensor(penalty, dtype=dtype, device=device).unsqueeze(1) - - def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: - score = torch.gather(scores, 1, input_ids) - - # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability - score = torch.where(score < 0, score * self.penalty_tensor, score / self.penalty_tensor) - - scores.scatter_(1, input_ids, score) - return scores - - def filter(self, indices): - self.penalty = [self.penalty[i] for i in indices] - if any([x != 1.0 for x in self.penalty]): - self.penalty_tensor = self.penalty_tensor[indices] - return self - return None - - -class HeterogeneousTemperatureLogitsWarper: - r""" - [`LogitsWarper`] for temperature (exponential scaling output probability distribution). - This version allows for a separate value for each sample and runs inplace when possible. - It doesn't validate inputs. - - Args: - temperature (`float`): - The value used to module the logits distribution. - """ - - def __init__(self, temperature: List[float], dtype: torch.dtype, device: torch.device): - self.temperature = temperature - self.temperature_tensor = torch.tensor(temperature, dtype=dtype, device=device).unsqueeze(1) - - def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: - scores.div_(self.temperature_tensor) - return scores - - def filter(self, indices): - self.temperature = [self.temperature[i] for i in indices] - if any([x != 1.0 for x in self.temperature]): - self.temperature_tensor = self.temperature_tensor[indices] - return self - return None - - -class HeterogeneousTopPLogitsWarper(LogitsWarper): - """ - [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. - This version allows for a separate value for each sample and runs inplace when possible. - It doesn't validate inputs. - - Args: - top_p (`float`): - If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or - higher are kept for generation. - filter_value (`float`, *optional*, defaults to `-float("Inf")`): - All filtered values will be set to this float value. - min_tokens_to_keep (`int`, *optional*, defaults to 1): - Minimum number of tokens that cannot be filtered. 
- """ - - def __init__( - self, - top_p: List[float], - dtype: torch.dtype, - device: torch.device, - filter_value: float = -math.inf, - min_tokens_to_keep: int = 1, - ): - self.top_p = top_p - self.top_p_opposite = 1 - torch.tensor(top_p, dtype=dtype, device=device).unsqueeze(1) - self.filter_value = filter_value - self.min_tokens_to_keep = min_tokens_to_keep - - def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: - sorted_logits, sorted_indices = torch.sort(scores, descending=False) - probs = sorted_logits.softmax(dim=-1) - # This is way faster for some reason - for i in range(probs.shape[0]): - probs[i] = probs[i].cumsum(dim=-1) - - # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) - sorted_indices_to_remove = probs <= self.top_p_opposite - # Keep at least min_tokens_to_keep - sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0 - - # scatter sorted tensors to original indexing - indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) - warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value) - - return warped_scores - - def filter(self, indices): - self.top_p = [self.top_p[i] for i in indices] - if any([x < 1.0 for x in self.top_p]): - self.top_p_opposite = self.top_p_opposite[indices] - return self - return None - - -class HeterogeneousTopKLogitsWarper(LogitsWarper): - r""" - [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. - This version allows for a separate value for each sample and runs inplace when possible. - It doesn't validate inputs. - - Args: - top_k (`int`): - The number of highest probability vocabulary tokens to keep for top-k-filtering. - filter_value (`float`, *optional*, defaults to `-float("Inf")`): - All filtered values will be set to this float value. - min_tokens_to_keep (`int`, *optional*, defaults to 1): - Minimum number of tokens that cannot be filtered. 
- """ - - def __init__( - self, - top_k: List[int], - device: torch.device, - filter_value: float = -math.inf, - min_tokens_to_keep: int = 1, - ): - self.top_k = top_k - self.max_top_k = max(top_k) - # value - 1 as we will use top_k to index and python uses 0 based numbering - self.top_k_tensor = torch.tensor( - [max(x - 1, min_tokens_to_keep - 1) for x in top_k], - dtype=torch.int64, - device=device, - ).unsqueeze(1) - - # 0 is a special value that disables top_k warping for this member of the batch - disabled = [x == 0 for x in top_k] - - if any(disabled): - self.top_k_disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device).view(-1, 1) - else: - self.top_k_disabled_mask = None - - self.filter_value = filter_value - - def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: - # If max_top_k is superior to the vocab, we need to clamp or the warper will fail - if scores.size(-1) < self.max_top_k: - max_top_k = scores.size(-1) - top_k = torch.clamp_max(self.top_k_tensor, max_top_k) - else: - max_top_k = self.max_top_k - top_k = self.top_k_tensor - - # Get the kth score for each member of the batch - kth_scores = torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k) - - # Mask member of kth_scores that do not want to use top_k warping - if self.top_k_disabled_mask is not None: - kth_scores.masked_fill_(self.top_k_disabled_mask, self.filter_value) - - # Remove all tokens with a probability less than the last token of the top-k - indices_to_remove = scores < kth_scores - scores.masked_fill_(indices_to_remove, self.filter_value) - return scores - - def filter(self, indices): - self.top_k = [self.top_k[i] for i in indices] - disabled = [x == 0 for x in self.top_k] - - if not all(disabled): - self.top_k_tensor = self.top_k_tensor[indices] - self.max_top_k = max(self.top_k) - - if self.top_k_disabled_mask is not None: - self.top_k_disabled_mask = self.top_k_disabled_mask[indices] if any(disabled) else None - - return self - return None - - -class HeterogeneousTypicalLogitsWarper(LogitsWarper): - r""" - [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language - Generation](https://arxiv.org/abs/2202.00666) for more information. - This version allows for a separate value for each sample and runs inplace when possible. - It doesn't validate inputs. - - Args: - mass (`float`): - Value of typical_p between 0 and 1 inclusive, defaults to 0.9. - filter_value (`float`, *optional*, defaults to `-float("Inf")`): - All filtered values will be set to this float value. - min_tokens_to_keep (`int`, *optional*, defaults to 1): - Minimum number of tokens that cannot be filtered. 
- """ - - def __init__( - self, - mass: List[float], - dtype: torch.dtype, - device: torch.device, - filter_value: float = -math.inf, - min_tokens_to_keep: int = 1, - ): - self.mass = mass - self.mass_tensor = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1) - - # 1 is a special value that disables typical_p warping for this member of the batch - disabled = [x == 1.0 for x in mass] - - if any(disabled): - self.disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device) - else: - self.disabled_mask = None - - self.filter_value = filter_value - self.min_tokens_to_keep = min_tokens_to_keep - - def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: - # calculate entropy - normalized = torch.nn.functional.log_softmax(scores, dim=-1) - p = torch.exp(normalized) - ent = -(normalized * p).nansum(-1, keepdim=True) - - # shift and sort - shifted_scores = torch.abs((-normalized) - ent) - sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False) - sorted_logits = scores.gather(-1, sorted_indices) - probs = sorted_logits.softmax(dim=-1) - # This is way faster for some reason - for i in range(probs.shape[0]): - probs[i] = probs[i].cumsum(dim=-1) - - # Remove tokens with cumulative mass above the threshold - last_ind = (probs < self.mass_tensor).sum(dim=1) - last_ind[last_ind < 0] = 0 - - if self.disabled_mask is not None: - last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1) - - sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1)) - if self.min_tokens_to_keep > 1: - # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) - sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 - indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) - - warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value) - - return warped_scores - - def filter(self, indices): - self.mass = [self.mass[i] for i in indices] - disabled = [x == 1.0 for x in self.mass] - - if not all(disabled): - self.mass_tensor = self.mass_tensor[indices] - - if self.disabled_mask is not None: - self.disabled_mask = self.disabled_mask[indices] if any(disabled) else None - - return self - return None - - -class HeterogeneousProcessorWrapper(LogitsProcessor): - r""" - A wrapper for logit warpers or processors without heterogeneous parameter support. - Args: - processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`): - A mapping of sample indices to logit warpers or processors, to be run sequentially. 
- """ - - def __init__( - self, - processors: Dict[int, Union[LogitsProcessor, LogitsWarper]], - ): - self.processors = processors - - def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: - for i, processor in self.processors.items(): - scores[i : i + 1] = processor(input_ids[i : i + 1], scores[i : i + 1]) - return scores - - def filter(self, indices): - new_processors = {} - for i, idx in enumerate(indices): - if idx in self.processors: - new_processors[i] = self.processors[idx] - - if new_processors: - self.processors = new_processors - return self - return None diff --git a/text-generation-inference/server/text_generation_server/utils/tokens.py b/text-generation-inference/server/text_generation_server/utils/tokens.py deleted file mode 100644 index 55754002ee..0000000000 --- a/text-generation-inference/server/text_generation_server/utils/tokens.py +++ /dev/null @@ -1,361 +0,0 @@ -import re -from typing import Callable, List, Optional, Tuple - -import torch -from text_generation_server.pb import generate_pb2 -from text_generation_server.pb.generate_pb2 import FinishReason -from text_generation_server.utils.logits_process import ( - HeterogeneousProcessorWrapper, - HeterogeneousRepetitionPenaltyLogitsProcessor, - HeterogeneousTemperatureLogitsWarper, - HeterogeneousTopKLogitsWarper, - HeterogeneousTopPLogitsWarper, - HeterogeneousTypicalLogitsWarper, - static_warper, -) -from text_generation_server.utils.watermark import WatermarkLogitsProcessor -from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor - - -class NextTokenChooser: - def __init__( - self, - watermark=False, - temperature=1.0, - repetition_penalty=1.0, - top_k=None, - top_p=None, - typical_p=None, - do_sample=False, - seed=0, - device="cpu", - ): - self.watermark_processor = WatermarkLogitsProcessor(device=device) if watermark else None - self.repetition_processor = ( - RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty) if repetition_penalty else None - ) - - has_warpers = ( - (temperature is not None and temperature != 1.0) - or (top_k is not None and top_k != 0) - or (top_p is not None and top_p < 1.0) - or (typical_p is not None and typical_p < 1.0) - ) - if has_warpers: - self.static_warper = static_warper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p) - else: - self.static_warper = None - - sampling = do_sample or has_warpers - self.choice = Sampling(seed, device) if sampling else Greedy() - - def __call__(self, input_ids, scores): - if self.watermark_processor is not None: - scores = self.watermark_processor(input_ids, scores) - if self.repetition_processor is not None: - scores = self.repetition_processor(input_ids, scores) - - if self.static_warper is None: - next_logprob = torch.log_softmax(scores, -1) - else: - scores, next_logprob = self.static_warper(scores) - - next_id = self.choice(scores[-1]).view(1, 1) - - return next_id, next_logprob - - @classmethod - def from_pb( - cls, - pb: generate_pb2.NextTokenChooserParameters, - device: torch.device, - ) -> "NextTokenChooser": - return NextTokenChooser( - watermark=pb.watermark, - temperature=pb.temperature, - repetition_penalty=pb.repetition_penalty, - top_k=pb.top_k, - top_p=pb.top_p, - typical_p=pb.typical_p, - do_sample=pb.do_sample, - seed=pb.seed, - device=device, - ) - - -class StopSequenceCriteria: - def __init__(self, stop_sequence: str): - stop_sequence = re.escape(stop_sequence) - self.regex = re.compile(f".*{stop_sequence}$") - - def __call__(self, output: str) -> 
bool: - if self.regex.findall(output): - return True - return False - - -class StoppingCriteria: - def __init__( - self, - eos_token_id: int, - stop_sequence_criterias: List[StopSequenceCriteria], - max_new_tokens: int = 20, - ignore_eos_token: bool = False, - ): - self.eos_token_id = eos_token_id - self.stop_sequence_criterias = stop_sequence_criterias - self.max_new_tokens = max_new_tokens - self.current_tokens = 0 - self.current_output = "" - self.ignore_eos_token = ignore_eos_token - - def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]: - self.current_tokens += 1 - if self.current_tokens >= self.max_new_tokens: - return True, FinishReason.FINISH_REASON_LENGTH - - if not self.ignore_eos_token and last_token == self.eos_token_id: - return True, FinishReason.FINISH_REASON_EOS_TOKEN - - self.current_output += last_output - for stop_sequence_criteria in self.stop_sequence_criterias: - if stop_sequence_criteria(self.current_output): - return True, FinishReason.FINISH_REASON_STOP_SEQUENCE - - return False, None - - @classmethod - def from_pb( - cls, - pb: generate_pb2.StoppingCriteriaParameters, - tokenizer: PreTrainedTokenizerBase, - ) -> "StoppingCriteria": - stop_sequence_criterias = [StopSequenceCriteria(sequence) for sequence in pb.stop_sequences] - return StoppingCriteria( - tokenizer.eos_token_id, - stop_sequence_criterias, - pb.max_new_tokens, - pb.ignore_eos_token, - ) - - -class HeterogeneousNextTokenChooser: - def __init__( - self, - dtype: torch.dtype, - device: torch.device, - watermark: List[bool], - temperature: List[float], - repetition_penalty: List[float], - top_k: List[int], - top_p: List[float], - typical_p: List[float], - do_sample: List[bool], - seeds: List[int], - ): - warpers = [] - - self.watermark_processor = ( - HeterogeneousProcessorWrapper( - { - i: WatermarkLogitsProcessor(device=device) - for i, do_watermark in enumerate(watermark) - if do_watermark - } - ) - if any(watermark) - else None - ) - - self.repetition_processor = ( - HeterogeneousRepetitionPenaltyLogitsProcessor(repetition_penalty, dtype, device) - if any([x != 1.0 for x in repetition_penalty]) - else None - ) - - if any([x != 1.0 for x in temperature]): - do_sample = [sample or x != 1.0 for x, sample in zip(temperature, do_sample)] - warpers.append(HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)) - - if any([x != 0 for x in top_k]): - do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)] - warpers.append(HeterogeneousTopKLogitsWarper(top_k, device)) - - if any([x < 1.0 for x in top_p]): - do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)] - warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device)) - - if any([x < 1.0 for x in typical_p]): - do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)] - warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device)) - - self.warpers = warpers - - if any(do_sample): - self.choice = HeterogeneousSampling(do_sample, seeds, device) - else: - self.choice = Greedy() - - self.seeds = seeds - self.do_sample = do_sample - self.dtype = dtype - self.device = device - - def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor): - if self.watermark_processor is not None: - scores = self.watermark_processor(input_ids, scores) - if self.repetition_processor is not None: - scores = self.repetition_processor(input_ids, scores) - - for warper in self.warpers: - scores = warper(input_ids, scores) - - next_ids = self.choice(scores) - logprobs = 
torch.log_softmax(scores, -1) - next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1) - - return next_ids, next_logprobs, logprobs - - def filter(self, indices): - if self.watermark_processor is not None: - self.watermark_processor = self.watermark_processor.filter(indices) - - if self.repetition_processor is not None: - self.repetition_processor = self.repetition_processor.filter(indices) - - filtered_warpers = [] - for warper in self.warpers: - filtered_warper = warper.filter(indices) - if filtered_warper is not None: - filtered_warpers.append(filtered_warper) - self.warpers = filtered_warpers - - self.seeds = [self.seeds[i] for i in indices] - self.do_sample = [self.do_sample[i] for i in indices] - - if any(self.do_sample): - self.choice.filter(indices) - else: - self.choice = Greedy() - - return self - - @classmethod - def from_pb( - cls, - pb: List[generate_pb2.NextTokenChooserParameters], - dtype: torch.dtype, - device: torch.device, - ) -> "HeterogeneousNextTokenChooser": - return HeterogeneousNextTokenChooser( - watermark=[pb_.watermark for pb_ in pb], - temperature=[pb_.temperature for pb_ in pb], - repetition_penalty=[pb_.repetition_penalty for pb_ in pb], - top_k=[pb_.top_k for pb_ in pb], - top_p=[pb_.top_p for pb_ in pb], - typical_p=[pb_.typical_p for pb_ in pb], - do_sample=[pb_.do_sample for pb_ in pb], - seeds=[pb_.seed for pb_ in pb], - device=device, - dtype=dtype, - ) - - -class Sampling: - def __init__(self, seed: int, device: str = "cpu"): - self.generator = torch.Generator("cpu") - self.generator.manual_seed(seed) - self.seed = seed - - def __call__(self, logits): - probs = torch.nn.functional.softmax(logits, -1) - # Avoid GPU<->CPU sync done by torch multinomial - # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637 - q = torch.empty_like(probs).exponential_(1, generator=self.generator) - return probs.div_(q).argmax() - - -class Greedy: - def __call__(self, logits): - return logits.argmax(dim=-1) - - -class HeterogeneousSampling: - r""" - Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample. - """ - - def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device): - self.seeds = seeds - - self.greedy_indices = [] - self.sampling_mapping = {} - for i, (sample, seed) in enumerate(zip(do_sample, seeds)): - if sample: - self.sampling_mapping[i] = Sampling(seed, device) - else: - self.greedy_indices.append(i) - - self.greedy = Greedy() - - def __call__(self, logits): - out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device) - if self.greedy_indices: - # Computing for all indices is faster than slicing - torch.argmax(logits, -1, out=out) - - for i, sampling in self.sampling_mapping.items(): - out[i] = sampling(logits[i]) - return out - - def filter(self, indices): - new_greedy_indices = [] - new_sampling_mapping = {} - for i, idx in enumerate(indices): - if idx in self.sampling_mapping: - new_sampling_mapping[i] = self.sampling_mapping[idx] - else: - new_greedy_indices.append(i) - - self.greedy_indices = new_greedy_indices - self.sampling_mapping = new_sampling_mapping - return self - - -def batch_top_tokens( - top_n_tokens: list[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor -) -> Tuple[List[List[int]], List[List[float]]]: - """Find the top n most likely tokens for a batch of generations. 
- - When multiple tokens have equal probabilities and they don't all fit, the - remaining tokens are also returned. - """ - max_top_n = max(top_n_tokens) - # Early exit when top_n_tokens is not used - if max_top_n == 0: - return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens) - - # Ensure top_n doesn't exceed vocab size - top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens] - - # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2 - # Sorted topk is faster than torch.sort() since we only need a small subset - sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values - nth_highest = torch.gather(sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)) - nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min - - # Find the new "fuzzy" top n values - top_n_indices = (logprobs >= nth_highest).nonzero() - _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True) - - # Take a new topk for these new max n values - top_k = torch.topk(logprobs, k=top_n_ishes.max(), dim=1, sorted=True) - - top_n_ishes = top_n_ishes.tolist() - top_indices = top_k.indices.tolist() - top_values = top_k.values.tolist() - - return ( - [idxs[:n] if req_n > 0 else [] for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens)], - [vals[:n] if req_n > 0 else [] for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens)], - ) diff --git a/text-generation-inference/server/text_generation_server/utils/watermark.py b/text-generation-inference/server/text_generation_server/utils/watermark.py deleted file mode 100644 index 7f4bf3676f..0000000000 --- a/text-generation-inference/server/text_generation_server/utils/watermark.py +++ /dev/null @@ -1,86 +0,0 @@ -# coding=utf-8 -# Copyright 2023 Authors of "A Watermark for Large Language Models" -# available at https://arxiv.org/abs/2301.10226 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import os - -import torch -from transformers import LogitsProcessor -from typing import List, Union - -GAMMA = float(os.getenv("WATERMARK_GAMMA", 0.5)) -DELTA = float(os.getenv("WATERMARK_DELTA", 2.0)) - - -class WatermarkLogitsProcessor(LogitsProcessor): - def __init__( - self, - gamma: float = GAMMA, - delta: float = DELTA, - hash_key: int = 15485863, # just a large prime number to create a rng seed with sufficient bit width - device: str = "cpu", - ): - # watermarking parameters - self.gamma = gamma - self.delta = delta - self.rng = torch.Generator(device="cpu") - self.hash_key = hash_key - - def _seed_rng(self, input_ids: Union[List[int], torch.LongTensor]): - if isinstance(input_ids, list): - assert len(input_ids) >= 1, "requires at least a 1 token prefix sequence to seed rng" - prev_token = input_ids[-1] - else: - assert len(input_ids) == 1 - input_ids = input_ids[0] - assert input_ids.shape[-1] >= 1, "requires at least a 1 token prefix sequence to seed rng" - prev_token = input_ids[-1].item() - self.rng.manual_seed(self.hash_key * prev_token) - - def _get_greenlist_ids( - self, - input_ids: Union[List[int], torch.LongTensor], - max_value: int, - device: torch.device, - ) -> List[int]: - # seed the rng using the previous tokens/prefix - self._seed_rng(input_ids) - - greenlist_size = int(max_value * self.gamma) - vocab_permutation = torch.randperm(max_value, device=device, generator=self.rng) - greenlist_ids = vocab_permutation[:greenlist_size] - return greenlist_ids - - @staticmethod - def _calc_greenlist_mask(scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor: - green_tokens_mask = torch.zeros_like(scores) - green_tokens_mask[-1, greenlist_token_ids] = 1 - final_mask = green_tokens_mask.bool() - return final_mask - - @staticmethod - def _bias_greenlist_logits( - scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float - ) -> torch.Tensor: - scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias - return scores - - def __call__(self, input_ids: Union[List[int], torch.LongTensor], scores: torch.FloatTensor) -> torch.FloatTensor: - greenlist_ids = self._get_greenlist_ids(input_ids, scores.shape[-1], scores.device) - green_tokens_mask = self._calc_greenlist_mask(scores=scores, greenlist_token_ids=greenlist_ids) - - scores = self._bias_greenlist_logits( - scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta - ) - return scores
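For reference, the greenlist-bias technique implemented by the deleted watermark.py can be summarised in a short, self-contained sketch. The standalone function watermark_bias below is illustrative only and is not part of the removed module; the gamma/delta defaults and the prime hash key simply mirror the values used in the deleted file.

import torch

def watermark_bias(scores: torch.Tensor, prev_token: int,
                   gamma: float = 0.5, delta: float = 2.0,
                   hash_key: int = 15485863) -> torch.Tensor:
    """Add `delta` to the logits of a pseudo-random "greenlist" of tokens.

    The greenlist is derived from the previous token id, so a detector that
    knows `hash_key` can recompute it and measure how often generated tokens
    fall inside it.
    """
    vocab_size = scores.shape[-1]
    rng = torch.Generator(device="cpu")
    rng.manual_seed(hash_key * prev_token)       # seed the RNG from the prefix
    greenlist_size = int(vocab_size * gamma)     # keep a gamma-fraction of the vocab
    greenlist = torch.randperm(vocab_size, generator=rng)[:greenlist_size]
    biased = scores.clone()
    biased[..., greenlist] += delta              # favour greenlist tokens
    return biased

# Example: bias the logits of a single decoding step.
logits = torch.randn(1, 32000)
biased_logits = watermark_bias(logits, prev_token=42)

The deleted WatermarkLogitsProcessor applied the same bias in place through the transformers LogitsProcessor interface, which is what allowed tokens.py to chain it with the repetition-penalty processor and the warpers defined in logits_process.py.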