Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor LLM code #6

Merged
merged 3 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
[![Coverage](https://codecov.io/gh/Azzaare/ConstraintsTranslator.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/Azzaare/ConstraintsTranslator.jl)
[![Aqua](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)

A package for translating natural-language descriptions of optimization problems into Constraint Programming models to be solved via [`CBLS.jl`](https://github.com/JuliaConstraints/CBLS.jl) using Large Language Models (LLMs).
A package for translating natural-language descriptions of optimization problems into Constraint Programming models using Large Language Models (LLMs). For this pre-stable version stage, our target is to have models solved via [`CBLS.jl`](https://github.com/JuliaConstraints/CBLS.jl). Eventually, we expect this library to work for most of Julia CP ecosystem, alongside other CP modeling languages such as MiniZinc, OR-Tools, etc.

This package acts as a light wrapper around common LLM API endpoints, supplying appropriate system prompts and context informations to the LLMs to generate CP models. Specifically, we first prompt the model for generating an high-level representation of the problem in editable Markdown format, and then we prompt the model to generate Julia code.

Expand All @@ -17,17 +17,28 @@ We currently support the following LLM APIs:
Groq and Gemini are currently offering rate-limited free access to their APIs, and llama.cpp is free and open-source. We are still actively experimenting with this package, and we are not in a position to pay for API access. We might consider adding support for other APIs in the future.

## Workflow example
To begin playing with the package, you can start from the example below:
Before playing with the package, we need to set up two environment variables:
1. The EDITOR variable for specifying a text editor (such as `vim`, `nano`, `emacs`, ...). This will be used during interactive execution.
2. An API key. This is necessary only for interacting with proprietary LLMs.

We can configure those variables by, e.g., appending the following to your `.bashrc` or equivalent:
```bash
export EDITOR="vim"
export GOOGLE_API_KEY="42"
```

Or we can configure them in Julia:
```julia
ENV["EDITOR"] = "vim"
ENV["GOOGLE_API_KEY"] = "42"
```

Finally, we can start playing with the package. Below, an example for translating a natural-language description of the Traveling Salesman Problem:
```julia
using ConstraintsTranslator

llm = GoogleLLM("gemini-1.5-pro")

# Optional setup of a terminal editor (uncomment and select a viable editor on your machine such as vim, nano, emacs, ...)
ENV["EDITOR"] = "vim"


description = """
We need to determine the shortest possible route for a salesman who must visit a set of cities exactly once and return to the starting city.
The objective is to minimize the total travel distance while ensuring that each city is visited exactly once.
Expand All @@ -49,9 +60,10 @@ response = translate(llm, description)

The `translate` function will first produce a Markdown representation of the problem, and then return the generated Julia code for parsing the input data and building the model.

This example uses Google Gemini as an LLM. You will need an API key and a model id to access proprietary API endpoints. Use `help?>` in the Julia REPL to learn more about the available models.
The flag `interactive=true` will enable a simple interactive command-line application, where you will be able to inspect, edit and regenerate each intermediate output.

At each generation step, it will prompt the user in an interactive menu to accept the answer, edit the prompt and/or the generated text, or generate another answer with the same prompt.

The LLM expects the user to provide examples of the input data format. If no examples are present, the LLM will make assumptions about the data format based on the problem description.

This example uses Google Gemini as an LLM. You will need an API key and a model id to access proprietary API endpoints. Use `help?>` in the Julia REPL to learn more about the available models.
152 changes: 44 additions & 108 deletions src/llm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,68 +3,75 @@ const GEMINI_URL::String = "https://generativelanguage.googleapis.com/v1beta/mod
const GEMINI_URL_STREAM::String = "https://generativelanguage.googleapis.com/v1beta/models/{{model_id}}:streamGenerateContent?alt=sse"

abstract type AbstractLLM end
abstract type OpenAILLM <: AbstractLLM end

"""
GroqLLM
Structure encapsulating the parameters for accessing the Groq LLM API.
- `api_key`: an API key for accessing the Groq API (https://groq.com), read from the environmental variable GROQ_API_KEY
- `api_key`: an API key for accessing the Groq API (https://groq.com), read from the environmental variable GROQ_API_KEY.
- `model_id`: a string identifier for the model to query. See https://console.groq.com/docs/models for the list of available models.
- `url`: URL for chat completions. Defaults to "https://api.groq.com/openai/v1/chat/completions".
"""
struct GroqLLM <: AbstractLLM
struct GroqLLM <: OpenAILLM
api_key::String
model_id::String
url::String

function GroqLLM(model_id::String = "llama-3.1-8b-instant")
function GroqLLM(model_id::String = "llama3-70b-8192", url = GROQ_URL)
api_key = get(ENV, "GROQ_API_KEY", "")
if isempty(api_key)
error("Environment variable GROQ_API_KEY is not set")
end
new(api_key, model_id)
new(api_key, model_id, url)
end
end

"""
Google LLM
Structure encapsulating the parameters for accessing the Google LLM API.
- `api_key`: an API key for accessing the Google Gemini API (https://ai.google.dev/gemini-api/docs/), read from the environmental variable GOOGLE_API_KEY
- `api_key`: an API key for accessing the Google Gemini API (https://ai.google.dev/gemini-api/docs/), read from the environmental variable GOOGLE_API_KEY.
- `model_id`: a string identifier for the model to query. See https://ai.google.dev/gemini-api/docs/models/gemini for the list of available models.
- `url`: URL for chat completions. Defaults to ""https://generativelanguage.googleapis.com/v1beta/models/{{model_id}}".
"""
struct GoogleLLM <: AbstractLLM
api_key::String
model_id::String
url::String

function GoogleLLM(model_id::String = "gemini-1.5-flash")
api_key = get(ENV, "GOOGLE_API_KEY", "")
if isempty(api_key)
error("Environment variable GOOGLE_API_KEY is not set")
end
new(api_key, model_id)
new(api_key, model_id, GEMINI_URL)
end
end

"""
LlamaCppLLM
Structure encapsulating the parameters for accessing the llama.cpp server API.
- `api_key`: an optional API key for accessing the server
- `model_id`: a string identifier for the model to query. Unused, kept for API compatibility.
- `url`: the URL of the llama.cpp server OpenAI API endpoint (e.g., http://localhost:8080)
NOTE: we do not apply the appropriate chat templates to the prompt.
This must be handled either in an external code path or by the server.
"""
struct LlamaCppLLM <: AbstractLLM
struct LlamaCppLLM <: OpenAILLM
api_key::String
model_id::String
url::String

function LlamaCppLLM(url::String)
api_key = get(ENV, "LLAMA_CPP_API_KEY", "no-key")
new(api_key, url)
new(api_key, "hal-9000-v2", url)
end
end

"""
get_completion(llm::GroqLLM, prompt::Prompt)
Returns a completion for the given prompt using the Groq LLM API.
get_completion(llm::OpenAILLM, prompt::Prompt)
Returns a completion for the given prompt using an OpenAI API compatible LLM
"""
function get_completion(llm::GroqLLM, prompt::Prompt)
function get_completion(llm::OpenAILLM, prompt::Prompt)
headers = [
"Authorization" => "Bearer $(llm.api_key)",
"Content-Type" => "application/json",
Expand All @@ -76,7 +83,7 @@ function get_completion(llm::GroqLLM, prompt::Prompt)
],
"model" => llm.model_id,
))
response = HTTP.post(GROQ_URL, headers, body)
response = HTTP.post(llm.url, headers, body)
body = JSON3.read(response.body)
return body["choices"][1]["message"]["content"]
end
Expand All @@ -102,85 +109,11 @@ function get_completion(llm::GoogleLLM, prompt::Prompt)
end

"""
get_completion(llm::LlamaCppLLM, prompt::Prompt)
Returns a completion for the given prompt using the llama.cpp server API.
"""
function get_completion(llm::LlamaCppLLM, prompt::Prompt)
url = join([llm.url, "v1/chat/completions"], "/")
header = [
"Authorization" => "Bearer $(llm.api_key)",
"Content-Type" => "application/json",
]
body = JSON3.write(Dict(
"messages" => [
Dict("role" => "system", "content" => prompt.system),
Dict("role" => "user", "content" => prompt.user),
],
))
response = HTTP.post(url, header, body)
body = JSON3.read(response.body)
return body["choices"][1]["message"]["content"]
end

"""
stream_completion(llm::LlamaCppLLM, prompt::Prompt)
Returns a completion for the given prompt using the Groq LLM API.
The completion is streamed to the terminal as it is generated.
"""
function stream_completion(llm::LlamaCppLLM, prompt::Prompt)
url = join([llm.url, "v1/chat/completions"], "/")
headers = [
"Authorization" => "Bearer $(llm.api_key)",
"Content-Type" => "application/json",
]
body = JSON3.write(Dict(
"messages" => [
Dict("role" => "system", "content" => prompt.system),
Dict("role" => "user", "content" => prompt.user),
],
"stream" => true,
))

accumulated_content = ""
event_buffer = ""

HTTP.open(:POST, url, headers; body = body) do io
write(io, body)
HTTP.closewrite(io)
HTTP.startread(io)
while !eof(io)
chunk = String(readavailable(io))
events = split(chunk, "\n\n")
if !endswith(event_buffer, "\n\n")
event_buffer = events[end]
events = events[1:(end - 1)]
else
event_buffer = ""
end
events = join(events, "\n")
for line in eachmatch(r"(?<=data: ).*", events, overlap = true)
if line.match == "[DONE]"
print("\n")
break
end
message = JSON3.read(line.match)
if !isempty(message["choices"][1]["delta"])
print(message["choices"][1]["delta"]["content"])
accumulated_content *= message["choices"][1]["delta"]["content"]
end
end
end
HTTP.closeread(io)
end
return accumulated_content
end

"""
stream_completion(llm::GroqLLM, prompt::Prompt)
Returns a completion for the given prompt using the Groq LLM API.
stream_completion(llm::OpenAILLM, prompt::Prompt)
Returns a completion for the given prompt using an OpenAI API compatible model.
The completion is streamed to the terminal as it is generated.
"""
function stream_completion(llm::GroqLLM, prompt::Prompt)
function stream_completion(llm::OpenAILLM, prompt::Prompt)
headers = [
"Authorization" => "Bearer $(llm.api_key)",
"Content-Type" => "application/json",
Expand All @@ -197,29 +130,32 @@ function stream_completion(llm::GroqLLM, prompt::Prompt)
accumulated_content = ""
event_buffer = ""

HTTP.open(:POST, GROQ_URL, headers; body = body) do io
HTTP.open(:POST, llm.url, headers; body = body) do io
write(io, body)
HTTP.closewrite(io)
HTTP.startread(io)
while !eof(io)
chunk = String(readavailable(io))
events = split(chunk, "\n\n")
event_buffer *= chunk
events = split(event_buffer, "\n\n")
if !endswith(event_buffer, "\n\n")
event_buffer = events[end]
events = events[1:(end - 1)]
else
event_buffer = ""
end
events = join(events, "\n")
for line in eachmatch(r"(?<=data: ).*", events, overlap = true)
if line.match == "[DONE]"
print("\n")
break
end
message = JSON3.read(line.match)
if !isempty(message["choices"][1]["delta"])
print(message["choices"][1]["delta"]["content"])
accumulated_content *= message["choices"][1]["delta"]["content"]

for event in events
for line in eachmatch(r"(?<=data: ).*", event)
if line.match == "[DONE]"
print("\n")
return accumulated_content
end
message = JSON3.read(line.match)
if !isempty(message["choices"][1]["delta"])
print(message["choices"][1]["delta"]["content"])
accumulated_content *= message["choices"][1]["delta"]["content"]
end
end
end
end
Expand Down Expand Up @@ -253,14 +189,14 @@ function stream_completion(llm::GoogleLLM, prompt::Prompt)
HTTP.startread(io)
while !eof(io)
chunk = String(readavailable(io))
line = match(r"(?<=data: ).*", chunk)
if isnothing(line)
print("\n")
break
for line in eachmatch(r"(?<=data: ).*", chunk)
if isnothing(line)
continue
end
message = JSON3.read(line.match)
print(message["candidates"][1]["content"]["parts"][1]["text"])
accumulated_content *= String(message["candidates"][1]["content"]["parts"][1]["text"])
end
message = JSON3.read(line.match)
print(message["candidates"][1]["content"]["parts"][1]["text"])
accumulated_content *= String(message["candidates"][1]["content"]["parts"][1]["text"])
end
HTTP.closeread(io)
end
Expand Down
4 changes: 2 additions & 2 deletions src/translate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,15 +140,15 @@ function fix_syntax_errors(model::AbstractLLM, code::AbstractString, error::Abst
end

"""
translate(model::AbstractLLM, description::AbstractString)
translate(model::AbstractLLM, description::AbstractString; interactive::Bool = false)
Translate the natural-language `description` of an optimization problem into
a Constraint Programming model by querying the Large Language Model `model`.
If `interactive`, the user will be prompted via the command line to inspect the
intermediate outputs of the LLM, and possibly modify them.
"""
function translate(
model::AbstractLLM,
description::AbstractString,
description::AbstractString;
interactive::Bool = false,
)
constraints = String[]
Expand Down
2 changes: 1 addition & 1 deletion templates/FixJuliaSyntax.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"_type": "metadatamessage"
},
{
"content": "You are an AI assistant specialized in writing Julia code. Your task is to examine a given code snippet alongside an error message related to syntax errors, and provide an updated version of the code snippet with the syntax errors resolved. \nIMPORTANT: 1. You must only fix the syntax errors without changing the functionality of the code.\n2. Think step-by-step, first describing the syntax errors in a bulleted list, and then providing the corrected code snippet in a Julia code block.\n3. You must report the complete code with the fix.",
"content": "You are an AI assistant specialized in writing Julia code. Your task is to examine a given code snippet alongside an error message related to syntax errors, and provide an updated version of the code snippet with the syntax errors resolved. \nIMPORTANT: 1. You must only fix the syntax errors without changing the functionality of the code.\n2. Think step-by-step, first describing the syntax errors in a bulleted list, and then providing the corrected code snippet in a Julia code block (i.e., ```julia [your code here] ```.\n3. You must report the complete code with the fix.",
"variables": [],
"_type": "systemmessage"
},
Expand Down
2 changes: 1 addition & 1 deletion templates/JumpifyModel.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"_type": "metadatamessage"
},
{
"content": "You are an AI assistant specialized in modeling Constraint Programming problems. Your task is to examine a given description of a Constraint Programming model and provide a code implementation in Julia, using JuMP and the CBLS solver. The code MUST: 1) Read the input data from external files into data structures according to the specifications provided in the description, using the appropriate Julia packages (e.g., DataFrames.jl, CSV.jl, etc.), 2) build the model, and 3) return the model.\nConstraints MUST be expressed with the following JuMP syntax: `@constraint(model, x in ConstraintName(kwargs))`, where `x` is a vector of variables, `ConstraintName` is the name of the constraint in camel-case (example: all different constraint -> AllDifferent()), and `kwargs` are the keyword arguments for the constraint (example: Sum(op=<=, val=10).\nIMPORTANT: 1. Output only the required function with no additional text or usage examples.\n2. You must write a docstring for the code.\n3. The code must be succinct and capture all the described constraints.\n4. You MUST use the provide syntax to express constraints. Do NOT express constraints in algebraic form.\n\n{{examples}}",
"content": "You are an AI assistant specialized in modeling Constraint Programming problems. Your task is to examine a given description of a Constraint Programming model and provide a code implementation in Julia, using JuMP and the CBLS solver. The code MUST: 1) Read the input data from external files into data structures according to the specifications provided in the description, using the appropriate Julia packages (e.g., DataFrames.jl, CSV.jl, etc.), 2) build the model, and 3) return the model.\nConstraints MUST be expressed with the following JuMP syntax: `@constraint(model, x in ConstraintName(kwargs))`, where `x` is a vector of variables, `ConstraintName` is the name of the constraint in camel-case (example: all different constraint -> AllDifferent()), and `kwargs` are the keyword arguments for the constraint (example: Sum(op=<=, val=10).\nIMPORTANT: 1. Output only the required function with no additional text or usage examples. The code must be wrapped in a Julia code block (i.e., ```julia [your code here] ```).\n2. You must write a docstring for the code.\n3. The code must be succinct and capture all the described constraints.\n4. You MUST use the provide syntax to express constraints. Do NOT express constraints in algebraic form.\n\n{{examples}}",
"variables": [
"examples"
],
Expand Down
Loading