-
Notifications
You must be signed in to change notification settings - Fork 5.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
234 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
name: Check Coverage | ||
|
||
on: | ||
push: | ||
branches: | ||
- main | ||
pull_request: | ||
|
||
env: | ||
POETRY_VERSION: "1.8.3" | ||
|
||
jobs: | ||
test: | ||
runs-on: ubuntu-latest-unit-tester | ||
strategy: | ||
fail-fast: false | ||
matrix: | ||
python-version: ["3.9"] | ||
steps: | ||
- name: clear space | ||
env: | ||
CI: true | ||
shell: bash | ||
run: rm -rf /opt/hostedtoolcache/* | ||
- uses: actions/checkout@v4 | ||
with: | ||
fetch-depth: 0 | ||
- name: update rustc | ||
run: rustup update stable | ||
- name: Install Poetry | ||
run: pipx install poetry==${{ env.POETRY_VERSION }} | ||
- name: Set up python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v5 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
cache: "poetry" | ||
cache-dependency-path: "**/poetry.lock" | ||
- uses: pantsbuild/actions/init-pants@v5-scie-pants | ||
with: | ||
# v0 makes it easy to bust the cache if needed | ||
# just increase the integer to start with a fresh cache | ||
gha-cache-key: v1-py${{ matrix.python_version }} | ||
named-caches-hash: v1-py${{ matrix.python_version }} | ||
pants-python-version: ${{ matrix.python-version }} | ||
pants-ci-config: pants.toml | ||
- name: Check BUILD files | ||
run: | | ||
pants tailor --check :: -docs/:: | ||
- name: Run coverage checks on changed packages | ||
run: | | ||
# Get the changed files | ||
CHANGED_FILES=$(pants list --changed-since=origin/main) | ||
# Find which roots contain changed files | ||
FILTER_PATTERNS="[" | ||
for file in $CHANGED_FILES; do | ||
root=$(echo "$file" | cut -d'/' -f1,2,3) | ||
if [[ ! "$FILTER_PATTERNS" =~ "$root" ]]; then | ||
FILTER_PATTERNS="${FILTER_PATTERNS}'${root}'," | ||
fi | ||
done | ||
# remove the last comma and close the bracket | ||
FILTER_PATTERNS="${FILTER_PATTERNS%,}]" | ||
echo "Coverage filter patterns: $FILTER_PATTERNS" | ||
pants --level=error --no-local-cache test \ | ||
--test-use-coverage \ | ||
--changed-since=origin/main \ | ||
--coverage-py-filter="$FILTER_PATTERNS" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
163 changes: 159 additions & 4 deletions
163
...x-integrations/llms/llama-index-llms-bedrock-converse/tests/test_llms_bedrock_converse.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,162 @@ | ||
from llama_index.core.base.llms.base import BaseLLM | ||
import pytest | ||
from llama_index.llms.bedrock_converse import BedrockConverse | ||
from llama_index.core.base.llms.types import ( | ||
ChatMessage, | ||
ChatResponse, | ||
MessageRole, | ||
CompletionResponse, | ||
) | ||
from llama_index.core.callbacks import CallbackManager | ||
|
||
# Expected values | ||
EXP_RESPONSE = "Test" | ||
EXP_STREAM_RESPONSE = ["Test ", "value"] | ||
EXP_MAX_TOKENS = 100 | ||
EXP_TEMPERATURE = 0.7 | ||
EXP_MODEL = "anthropic.claude-v2" | ||
|
||
def test_text_inference_embedding_class(): | ||
names_of_base_classes = [b.__name__ for b in BedrockConverse.__mro__] | ||
assert BaseLLM.__name__ in names_of_base_classes | ||
# Reused chat message and prompt | ||
messages = [ChatMessage(role=MessageRole.USER, content="Test")] | ||
prompt = "Test" | ||
|
||
|
||
class MockExceptions: | ||
class ThrottlingException(Exception): | ||
pass | ||
|
||
|
||
class AsyncMockClient: | ||
def __init__(self) -> "AsyncMockClient": | ||
self.exceptions = MockExceptions() | ||
|
||
async def __aenter__(self) -> "AsyncMockClient": | ||
return self | ||
|
||
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: | ||
pass | ||
|
||
async def converse(self, *args, **kwargs): | ||
return {"output": {"message": {"content": [{"text": EXP_RESPONSE}]}}} | ||
|
||
async def converse_stream(self, *args, **kwargs): | ||
async def stream_generator(): | ||
for element in EXP_STREAM_RESPONSE: | ||
yield {"contentBlockDelta": {"delta": {"text": element}}} | ||
|
||
return {"stream": stream_generator()} | ||
|
||
|
||
class MockClient: | ||
def __init__(self) -> "MockClient": | ||
self.exceptions = MockExceptions() | ||
|
||
def converse(self, *args, **kwargs): | ||
return {"output": {"message": {"content": [{"text": EXP_RESPONSE}]}}} | ||
|
||
def converse_stream(self, *args, **kwargs): | ||
def stream_generator(): | ||
for element in EXP_STREAM_RESPONSE: | ||
yield {"contentBlockDelta": {"delta": {"text": element}}} | ||
|
||
return {"stream": stream_generator()} | ||
|
||
|
||
class MockAsyncSession: | ||
def __init__(self, *args, **kwargs) -> "MockAsyncSession": | ||
pass | ||
|
||
def client(self, *args, **kwargs): | ||
return AsyncMockClient() | ||
|
||
|
||
@pytest.fixture() | ||
def mock_boto3_session(monkeypatch): | ||
def mock_client(*args, **kwargs): | ||
return MockClient() | ||
|
||
monkeypatch.setattr("boto3.Session.client", mock_client) | ||
|
||
|
||
@pytest.fixture() | ||
def mock_aioboto3_session(monkeypatch): | ||
monkeypatch.setattr("aioboto3.Session", MockAsyncSession) | ||
|
||
|
||
@pytest.fixture() | ||
def bedrock_converse(mock_boto3_session, mock_aioboto3_session): | ||
return BedrockConverse( | ||
model=EXP_MODEL, | ||
max_tokens=EXP_MAX_TOKENS, | ||
temperature=EXP_TEMPERATURE, | ||
callback_manager=CallbackManager(), | ||
) | ||
|
||
|
||
def test_init(bedrock_converse): | ||
assert bedrock_converse.model == EXP_MODEL | ||
assert bedrock_converse.max_tokens == EXP_MAX_TOKENS | ||
assert bedrock_converse.temperature == EXP_TEMPERATURE | ||
assert bedrock_converse._client is not None | ||
|
||
|
||
def test_chat(bedrock_converse): | ||
response = bedrock_converse.chat(messages) | ||
|
||
assert response.message.role == MessageRole.ASSISTANT | ||
assert response.message.content == EXP_RESPONSE | ||
|
||
|
||
def test_complete(bedrock_converse): | ||
response = bedrock_converse.complete(prompt) | ||
|
||
assert isinstance(response, CompletionResponse) | ||
assert response.text == EXP_RESPONSE | ||
assert response.additional_kwargs["status"] == [] | ||
assert response.additional_kwargs["tool_call_id"] == [] | ||
assert response.additional_kwargs["tool_calls"] == [] | ||
|
||
|
||
def test_stream_chat(bedrock_converse): | ||
response_stream = bedrock_converse.stream_chat(messages) | ||
|
||
for response in response_stream: | ||
assert response.message.role == MessageRole.ASSISTANT | ||
assert response.delta in EXP_STREAM_RESPONSE | ||
|
||
|
||
@pytest.mark.asyncio() | ||
async def test_achat(bedrock_converse): | ||
response = await bedrock_converse.achat(messages) | ||
|
||
assert isinstance(response, ChatResponse) | ||
assert response.message.role == MessageRole.ASSISTANT | ||
assert response.message.content == EXP_RESPONSE | ||
|
||
|
||
@pytest.mark.asyncio() | ||
async def test_astream_chat(bedrock_converse): | ||
response_stream = await bedrock_converse.astream_chat(messages) | ||
|
||
responses = [] | ||
async for response in response_stream: | ||
assert response.message.role == MessageRole.ASSISTANT | ||
assert response.delta in EXP_STREAM_RESPONSE | ||
|
||
|
||
@pytest.mark.asyncio() | ||
async def test_acomplete(bedrock_converse): | ||
response = await bedrock_converse.acomplete(prompt) | ||
|
||
assert isinstance(response, CompletionResponse) | ||
assert response.text == EXP_RESPONSE | ||
assert response.additional_kwargs["status"] == [] | ||
assert response.additional_kwargs["tool_call_id"] == [] | ||
assert response.additional_kwargs["tool_calls"] == [] | ||
|
||
|
||
@pytest.mark.asyncio() | ||
async def test_astream_complete(bedrock_converse): | ||
response_stream = await bedrock_converse.astream_complete(prompt) | ||
|
||
async for response in response_stream: | ||
assert response.delta in EXP_STREAM_RESPONSE |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters