Skip to content

Commit

Permalink
Enable inference serving capabilities on sagemaker endpoint using tor…
Browse files Browse the repository at this point in the history
…nado
  • Loading branch information
gwang111 committed Jan 9, 2025
1 parent 33b6986 commit 4f0921d
Show file tree
Hide file tree
Showing 12 changed files with 400 additions and 1 deletion.
2 changes: 1 addition & 1 deletion template/v3/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ RUN mkdir -p $SAGEMAKER_LOGGING_DIR && \
&& ${HOME_DIR}/oss_compliance/generate_oss_compliance.sh ${HOME_DIR} python \
&& rm -rf ${HOME_DIR}/oss_compliance*

ENV PATH="/opt/conda/bin:/opt/conda/condabin:$PATH"
ENV PATH="/etc/sagemaker-inference-server:/opt/conda/bin:/opt/conda/condabin:$PATH"
WORKDIR "/home/${NB_USER}"
ENV SHELL=/bin/bash
ENV OPENSSL_MODULES=/opt/conda/lib64/ossl-modules/
Expand Down
3 changes: 3 additions & 0 deletions template/v3/dirs/etc/sagemaker-inference-server/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from __future__ import absolute_import

import utils.logger
2 changes: 2 additions & 0 deletions template/v3/dirs/etc/sagemaker-inference-server/serve
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#!/bin/bash
# Launch the inference server by running serve.py with the `python` found on PATH.
python /etc/sagemaker-inference-server/serve.py
6 changes: 6 additions & 0 deletions template/v3/dirs/etc/sagemaker-inference-server/serve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from __future__ import absolute_import

from tornado_server.server import TornadoServer

# Build the Tornado-backed server and start serving as soon as this script is executed.
inference_server = TornadoServer()
inference_server.serve()
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from __future__ import absolute_import

import pathlib
import sys

# Directory that holds this package's own modules, so they can import each other by bare name.
tornado_module_path = pathlib.Path(__file__).parent
# Sibling "utils" directory located next to this package.
utils_path = tornado_module_path.parent / "utils"

# Prepend both directories to the import search path. The package directory is
# inserted last, so it ends up ahead of the utils directory in sys.path.
sys.path.insert(0, str(utils_path.resolve()))
sys.path.insert(0, str(tornado_module_path.resolve()))
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from __future__ import absolute_import

import asyncio
import logging
import tornado.web
from utils.environment import Environment
from utils.exception import AsyncInvocationsException
from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER

logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER)


class InvocationsHandler(tornado.web.RequestHandler):
    """Handler mapped to the /invocations POST route.

    Wraps the async handler retrieved from the inference script and exposes it
    through the asynchronous post() method.
    """

    def initialize(self, handler: callable, environment: Environment):
        """Store the handler function and the serving environment.

        Args:
            handler: Coroutine function from the inference script; it is awaited
                with the incoming request as its only argument.
            environment: Serving environment supplied when the route is registered.
        """

        self._handler = handler
        self._environment = environment

    async def post(self):
        """Await the inference handler with the request and write its result.

        Raises:
            AsyncInvocationsException: If the handler, or writing its response,
                raises. The original exception is attached as ``__cause__``.
        """

        try:
            response = await self._handler(self.request)
            self.write(response)
        except Exception as e:
            # Chain explicitly so the traceback reports the handler failure as the direct cause.
            raise AsyncInvocationsException(e) from e


class PingHandler(tornado.web.RequestHandler):
    """Handler mapped to the /ping GET route.

    Used to check the health of the Tornado server.
    """

    def get(self):
        """Answer the health check by writing an empty body."""

        self.write("")


async def handle(handler: callable, environment: Environment):
    """Serve the async handler function using Tornado.

    Registers the /invocations and /ping routes used by a SageMaker Endpoint,
    listens on the port taken from ``environment``, and then waits indefinitely.
    """

    logger.info("Starting inference server in asynchronous mode...")

    # Keyword arguments that are handed to InvocationsHandler.initialize().
    invocation_kwargs = {"handler": handler, "environment": environment}
    routes = [
        (r"/invocations", InvocationsHandler, invocation_kwargs),
        (r"/ping", PingHandler),
    ]
    application = tornado.web.Application(routes)
    application.listen(environment.port)
    logger.debug(f"Asynchronous inference server listening on port: `{environment.port}`")

    # Wait on an event that is never set so this coroutine keeps running.
    never_set = asyncio.Event()
    await never_set.wait()
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from __future__ import absolute_import

import asyncio
import importlib
import logging
import subprocess
import sys
from pathlib import Path
from utils.environment import Environment
from utils.exception import InferenceCodeLoadException, RequirementsInstallException, ServerStartException
from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER

logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER)


class TornadoServer:
    """Holds serving logic using the Tornado framework.

    The serve.py script invokes TornadoServer.serve() to start the serving process.
    The server installs the runtime requirements listed in a requirements file,
    loads a handler function from an inference script, and then fronts it with an
    /invocations route using the Tornado framework.
    """

    def __init__(self):
        """Initialize the serving behaviors.

        Reads the serving configuration through Environment() and works out the
        directory that contains the inference code: the base directory, joined
        with the code directory when one is configured.
        """

        self._environment = Environment()
        logger.setLevel(self._environment.logging_level)
        logger.debug(f"Environment: {str(self._environment)}")

        self._path_to_inference_code = (
            Path(self._environment.base_directory).joinpath(self._environment.code_directory)
            if self._environment.code_directory
            else Path(self._environment.base_directory)
        )
        logger.debug(f"Path to inference code: `{str(self._path_to_inference_code)}`")

        # Populated by initialize() once the inference script has been loaded.
        self._handler = None

    def initialize(self):
        """Initialize the serving artifacts and dependencies.

        Installs the runtime requirements and then loads the handler function
        from the inference script.

        Raises:
            RequirementsInstallException: If installing the requirements fails.
            InferenceCodeLoadException: If the handler cannot be located.
        """

        logger.info("Initializing inference server...")
        self._install_runtime_requirements()
        self._handler = self._load_inference_handler()

    def serve(self):
        """Orchestrate the initialization and server startup behavior.

        Calls initialize(), picks the async or sync Tornado handler module
        depending on whether the loaded handler is a coroutine function, and
        then runs that module's handle() coroutine through asyncio.

        Raises:
            ServerStartException: If running the server raises. The original
                exception is attached as ``__cause__``.
        """

        logger.info("Serving inference requests using Tornado...")
        self.initialize()

        # These are imported by bare module name, so the directory containing
        # them must already be on sys.path.
        if asyncio.iscoroutinefunction(self._handler):
            import async_handler as inference_handler
        else:
            import sync_handler as inference_handler

        try:
            asyncio.run(inference_handler.handle(self._handler, self._environment))
        except Exception as e:
            raise ServerStartException(e) from e

    def _install_runtime_requirements(self):
        """Install the runtime requirements with pip, if a requirements file exists.

        Raises:
            RequirementsInstallException: If the pip invocation fails. The
                original exception is attached as ``__cause__``.
        """

        logger.info("Installing runtime requirements...")
        requirements_txt = self._path_to_inference_code.joinpath(self._environment.requirements)
        if not requirements_txt.is_file():
            logger.debug(f"No requirements file was found at `{str(requirements_txt)}`")
            return

        try:
            # Use the running interpreter so packages land in the serving environment.
            subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", str(requirements_txt)])
        except Exception as e:
            raise RequirementsInstallException(e) from e

    def _load_inference_handler(self) -> callable:
        """Load the handler function from the inference script.

        The configured code reference must have the form ``<module>.<handler>``;
        ``<module>.py`` is looked up inside the inference code directory.

        Returns:
            The attribute named ``<handler>`` found in the loaded module.

        Raises:
            InferenceCodeLoadException: If the code reference is malformed, the
                module file does not exist, or the module has no such handler.
        """

        # `import importlib` alone does not guarantee the `util` submodule is loaded.
        import importlib.util

        logger.info("Loading inference handler...")

        # Validate before unpacking so a malformed value produces the intended
        # error instead of a ValueError from tuple unpacking.
        code_parts = self._environment.code.split(".")
        if len(code_parts) != 2 or not all(code_parts):
            raise InferenceCodeLoadException(
                f"Inference code expected in the format of `<module>.<handler>` but was provided as {self._environment.code}"
            )
        inference_module_name, handle_name = code_parts

        inference_module_file = f"{inference_module_name}.py"
        inference_module_path = self._path_to_inference_code.joinpath(inference_module_file)

        # spec_from_file_location() does not check that the file exists, so
        # test for it explicitly to report a missing script clearly.
        module_spec = (
            importlib.util.spec_from_file_location(inference_module_file, str(inference_module_path))
            if inference_module_path.is_file()
            else None
        )
        if not module_spec:
            raise InferenceCodeLoadException(
                f"Inference code could not be found at `{str(inference_module_path)}`"
            )

        # Let the inference script import modules that sit next to it.
        sys.path.insert(0, str(self._path_to_inference_code.resolve()))
        module = importlib.util.module_from_spec(module_spec)
        module_spec.loader.exec_module(module)

        if not hasattr(module, handle_name):
            # Log what the module does expose to help diagnose a wrong handler name.
            logger.info(dir(module))
            raise InferenceCodeLoadException(
                f"Handler `{handle_name}` could not be found in module `{inference_module_file}`"
            )

        handler = getattr(module, handle_name)
        logger.debug(f"Loaded handler `{handle_name}` from module `{inference_module_name}`")
        return handler
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from __future__ import absolute_import

