Unit Tests for llama-index-llms-bedrock-converse (run-llama#16379)
* Removed unused `llama-index-llms-anthropic` dependency. Incremented to `0.3.0`.

* Expanded range for `pytest` and `pytest-mock` to support `pytest-asyncio`

* Unit tests for main functions

* remove lock

* make coverage checks more narrow

* rename test

* update makefile

* wrong arg names

* even better workflow

* improve check

* try again?

* ok, i think this works

* Streamlined unit tests.

* Consolidated mock exception

---------

Co-authored-by: Logan Markewich <[email protected]>
brycecf and logan-markewich authored Oct 6, 2024
1 parent af6ea71 commit 535d0a4
Showing 5 changed files with 234 additions and 14 deletions.
71 changes: 71 additions & 0 deletions .github/workflows/coverage.yml
@@ -0,0 +1,71 @@
name: Check Coverage

on:
  push:
    branches:
      - main
  pull_request:

env:
  POETRY_VERSION: "1.8.3"

jobs:
  test:
    runs-on: ubuntu-latest-unit-tester
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.9"]
    steps:
      - name: clear space
        env:
          CI: true
        shell: bash
        run: rm -rf /opt/hostedtoolcache/*
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: update rustc
        run: rustup update stable
      - name: Install Poetry
        run: pipx install poetry==${{ env.POETRY_VERSION }}
      - name: Set up python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
          cache: "poetry"
          cache-dependency-path: "**/poetry.lock"
      - uses: pantsbuild/actions/init-pants@v5-scie-pants
        with:
          # v0 makes it easy to bust the cache if needed
          # just increase the integer to start with a fresh cache
          gha-cache-key: v1-py${{ matrix.python-version }}
          named-caches-hash: v1-py${{ matrix.python-version }}
          pants-python-version: ${{ matrix.python-version }}
          pants-ci-config: pants.toml
      - name: Check BUILD files
        run: |
          pants tailor --check :: -docs/::
      - name: Run coverage checks on changed packages
        run: |
          # Get the changed files
          CHANGED_FILES=$(pants list --changed-since=origin/main)
          # Find which roots contain changed files
          FILTER_PATTERNS="["
          for file in $CHANGED_FILES; do
            root=$(echo "$file" | cut -d'/' -f1,2,3)
            if [[ ! "$FILTER_PATTERNS" =~ "$root" ]]; then
              FILTER_PATTERNS="${FILTER_PATTERNS}'${root}',"
            fi
          done
          # remove the last comma and close the bracket
          FILTER_PATTERNS="${FILTER_PATTERNS%,}]"
          echo "Coverage filter patterns: $FILTER_PATTERNS"
          pants --level=error --no-local-cache test \
            --test-use-coverage \
            --changed-since=origin/main \
            --coverage-py-filter="$FILTER_PATTERNS"
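
For illustration, the root-derivation in that shell loop can be traced in a minimal Python sketch; the changed-file paths below are hypothetical stand-ins for `pants list` output, not real results:

# Mirrors the shell loop above: keep the first three path segments of each
# changed file (`cut -d'/' -f1,2,3`), dedupe, and build the bracketed list.
changed_files = [  # hypothetical example paths
    "llama-index-integrations/llms/llama-index-llms-bedrock-converse/pyproject.toml",
    "llama-index-integrations/llms/llama-index-llms-bedrock-converse/tests/test_converse.py",
    "llama-index-core/llama_index/core/llms/llm.py",
]

roots = []
for path in changed_files:
    root = "/".join(path.split("/")[:3])
    if root not in roots:  # the shell version dedupes with a substring match
        roots.append(root)

filter_patterns = "[" + ",".join(f"'{root}'" for root in roots) + "]"
print(filter_patterns)
# ['llama-index-integrations/llms/llama-index-llms-bedrock-converse','llama-index-core/llama_index/core']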
2 changes: 1 addition & 1 deletion Makefile
@@ -11,7 +11,7 @@ lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy
	pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files

test: ## Run tests via pants
-	pants --level=error --no-local-cache --changed-since=origin/main --changed-dependents=transitive test
+	pants --level=error --no-local-cache --changed-since=origin/main --changed-dependents=transitive --no-test-use-coverage test

test-core: ## Run tests via pants
	pants --no-local-cache test llama-index-core/::
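Note: the `--no-test-use-coverage` flag added above keeps coverage instrumentation out of the default `make test` run; coverage is now collected only by the dedicated Check Coverage workflow.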
@@ -41,8 +41,9 @@ jupyter = "^1.0.0"
mypy = "0.991"
pre-commit = "3.2.0"
pylint = "2.15.10"
pytest = "7.2.1"
pytest-mock = "3.11.1"
pytest = ">=7.2.1"
pytest-asyncio = "^0.24.0"
pytest-mock = ">=3.11.1"
ruff = "0.0.292"
tree-sitter-languages = "^1.8.0"
types-Deprecated = ">=0.1.0"
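The `pytest-asyncio` addition backs the `@pytest.mark.asyncio()` markers in the new test module below; in pytest-asyncio's default strict mode, coroutine tests only run when marked this way.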
@@ -1,7 +1,162 @@
from llama_index.core.base.llms.base import BaseLLM
import pytest
from llama_index.llms.bedrock_converse import BedrockConverse
from llama_index.core.base.llms.types import (
    ChatMessage,
    ChatResponse,
    MessageRole,
    CompletionResponse,
)
from llama_index.core.callbacks import CallbackManager

# Expected values
EXP_RESPONSE = "Test"
EXP_STREAM_RESPONSE = ["Test ", "value"]
EXP_MAX_TOKENS = 100
EXP_TEMPERATURE = 0.7
EXP_MODEL = "anthropic.claude-v2"


def test_text_inference_embedding_class():
    names_of_base_classes = [b.__name__ for b in BedrockConverse.__mro__]
    assert BaseLLM.__name__ in names_of_base_classes


# Reused chat message and prompt
messages = [ChatMessage(role=MessageRole.USER, content="Test")]
prompt = "Test"


class MockExceptions:
    class ThrottlingException(Exception):
        pass


class AsyncMockClient:
    def __init__(self) -> None:
        self.exceptions = MockExceptions()

    async def __aenter__(self) -> "AsyncMockClient":
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        pass

    async def converse(self, *args, **kwargs):
        return {"output": {"message": {"content": [{"text": EXP_RESPONSE}]}}}

    async def converse_stream(self, *args, **kwargs):
        async def stream_generator():
            for element in EXP_STREAM_RESPONSE:
                yield {"contentBlockDelta": {"delta": {"text": element}}}

        return {"stream": stream_generator()}


class MockClient:
    def __init__(self) -> None:
        self.exceptions = MockExceptions()

    def converse(self, *args, **kwargs):
        return {"output": {"message": {"content": [{"text": EXP_RESPONSE}]}}}

    def converse_stream(self, *args, **kwargs):
        def stream_generator():
            for element in EXP_STREAM_RESPONSE:
                yield {"contentBlockDelta": {"delta": {"text": element}}}

        return {"stream": stream_generator()}


class MockAsyncSession:
    def __init__(self, *args, **kwargs) -> None:
        pass

    def client(self, *args, **kwargs):
        return AsyncMockClient()
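

# The async mocks mirror the client-acquisition pattern the async code path
# presumably uses with aioboto3, roughly:
#     async with session.client(...) as client:
#         await client.converse(...)
# hence AsyncMockClient implements __aenter__/__aexit__ while the sync
# MockClient does not.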


@pytest.fixture()
def mock_boto3_session(monkeypatch):
    def mock_client(*args, **kwargs):
        return MockClient()

    monkeypatch.setattr("boto3.Session.client", mock_client)


@pytest.fixture()
def mock_aioboto3_session(monkeypatch):
    monkeypatch.setattr("aioboto3.Session", MockAsyncSession)


@pytest.fixture()
def bedrock_converse(mock_boto3_session, mock_aioboto3_session):
    return BedrockConverse(
        model=EXP_MODEL,
        max_tokens=EXP_MAX_TOKENS,
        temperature=EXP_TEMPERATURE,
        callback_manager=CallbackManager(),
    )
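

# Note: `bedrock_converse` requests both session fixtures, so boto3 and
# aioboto3 are patched before BedrockConverse is constructed and none of
# the tests below ever reach AWS.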


def test_init(bedrock_converse):
    assert bedrock_converse.model == EXP_MODEL
    assert bedrock_converse.max_tokens == EXP_MAX_TOKENS
    assert bedrock_converse.temperature == EXP_TEMPERATURE
    assert bedrock_converse._client is not None


def test_chat(bedrock_converse):
    response = bedrock_converse.chat(messages)

    assert response.message.role == MessageRole.ASSISTANT
    assert response.message.content == EXP_RESPONSE


def test_complete(bedrock_converse):
    response = bedrock_converse.complete(prompt)

    assert isinstance(response, CompletionResponse)
    assert response.text == EXP_RESPONSE
    assert response.additional_kwargs["status"] == []
    assert response.additional_kwargs["tool_call_id"] == []
    assert response.additional_kwargs["tool_calls"] == []


def test_stream_chat(bedrock_converse):
    response_stream = bedrock_converse.stream_chat(messages)

    for response in response_stream:
        assert response.message.role == MessageRole.ASSISTANT
        assert response.delta in EXP_STREAM_RESPONSE


@pytest.mark.asyncio()
async def test_achat(bedrock_converse):
    response = await bedrock_converse.achat(messages)

    assert isinstance(response, ChatResponse)
    assert response.message.role == MessageRole.ASSISTANT
    assert response.message.content == EXP_RESPONSE


@pytest.mark.asyncio()
async def test_astream_chat(bedrock_converse):
    response_stream = await bedrock_converse.astream_chat(messages)

    async for response in response_stream:
        assert response.message.role == MessageRole.ASSISTANT
        assert response.delta in EXP_STREAM_RESPONSE


@pytest.mark.asyncio()
async def test_acomplete(bedrock_converse):
    response = await bedrock_converse.acomplete(prompt)

    assert isinstance(response, CompletionResponse)
    assert response.text == EXP_RESPONSE
    assert response.additional_kwargs["status"] == []
    assert response.additional_kwargs["tool_call_id"] == []
    assert response.additional_kwargs["tool_calls"] == []


@pytest.mark.asyncio()
async def test_astream_complete(bedrock_converse):
    response_stream = await bedrock_converse.astream_complete(prompt)

    async for response in response_stream:
        assert response.delta in EXP_STREAM_RESPONSE
7 changes: 0 additions & 7 deletions pants.toml
@@ -20,13 +20,6 @@ config = "./pyproject.toml"

[coverage-py]
fail_under = 50
-filter = [
-    'llama-index-core/',
-    'llama-index-experimental/',
-    'llama-index-finetuning/',
-    'llama-index-integrations/',
-    'llama-index-utils/',
-]
global_report = false
report = ["console", "html", "xml"]
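
With the Check Coverage workflow now computing `--coverage-py-filter` per run from the changed files, the static filter list removed here had become redundant.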

