Skip to content

Commit

Permalink
improved logging with sensible defaults (#391)
Browse files Browse the repository at this point in the history
  • Loading branch information
aniketmaurya authored Dec 10, 2024
1 parent 00eae56 commit 4a41f7f
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 6 deletions.
3 changes: 3 additions & 0 deletions src/litserve/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from litserve.loggers import Logger
from litserve.server import LitServer, Request, Response
from litserve.specs import OpenAIEmbeddingSpec, OpenAISpec
from litserve.utils import configure_logging

configure_logging()

__all__ = [
"LitAPI",
Expand Down
44 changes: 44 additions & 0 deletions src/litserve/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import dataclasses
import logging
import pickle
import sys
from contextlib import contextmanager
from typing import TYPE_CHECKING, AsyncIterator

Expand Down Expand Up @@ -87,3 +88,46 @@ class WorkerSetupStatus:
READY: str = "ready"
ERROR: str = "error"
FINISHED: str = "finished"


def configure_logging(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", stream=sys.stdout
):
"""Configure logging for the entire library with sensible defaults.
Args:
level (int): Logging level (default: logging.INFO)
format (str): Log message format string
stream (file-like): Output stream for logs
"""
# Create a library-wide handler
handler = logging.StreamHandler(stream)

# Set formatter with user-configurable format
formatter = logging.Formatter(format)
handler.setFormatter(formatter)

# Configure root library logger
library_logger = logging.getLogger("litserve")
library_logger.setLevel(level)
library_logger.addHandler(handler)

# Prevent propagation to root logger to avoid duplicate logs
library_logger.propagate = False


def set_log_level(level):
"""Allow users to set the global logging level for the library."""
logging.getLogger("litserve").setLevel(level)


def add_log_handler(handler):
"""Allow users to add custom log handlers.
Example usage:
file_handler = logging.FileHandler('library_logs.log')
add_log_handler(file_handler)
"""
logging.getLogger("litserve").addHandler(handler)
60 changes: 60 additions & 0 deletions tests/test_logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import io
import logging

import pytest

from litserve.utils import add_log_handler, configure_logging, set_log_level


@pytest.fixture
def log_stream():
return io.StringIO()


def test_configure_logging(log_stream):
# Configure logging with test stream
configure_logging(level=logging.DEBUG, stream=log_stream)

# Get logger and log a test message
logger = logging.getLogger("litserve")
test_message = "Test debug message"
logger.debug(test_message)

# Verify log output
log_contents = log_stream.getvalue()
assert test_message in log_contents
assert "DEBUG" in log_contents
assert logger.propagate is False


def test_set_log_level():
# Set log level to WARNING
set_log_level(logging.WARNING)

# Verify logger level
logger = logging.getLogger("litserve")
assert logger.level == logging.WARNING


def test_add_log_handler():
# Create and add a custom handler
stream = io.StringIO()
custom_handler = logging.StreamHandler(stream)
add_log_handler(custom_handler)

# Verify handler is added
logger = logging.getLogger("litserve")
assert custom_handler in logger.handlers

# Test the handler works
test_message = "Test handler message"
logger.info(test_message)
assert test_message in stream.getvalue()


@pytest.fixture(autouse=True)
def cleanup_logger():
yield
logger = logging.getLogger("litserve")
logger.handlers.clear()
logger.setLevel(logging.INFO)
19 changes: 13 additions & 6 deletions tests/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import io
import json
import threading
import time
Expand Down Expand Up @@ -229,7 +230,9 @@ def test_run_single_loop():
assert response == ("UUID-001", ({"output": 16.0}, LitAPIStatus.OK))


def test_run_single_loop_timeout(caplog):
def test_run_single_loop_timeout():
stream = io.StringIO()
ls.configure_logging(stream=stream)
lit_api = ls.test_examples.SimpleLitAPI()
lit_api.setup(None)
lit_api.request_timeout = 0.0001
Expand All @@ -248,7 +251,7 @@ def test_run_single_loop_timeout(caplog):

request_queue.put((None, None, None, None))
loop_thread.join()
assert "Request UUID-001 was waiting in the queue for too long" in caplog.text
assert "Request UUID-001 was waiting in the queue for too long" in stream.getvalue()
assert isinstance(response_queues[0].get()[1][0], HTTPException), "Timeout should return an HTTPException"


Expand Down Expand Up @@ -284,7 +287,9 @@ def test_run_batched_loop():
assert response_2 == ("UUID-002", ({"output": 25.0}, LitAPIStatus.OK))


def test_run_batched_loop_timeout(caplog):
def test_run_batched_loop_timeout():
stream = io.StringIO()
ls.configure_logging(stream=stream)
lit_api = ls.test_examples.SimpleBatchedAPI()
lit_api.setup(None)
lit_api._sanitize(2, None)
Expand All @@ -309,7 +314,7 @@ def test_run_batched_loop_timeout(caplog):
# Allow some time for the loop to process
time.sleep(1)

assert "Request UUID-001 was waiting in the queue for too long" in caplog.text
assert "Request UUID-001 was waiting in the queue for too long" in stream.getvalue()
resp1 = response_queues[0].get(timeout=10)[1]
resp2 = response_queues[0].get(timeout=10)[1]
assert isinstance(resp1[0], HTTPException), "First request was timed out"
Expand Down Expand Up @@ -348,7 +353,9 @@ def test_run_streaming_loop():
assert response == {"output": f"{i}: Hello"}


def test_run_streaming_loop_timeout(caplog):
def test_run_streaming_loop_timeout():
stream = io.StringIO()
ls.configure_logging(stream=stream)
lit_api = ls.test_examples.SimpleStreamAPI()
lit_api.setup(None)
lit_api.request_timeout = 0.1
Expand All @@ -370,7 +377,7 @@ def test_run_streaming_loop_timeout(caplog):
request_queue.put((None, None, None, None))
loop_thread.join()

assert "Request UUID-001 was waiting in the queue for too long" in caplog.text
assert "Request UUID-001 was waiting in the queue for too long" in stream.getvalue()
response = response_queues[0].get(timeout=10)[1]
assert isinstance(response[0], HTTPException), "request was timed out"

Expand Down

0 comments on commit 4a41f7f

Please sign in to comment.