chore(api/tests): apply ruff reformat #7590 (#7591)
Co-authored-by: -LAN- <[email protected]>
bowenliang123 and laipz8200 authored Aug 23, 2024
1 parent 2da6365 commit b035c02
Showing 155 changed files with 4,272 additions and 5,918 deletions.
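Apart from the one-line api/pyproject.toml change, every hunk below is the mechanical output of Ruff's formatter rather than a hand edit. A minimal sketch of the kind of invocation that produces such a reformat, assuming Ruff is installed and run from the api/ directory (the exact command used by the authors is not recorded in the commit):

# Sketch only, not the authors' recorded command. Ruff's formatter
# rewrites single-quoted strings to double quotes, collapses short
# call sites onto one line, and normalizes blank lines and end-of-file
# newlines; those defaults account for the hunks below.
ruff format .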
1 change: 0 additions & 1 deletion api/pyproject.toml
@@ -76,7 +76,6 @@ exclude = [
     "migrations/**/*",
     "services/**/*.py",
     "tasks/**/*.py",
-    "tests/**/*.py",
 ]
 
 [tool.pytest_env]
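The single deletion above removes "tests/**/*.py" from the exclude list, which is what brings the 154 test files below into the formatter's scope. For illustration, a hypothetical pyproject.toml fragment that would pin the quote style seen throughout the reformat (this section is not part of the commit, and "double" is already Ruff's default):

# Hypothetical fragment, not part of this commit; shown only to make
# the formatter's default quote behavior explicit.
[tool.ruff.format]
quote-style = "double"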
68 changes: 28 additions & 40 deletions api/tests/integration_tests/model_runtime/__mock/anthropic.py
@@ -22,76 +22,64 @@
 )
 from anthropic.types.message_delta_event import Delta
 
-MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true'
+MOCK = os.getenv("MOCK_SWITCH", "false") == "true"
 
 
 class MockAnthropicClass:
     @staticmethod
     def mocked_anthropic_chat_create_sync(model: str) -> Message:
         return Message(
-            id='msg-123',
-            type='message',
-            role='assistant',
-            content=[ContentBlock(text='hello, I\'m a chatbot from anthropic', type='text')],
+            id="msg-123",
+            type="message",
+            role="assistant",
+            content=[ContentBlock(text="hello, I'm a chatbot from anthropic", type="text")],
             model=model,
-            stop_reason='stop_sequence',
-            usage=Usage(
-                input_tokens=1,
-                output_tokens=1
-            )
+            stop_reason="stop_sequence",
+            usage=Usage(input_tokens=1, output_tokens=1),
         )
 
     @staticmethod
     def mocked_anthropic_chat_create_stream(model: str) -> Stream[MessageStreamEvent]:
         full_response_text = "hello, I'm a chatbot from anthropic"
 
         yield MessageStartEvent(
-            type='message_start',
+            type="message_start",
             message=Message(
-                id='msg-123',
+                id="msg-123",
                 content=[],
-                role='assistant',
+                role="assistant",
                 model=model,
                 stop_reason=None,
-                type='message',
-                usage=Usage(
-                    input_tokens=1,
-                    output_tokens=1
-                )
-            )
+                type="message",
+                usage=Usage(input_tokens=1, output_tokens=1),
+            ),
         )
 
         index = 0
         for i in range(0, len(full_response_text)):
             yield ContentBlockDeltaEvent(
-                type='content_block_delta',
-                delta=TextDelta(text=full_response_text[i], type='text_delta'),
-                index=index
+                type="content_block_delta", delta=TextDelta(text=full_response_text[i], type="text_delta"), index=index
             )
 
             index += 1
 
         yield MessageDeltaEvent(
-            type='message_delta',
-            delta=Delta(
-                stop_reason='stop_sequence'
-            ),
-            usage=MessageDeltaUsage(
-                output_tokens=1
-            )
+            type="message_delta", delta=Delta(stop_reason="stop_sequence"), usage=MessageDeltaUsage(output_tokens=1)
         )
 
-        yield MessageStopEvent(type='message_stop')
+        yield MessageStopEvent(type="message_stop")
 
-    def mocked_anthropic(self: Messages, *,
-                         max_tokens: int,
-                         messages: Iterable[MessageParam],
-                         model: str,
-                         stream: Literal[True],
-                         **kwargs: Any
-                         ) -> Union[Message, Stream[MessageStreamEvent]]:
+    def mocked_anthropic(
+        self: Messages,
+        *,
+        max_tokens: int,
+        messages: Iterable[MessageParam],
+        model: str,
+        stream: Literal[True],
+        **kwargs: Any,
+    ) -> Union[Message, Stream[MessageStreamEvent]]:
         if len(self._client.api_key) < 18:
-            raise anthropic.AuthenticationError('Invalid API key')
+            raise anthropic.AuthenticationError("Invalid API key")
 
         if stream:
             return MockAnthropicClass.mocked_anthropic_chat_create_stream(model=model)
@@ -102,7 +90,7 @@ def mocked_anthropic(self: Messages, *,
 @pytest.fixture
 def setup_anthropic_mock(request, monkeypatch: MonkeyPatch):
     if MOCK:
-        monkeypatch.setattr(Messages, 'create', MockAnthropicClass.mocked_anthropic)
+        monkeypatch.setattr(Messages, "create", MockAnthropicClass.mocked_anthropic)
 
     yield
 
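The mock files all follow the same pattern: when MOCK_SWITCH=true, the fixture monkeypatches the SDK entry point with the mock implementation, and monkeypatch.undo() restores it afterwards. A hypothetical test using the Anthropic fixture above (the fixture, the patched Messages.create, and the canned reply are from this commit; the test itself is an illustrative sketch):

# Hypothetical usage sketch, not part of this commit. Assumes the suite
# is launched with MOCK_SWITCH=true so the fixture patches Messages.create.
import anthropic

from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock


def test_chat_stream_is_mocked(setup_anthropic_mock):
    # Any key of 18+ characters passes the mock's length check.
    client = anthropic.Anthropic(api_key="sk-ant-0123456789abcdef")
    events = client.messages.create(
        model="claude-3-haiku-20240307",  # illustrative model id
        max_tokens=16,
        messages=[{"role": "user", "content": "hi"}],
        stream=True,
    )
    text = "".join(e.delta.text for e in events if e.type == "content_block_delta")
    assert text == "hello, I'm a chatbot from anthropic"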
58 changes: 21 additions & 37 deletions api/tests/integration_tests/model_runtime/__mock/google.py
@@ -12,63 +12,46 @@
 from google.generativeai.types import GenerateContentResponse
 from google.generativeai.types.generation_types import BaseGenerateContentResponse
 
-current_api_key = ''
+current_api_key = ""
 
 
 class MockGoogleResponseClass:
     _done = False
 
     def __iter__(self):
-        full_response_text = 'it\'s google!'
+        full_response_text = "it's google!"
 
         for i in range(0, len(full_response_text) + 1, 1):
             if i == len(full_response_text):
                 self._done = True
                 yield GenerateContentResponse(
-                    done=True,
-                    iterator=None,
-                    result=glm.GenerateContentResponse({
-
-                    }),
-                    chunks=[]
+                    done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
                 )
             else:
                 yield GenerateContentResponse(
-                    done=False,
-                    iterator=None,
-                    result=glm.GenerateContentResponse({
-
-                    }),
-                    chunks=[]
+                    done=False, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
                 )
 
 
 class MockGoogleResponseCandidateClass:
-    finish_reason = 'stop'
+    finish_reason = "stop"
 
     @property
     def content(self) -> gag_content.Content:
-        return gag_content.Content(
-            parts=[
-                gag_content.Part(text='it\'s google!')
-            ]
-        )
+        return gag_content.Content(parts=[gag_content.Part(text="it's google!")])
 
 
 class MockGoogleClass:
     @staticmethod
     def generate_content_sync() -> GenerateContentResponse:
-        return GenerateContentResponse(
-            done=True,
-            iterator=None,
-            result=glm.GenerateContentResponse({
-
-            }),
-            chunks=[]
-        )
+        return GenerateContentResponse(done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[])
 
     @staticmethod
     def generate_content_stream() -> Generator[GenerateContentResponse, None, None]:
         return MockGoogleResponseClass()
 
-    def generate_content(self: GenerativeModel,
+    def generate_content(
+        self: GenerativeModel,
         contents: content_types.ContentsType,
         *,
         generation_config: generation_config_types.GenerationConfigType | None = None,
@@ -79,21 +62,21 @@ def generate_content(self: GenerativeModel,
         global current_api_key
 
         if len(current_api_key) < 16:
-            raise Exception('Invalid API key')
+            raise Exception("Invalid API key")
 
         if stream:
             return MockGoogleClass.generate_content_stream()
 
         return MockGoogleClass.generate_content_sync()
 
     @property
     def generative_response_text(self) -> str:
-        return 'it\'s google!'
+        return "it's google!"
 
     @property
     def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]:
         return [MockGoogleResponseCandidateClass()]
 
     def make_client(self: _ClientManager, name: str):
         global current_api_key
 
@@ -121,7 +104,8 @@ def nop(self, *args, **kwargs):
 
         if not self.default_metadata:
             return client
 
+
 @pytest.fixture
 def setup_google_mock(request, monkeypatch: MonkeyPatch):
     monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text)
@@ -131,4 +115,4 @@ def setup_google_mock(request, monkeypatch: MonkeyPatch):
 
     yield
 
-    monkeypatch.undo()
\ No newline at end of file
+    monkeypatch.undo()
@@ -6,14 +6,15 @@
 
 from tests.integration_tests.model_runtime.__mock.huggingface_chat import MockHuggingfaceChatClass
 
-MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
+MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
 
+
 @pytest.fixture
 def setup_huggingface_mock(request, monkeypatch: MonkeyPatch):
     if MOCK:
         monkeypatch.setattr(InferenceClient, "text_generation", MockHuggingfaceChatClass.text_generation)
 
     yield
 
     if MOCK:
-        monkeypatch.undo()
\ No newline at end of file
+        monkeypatch.undo()
@@ -22,10 +22,8 @@ def generate_create_sync(model: str) -> TextGenerationResponse:
             details=Details(
                 finish_reason="length",
                 generated_tokens=6,
-                tokens=[
-                    Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6)
-                ]
-            )
+                tokens=[Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6)],
+            ),
         )
 
         return response
@@ -36,26 +34,23 @@ def generate_create_stream(model: str) -> Generator[TextGenerationStreamResponse
 
         for i in range(0, len(full_text)):
             response = TextGenerationStreamResponse(
-                token = Token(id=i, text=full_text[i], logprob=0.0, special=False),
+                token=Token(id=i, text=full_text[i], logprob=0.0, special=False),
             )
             response.generated_text = full_text[i]
-            response.details = StreamDetails(finish_reason='stop_sequence', generated_tokens=1)
+            response.details = StreamDetails(finish_reason="stop_sequence", generated_tokens=1)
 
             yield response
 
-    def text_generation(self: InferenceClient, prompt: str, *,
-                        stream: Literal[False] = ...,
-                        model: Optional[str] = None,
-                        **kwargs: Any
+    def text_generation(
+        self: InferenceClient, prompt: str, *, stream: Literal[False] = ..., model: Optional[str] = None, **kwargs: Any
     ) -> Union[TextGenerationResponse, Generator[TextGenerationStreamResponse, None, None]]:
         # check if key is valid
-        if not re.match(r'Bearer\shf\-[a-zA-Z0-9]{16,}', self.headers['authorization']):
-            raise BadRequestError('Invalid API key')
+        if not re.match(r"Bearer\shf\-[a-zA-Z0-9]{16,}", self.headers["authorization"]):
+            raise BadRequestError("Invalid API key")
 
         if model is None:
-            raise BadRequestError('Invalid model')
+            raise BadRequestError("Invalid model")
 
         if stream:
             return MockHuggingfaceChatClass.generate_create_stream(model)
         return MockHuggingfaceChatClass.generate_create_sync(model)

42 changes: 21 additions & 21 deletions api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py
@@ -5,10 +5,10 @@ class MockTEIClass:
     @staticmethod
     def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
         # During mock, we don't have a real server to query, so we just return a dummy value
-        if 'rerank' in model_name:
-            model_type = 'reranker'
+        if "rerank" in model_name:
+            model_type = "reranker"
         else:
-            model_type = 'embedding'
+            model_type = "embedding"
 
         return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1)

@@ -17,16 +17,16 @@ def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
         # Use space as token separator, and split the text into tokens
         tokenized_texts = []
         for text in texts:
-            tokens = text.split(' ')
+            tokens = text.split(" ")
             current_index = 0
             tokenized_text = []
             for idx, token in enumerate(tokens):
                 s_token = {
-                    'id': idx,
-                    'text': token,
-                    'special': False,
-                    'start': current_index,
-                    'stop': current_index + len(token),
+                    "id": idx,
+                    "text": token,
+                    "special": False,
+                    "start": current_index,
+                    "stop": current_index + len(token),
                 }
                 current_index += len(token) + 1
                 tokenized_text.append(s_token)
@@ -55,18 +55,18 @@ def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
             embedding = [0.1] * 768
             embeddings.append(
                 {
-                    'object': 'embedding',
-                    'embedding': embedding,
-                    'index': idx,
+                    "object": "embedding",
+                    "embedding": embedding,
+                    "index": idx,
                 }
             )
         return {
-            'object': 'list',
-            'data': embeddings,
-            'model': 'MODEL_NAME',
-            'usage': {
-                'prompt_tokens': sum(len(text.split(' ')) for text in texts),
-                'total_tokens': sum(len(text.split(' ')) for text in texts),
+            "object": "list",
+            "data": embeddings,
+            "model": "MODEL_NAME",
+            "usage": {
+                "prompt_tokens": sum(len(text.split(" ")) for text in texts),
+                "total_tokens": sum(len(text.split(" ")) for text in texts),
             },
         }

@@ -83,9 +83,9 @@ def invoke_rerank(server_url: str, query: str, texts: list[str]) -> list[dict]:
         for idx, text in enumerate(texts):
             reranked_docs.append(
                 {
-                    'index': idx,
-                    'text': text,
-                    'score': 0.9,
+                    "index": idx,
+                    "text": text,
+                    "score": 0.9,
                 }
             )
         # For mock, only return the first document