-
Notifications
You must be signed in to change notification settings - Fork 58
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enable inference serving capabilities on sagemaker endpoint using tor…
…nado
- Loading branch information
Showing
12 changed files
with
400 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from __future__ import absolute_import | ||
|
||
import utils.logger |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
#!/bin/bash | ||
python /etc/sagemaker-inference-server/serve.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from __future__ import absolute_import | ||
|
||
from tornado_server.server import TornadoServer | ||
|
||
inference_server = TornadoServer() | ||
inference_server.serve() |
12 changes: 12 additions & 0 deletions
12
template/v3/dirs/etc/sagemaker-inference-server/tornado_server/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from __future__ import absolute_import | ||
|
||
import pathlib | ||
import sys | ||
|
||
# make the utils modules accessible to modules from within the tornado_server folder | ||
utils_path = pathlib.Path(__file__).parent.parent / "utils" | ||
sys.path.insert(0, str(utils_path.resolve())) | ||
|
||
# make the tornado_server modules accessible to each other | ||
tornado_module_path = pathlib.Path(__file__).parent | ||
sys.path.insert(0, str(tornado_module_path.resolve())) |
66 changes: 66 additions & 0 deletions
66
template/v3/dirs/etc/sagemaker-inference-server/tornado_server/async_handler.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from __future__ import absolute_import | ||
|
||
import asyncio | ||
import logging | ||
import tornado.web | ||
from utils.environment import Environment | ||
from utils.exception import AsyncInvocationsException | ||
from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER | ||
|
||
logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER) | ||
|
||
|
||
class InvocationsHandler(tornado.web.RequestHandler): | ||
"""Handler mapped to the /invocations POST route. | ||
This handler wraps the async handler retrieved from the inference script | ||
and encapsulates it behind the post() method. The post() method is done | ||
asynchronously. | ||
""" | ||
|
||
def initialize(self, handler: callable, environment: Environment): | ||
"""Initializes the handler function and the serving environment.""" | ||
|
||
self._handler = handler | ||
self._environment = environment | ||
|
||
async def post(self): | ||
"""POST method used to encapsulate and invoke the async handle method asynchronously""" | ||
|
||
try: | ||
response = await self._handler(self.request) | ||
self.write(response) | ||
except Exception as e: | ||
raise AsyncInvocationsException(e) | ||
|
||
|
||
class PingHandler(tornado.web.RequestHandler): | ||
"""Handler mapped to the /ping GET route. | ||
Ping handler to monitor the health of the Tornados server. | ||
""" | ||
|
||
def get(self): | ||
"""Simple GET method to assess the health of the server.""" | ||
|
||
self.write("") | ||
|
||
|
||
async def handle(handler: callable, environment: Environment): | ||
"""Serves the async handler function using Tornado. | ||
Opens the /invocations and /ping routes used by a SageMaker Endpoint | ||
for inference serving capabilities. | ||
""" | ||
|
||
logger.info("Starting inference server in asynchronous mode...") | ||
|
||
app = tornado.web.Application( | ||
[ | ||
(r"/invocations", InvocationsHandler, dict(handler=handler, environment=environment)), | ||
(r"/ping", PingHandler), | ||
] | ||
) | ||
app.listen(environment.port) | ||
logger.debug(f"Asynchronous inference server listening on port: `{environment.port}`") | ||
await asyncio.Event().wait() |
119 changes: 119 additions & 0 deletions
119
template/v3/dirs/etc/sagemaker-inference-server/tornado_server/server.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
from __future__ import absolute_import | ||
|
||
import asyncio | ||
import importlib | ||
import logging | ||
import subprocess | ||
import sys | ||
from pathlib import Path | ||
from utils.environment import Environment | ||
from utils.exception import InferenceCodeLoadException, RequirementsInstallException, ServerStartException | ||
from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER | ||
|
||
logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER) | ||
|
||
|
||
class TornadoServer: | ||
"""Holds serving logic using the Tornado framework. | ||
The serve.py script will invoke TornadoServer.serve() to start the serving process. | ||
The TornadoServer will install the runtime requirements specified through a requirements file. | ||
It will then load an handler function within an inference script and then front it will an /invocations | ||
route using the Tornado framework. | ||
""" | ||
|
||
def __init__(self): | ||
"""Initialize the serving behaviors. | ||
Defines the serving behavior through Environment() and locate where | ||
the inference code is contained. | ||
""" | ||
|
||
self._environment = Environment() | ||
logger.setLevel(self._environment.logging_level) | ||
logger.debug(f"Environment: {str(self._environment)}") | ||
|
||
self._path_to_inference_code = ( | ||
Path(self._environment.base_directory).joinpath(self._environment.code_directory) | ||
if self._environment.code_directory | ||
else Path(self._environment.base_directory) | ||
) | ||
logger.debug(f"Path to inference code: `{str(self._path_to_inference_code)}`") | ||
|
||
self._handler = None | ||
|
||
def initialize(self): | ||
"""Initialize the serving artifacts and dependencies. | ||
Install the runtime requirements and then locate the handler function from | ||
the inference script. | ||
""" | ||
|
||
logger.info("Initializing inference server...") | ||
self._install_runtime_requirements() | ||
self._handler = self._load_inference_handler() | ||
|
||
def serve(self): | ||
"""Orchestrate the initialization and server startup behavior. | ||
Call the initalize() method, determine the right Tornado serving behavior (async or sync), | ||
and then start the Tornado server through asyncio | ||
""" | ||
|
||
logger.info("Serving inference requests using Tornado...") | ||
self.initialize() | ||
|
||
if asyncio.iscoroutinefunction(self._handler): | ||
import async_handler as inference_handler | ||
else: | ||
import sync_handler as inference_handler | ||
|
||
try: | ||
asyncio.run(inference_handler.handle(self._handler, self._environment)) | ||
except Exception as e: | ||
raise ServerStartException(e) | ||
|
||
def _install_runtime_requirements(self): | ||
"""Install the runtime requirements.""" | ||
|
||
logger.info("Installing runtime requirements...") | ||
requirements_txt = self._path_to_inference_code.joinpath(self._environment.requirements) | ||
if requirements_txt.is_file(): | ||
try: | ||
subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", str(requirements_txt)]) | ||
except Exception as e: | ||
raise RequirementsInstallException(e) | ||
else: | ||
logger.debug(f"No requirements file was found at `{str(requirements_txt)}`") | ||
|
||
def _load_inference_handler(self) -> callable: | ||
"""Load the handler function from the inference script.""" | ||
|
||
logger.info("Loading inference handler...") | ||
inference_module_name, handle_name = self._environment.code.split(".") | ||
if inference_module_name and handle_name: | ||
inference_module_file = f"{inference_module_name}.py" | ||
module_spec = importlib.util.spec_from_file_location( | ||
inference_module_file, str(self._path_to_inference_code.joinpath(inference_module_file)) | ||
) | ||
if module_spec: | ||
sys.path.insert(0, str(self._path_to_inference_code.resolve())) | ||
module = importlib.util.module_from_spec(module_spec) | ||
module_spec.loader.exec_module(module) | ||
|
||
if hasattr(module, handle_name): | ||
handler = getattr(module, handle_name) | ||
else: | ||
logger.info(dir(inference_module)) | ||
raise InferenceCodeLoadException( | ||
f"Handler `{handle_name}` could not be found in module `{inference_module_file}`" | ||
) | ||
logger.debug(f"Loaded handler `{handle_name}` from module `{inference_module_name}`") | ||
return handler | ||
else: | ||
raise InferenceCodeLoadException( | ||
f"Inference code could not be found at `{str(self._path_to_inference_code.joinpath(inference_module_file))}`" | ||
) | ||
raise InferenceCodeLoadException( | ||
f"Inference code expected in the format of `<module>.<handler>` but was provided as {self._environment.code}" | ||
) |
67 changes: 67 additions & 0 deletions
67
template/v3/dirs/etc/sagemaker-inference-server/tornado_server/sync_handler.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from __future__ import absolute_import | ||
|
||
import asyncio | ||
import logging | ||
import tornado.web | ||
from utils.environment import Environment | ||
from utils.exception import SyncInvocationsException | ||
from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER | ||
from tornado.ioloop import IOLoop | ||
|
||
logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER) | ||
|
||
|
||
class InvocationsHandler(tornado.web.RequestHandler): | ||
"""Handler mapped to the /invocations POST route. | ||
This handler wraps the sync handler retrieved from the inference script | ||
and encapsulates it behind the post() method. The post() method is done | ||
asynchronously. | ||
""" | ||
|
||
def initialize(self, handler: callable, environment: Environment): | ||
"""Initializes the handler function and the serving environment.""" | ||
|
||
self._handler = handler | ||
self._environment = environment | ||
|
||
async def post(self): | ||
"""POST method used to encapsulate and invoke the sync handle method asynchronously""" | ||
|
||
try: | ||
response = await IOLoop.current().run_in_executor(None, self._handler, self.request) | ||
self.write(response) | ||
except Exception as e: | ||
raise SyncInvocationsException(e) | ||
|
||
|
||
class PingHandler(tornado.web.RequestHandler): | ||
"""Handler mapped to the /ping GET route. | ||
Ping handler to monitor the health of the Tornados server. | ||
""" | ||
|
||
def get(self): | ||
"""Simple GET method to assess the health of the server.""" | ||
|
||
self.write("") | ||
|
||
|
||
async def handle(handler: callable, environment: Environment): | ||
"""Serves the sync handler function using Tornado. | ||
Opens the /invocations and /ping routes used by a SageMaker Endpoint | ||
for inference serving capabilities. | ||
""" | ||
|
||
logger.info("Starting inference server in synchronous mode...") | ||
|
||
app = tornado.web.Application( | ||
[ | ||
(r"/invocations", InvocationsHandler, dict(handler=handler, environment=environment)), | ||
(r"/ping", PingHandler), | ||
] | ||
) | ||
app.listen(environment.port) | ||
logger.debug(f"Synchronous inference server listening on port: `{environment.port}`") | ||
await asyncio.Event().wait() |
1 change: 1 addition & 0 deletions
1
template/v3/dirs/etc/sagemaker-inference-server/utils/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from __future__ import absolute_import |
59 changes: 59 additions & 0 deletions
59
template/v3/dirs/etc/sagemaker-inference-server/utils/environment.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from __future__ import absolute_import | ||
|
||
import json | ||
import os | ||
from enum import Enum | ||
|
||
|
||
class SageMakerInference(str, Enum): | ||
"""Simple enum to define the mapping between dictionary key and environement variable.""" | ||
|
||
BASE_DIRECTORY = "SAGEMAKER_INFERENCE_BASE_DIRECTORY" | ||
REQUIREMENTS = "SAGEMAKER_INFERENCE_REQUIREMENTS" | ||
CODE_DIRECTORY = "SAGEMAKER_INFERENCE_CODE_DIRECTORY" | ||
CODE = "SAGEMAKER_INFERENCE_CODE" | ||
LOGGING_LEVEL = "SAGEMAKER_INFERENCE_LOGGING_LEVEL" | ||
PORT = "SAGEMAKER_INFERENCE_PORT" | ||
|
||
|
||
class Environment: | ||
"""Retrieves and encapsulates SAGEMAKER_INFERENCE prefixed environment variables.""" | ||
|
||
def __init__(self): | ||
"""Initialize the environment variable mapping""" | ||
|
||
self._environment_variables = { | ||
SageMakerInference.BASE_DIRECTORY: "/opt/ml/model", | ||
SageMakerInference.REQUIREMENTS: "requirements.txt", | ||
SageMakerInference.CODE_DIRECTORY: os.getenv(SageMakerInference.CODE_DIRECTORY, None), | ||
SageMakerInference.CODE: os.getenv(SageMakerInference.CODE, "inference.handler"), | ||
SageMakerInference.LOGGING_LEVEL: os.getenv(SageMakerInference.LOGGING_LEVEL, 10), | ||
SageMakerInference.PORT: os.getenv(SageMakerInference.PORT, 8080), | ||
} | ||
|
||
def __str__(self): | ||
return json.dumps(self._environment_variables) | ||
|
||
@property | ||
def base_directory(self): | ||
return self._environment_variables.get(SageMakerInference.BASE_DIRECTORY) | ||
|
||
@property | ||
def requirements(self): | ||
return self._environment_variables.get(SageMakerInference.REQUIREMENTS) | ||
|
||
@property | ||
def code_directory(self): | ||
return self._environment_variables.get(SageMakerInference.CODE_DIRECTORY) | ||
|
||
@property | ||
def code(self): | ||
return self._environment_variables.get(SageMakerInference.CODE) | ||
|
||
@property | ||
def logging_level(self): | ||
return self._environment_variables.get(SageMakerInference.LOGGING_LEVEL) | ||
|
||
@property | ||
def port(self): | ||
return self._environment_variables.get(SageMakerInference.PORT) |
21 changes: 21 additions & 0 deletions
21
template/v3/dirs/etc/sagemaker-inference-server/utils/exception.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from __future__ import absolute_import | ||
|
||
|
||
class RequirementsInstallException(Exception): | ||
pass | ||
|
||
|
||
class InferenceCodeLoadException(Exception): | ||
pass | ||
|
||
|
||
class ServerStartException(Exception): | ||
pass | ||
|
||
|
||
class SyncInvocationsException(Exception): | ||
pass | ||
|
||
|
||
class AsyncInvocationsException(Exception): | ||
pass |
Oops, something went wrong.