Skip to content

Commit

Permalink
feat(opentrons-ai-server): AUTH-403 BE chat completion (#15213)
Browse files Browse the repository at this point in the history
<!--
Thanks for taking the time to open a pull request! Please make sure
you've read the "Opening Pull Requests" section of our Contributing
Guide:


https://github.com/Opentrons/opentrons/blob/edge/CONTRIBUTING.md#opening-pull-requests

To ensure your code is reviewed quickly and thoroughly, please fill out
the sections below to the best of your ability!
-->

# Overview

- Extended the code to include storage and tools
- Now server correctly responds to the client

close [AUTH-403](https://opentrons.atlassian.net/browse/AUTH-403)

<!--
Use this section to describe your pull-request at a high level. If the
PR addresses any open issues, please tag the issues here.
-->

# Test Plan 
- make unit-test
- make direct-chat-completion
  - when you see shown below, type your prompt
    ` Type a prompt to send to the OpenAI API::` 


<!--
Use this section to describe the steps that you took to test your Pull
Request.
If you did not perform any testing provide justification why.

OT-3 Developers: You should default to testing on actual physical
hardware.
Once again, if you did not perform testing against hardware, justify
why.

Note: It can be helpful to write a test plan before doing development

Example Test Plan (HTTP API Change)

- Verified that new optional argument `dance-party` causes the robot to
flash its lights, move the pipettes,
then home.
- Verified that when you omit the `dance-party` option the robot homes
normally
- Added protocol that uses `dance-party` argument to G-Code Testing
Suite
- Ran protocol that did not use `dance-party` argument and everything
was successful
- Added unit tests to validate that changes to pydantic model are
correct

-->

# Changelog

<!--
List out the changes to the code in this PR. Please try your best to
categorize your changes and describe what has changed and why.

Example changelog:
- Fixed app crash when trying to calibrate an illegal pipette
- Added state to API to track pipette usage
- Updated API docs to mention only two pipettes are supported

IMPORTANT: MAKE SURE ANY BREAKING CHANGES ARE PROPERLY COMMUNICATED
-->

# Review requests

<!--
Describe any requests for your reviewers here.
-->

# Risk assessment
low
<!--
Carefully go over your pull request and look at the other parts of the
codebase it may affect. Look for the possibility, even if you think it's
small, that your change may affect some other part of the system - for
instance, changing return tip behavior in protocol may also change the
behavior of labware calibration.

Identify the other parts of the system your codebase may affect, so that
in addition to your own review and testing, other people who may not
have the system internalized as much as you can focus their attention
and testing there.
-->


[AUTH-403]:
https://opentrons.atlassian.net/browse/AUTH-403?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ
  • Loading branch information
Elyorcv authored May 17, 2024
1 parent ea27e76 commit 6090b72
Show file tree
Hide file tree
Showing 14 changed files with 916,955 additions and 8 deletions.
158 changes: 151 additions & 7 deletions opentrons-ai-server/api/domain/openai_predict.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,139 @@
from typing import List
from pathlib import Path
from typing import List, Tuple

from llama_index.core import Settings as li_settings
from llama_index.core import StorageContext, load_index_from_storage
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI as li_OpenAI
from llama_index.program.openai import OpenAIPydanticProgram
from openai import OpenAI
from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageParam
from openai.types.chat import ChatCompletion, ChatCompletionFunctionMessageParam, ChatCompletionMessage, ChatCompletionMessageParam
from pydantic import BaseModel

from api.domain.prompts import system_notes
from api.domain.prompts import (
example_pcr_1,
execute_function_call,
general_rules_1,
pipette_type,
prompt_template_str,
rules_for_transfer,
standard_labware_api,
system_notes,
tools,
)
from api.domain.utils import refine_characters
from api.settings import Settings, is_running_on_lambda

ROOT_PATH: Path = Path(Path(__file__)).parent.parent.parent


class OpenAIPredict:
def __init__(self, settings: Settings) -> None:
self.settings: Settings = settings
self.client: OpenAI = OpenAI(api_key=settings.openai_api_key.get_secret_value())
li_settings.embed_model = OpenAIEmbedding(
model_name="text-embedding-3-large", api_key=self.settings.openai_api_key.get_secret_value()
)

def get_docs_all(self, query: str) -> Tuple[str, str, str]:
commands = self.extract_atomic_description(query)
print(f"commands: {commands}")

# define file paths for storage
example_command_path = str(ROOT_PATH / "api" / "storage" / "index" / "commands")
documentation_path = str(ROOT_PATH / "api" / "storage" / "index" / "v215")
labware_api_path = standard_labware_api

# retrieve example commands
example_commands = f"\n\n{'='*15} EXAMPLE COMMANDS {'='*15}\n"
storage_context = StorageContext.from_defaults(persist_dir=example_command_path)
index = load_index_from_storage(storage_context)
retriever = index.as_retriever(similarity_top_k=1)
content_all = ""
if isinstance(commands, list):
for command in commands:
nodes = retriever.retrieve(command)
content = "\n".join(node.text for node in nodes)
content_all += f">>>> >>>> \n\\{content}n"
example_commands += content_all
else:
example_commands = []

# retrieve documentation
storage_context = StorageContext.from_defaults(persist_dir=documentation_path)
index = load_index_from_storage(storage_context)
retriever = index.as_retriever(similarity_top_k=3)
nodes = retriever.retrieve(query)
docs = "\n".join(node.text.strip() for node in nodes)
docs_v215 = f"\n{'='*15} DOCUMENTATION {'='*15}\n\n" + docs

# standard api names
standard_api_names = f"\n{'='*15} STANDARD API NAMES {'='*15}\n\n" + labware_api_path

return example_commands, docs_v215, standard_api_names

def extract_atomic_description(self, protocol_description: str) -> List[str]:
class atomic_descr(BaseModel):
"""
Model for atomic descriptions
"""

desc: List[str]

program = OpenAIPydanticProgram.from_defaults(
output_cls=atomic_descr,
prompt_template_str=prompt_template_str.format(protocol_description=protocol_description),
verbose=False,
llm=li_OpenAI(model=self.settings.OPENAI_MODEL_NAME),
)
details = program(protocol_description=protocol_description)
descriptions = []
print("=" * 50)
for x in details.desc:
if x not in ["Modules:", "Adapter:", "Labware:", "Pipette mount:", "Commands:", "Well Allocation:", "No modules"]:
descriptions.append(x)
return descriptions

def refine_response(self, assitant_message: str) -> str:
if assitant_message is None:
return ""
system_message: ChatCompletionMessageParam = {
"role": "system",
"content": f"{general_rules_1}\n Please leave useful comments for each command.",
}

user_message: ChatCompletionMessageParam = {"role": "user", "content": assitant_message}

response = self.client.chat.completions.create(
model=self.settings.OPENAI_MODEL_NAME,
messages=[system_message, user_message],
stream=False,
temperature=0.005,
max_tokens=4000,
top_p=0.0,
frequency_penalty=0,
presence_penalty=0,
)

return response.choices[0].message.content if response.choices[0].message.content is not None else ""

def predict(self, prompt: str, chat_completion_message_params: List[ChatCompletionMessageParam] | None = None) -> None | str:
"""The simplest chat completion from the OpenAI API"""
top_p = 0.0

prompt = refine_characters(prompt)
messages: List[ChatCompletionMessageParam] = [{"role": "system", "content": system_notes}]
if chat_completion_message_params:
messages += chat_completion_message_params

user_message: ChatCompletionMessageParam = {"role": "user", "content": f"QUESTION/DESCRIPTION: \n{prompt}\n\n"}
example_commands, docs_v215, standard_api_names = self.get_docs_all(prompt)

user_message: ChatCompletionMessageParam = {
"role": "user",
"content": f"QUESTION/DESCRIPTION: \n{prompt}\n\n"
f"PYTHON API V2 DOCUMENTATION: \n{example_commands}\n"
f"{pipette_type}\n{example_pcr_1}\n\n{docs_v215}\n\n"
f"{rules_for_transfer}\n\n{standard_api_names}\n\n",
}

messages.append(user_message)

response: ChatCompletion = self.client.chat.completions.create(
Expand All @@ -28,12 +142,42 @@ def predict(self, prompt: str, chat_completion_message_params: List[ChatCompleti
stream=False,
temperature=0.005,
max_tokens=4000,
top_p=top_p,
top_p=0.0,
frequency_penalty=0,
presence_penalty=0,
tools=tools,
tool_choice="auto",
)

assistant_message: ChatCompletionMessage = response.choices[0].message
if assistant_message.content is None:
assistant_message.content = ""
assistant_message.content = str(self.refine_response(assistant_message.content))

if assistant_message.tool_calls and assistant_message.tool_calls[0]:
print("Simulation is started.")
if assistant_message.tool_calls[0]:
assistant_message.content = str(assistant_message.tool_calls[0].function)
messages.append({"role": assistant_message.role, "content": assistant_message.content})
tool_call = assistant_message.tool_calls[0]
function_response = execute_function_call(tool_call.function.name, tool_call.function.arguments)

# append tool call response to messages
messages.append(
ChatCompletionFunctionMessageParam(role="function", name=tool_call.function.name, content=str(function_response))
)
response2: ChatCompletion = self.client.chat.completions.create(
model=self.settings.OPENAI_MODEL_NAME,
messages=messages,
stream=False,
temperature=0,
max_tokens=4000,
top_p=0.0,
frequency_penalty=0,
presence_penalty=0,
)
final_response = response2.choices[0].message.content
return final_response
return assistant_message.content


Expand Down
Loading

0 comments on commit 6090b72

Please sign in to comment.