import asyncio
import logging
import tornado.web
from utils.environment import Environment
from utils.exception import SyncInvocationsException
from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER
from tornado.ioloop import IOLoop

logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER)


class InvocationsHandler(tornado.web.RequestHandler):
    """Handler mapped to the /invocations POST route.

    Wraps the sync handler retrieved from the inference script and exposes it
    through the asynchronous post() method.
    """

    def initialize(self, handler: callable, environment: Environment):
        """Store the handler function and the serving environment.

        Args:
            handler: Synchronous function from the inference script; it is called
                with the incoming request as its only argument.
            environment: Serving environment supplied when the route is registered.
        """

        self._handler = handler
        self._environment = environment

    async def post(self):
        """Run the sync inference handler in an executor and write its result.

        Raises:
            SyncInvocationsException: If the handler, or writing its response,
                raises. The original exception is attached as ``__cause__``.
        """

        try:
            # Passing None as the executor selects the IOLoop's default executor.
            response = await IOLoop.current().run_in_executor(None, self._handler, self.request)
            self.write(response)
        except Exception as e:
            # Chain explicitly so the traceback reports the handler failure as the direct cause.
            raise SyncInvocationsException(e) from e


class PingHandler(tornado.web.RequestHandler):
    """Handler mapped to the /ping GET route.

    Used to check the health of the Tornado server.
    """

    def get(self):
        """Answer the health check by writing an empty body."""

        self.write("")


async def handle(handler: callable, environment: Environment):
    """Serve the sync handler function using Tornado.

    Registers the /invocations and /ping routes used by a SageMaker Endpoint,
    listens on the port taken from ``environment``, and then waits indefinitely.
    """

    logger.info("Starting inference server in synchronous mode...")

    # Keyword arguments that are handed to InvocationsHandler.initialize().
    invocation_kwargs = {"handler": handler, "environment": environment}
    routes = [
        (r"/invocations", InvocationsHandler, invocation_kwargs),
        (r"/ping", PingHandler),
    ]
    application = tornado.web.Application(routes)
    application.listen(environment.port)
    logger.debug(f"Synchronous inference server listening on port: `{environment.port}`")

    # Wait on an event that is never set so this coroutine keeps running.
    never_set = asyncio.Event()
    await never_set.wait()
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from __future__ import absolute_import
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from __future__ import absolute_import

