diff --git a/.gitignore b/.gitignore index 4f156a52e..8f371947d 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,9 @@ ## ## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore +# rest client tests for local development and testing +*.http + # generated frontend files code/dist/ diff --git a/code/backend/__init__.py b/code/backend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/code/backend/batch/get_conversation_response.py b/code/backend/batch/get_conversation_response.py index aa3f74358..53a34e90e 100644 --- a/code/backend/batch/get_conversation_response.py +++ b/code/backend/batch/get_conversation_response.py @@ -33,19 +33,20 @@ async def do_get_conversation_response(req: func.HttpRequest) -> func.HttpRespon lambda x: x["role"] in ("user", "assistant"), req_body["messages"][0:-1] ) ) - chat_history = [] - for i, k in enumerate(user_assistant_messages): - if i % 2 == 0: - chat_history.append( - ( - user_assistant_messages[i]["content"], - user_assistant_messages[i + 1]["content"], - ) - ) + # JM commented out + # chat_history = [] + # for i, k in enumerate(user_assistant_messages): + # if i % 2 == 0: + # chat_history.append( + # ( + # user_assistant_messages[i]["content"], + # user_assistant_messages[i + 1]["content"], + # ) + # ) messages = await message_orchestrator.handle_message( user_message=user_message, - chat_history=chat_history, + chat_history=user_assistant_messages, # was chat_history, #JM changed conversation_id=conversation_id, orchestrator=ConfigHelper.get_active_config_or_default().orchestrator, ) diff --git a/code/backend/batch/local.settings.json.sample b/code/backend/batch/local.settings.json.sample deleted file mode 100644 index 95a22ee00..000000000 --- a/code/backend/batch/local.settings.json.sample +++ /dev/null @@ -1,14 +0,0 @@ -{ - "IsEncrypted": false, - "Values": { - "FUNCTIONS_WORKER_RUNTIME": "python", - "AzureWebJobsStorage": "", - "MyBindingConnection": "", - "AzureWebJobs.HttpExample.Disabled": "true" - }, - "Host": { - "LocalHttpPort": 7071, - "CORS": "*", - "CORSCredentials": false - } - } \ No newline at end of file diff --git a/code/backend/batch/utilities/common/source_document.py b/code/backend/batch/utilities/common/source_document.py index 8c651c315..b4d06c9e0 100644 --- a/code/backend/batch/utilities/common/source_document.py +++ b/code/backend/batch/utilities/common/source_document.py @@ -52,15 +52,22 @@ def from_json(cls, json_string): @classmethod def from_dict(cls, dict_obj): + """ + Create a SourceDocument instance from a dictionary. + + :param dict_obj: Dictionary containing the SourceDocument attributes, at least the mandatory ones. + :return: SourceDocument instance. 
+ """ return cls( - dict_obj["id"], + # using dict.get() for the optional attributes dict_obj["content"], dict_obj["source"], - dict_obj["title"], - dict_obj["chunk"], - dict_obj["offset"], - dict_obj["page_number"], - dict_obj["chunk_id"], + dict_obj.get("id"), + dict_obj.get("title"), + dict_obj.get("chunk"), + dict_obj.get("offset"), + dict_obj.get("page_number"), + dict_obj.get("chunk_id"), ) @classmethod diff --git a/code/backend/batch/utilities/orchestrator/byod_orchestrator.py b/code/backend/batch/utilities/orchestrator/byod_orchestrator.py new file mode 100644 index 000000000..78924bd7f --- /dev/null +++ b/code/backend/batch/utilities/orchestrator/byod_orchestrator.py @@ -0,0 +1,337 @@ +import logging +from typing import List +import json +from openai import Stream +from openai.types.chat import ChatCompletionChunk, ChatCompletion +from flask import Response + +from .orchestrator_base import OrchestratorBase +from ..helpers.llm_helper import LLMHelper +from ..helpers.env_helper import EnvHelper +from ..common.answer import Answer +from ..common.source_document import SourceDocument + +logger = logging.getLogger(__name__) + + +class ByodOrchestrator(OrchestratorBase): + def __init__(self) -> None: + super().__init__() + self.llm_helper = LLMHelper() + self.env_helper = EnvHelper() + # delete config if default message is not needed + #self.config = ConfigHelper.get_active_config_or_default() + + + async def orchestrate( + self, + user_message: str, + chat_history: List[dict], + **kwargs: dict + ) -> list[dict]: + + # Call Content Safety tool + if self.config.prompts.enable_content_safety: + if response := self.call_content_safety_input(user_message): + return response + + # should use data func - checks index config but I think it should be handled as an exception rather than generate an option for an API call with no index reference + # I don't think there should be a distinction between should use data and should not use data - let's just leave the without data func but default to the other one + # - in_scope: it's a parameter in the payload so it's implied and managed by the server if optional or mandatory + + openai_client = self.llm_helper.openai_client + messages = [] + + # Create conversation history + if self.config.prompts.use_on_your_data_format: + messages.append( + {"role": "system", "content": self.config.prompts.answering_system_prompt} + ) + else: + messages.append( + {"role": "system", "content": "You are a helpful AI agent."} + ) + + + # Create conversation history + for message in chat_history: + messages.append({"role": message["role"], "content": message["content"]}) + messages.append({"role": "user", "content": user_message}) + + is_in_scope = self.env_helper.AZURE_SEARCH_ENABLE_IN_DOMAIN + #request_messages: List[dict] = [{"role": "user", "content": user_message}] + + #messages= [ + # { + # "role": "user", + # "content": "Summarize the Life in Green case study." + # }, + # { + # "role": "assistant", + # "content": "The \"Life in Green\" case study revolves around a unique campaign designed to support Real Betis, a football club in Seville, Spain. The challenge was to create a way for fans to support their team during significant life moments, specifically targeting the rivalry with Sevilla FC, whose colors are red and white." + # }, + # { + # "role": "user", + # "content": "Please reformat it into 2 key bulletpoints." 
+ # } + #] + # keeping the default prompts for now - change here if needed + + # build the message array for the payload + logger.info("Request messages: %s", messages) + #for message in request_messages: + # messages.append({"role": message['role'], "content": message["content"]}) + + # Azure OpenAI takes the deployment name as the model name, "AZURE_OPENAI_MODEL" means + # deployment name. + response = openai_client.chat.completions.create( + model=self.env_helper.AZURE_OPENAI_MODEL, + messages=messages, + temperature=float(self.env_helper.AZURE_OPENAI_TEMPERATURE), + max_tokens=int(self.env_helper.AZURE_OPENAI_MAX_TOKENS), + top_p=float(self.env_helper.AZURE_OPENAI_TOP_P), + stop=( + self.env_helper.AZURE_OPENAI_STOP_SEQUENCE.split("|") + if self.env_helper.AZURE_OPENAI_STOP_SEQUENCE + else None + ), + stream=self.env_helper.SHOULD_STREAM, # consider if Teams should have its own stream logic + extra_body={ + "data_sources": [ + { + "type": "azure_search", + "parameters": { + "authentication": ( + { + "type": "api_key", + "key": self.env_helper.AZURE_SEARCH_KEY, + } + if self.env_helper.is_auth_type_keys() + else { + "type": "system_assigned_managed_identity", + } + ), + "endpoint": self.env_helper.AZURE_SEARCH_SERVICE, + "index_name": self.env_helper.AZURE_SEARCH_INDEX, + "fields_mapping": { + "content_fields": ( + self.env_helper.AZURE_SEARCH_CONTENT_COLUMN.split("|") + if self.env_helper.AZURE_SEARCH_CONTENT_COLUMN + else [] + ), + "vector_fields": [ + self.env_helper.AZURE_SEARCH_CONTENT_VECTOR_COLUMN + ], + "title_field": self.env_helper.AZURE_SEARCH_TITLE_COLUMN or None, + "url_field": self.env_helper.AZURE_SEARCH_FIELDS_METADATA + or None, + "filepath_field": ( + self.env_helper.AZURE_SEARCH_FILENAME_COLUMN or None + ), + }, + "filter": self.env_helper.AZURE_SEARCH_FILTER, + # defaults to false - differences vs non OYD API calls? + "in_scope": self.env_helper.AZURE_SEARCH_ENABLE_IN_DOMAIN, + "top_n_documents": self.env_helper.AZURE_SEARCH_TOP_K, + "embedding_dependency": { + "type": "deployment_name", + "deployment_name": self.env_helper.AZURE_OPENAI_EMBEDDING_MODEL, + }, + "query_type": ( + "vector_semantic_hybrid" + if self.env_helper.AZURE_SEARCH_USE_SEMANTIC_SEARCH + else "vector_simple_hybrid" + ), + "semantic_configuration": ( + self.env_helper.AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG + if self.env_helper.AZURE_SEARCH_USE_SEMANTIC_SEARCH + and self.env_helper.AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG + else "" + ), + # is this overriding the system message?????? + # "role_information": self.env_helper.AZURE_OPENAI_SYSTEM_MESSAGE, # is this overriding the system message?????? 
+ }, + } + ] + }, + ) + + # update chat history with response + #chat_history = self._update_chat_history_with_llm_response(chat_history, response.choices[0].message) + + + if not self.env_helper.SHOULD_STREAM: + citations = self.get_citations(citation_list=response.choices[0].message.model_extra["context"]) +# response_obj = { +# "id": response.id, +# "model": response.model, +# "created": response.created, +# "object": response.object, +# "choices": [ +# { +# "messages": [ +# { +# "content": json.dumps( +# citations, +# ensure_ascii=False, +# ), +# "end_turn": False, +# "role": "tool", +# }, +# { +# "end_turn": True, +# "content": response.choices[0].message.content, +# "role": "assistant", +# }, +# ] +# } +# ], +# } + + ##format answer + #answer = Answer( + # question=user_message, + # answer=response_obj.choices[0].messages[1].content + #) +# + #if answer.answer is None: + # answer.answer = "The requested information is not available in the retrieved data. Please try another query or topic." +# + ## Call Content Safety tool with answers + #if self.config.prompts.enable_content_safety: + # if response := self.call_content_safety_output(user_message, answer.answer): + # return response +# + #citations_array = response.choices[0].message.model_extra["context"].get("citations") +# + ## Format the output for the UI + #answer = Answer.from_json(json.dumps(response.choices[0]. ) + #answer = Answer.from_json( {"question": , answer, citations}) + + list_source_docs = [SourceDocument.from_dict(c) for c in citations['citations']] + + + + #answer = Answer( + # question=user_message, + # answer=response.choices[0].message.content, + # source_documents=[SourceDocument.from_json(c) for c in citations] + # #[SourceDocument.from_json(doc['url']) for doc in citations_array] + # #source_documents = response.choices[0].message.model_extra["context"].get("citations") + #) + + #q = Answer.from_json + + parsed_messages = self.output_parser.parse( + question=user_message, + answer=response.choices[0].message.content, + source_documents=list_source_docs + ) + return parsed_messages + + #return response_obj + + return Response(self.stream_with_data(response), mimetype="application/json-lines") + + +# def get_markdown_url(self, source, title, container_sas): +# """Get Markdown URL for a citation""" +# +# url = quote(source, safe=":/") +# if "_SAS_TOKEN_PLACEHOLDER_" in url: +# url = url.replace("_SAS_TOKEN_PLACEHOLDER_", container_sas) +# return f"[{title}]({url})" + + def _update_chat_history_with_llm_response(self, chat_history: List[dict], message) -> List[dict]: + """ + Add a message to the chat history dictionary list + :param self + :param chat_history: List of messages + :param message: Message to add from the response + :return: Updated chat history + """ + chat_history.append({"role": "assistant", "content": message.content}) + logger.debug("Chat history updated.") + return chat_history + + def get_citations(self, citation_list): + """Returns Formated Citations""" + #blob_client = AzureBlobStorageClient() + #container_sas = blob_client.get_container_sas() + citations_dict = {"citations": []} + for citation in citation_list.get("citations"): + metadata = ( + json.loads(citation["url"]) + if isinstance(citation["url"], str) + else citation["url"] + ) + title = citation["title"] + #url = self.get_markdown_url(metadata["source"], title, container_sas) + citations_dict["citations"].append( + { + "content": citation["content"], #url + "\n\n\n" + citation["content"], , + "id": metadata["id"], + "chunk_id": 
citation.get('chunk_id'),#( + # re.findall(r"\d+", metadata["chunk_id"])[-1] + # if metadata["chunk_id"] is not None + # else metadata["chunk"] + #), + "title": title, + #"filepath": title.split("/")[-1], + "source": metadata["source"], + #"chunk": 0 + } + ) + return citations_dict + + def stream_with_data(self, response: Stream[ChatCompletionChunk]): + '''This function streams the response from Azure OpenAI with data.''' + response_obj = { + "id": "", + "model": "", + "created": 0, + "object": "", + "choices": [ + { + "messages": [ + { + "content": "", + "end_turn": False, + "role": "tool", + }, + { + "content": "", + "end_turn": False, + "role": "assistant", + }, + ] + } + ], + } + + for line in response: + choice = line.choices[0] + + if choice.model_extra["end_turn"]: + response_obj["choices"][0]["messages"][1]["end_turn"] = True + yield json.dumps(response_obj, ensure_ascii=False) + "\n" + return + + response_obj["id"] = line.id + response_obj["model"] = line.model + response_obj["created"] = line.created + response_obj["object"] = line.object + + delta = choice.delta + role = delta.role + + if role == "assistant": + citations = self.get_citations(delta.model_extra["context"]) + response_obj["choices"][0]["messages"][0]["content"] = json.dumps( + citations, + ensure_ascii=False, + ) + else: + response_obj["choices"][0]["messages"][1]["content"] += delta.content + + yield json.dumps(response_obj, ensure_ascii=False) + "\n" diff --git a/code/backend/batch/utilities/orchestrator/orchestration_strategy.py b/code/backend/batch/utilities/orchestrator/orchestration_strategy.py index bc212e1c2..d2843e549 100644 --- a/code/backend/batch/utilities/orchestrator/orchestration_strategy.py +++ b/code/backend/batch/utilities/orchestrator/orchestration_strategy.py @@ -6,3 +6,4 @@ class OrchestrationStrategy(Enum): LANGCHAIN = "langchain" SEMANTIC_KERNEL = "semantic_kernel" PROMPT_FLOW = "prompt_flow" + BYOD = "byod" diff --git a/code/backend/batch/utilities/orchestrator/strategies.py b/code/backend/batch/utilities/orchestrator/strategies.py index 349cd0aa8..ed095aea9 100644 --- a/code/backend/batch/utilities/orchestrator/strategies.py +++ b/code/backend/batch/utilities/orchestrator/strategies.py @@ -3,6 +3,7 @@ from .lang_chain_agent import LangChainAgent from .semantic_kernel import SemanticKernelOrchestrator from .prompt_flow import PromptFlowOrchestrator +from .byod_orchestrator import ByodOrchestrator def get_orchestrator(orchestration_strategy: str): @@ -14,5 +15,7 @@ def get_orchestrator(orchestration_strategy: str): return SemanticKernelOrchestrator() elif orchestration_strategy == OrchestrationStrategy.PROMPT_FLOW.value: return PromptFlowOrchestrator() + elif orchestration_strategy == OrchestrationStrategy.BYOD.value: + return ByodOrchestrator() else: - raise Exception(f"Unknown orchestration strategy: {orchestration_strategy}") + raise ValueError(f"Unknown orchestration strategy: {orchestration_strategy}") diff --git a/code/backend/pages/04_Configuration.py b/code/backend/pages/04_Configuration.py index 1ac80215e..8974137f2 100644 --- a/code/backend/pages/04_Configuration.py +++ b/code/backend/pages/04_Configuration.py @@ -66,8 +66,9 @@ def load_css(file_path): st.session_state["orchestrator_strategy"] = config.orchestrator.strategy.value if "ai_assistant_type" not in st.session_state: st.session_state["ai_assistant_type"] = config.prompts.ai_assistant_type -if "conversational_flow" not in st.session_state: - st.session_state["conversational_flow"] = config.prompts.conversational_flow 
+################# Conversational flow to be deleted +#if "conversational_flow" not in st.session_state: +# st.session_state["conversational_flow"] = config.prompts.conversational_flow if "enable_chat_history" not in st.session_state: st.session_state["enable_chat_history"] = st.session_state[ "enable_chat_history" @@ -187,16 +188,18 @@ def validate_documents(): try: - conversational_flow_help = "Whether to use the custom conversational flow or byod conversational flow. Refer to the Conversational flow options README for more details." - with st.expander("Conversational flow configuration", expanded=True): - cols = st.columns([2, 4]) - with cols[0]: - conv_flow = st.selectbox( - "Conversational flow", - key="conversational_flow", - options=config.get_available_conversational_flows(), - help=conversational_flow_help, - ) + ################# conversationanl flow to be deleted + # this is the box on the admin config to choose custom or byod conversational flow + #conversational_flow_help = "Whether to use the custom conversational flow or byod conversational flow. Refer to the Conversational flow options README for more details." + #with st.expander("Conversational flow configuration", expanded=True): + # cols = st.columns([2, 4]) + # with cols[0]: + # conv_flow = st.selectbox( + # "Conversational flow", + # key="conversational_flow", + # options=config.get_available_conversational_flows(), + # help=conversational_flow_help, + # ) with st.expander("Orchestrator configuration", expanded=True): cols = st.columns([2, 4]) @@ -204,13 +207,16 @@ def validate_documents(): st.selectbox( "Orchestrator strategy", key="orchestrator_strategy", - options=config.get_available_orchestration_strategies(), - disabled=( - True - if st.session_state["conversational_flow"] - == ConversationFlow.BYOD.value - else False - ), + ######################## + #### fix this conversational flow deleted reference + ########################### + options=config.get_available_orchestration_strategies() #, + #disabled=( + # True + # if st.session_state["conversational_flow"] + # == ConversationFlow.BYOD.value + # else False + #), ) # # # condense_question_prompt_help = "This prompt is used to convert the user's input to a standalone question, using the context of the chat history." @@ -438,7 +444,8 @@ def validate_documents(): ], "enable_content_safety": st.session_state["enable_content_safety"], "ai_assistant_type": st.session_state["ai_assistant_type"], - "conversational_flow": st.session_state["conversational_flow"], + ################### conversational flow to be deleted + #"conversational_flow": st.session_state["conversational_flow"], }, "messages": { "post_answering_filter": st.session_state[ diff --git a/code/create_app.py b/code/create_app.py index c9d1368c5..b0d29ca9a 100644 --- a/code/create_app.py +++ b/code/create_app.py @@ -31,46 +31,48 @@ ERROR_GENERIC_MESSAGE = "An error occurred. Please try again. If the problem persists, please contact the site administrator." 
logger = logging.getLogger(__name__) - -def get_markdown_url(source, title, container_sas): - """Get Markdown URL for a citation""" - - url = quote(source, safe=":/") - if "_SAS_TOKEN_PLACEHOLDER_" in url: - url = url.replace("_SAS_TOKEN_PLACEHOLDER_", container_sas) - return f"[{title}]({url})" - - -def get_citations(citation_list): - """Returns Formated Citations""" - blob_client = AzureBlobStorageClient() - container_sas = blob_client.get_container_sas() - citations_dict = {"citations": []} - for citation in citation_list.get("citations"): - metadata = ( - json.loads(citation["url"]) - if isinstance(citation["url"], str) - else citation["url"] - ) - title = citation["title"] - url = get_markdown_url(metadata["source"], title, container_sas) - citations_dict["citations"].append( - { - "content": url + "\n\n\n" + citation["content"], - "id": metadata["id"], - "chunk_id": ( - re.findall(r"\d+", metadata["chunk_id"])[-1] - if metadata["chunk_id"] is not None - else metadata["chunk"] - ), - "title": title, - "filepath": title.split("/")[-1], - "url": url, - } - ) - return citations_dict - - +############################# TO DELETE ######################################## +#def get_markdown_url(source, title, container_sas): +# """Get Markdown URL for a citation""" +# +# url = quote(source, safe=":/") +# if "_SAS_TOKEN_PLACEHOLDER_" in url: +# url = url.replace("_SAS_TOKEN_PLACEHOLDER_", container_sas) +# return f"[{title}]({url})" +# +# +#def get_citations(citation_list): +# """Returns Formated Citations""" +# blob_client = AzureBlobStorageClient() +# container_sas = blob_client.get_container_sas() +# citations_dict = {"citations": []} +# for citation in citation_list.get("citations"): +# metadata = ( +# json.loads(citation["url"]) +# if isinstance(citation["url"], str) +# else citation["url"] +# ) +# title = citation["title"] +# url = get_markdown_url(metadata["source"], title, container_sas) +# citations_dict["citations"].append( +# { +# "content": url + "\n\n\n" + citation["content"], +# "id": metadata["id"], +# "chunk_id": ( +# re.findall(r"\d+", metadata["chunk_id"])[-1] +# if metadata["chunk_id"] is not None +# else metadata["chunk"] +# ), +# "title": title, +# "filepath": title.split("/")[-1], +# "url": url, +# } +# ) +# return citations_dict +############################################## till here + + +#################### POTENTIALLY TO DELETE ############################ that's just checking if Azure Search is enabled but rather get an exception than ignoring it def should_use_data( env_helper: EnvHelper, azure_search_helper: AzureSearchHelper ) -> bool: @@ -83,216 +85,216 @@ def should_use_data( return True return False - -def stream_with_data(response: Stream[ChatCompletionChunk]): - """This function streams the response from Azure OpenAI with data.""" - response_obj = { - "id": "", - "model": "", - "created": 0, - "object": "", - "choices": [ - { - "messages": [ - { - "content": "", - "end_turn": False, - "role": "tool", - }, - { - "content": "", - "end_turn": False, - "role": "assistant", - }, - ] - } - ], - } - - for line in response: - choice = line.choices[0] - - if choice.model_extra["end_turn"]: - response_obj["choices"][0]["messages"][1]["end_turn"] = True - yield json.dumps(response_obj, ensure_ascii=False) + "\n" - return - - response_obj["id"] = line.id - response_obj["model"] = line.model - response_obj["created"] = line.created - response_obj["object"] = line.object - - delta = choice.delta - role = delta.role - - if role == "assistant": - citations = 
get_citations(delta.model_extra["context"]) - response_obj["choices"][0]["messages"][0]["content"] = json.dumps( - citations, - ensure_ascii=False, - ) - else: - response_obj["choices"][0]["messages"][1]["content"] += delta.content - - yield json.dumps(response_obj, ensure_ascii=False) + "\n" - - -def conversation_with_data(conversation: Request, env_helper: EnvHelper): - """This function streams the response from Azure OpenAI with data.""" - if env_helper.is_auth_type_keys(): - openai_client = AzureOpenAI( - azure_endpoint=env_helper.AZURE_OPENAI_ENDPOINT, - api_version=env_helper.AZURE_OPENAI_API_VERSION, - api_key=env_helper.AZURE_OPENAI_API_KEY, - ) - else: - openai_client = AzureOpenAI( - azure_endpoint=env_helper.AZURE_OPENAI_ENDPOINT, - api_version=env_helper.AZURE_OPENAI_API_VERSION, - azure_ad_token_provider=env_helper.AZURE_TOKEN_PROVIDER, - ) - - request_messages = conversation.json["messages"] - messages = [] - config = ConfigHelper.get_active_config_or_default() - if config.prompts.use_on_your_data_format: - messages.append( - {"role": "system", "content": config.prompts.answering_system_prompt} - ) - - for message in request_messages: - messages.append({"role": message["role"], "content": message["content"]}) - - # Azure OpenAI takes the deployment name as the model name, "AZURE_OPENAI_MODEL" means - # deployment name. - response = openai_client.chat.completions.create( - model=env_helper.AZURE_OPENAI_MODEL, - messages=messages, - temperature=float(env_helper.AZURE_OPENAI_TEMPERATURE), - max_tokens=int(env_helper.AZURE_OPENAI_MAX_TOKENS), - top_p=float(env_helper.AZURE_OPENAI_TOP_P), - stop=( - env_helper.AZURE_OPENAI_STOP_SEQUENCE.split("|") - if env_helper.AZURE_OPENAI_STOP_SEQUENCE - else None - ), - stream=env_helper.SHOULD_STREAM, - extra_body={ - "data_sources": [ - { - "type": "azure_search", - "parameters": { - "authentication": ( - { - "type": "api_key", - "key": env_helper.AZURE_SEARCH_KEY, - } - if env_helper.is_auth_type_keys() - else { - "type": "system_assigned_managed_identity", - } - ), - "endpoint": env_helper.AZURE_SEARCH_SERVICE, - "index_name": env_helper.AZURE_SEARCH_INDEX, - "fields_mapping": { - "content_fields": ( - env_helper.AZURE_SEARCH_CONTENT_COLUMN.split("|") - if env_helper.AZURE_SEARCH_CONTENT_COLUMN - else [] - ), - "vector_fields": [ - env_helper.AZURE_SEARCH_CONTENT_VECTOR_COLUMN - ], - "title_field": env_helper.AZURE_SEARCH_TITLE_COLUMN or None, - "url_field": env_helper.AZURE_SEARCH_FIELDS_METADATA - or None, - "filepath_field": ( - env_helper.AZURE_SEARCH_FILENAME_COLUMN or None - ), - }, - "filter": env_helper.AZURE_SEARCH_FILTER, - "in_scope": env_helper.AZURE_SEARCH_ENABLE_IN_DOMAIN, - "top_n_documents": env_helper.AZURE_SEARCH_TOP_K, - "embedding_dependency": { - "type": "deployment_name", - "deployment_name": env_helper.AZURE_OPENAI_EMBEDDING_MODEL, - }, - "query_type": ( - "vector_semantic_hybrid" - if env_helper.AZURE_SEARCH_USE_SEMANTIC_SEARCH - else "vector_simple_hybrid" - ), - "semantic_configuration": ( - env_helper.AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG - if env_helper.AZURE_SEARCH_USE_SEMANTIC_SEARCH - and env_helper.AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG - else "" - ), - "role_information": env_helper.AZURE_OPENAI_SYSTEM_MESSAGE, - }, - } - ] - }, - ) - - if not env_helper.SHOULD_STREAM: - citations = get_citations(response.choices[0].message.model_extra["context"]) - response_obj = { - "id": response.id, - "model": response.model, - "created": response.created, - "object": response.object, - "choices": [ - { - "messages": [ - { 
- "content": json.dumps( - citations, - ensure_ascii=False, - ), - "end_turn": False, - "role": "tool", - }, - { - "end_turn": True, - "content": response.choices[0].message.content, - "role": "assistant", - }, - ] - } - ], - } - - return response_obj - - return Response(stream_with_data(response), mimetype="application/json-lines") - - -def stream_without_data(response: Stream[ChatCompletionChunk]): - """This function streams the response from Azure OpenAI without data.""" - response_text = "" - for line in response: - if not line.choices: - continue - - delta_text = line.choices[0].delta.content - - if delta_text is None: - return - - response_text += delta_text - - response_obj = { - "id": line.id, - "model": line.model, - "created": line.created, - "object": line.object, - "choices": [ - {"messages": [{"role": "assistant", "content": response_text}]} - ], - } - yield json.dumps(response_obj, ensure_ascii=False) + "\n" - +############################### TO DELETE ######################################## +#def stream_with_data(response: Stream[ChatCompletionChunk]): +# """This function streams the response from Azure OpenAI with data.""" +# response_obj = { +# "id": "", +# "model": "", +# "created": 0, +# "object": "", +# "choices": [ +# { +# "messages": [ +# { +# "content": "", +# "end_turn": False, +# "role": "tool", +# }, +# { +# "content": "", +# "end_turn": False, +# "role": "assistant", +# }, +# ] +# } +# ], +# } +# +# for line in response: +# choice = line.choices[0] +# +# if choice.model_extra["end_turn"]: +# response_obj["choices"][0]["messages"][1]["end_turn"] = True +# yield json.dumps(response_obj, ensure_ascii=False) + "\n" +# return +# +# response_obj["id"] = line.id +# response_obj["model"] = line.model +# response_obj["created"] = line.created +# response_obj["object"] = line.object +# +# delta = choice.delta +# role = delta.role +# +# if role == "assistant": +# citations = get_citations(delta.model_extra["context"]) +# response_obj["choices"][0]["messages"][0]["content"] = json.dumps( +# citations, +# ensure_ascii=False, +# ) +# else: +# response_obj["choices"][0]["messages"][1]["content"] += delta.content +# +# yield json.dumps(response_obj, ensure_ascii=False) + "\n" + + +#def conversation_with_data(conversation: Request, env_helper: EnvHelper): +# """This function streams the response from Azure OpenAI with data.""" +# if env_helper.is_auth_type_keys(): +# openai_client = AzureOpenAI( +# azure_endpoint=env_helper.AZURE_OPENAI_ENDPOINT, +# api_version=env_helper.AZURE_OPENAI_API_VERSION, +# api_key=env_helper.AZURE_OPENAI_API_KEY, +# ) +# else: +# openai_client = AzureOpenAI( +# azure_endpoint=env_helper.AZURE_OPENAI_ENDPOINT, +# api_version=env_helper.AZURE_OPENAI_API_VERSION, +# azure_ad_token_provider=env_helper.AZURE_TOKEN_PROVIDER, +# ) +# +# request_messages = conversation.json["messages"] +# messages = [] +# config = ConfigHelper.get_active_config_or_default() +# if config.prompts.use_on_your_data_format: +# messages.append( +# {"role": "system", "content": config.prompts.answering_system_prompt} +# ) +# +# for message in request_messages: +# messages.append({"role": message["role"], "content": message["content"]}) +# +# # Azure OpenAI takes the deployment name as the model name, "AZURE_OPENAI_MODEL" means +# # deployment name. 
+# response = openai_client.chat.completions.create( +# model=env_helper.AZURE_OPENAI_MODEL, +# messages=messages, +# temperature=float(env_helper.AZURE_OPENAI_TEMPERATURE), +# max_tokens=int(env_helper.AZURE_OPENAI_MAX_TOKENS), +# top_p=float(env_helper.AZURE_OPENAI_TOP_P), +# stop=( +# env_helper.AZURE_OPENAI_STOP_SEQUENCE.split("|") +# if env_helper.AZURE_OPENAI_STOP_SEQUENCE +# else None +# ), +# stream=env_helper.SHOULD_STREAM, +# extra_body={ +# "data_sources": [ +# { +# "type": "azure_search", +# "parameters": { +# "authentication": ( +# { +# "type": "api_key", +# "key": env_helper.AZURE_SEARCH_KEY, +# } +# if env_helper.is_auth_type_keys() +# else { +# "type": "system_assigned_managed_identity", +# } +# ), +# "endpoint": env_helper.AZURE_SEARCH_SERVICE, +# "index_name": env_helper.AZURE_SEARCH_INDEX, +# "fields_mapping": { +# "content_fields": ( +# env_helper.AZURE_SEARCH_CONTENT_COLUMN.split("|") +# if env_helper.AZURE_SEARCH_CONTENT_COLUMN +# else [] +# ), +# "vector_fields": [ +# env_helper.AZURE_SEARCH_CONTENT_VECTOR_COLUMN +# ], +# "title_field": env_helper.AZURE_SEARCH_TITLE_COLUMN or None, +# "url_field": env_helper.AZURE_SEARCH_FIELDS_METADATA +# or None, +# "filepath_field": ( +# env_helper.AZURE_SEARCH_FILENAME_COLUMN or None +# ), +# }, +# "filter": env_helper.AZURE_SEARCH_FILTER, +# "in_scope": env_helper.AZURE_SEARCH_ENABLE_IN_DOMAIN, +# "top_n_documents": env_helper.AZURE_SEARCH_TOP_K, +# "embedding_dependency": { +# "type": "deployment_name", +# "deployment_name": env_helper.AZURE_OPENAI_EMBEDDING_MODEL, +# }, +# "query_type": ( +# "vector_semantic_hybrid" +# if env_helper.AZURE_SEARCH_USE_SEMANTIC_SEARCH +# else "vector_simple_hybrid" +# ), +# "semantic_configuration": ( +# env_helper.AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG +# if env_helper.AZURE_SEARCH_USE_SEMANTIC_SEARCH +# and env_helper.AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG +# else "" +# ), +# "role_information": env_helper.AZURE_OPENAI_SYSTEM_MESSAGE, +# }, +# } +# ] +# }, +# ) +# +# if not env_helper.SHOULD_STREAM: +# citations = get_citations(response.choices[0].message.model_extra["context"]) +# response_obj = { +# "id": response.id, +# "model": response.model, +# "created": response.created, +# "object": response.object, +# "choices": [ +# { +# "messages": [ +# { +# "content": json.dumps( +# citations, +# ensure_ascii=False, +# ), +# "end_turn": False, +# "role": "tool", +# }, +# { +# "end_turn": True, +# "content": response.choices[0].message.content, +# "role": "assistant", +# }, +# ] +# } +# ], +# } +# +# return response_obj +# +# return Response(stream_with_data(response), mimetype="application/json-lines") +# +# +#def stream_without_data(response: Stream[ChatCompletionChunk]): +# """This function streams the response from Azure OpenAI without data.""" +# response_text = "" +# for line in response: +# if not line.choices: +# continue +# +# delta_text = line.choices[0].delta.content +# +# if delta_text is None: +# return +# +# response_text += delta_text +# +# response_obj = { +# "id": line.id, +# "model": line.model, +# "created": line.created, +# "object": line.object, +# "choices": [ +# {"messages": [{"role": "assistant", "content": response_text}]} +# ], +# } +# yield json.dumps(response_obj, ensure_ascii=False) + "\n" +######################################################################### till here def get_message_orchestrator(): """This function gets the message orchestrator.""" @@ -303,64 +305,66 @@ def get_orchestrator_config(): """This function gets the orchestrator configuration.""" return 
ConfigHelper.get_active_config_or_default().orchestrator +#################### TO DELETE ############################ +#def conversation_without_data(conversation: Request, env_helper: EnvHelper): +# """This function streams the response from Azure OpenAI without data.""" +# if env_helper.AZURE_AUTH_TYPE == "rbac": +# openai_client = AzureOpenAI( +# azure_endpoint=env_helper.AZURE_OPENAI_ENDPOINT, +# api_version=env_helper.AZURE_OPENAI_API_VERSION, +# azure_ad_token_provider=env_helper.AZURE_TOKEN_PROVIDER, +# ) +# else: +# openai_client = AzureOpenAI( +# azure_endpoint=env_helper.AZURE_OPENAI_ENDPOINT, +# api_version=env_helper.AZURE_OPENAI_API_VERSION, +# api_key=env_helper.AZURE_OPENAI_API_KEY, +# ) +# +# request_messages = conversation.json["messages"] +# messages = [{"role": "system", "content": env_helper.AZURE_OPENAI_SYSTEM_MESSAGE}] +# +# for message in request_messages: +# messages.append({"role": message["role"], "content": message["content"]}) +# +# # Azure Open AI takes the deployment name as the model name, "AZURE_OPENAI_MODEL" means +# # deployment name. +# response = openai_client.chat.completions.create( +# model=env_helper.AZURE_OPENAI_MODEL, +# messages=messages, +# temperature=float(env_helper.AZURE_OPENAI_TEMPERATURE), +# max_tokens=int(env_helper.AZURE_OPENAI_MAX_TOKENS), +# top_p=float(env_helper.AZURE_OPENAI_TOP_P), +# stop=( +# env_helper.AZURE_OPENAI_STOP_SEQUENCE.split("|") +# if env_helper.AZURE_OPENAI_STOP_SEQUENCE +# else None +# ), +# stream=env_helper.SHOULD_STREAM, +# ) +# +# if not env_helper.SHOULD_STREAM: +# response_obj = { +# "id": response.id, +# "model": response.model, +# "created": response.created, +# "object": response.object, +# "choices": [ +# { +# "messages": [ +# { +# "role": "assistant", +# "content": response.choices[0].message.content, +# } +# ] +# } +# ], +# } +# return jsonify(response_obj), 200 +# +# return Response(stream_without_data(response), mimetype="application/json-lines") +########################################################## till here -def conversation_without_data(conversation: Request, env_helper: EnvHelper): - """This function streams the response from Azure OpenAI without data.""" - if env_helper.AZURE_AUTH_TYPE == "rbac": - openai_client = AzureOpenAI( - azure_endpoint=env_helper.AZURE_OPENAI_ENDPOINT, - api_version=env_helper.AZURE_OPENAI_API_VERSION, - azure_ad_token_provider=env_helper.AZURE_TOKEN_PROVIDER, - ) - else: - openai_client = AzureOpenAI( - azure_endpoint=env_helper.AZURE_OPENAI_ENDPOINT, - api_version=env_helper.AZURE_OPENAI_API_VERSION, - api_key=env_helper.AZURE_OPENAI_API_KEY, - ) - - request_messages = conversation.json["messages"] - messages = [{"role": "system", "content": env_helper.AZURE_OPENAI_SYSTEM_MESSAGE}] - - for message in request_messages: - messages.append({"role": message["role"], "content": message["content"]}) - - # Azure Open AI takes the deployment name as the model name, "AZURE_OPENAI_MODEL" means - # deployment name. 
- response = openai_client.chat.completions.create( - model=env_helper.AZURE_OPENAI_MODEL, - messages=messages, - temperature=float(env_helper.AZURE_OPENAI_TEMPERATURE), - max_tokens=int(env_helper.AZURE_OPENAI_MAX_TOKENS), - top_p=float(env_helper.AZURE_OPENAI_TOP_P), - stop=( - env_helper.AZURE_OPENAI_STOP_SEQUENCE.split("|") - if env_helper.AZURE_OPENAI_STOP_SEQUENCE - else None - ), - stream=env_helper.SHOULD_STREAM, - ) - - if not env_helper.SHOULD_STREAM: - response_obj = { - "id": response.id, - "model": response.model, - "created": response.created, - "object": response.object, - "choices": [ - { - "messages": [ - { - "role": "assistant", - "content": response.choices[0].message.content, - } - ] - } - ], - } - return jsonify(response_obj), 200 - - return Response(stream_without_data(response), mimetype="application/json-lines") @functools.cache @@ -408,25 +412,29 @@ def static_file(path): def health(): return "OK" - def conversation_azure_byod(): - try: - if should_use_data(env_helper, azure_search_helper): - return conversation_with_data(request, env_helper) - else: - return conversation_without_data(request, env_helper) - except APIStatusError as e: - error_message = str(e) - logger.exception("Exception in /api/conversation | %s", error_message) - response_json = e.response.json() - response_message = response_json.get("error", {}).get("message", "") - response_code = response_json.get("error", {}).get("code", "") - if response_code == "429" or "429" in response_message: - return jsonify({"error": ERROR_429_MESSAGE}), 429 - return jsonify({"error": ERROR_GENERIC_MESSAGE}), 500 - except Exception as e: - error_message = str(e) - logger.exception("Exception in /api/conversation | %s", error_message) - return jsonify({"error": ERROR_GENERIC_MESSAGE}), 500 +################## TO DELETE ############################ +# def conversation_azure_byod(): +# try: +# if should_use_data(env_helper, azure_search_helper): +# return conversation_with_data(request, env_helper) +# else: +# return conversation_without_data(request, env_helper) +# except APIStatusError as e: +# error_message = str(e) +# logger.exception("Exception in /api/conversation | %s", error_message) +# response_json = e.response.json() +# response_message = response_json.get("error", {}).get("message", "") +# response_code = response_json.get("error", {}).get("code", "") +# if response_code == "429" or "429" in response_message: +# return jsonify({"error": ERROR_429_MESSAGE}), 429 +# return jsonify({"error": ERROR_GENERIC_MESSAGE}), 500 +# except Exception as e: +# error_message = str(e) +# logger.exception("Exception in /api/conversation | %s", error_message) +# return jsonify({"error": ERROR_GENERIC_MESSAGE}), 500 +# +######################################################### till here + async def conversation_custom(): message_orchestrator = get_message_orchestrator() @@ -474,22 +482,25 @@ async def conversation_custom(): @app.route("/api/conversation", methods=["POST"]) async def conversation(): - ConfigHelper.get_active_config_or_default.cache_clear() - result = ConfigHelper.get_active_config_or_default() - conversation_flow = result.prompts.conversational_flow - if conversation_flow == ConversationFlow.CUSTOM.value: - return await conversation_custom() - elif conversation_flow == ConversationFlow.BYOD.value: - return conversation_azure_byod() - else: - return ( - jsonify( - { - "error": "Invalid conversation flow configured. Value can only be 'custom' or 'byod'." 
- } - ), - 500, - ) + return await conversation_custom() + #ConfigHelper.get_active_config_or_default.cache_clear() + #result = ConfigHelper.get_active_config_or_default() + # conversation flow deprecated + #conversation_flow = result.prompts.conversational_flow + #if conversation_flow == ConversationFlow.CUSTOM.value: + #return await conversation_custom() + # elif conversation_flow == ConversationFlow.BYOD.value: + #raise NotImplementedError("Conversation flow BYOD is no longer a feature and it's now part of the orchestrators env var.") + #return conversation_azure_byod() + #else: + # return ( + # jsonify( + # { + # "error": "Invalid conversation flow configured. Value can only be 'custom'" # or 'byod'." + # } + # ), + # 500, + # ) @app.route("/api/speech", methods=["GET"]) def speech_config(): diff --git a/code/tests/utilities/orchestrator/test_byod_orchestrator.py b/code/tests/utilities/orchestrator/test_byod_orchestrator.py new file mode 100644 index 000000000..97b7a25c9 --- /dev/null +++ b/code/tests/utilities/orchestrator/test_byod_orchestrator.py @@ -0,0 +1,177 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from backend.batch.utilities.orchestrator.byod_orchestrator import ( + ByodOrchestrator +) +from backend.batch.utilities.parser.output_parser_tool import OutputParserTool + + +import pytest +from unittest.mock import MagicMock, patch, AsyncMock +from backend.batch.utilities.orchestrator.byod_orchestrator import ByodOrchestrator +from backend.batch.utilities.helpers.llm_helper import LLMHelper +from backend.batch.utilities.helpers.env_helper import EnvHelper + + +@pytest.fixture(autouse=True) +def llm_helper_mock(): + with patch( + "backend.batch.utilities.orchestrator.byod_orchestrator.LLMHelper" + ) as mock: + llm_helper = mock.return_value + + yield llm_helper + + +@pytest.fixture +def orchestrator(autouse=True): + with patch("backend.batch.utilities.orchestrator.orchestrator_base.ConfigHelper.get_active_config_or_default") as mock_config: + mock_config.return_value.prompts.enable_content_safety = True + orchestrator = ByodOrchestrator() + orchestrator.llm_helper = MagicMock(spec=LLMHelper) + orchestrator.llm_helper.openai_client = MagicMock() + orchestrator.llm_helper.AZURE_OPENAI_MODEL = "test-model" + orchestrator.env_helper = MagicMock(spec=EnvHelper) + + env_helper_mock = MagicMock(spec=EnvHelper) + + # Dictionary of necessary attributes from .env + env_attributes = { + "AZURE_OPENAI_MODEL": "test-model", + "AZURE_OPENAI_TEMPERATURE": 0.6, + "AZURE_OPENAI_MAX_TOKENS": 1500, + "AZURE_OPENAI_TOP_P": 1, + "AZURE_OPENAI_STOP_SEQUENCE": None, + "SHOULD_STREAM": False, + "AZURE_SEARCH_KEY": "AZURE-SEARCH-KEY", + "AZURE_SEARCH_SERVICE": "https://search-tmx73bp4hzfbw.search.windows.net/", + "AZURE_SEARCH_INDEX": "index-tmx73bp4hzfbw", + "AZURE_SEARCH_CONTENT_COLUMN": "content", + "AZURE_SEARCH_CONTENT_VECTOR_COLUMN": "content_vector", + "AZURE_SEARCH_TITLE_COLUMN": "title", + "AZURE_SEARCH_FIELDS_METADATA": "metadata", + "AZURE_SEARCH_FILENAME_COLUMN": "filename", + "AZURE_SEARCH_FILTER": "", + "AZURE_SEARCH_ENABLE_IN_DOMAIN": True, + "AZURE_SEARCH_TOP_K": 5, + "AZURE_OPENAI_EMBEDDING_MODEL": "text-embedding-ada-002", + "AZURE_SEARCH_USE_SEMANTIC_SEARCH": False, + "AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG": "default", + "AZURE_OPENAI_SYSTEM_MESSAGE": "You are an AI assistant that helps people find information." 
+ } + + # Set attributes on the MagicMock instance + for attr, value in env_attributes.items(): + setattr(env_helper_mock, attr, value) + + orchestrator.env_helper = env_helper_mock + + return orchestrator + + +def test_initialization(orchestrator): + assert isinstance(orchestrator, ByodOrchestrator) + + + +@pytest.mark.asyncio +async def test_orchestrate3(orchestrator): + # Arrange + #orchestrator = ByodOrchestrator() + + user_message = "Tell me about Azure AI" + chat_history = [{"role": "system", "content": "This is a test"}] + + # Define a mocked response from the API using SimpleNamespace to simulate an object with attributes + mock_message = SimpleNamespace( + content="Azure AI is a set of tools and services...", + model_extra={"context": {"citations": [{"content": "Citation text", "url": "example.com"}]}} + ) + mock_choice = SimpleNamespace(message=mock_message) + mock_api_response = AsyncMock() + mock_api_response.choices = [mock_choice] + + with patch.object( + orchestrator.llm_helper.openai_client.chat.completions, + "create", + return_value=mock_api_response + ) as mock_create: + # Act + result = await orchestrator.orchestrate(user_message, chat_history) + + # Assert + mock_create.assert_called_once() # Ensure API call was made once + assert result # Check the result is not None or empty + assert isinstance(result, list) # Ensure output is a list + assert result[0].get("content") == "Azure AI is a set of tools and services..." # Check response content + + +@pytest.mark.asyncio +async def test_orchestrate(orchestrator): + orchestrator.llm_helper.openai_client.chat.completions.create = AsyncMock(return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="response content", model_extra={"context": {}}))])) + user_message = "Hello" + chat_history = [] + response = await orchestrator.orchestrate(user_message, chat_history) + assert response is not None + + + +def test_get_citations(orchestrator): + citation_list = { + "citations": [ + { + "content": "citation content", + "url": '{"source": "source_url", "id": "1"}', + "title": "citation title", + "chunk_id": "1" + } + ] + } + citations = orchestrator.get_citations(citation_list) + assert citations is not None + assert len(citations["citations"]) == 1 + + +@pytest.mark.asyncio +async def test_orchestrate_with_content_safety_enabled(orchestrator): + orchestrator.config.prompts.enable_content_safety = True + orchestrator.call_content_safety_input = MagicMock(return_value=[{"role": "assistant", "content": "Content safety response"}]) + user_message = "Hello" + chat_history = [] + response = await orchestrator.orchestrate(user_message, chat_history) + assert response == [{"role": "assistant", "content": "Content safety response"}] + + +@pytest.mark.asyncio +async def test_orchestrate_without_content_safety_enabled(orchestrator): + orchestrator.config.prompts.enable_content_safety = False + orchestrator.llm_helper.openai_client.chat.completions.create = AsyncMock(return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="response content", model_extra={"context": {}}))])) + user_message = "Hello" + chat_history = [] + response = await orchestrator.orchestrate(user_message, chat_history) + assert response is not None + assert isinstance(response, list) + + +@pytest.mark.asyncio +async def test_orchestrate_with_streaming_disabled(orchestrator): + orchestrator.env_helper.SHOULD_STREAM = False + orchestrator.llm_helper.openai_client.chat.completions.create = 
AsyncMock(return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="response content", model_extra={"context": {"citations": []}}))])) + user_message = "Hello" + chat_history = [] + response = await orchestrator.orchestrate(user_message, chat_history) + assert response is not None + assert isinstance(response, list) + + +##@pytest.mark.asyncio +##async def test_orchestrate_with_streaming_enabled(orchestrator): +## orchestrator.env_helper.SHOULD_STREAM = True +## orchestrator.llm_helper.openai_client.chat.completions.create = AsyncMock(return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="response content", model_extra={"context": {}}))])) +## user_message = "Hello" +## chat_history = [] +## response = await orchestrator.orchestrate(user_message, chat_history) +## assert response is not None +## assert isinstance(response, Response) diff --git a/code/tests/utilities/orchestrator/test_orchestrator.py b/code/tests/utilities/orchestrator/test_orchestrator.py index 991fe4faa..bdb189e7e 100644 --- a/code/tests/utilities/orchestrator/test_orchestrator.py +++ b/code/tests/utilities/orchestrator/test_orchestrator.py @@ -5,6 +5,11 @@ ) +""" +TOKEN RATE LIMITS MAY APPLY AND THROTTLE OPENAI CALLS ON THE TESTS +""" + + @pytest.mark.azure("This test requires Azure Open AI configured") @pytest.mark.asyncio async def test_orchestrator_openai_function(): @@ -33,3 +38,18 @@ async def test_orchestrator_langchain(): ) assert messages[-1]["role"] == "assistant" assert messages[-1]["content"] != "" + + +@pytest.mark.azure("This test requires Azure Open AI configured") +@pytest.mark.asyncio +async def test_orchestrator_byod(): + message_orchestrator = Orchestrator() + strategy = "byod" + messages = await message_orchestrator.handle_message( + user_message="What's Azure AI Search?", + chat_history=[], + conversation_id="test_byod", + orchestrator=OrchestrationSettings({"strategy": strategy}), + ) + assert messages[-1]["role"] == "assistant" + assert messages[-1]["content"] != ""
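For orientation, the orchestrator wiring in this diff is self-contained: adding the byod member to OrchestrationStrategy and the extra branch in strategies.py is all the existing factory needs in order to hand back the new orchestrator. A minimal sketch of that resolution, assuming the repo's backend.batch package is on the import path (as it is for the tests) and that the environment/config expected by the orchestrator base class is available:

    from backend.batch.utilities.orchestrator.orchestration_strategy import OrchestrationStrategy
    from backend.batch.utilities.orchestrator.strategies import get_orchestrator

    # The factory now resolves the new enum member to the On Your Data orchestrator.
    orchestrator = get_orchestrator(OrchestrationStrategy.BYOD.value)  # -> ByodOrchestrator instance

    # Unknown strategies now raise ValueError instead of a bare Exception.
    try:
        get_orchestrator("not-a-strategy")
    except ValueError as err:
        print(err)

End to end, the same path is exercised by the new test_orchestrator_byod test above, which passes OrchestrationSettings({"strategy": "byod"}) to handle_message.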
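On the SourceDocument.from_dict change: only "content" and "source" are still looked up directly; every other attribute now falls back to None through dict.get(), and the positional order passed to the constructor changed (content and source first, with id moved into the optional group), so positional callers relying on the old order need the update shown in the diff. A small sketch with a made-up citation dictionary of the shape get_citations produces:

    from backend.batch.utilities.common.source_document import SourceDocument

    citation = {
        "content": "Azure AI Search is a retrieval service ...",               # mandatory
        "source": "https://example.blob.core.windows.net/docs/case-study.pdf", # mandatory
        "chunk_id": "3",                                                        # optional
    }
    doc = SourceDocument.from_dict(citation)  # id, title, chunk, offset and page_number default to None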
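When SHOULD_STREAM is enabled, ByodOrchestrator.orchestrate returns a Flask Response carrying the same application/json-lines protocol the removed stream_with_data in create_app.py produced: one JSON object per line, whose first message (role "tool") holds the citations JSON and whose second message (role "assistant") accumulates the answer text. If the /api/conversation route forwards that Response unchanged, a client could consume it roughly as below; the URL, port, payload shape and use of requests are assumptions for illustration, not part of this diff:

    import json
    import requests  # any HTTP client that can iterate over response lines works

    payload = {"conversation_id": "123", "messages": [{"role": "user", "content": "What's Azure AI Search?"}]}
    with requests.post("http://localhost:5000/api/conversation", json=payload, stream=True) as resp:
        for raw_line in resp.iter_lines():
            if not raw_line:
                continue
            chunk = json.loads(raw_line)
            tool_msg, assistant_msg = chunk["choices"][0]["messages"]
            # tool_msg["content"] is a JSON string of {"citations": [...]};
            # assistant_msg["content"] is the partial answer streamed so far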