Skip to content

Commit

Permalink
⚙️ Fix Integration Test for TGI (#124)
Browse files Browse the repository at this point in the history
* fix(tests): fix broken GPT2 integration test

* feat(tests): add Gemma integration test

* refactor(logging): migrate to Loguru

* fix(tests): fix broken connection to docker container

* refactor(tests): make run arguments to model config

* refactor(tests): make run arguments to model config
  • Loading branch information
baptistecolle authored Dec 5, 2024
1 parent d180634 commit 9d44709
Show file tree
Hide file tree
Showing 5 changed files with 307 additions and 149 deletions.
7 changes: 4 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,14 @@ clean:
rm -rf dist deps
make -C text-generation-inference/server/ clean

# ulimit nofile=100000:100000 is required for TPUs
# https://cloud.google.com/kubernetes-engine/docs/how-to/tpus#privileged-mode
tpu-tgi:
docker build --rm -f text-generation-inference/docker/Dockerfile \
--build-arg VERSION=$(VERSION) \
--build-arg TGI_VERSION=$(TGI_VERSION) \
--ulimit nofile=100000:100000 \
-t huggingface/optimum-tpu:$(VERSION)-tgi .
--ulimit nofile=100000:100000 \
-t huggingface/optimum-tpu:$(VERSION)-tgi .
docker tag huggingface/optimum-tpu:$(VERSION)-tgi huggingface/optimum-tpu:latest

tpu-tgi-ie:
Expand All @@ -64,7 +66,6 @@ style_check:
ruff check .

style:
ruff check . --fix

# Utilities to release to PyPi
build_dist_install_tools:
Expand Down
233 changes: 177 additions & 56 deletions text-generation-inference/integration-tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import contextlib
import os
import shlex
import signal
import subprocess
import sys
import threading
import time
from tempfile import TemporaryDirectory
from typing import List
Expand All @@ -12,34 +14,85 @@
import pytest
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
from docker.errors import NotFound
from loguru import logger
from test_model import MODEL_CONFIGS
from text_generation import AsyncClient
from text_generation.types import Response


DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", "tpu-tgi:latest")
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", "huggingface/optimum-tpu:latest")
HF_TOKEN = os.getenv("HF_TOKEN", None)
DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")

logger.add(
sys.stderr,
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
level="INFO"
)


def cleanup_handler(signum, frame):
    """Signal handler: stop and remove every tgi-tests-* container, then exit.

    Registered for SIGINT/SIGTERM so interrupted test runs do not leave
    TGI containers behind. Always exits with status 1 (we are shutting
    down because of a signal).
    """
    logger.info("\nCleaning up containers due to shutdown, please wait...")
    try:
        matching = docker.from_env().containers.list(filters={"name": "tgi-tests-"})
    except Exception as e:
        # Could not even talk to the Docker daemon; nothing more to do.
        logger.error(f"Error during cleanup: {str(e)}")
    else:
        for container in matching:
            # Best-effort per container: one failure must not stop the rest.
            try:
                container.stop()
                container.remove()
                logger.info(f"Successfully cleaned up container {container.name}")
            except Exception as e:
                logger.error(f"Error cleaning up container {container.name}: {str(e)}")
    sys.exit(1)

# Install the container-cleanup handler for both interrupt and terminate.
for _sig in (signal.SIGINT, signal.SIGTERM):
    signal.signal(_sig, cleanup_handler)

def stream_container_logs(container):
    """Forward a container's log stream to stderr with a readable prefix.

    Meant to run in a daemon thread; on any streaming error it logs and
    returns instead of raising into the thread.
    """
    try:
        stream = container.logs(stream=True, follow=True)
        for chunk in stream:
            text = "[TGI Server Logs] " + chunk.decode("utf-8")
            print(text, end="", file=sys.stderr, flush=True)
    except Exception as e:
        logger.error(f"Error streaming container logs: {str(e)}")


class LauncherHandle:
def __init__(self, port: int):
self.client = AsyncClient(f"http://localhost:{port}")
self.client = AsyncClient(f"http://localhost:{port}", timeout=3600)

def _inner_health(self):
raise NotImplementedError

async def health(self, timeout: int = 60):
assert timeout > 0
for _ in range(timeout):
start_time = time.time()
logger.info(f"Starting health check with timeout of {timeout}s")

for attempt in range(timeout):
if not self._inner_health():
logger.error("Launcher crashed during health check")
raise RuntimeError("Launcher crashed")

try:
await self.client.generate("test")
elapsed = time.time() - start_time
logger.info(f"Health check passed after {elapsed:.1f}s")
return
except (ClientConnectorError, ClientOSError, ServerDisconnectedError):
except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e:
if attempt == timeout - 1:
logger.error(f"Health check failed after {timeout}s: {str(e)}")
raise RuntimeError(f"Health check failed: {str(e)}")
logger.debug(f"Connection attempt {attempt+1}/{timeout} failed: {str(e)}")
time.sleep(1)
raise RuntimeError("Health check failed")
except Exception as e:
logger.error(f"Unexpected error during health check: {str(e)}")
# Get full traceback for debugging
import traceback
logger.error(f"Full traceback:\n{traceback.format_exc()}")
raise


class ContainerLauncherHandle(LauncherHandle):
Expand All @@ -49,8 +102,18 @@ def __init__(self, docker_client, container_name, port: int):
self.container_name = container_name

def _inner_health(self) -> bool:
container = self.docker_client.containers.get(self.container_name)
return container.status in ["running", "created"]
try:
container = self.docker_client.containers.get(self.container_name)
status = container.status
if status not in ["running", "created"]:
logger.warning(f"Container status is {status}")
# Get container logs for debugging
logs = container.logs().decode("utf-8")
logger.debug(f"Container logs:\n{logs}")
return status in ["running", "created"]
except Exception as e:
logger.error(f"Error checking container health: {str(e)}")
return False


class ProcessLauncherHandle(LauncherHandle):
Expand All @@ -62,95 +125,153 @@ def _inner_health(self) -> bool:
return self.process.poll() is None


@pytest.fixture(scope="module")
def event_loop():
    """Module-scoped asyncio event loop for async tests.

    Creates a fresh loop instead of calling ``asyncio.get_event_loop()``,
    which is deprecated when no loop is running and would hand back the
    thread's shared default loop — closing that shared loop afterwards
    would break any later code that relies on it.

    Yields:
        The new event loop; it is closed (and unset as the current loop)
        when the module's tests finish.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    yield loop
    loop.close()
    asyncio.set_event_loop(None)


@pytest.fixture(scope="module")
def data_volume():
tmpdir = TemporaryDirectory()
yield tmpdir.name
# Cleanup the temporary directory using sudo as it contains root files created by the container
subprocess.run(shlex.split(f"sudo rm -rf {tmpdir.name}"))
try:
# Cleanup the temporary directory using sudo as it contains root files created by the container
subprocess.run(shlex.split(f"sudo rm -rf {tmpdir.name}"), check=True)
except subprocess.CalledProcessError as e:
logger.error(f"Error cleaning up temporary directory: {str(e)}")


@pytest.fixture(scope="module")
def launcher(event_loop, data_volume):
def launcher(data_volume):
@contextlib.contextmanager
def docker_launcher(
model_id: str,
trust_remote_code: bool = False,
):
# TODO: consider finding out how to forward a port in the container instead of leaving it to 80.
# For now this is necessary because TPU dockers require to run with net=host and privileged mode.
port = 80

args = ["--model-id", model_id, "--env"]

if trust_remote_code:
args.append("--trust-remote-code")
logger.info(f"Starting docker launcher for model {model_id}")
port = 8080

client = docker.from_env()

container_name = f"tgi-tests-{model_id.split('/')[-1]}"

try:
container = client.containers.get(container_name)
logger.info(f"Stopping existing container {container_name}")
container.stop()
container.wait()
except NotFound:
pass
except Exception as e:
logger.error(f"Error handling existing container: {str(e)}")

env = {"LOG_LEVEL": "info,text_generation_router=debug"}
model_name = next(name for name, cfg in MODEL_CONFIGS.items() if cfg["model_id"] == model_id)

args = MODEL_CONFIGS[model_name]["args"].copy()
if trust_remote_code:
args.append("--trust-remote-code")

env = {
"LOG_LEVEL": "info,text_generation_router,text_generation_launcher=debug",
"HF_HUB_ENABLE_HF_TRANSFER": "0"
}
env.update(MODEL_CONFIGS[model_name]["env_config"].copy())

if HUGGING_FACE_HUB_TOKEN is not None:
env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN

# Add model_id to env
env["MODEL_ID"] = model_id

if HF_TOKEN is not None:
env["HF_TOKEN"] = HF_TOKEN

for var in ["MAX_BATCH_SIZE", "HF_SEQUENCE_LENGTH"]:
if var in os.environ:
env[var] = os.environ[var]

volumes = [f"{data_volume}:/data"]

container = client.containers.run(
DOCKER_IMAGE,
command=args,
name=container_name,
environment=env,
auto_remove=False,
detach=True,
volumes=volumes,
shm_size="1G",
privileged=True,
network_mode="host",
)

yield ContainerLauncherHandle(client, container.name, port)

try:
container.stop()
container.wait()
except NotFound:
pass
# Add debug logging before container creation
logger.debug(f"Creating container with image {DOCKER_IMAGE}")
logger.debug(f"Container environment: {env}")
logger.debug(f"Container volumes: {volumes}")

# Log equivalent docker run command
env_str = ' '.join([f'-e {k}="{v}"' for k,v in env.items()])
volume_str = ' '.join([f'-v {v}' for v in volumes])
cmd_str = f'docker run -d --name {container_name} {env_str} {volume_str} --shm-size 16G --privileged --ipc host {DOCKER_IMAGE} {" ".join(args)}'
logger.debug(f"Equivalent docker run command:\n{cmd_str}")

container = client.containers.run(
DOCKER_IMAGE,
command=args,
name=container_name,
environment=env,
auto_remove=False,
detach=True,
volumes=volumes,
shm_size="16G",
privileged=True,
ipc_mode="host",
ports={"80/tcp": 8080}
)
logger.info(f"Container {container_name} started successfully")

# Start log streaming in a background thread
log_thread = threading.Thread(
target=stream_container_logs,
args=(container,),
daemon=True # This ensures the thread will be killed when the main program exits
)
log_thread.start()

# Add a small delay to allow container to initialize
time.sleep(2)

# Check container status after creation
status = container.status
logger.debug(f"Initial container status: {status}")
if status not in ["running", "created"]:
logs = container.logs().decode("utf-8")
logger.error(f"Container failed to start properly. Logs:\n{logs}")

yield ContainerLauncherHandle(client, container.name, port)

except Exception as e:
logger.error(f"Error starting container: {str(e)}")
# Get full traceback for debugging
import traceback
logger.error(f"Full traceback:\n{traceback.format_exc()}")
raise
finally:
try:
container = client.containers.get(container_name)
logger.info(f"Stopping container {container_name}")
container.stop()
container.wait()

container_output = container.logs().decode("utf-8")
print(container_output, file=sys.stderr)
container_output = container.logs().decode("utf-8")
print(container_output, file=sys.stderr)

container.remove()
container.remove()
logger.info(f"Container {container_name} removed successfully")
except NotFound:
pass
except Exception as e:
logger.error(f"Error cleaning up container: {str(e)}")

return docker_launcher


@pytest.fixture(scope="module")
def generate_load():
async def generate_load_inner(client: AsyncClient, prompt: str, max_new_tokens: int, n: int) -> List[Response]:
futures = [
client.generate(prompt, max_new_tokens=max_new_tokens, decoder_input_details=True) for _ in range(n)
]

return await asyncio.gather(*futures)
try:
futures = [
client.generate(
prompt,
max_new_tokens=max_new_tokens,
decoder_input_details=True,
) for _ in range(n)
]
return await asyncio.gather(*futures)
except Exception as e:
logger.error(f"Error generating load: {str(e)}")
raise

return generate_load_inner
Loading

0 comments on commit 9d44709

Please sign in to comment.