import json
import os
from enum import Enum


class SageMakerInference(str, Enum):
    """Names of the SAGEMAKER_INFERENCE environment variables.

    The members double as the keys of the settings dictionary held by Environment.
    """

    BASE_DIRECTORY = "SAGEMAKER_INFERENCE_BASE_DIRECTORY"
    REQUIREMENTS = "SAGEMAKER_INFERENCE_REQUIREMENTS"
    CODE_DIRECTORY = "SAGEMAKER_INFERENCE_CODE_DIRECTORY"
    CODE = "SAGEMAKER_INFERENCE_CODE"
    LOGGING_LEVEL = "SAGEMAKER_INFERENCE_LOGGING_LEVEL"
    PORT = "SAGEMAKER_INFERENCE_PORT"


class Environment:
    """Retrieves and encapsulates SAGEMAKER_INFERENCE prefixed environment variables."""

    def __init__(self):
        """Initialize the environment variable mapping.

        The logging level and port are normalized to ``int`` when their value is
        numeric, so a value read from the process environment (always a string)
        has the same type as the built-in integer default.
        """

        self._environment_variables = {
            # These two are fixed values; they are not read from the process environment.
            SageMakerInference.BASE_DIRECTORY: "/opt/ml/model",
            SageMakerInference.REQUIREMENTS: "requirements.txt",
            SageMakerInference.CODE_DIRECTORY: os.getenv(SageMakerInference.CODE_DIRECTORY, None),
            SageMakerInference.CODE: os.getenv(SageMakerInference.CODE, "inference.handler"),
            # A numeric string such as "20" is not accepted as a logging level
            # name, so it is converted to the integer level.
            SageMakerInference.LOGGING_LEVEL: self._int_if_numeric(
                os.getenv(SageMakerInference.LOGGING_LEVEL, 10)
            ),
            SageMakerInference.PORT: self._int_if_numeric(os.getenv(SageMakerInference.PORT, 8080)),
        }

    @staticmethod
    def _int_if_numeric(value):
        """Return ``int(value)`` when that conversion succeeds, else ``value`` unchanged.

        Non-numeric settings (for example a level name such as "INFO") pass through as-is.
        """

        try:
            return int(value)
        except (TypeError, ValueError):
            return value

    def __str__(self):
        """Return the settings dictionary serialized as a JSON string."""
        return json.dumps(self._environment_variables)

    @property
    def base_directory(self):
        """Base directory setting."""
        return self._environment_variables.get(SageMakerInference.BASE_DIRECTORY)

    @property
    def requirements(self):
        """Requirements file name setting."""
        return self._environment_variables.get(SageMakerInference.REQUIREMENTS)

    @property
    def code_directory(self):
        """Code directory setting, or None when the variable is not set."""
        return self._environment_variables.get(SageMakerInference.CODE_DIRECTORY)

    @property
    def code(self):
        """Code reference setting; defaults to "inference.handler"."""
        return self._environment_variables.get(SageMakerInference.CODE)

    @property
    def logging_level(self):
        """Logging level: an int when numeric, otherwise the raw string."""
        return self._environment_variables.get(SageMakerInference.LOGGING_LEVEL)

    @property
    def port(self):
        """Port setting: an int when numeric, otherwise the raw string."""
        return self._environment_variables.get(SageMakerInference.PORT)
21 changes: 21 additions & 0 deletions template/v3/dirs/etc/sagemaker-inference-server/utils/exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from __future__ import absolute_import


class RequirementsInstallException(Exception):
    """Raised when installing the runtime requirements fails."""


class InferenceCodeLoadException(Exception):
    """Raised when the inference script or its handler cannot be loaded."""


class ServerStartException(Exception):
    """Raised when running the inference server fails."""


class SyncInvocationsException(Exception):
    """Raised when handling an /invocations request with a sync handler fails."""


class AsyncInvocationsException(Exception):
    """Raised when handling an /invocations request with an async handler fails."""
Loading

0 comments on commit 4f0921d

Please sign in to comment.