Merge pull request #8 from JuliaConstraints/dev
Release 0.0.2
nicoladicicco authored Sep 24, 2024
2 parents cb8b117 + 07924d8 commit 67a7702
Showing 7 changed files with 61 additions and 120 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
 name = "ConstraintsTranslator"
 uuid = "314c63f5-3dda-4b35-95e7-4cc933f13053"
 authors = ["Jean-François BAFFIER (@Azzaare)"]
-version = "0.0.1"
+version = "0.0.2"

 [deps]
 Constraints = "30f324ab-b02d-43f0-b619-e131c61659f7"
4 changes: 3 additions & 1 deletion README.md
@@ -55,11 +55,13 @@ CityA,CityB,10
 CityA,CityC,8
 """

-response = translate(llm, description)
+response = translate(llm, description, interactive=true)
 ```

 The `translate` function will first produce a Markdown representation of the problem, and then return the generated Julia code for parsing the input data and building the model.

+The flag `interactive=true` enables a simple interactive command-line application where you can inspect, edit, and regenerate each intermediate output.
+
 At each generation step, it will prompt the user in an interactive menu to accept the answer, edit the prompt and/or the generated text, or generate another answer with the same prompt.

 The LLM expects the user to provide examples of the input data format. If no examples are present, the LLM will make assumptions about the data format based on the problem description.
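
For context, a minimal end-to-end sketch of the workflow the updated README describes; the model name and problem description below are illustrative placeholders, not part of the diff:

```julia
using ConstraintsTranslator

# Illustrative sketch: any backend defined in src/llm.jl
# (GroqLLM, GoogleLLM, LlamaCppLLM) can stand in for `llm`.
llm = GroqLLM("llama3-70b-8192")

description = """
Plan deliveries between cities. Each CSV row gives the distance between
two cities, for example:
CityA,CityB,10
CityA,CityC,8
"""

# With interactive=true, each intermediate output can be inspected,
# edited, or regenerated before moving on.
response = translate(llm, description, interactive = true)
```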
161 changes: 49 additions & 112 deletions src/llm.jl
@@ -1,70 +1,76 @@
 const GROQ_URL::String = "https://api.groq.com/openai/v1/chat/completions"
-const GEMINI_URL::String = "https://generativelanguage.googleapis.com/v1beta/models/{{model_id}}:generateContent"
-const GEMINI_URL_STREAM::String = "https://generativelanguage.googleapis.com/v1beta/models/{{model_id}}:streamGenerateContent?alt=sse"
+const GEMINI_URL::String = "https://generativelanguage.googleapis.com/v1beta/models/{{model_id}}"

 abstract type AbstractLLM end
+abstract type OpenAILLM <: AbstractLLM end

 """
     GroqLLM
 Structure encapsulating the parameters for accessing the Groq LLM API.
-- `api_key`: an API key for accessing the Groq API (https://groq.com), read from the environmental variable GROQ_API_KEY
+- `api_key`: an API key for accessing the Groq API (https://groq.com), read from the environment variable GROQ_API_KEY.
 - `model_id`: a string identifier for the model to query. See https://console.groq.com/docs/models for the list of available models.
+- `url`: URL for chat completions. Defaults to "https://api.groq.com/openai/v1/chat/completions".
 """
-struct GroqLLM <: AbstractLLM
+struct GroqLLM <: OpenAILLM
     api_key::String
     model_id::String
+    url::String

-    function GroqLLM(model_id::String = "llama-3.1-8b-instant")
+    function GroqLLM(model_id::String = "llama3-70b-8192", url = GROQ_URL)
         api_key = get(ENV, "GROQ_API_KEY", "")
         if isempty(api_key)
             error("Environment variable GROQ_API_KEY is not set")
         end
-        new(api_key, model_id)
+        new(api_key, model_id, url)
     end
 end

 """
     Google LLM
 Structure encapsulating the parameters for accessing the Google LLM API.
-- `api_key`: an API key for accessing the Google Gemini API (https://ai.google.dev/gemini-api/docs/), read from the environmental variable GOOGLE_API_KEY
+- `api_key`: an API key for accessing the Google Gemini API (https://ai.google.dev/gemini-api/docs/), read from the environment variable GOOGLE_API_KEY.
 - `model_id`: a string identifier for the model to query. See https://ai.google.dev/gemini-api/docs/models/gemini for the list of available models.
+- `url`: URL for chat completions. Defaults to "https://generativelanguage.googleapis.com/v1beta/models/{{model_id}}".
 """
 struct GoogleLLM <: AbstractLLM
     api_key::String
     model_id::String
+    url::String

     function GoogleLLM(model_id::String = "gemini-1.5-flash")
         api_key = get(ENV, "GOOGLE_API_KEY", "")
         if isempty(api_key)
             error("Environment variable GOOGLE_API_KEY is not set")
         end
-        new(api_key, model_id)
+        new(api_key, model_id, GEMINI_URL)
     end
 end

 """
     LlamaCppLLM
 Structure encapsulating the parameters for accessing the llama.cpp server API.
 - `api_key`: an optional API key for accessing the server
+- `model_id`: a string identifier for the model to query. Unused, kept for API compatibility.
 - `url`: the URL of the llama.cpp server OpenAI API endpoint (e.g., http://localhost:8080)
 NOTE: we do not apply the appropriate chat templates to the prompt.
 This must be handled either in an external code path or by the server.
 """
-struct LlamaCppLLM <: AbstractLLM
+struct LlamaCppLLM <: OpenAILLM
     api_key::String
+    model_id::String
     url::String

     function LlamaCppLLM(url::String)
         api_key = get(ENV, "LLAMA_CPP_API_KEY", "no-key")
-        new(api_key, url)
+        new(api_key, "hal-9000-v2", url)
     end
 end

 """
-    get_completion(llm::GroqLLM, prompt::Prompt)
-Returns a completion for the given prompt using the Groq LLM API.
+    get_completion(llm::OpenAILLM, prompt::Prompt)
+Returns a completion for the given prompt using an OpenAI API-compatible LLM.
 """
-function get_completion(llm::GroqLLM, prompt::Prompt)
+function get_completion(llm::OpenAILLM, prompt::Prompt)
     headers = [
         "Authorization" => "Bearer $(llm.api_key)",
         "Content-Type" => "application/json",
@@ -76,7 +82,7 @@ function get_completion(llm::GroqLLM, prompt::Prompt)
         ],
         "model" => llm.model_id,
     ))
-    response = HTTP.post(GROQ_URL, headers, body)
+    response = HTTP.post(llm.url, headers, body)
     body = JSON3.read(response.body)
     return body["choices"][1]["message"]["content"]
 end
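
A hedged usage sketch of the unified method; the two-argument `Prompt(system, user)` constructor is assumed here from the `prompt.system`/`prompt.user` accesses above:

```julia
# Sketch only: Prompt(system, user) is assumed from its field usage above.
llm = GroqLLM()  # reads GROQ_API_KEY; model defaults to "llama3-70b-8192"
prompt = Prompt("You are a terse assistant.", "Say hello in one word.")
# The same method now serves every OpenAILLM subtype (GroqLLM, LlamaCppLLM).
println(get_completion(llm, prompt))
```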
@@ -86,7 +92,8 @@ end
 Returns a completion for the given prompt using the Google Gemini LLM API.
 """
 function get_completion(llm::GoogleLLM, prompt::Prompt)
-    url = replace(GEMINI_URL, "{{model_id}}" => llm.model_id)
+    url = replace(llm.url, "{{model_id}}" => llm.model_id)
+    url *= ":generateContent"
     headers = [
         "x-goog-api-key" => "$(llm.api_key)",
         "Content-Type" => "application/json",
@@ -102,85 +109,11 @@ function get_completion(llm::GoogleLLM, prompt::Prompt)
 end

 """
-    get_completion(llm::LlamaCppLLM, prompt::Prompt)
-Returns a completion for the given prompt using the llama.cpp server API.
-"""
-function get_completion(llm::LlamaCppLLM, prompt::Prompt)
-    url = join([llm.url, "v1/chat/completions"], "/")
-    header = [
-        "Authorization" => "Bearer $(llm.api_key)",
-        "Content-Type" => "application/json",
-    ]
-    body = JSON3.write(Dict(
-        "messages" => [
-            Dict("role" => "system", "content" => prompt.system),
-            Dict("role" => "user", "content" => prompt.user),
-        ],
-    ))
-    response = HTTP.post(url, header, body)
-    body = JSON3.read(response.body)
-    return body["choices"][1]["message"]["content"]
-end
-
-"""
-    stream_completion(llm::LlamaCppLLM, prompt::Prompt)
-Returns a completion for the given prompt using the Groq LLM API.
-The completion is streamed to the terminal as it is generated.
-"""
-function stream_completion(llm::LlamaCppLLM, prompt::Prompt)
-    url = join([llm.url, "v1/chat/completions"], "/")
-    headers = [
-        "Authorization" => "Bearer $(llm.api_key)",
-        "Content-Type" => "application/json",
-    ]
-    body = JSON3.write(Dict(
-        "messages" => [
-            Dict("role" => "system", "content" => prompt.system),
-            Dict("role" => "user", "content" => prompt.user),
-        ],
-        "stream" => true,
-    ))
-
-    accumulated_content = ""
-    event_buffer = ""
-
-    HTTP.open(:POST, url, headers; body = body) do io
-        write(io, body)
-        HTTP.closewrite(io)
-        HTTP.startread(io)
-        while !eof(io)
-            chunk = String(readavailable(io))
-            events = split(chunk, "\n\n")
-            if !endswith(event_buffer, "\n\n")
-                event_buffer = events[end]
-                events = events[1:(end - 1)]
-            else
-                event_buffer = ""
-            end
-            events = join(events, "\n")
-            for line in eachmatch(r"(?<=data: ).*", events, overlap = true)
-                if line.match == "[DONE]"
-                    print("\n")
-                    break
-                end
-                message = JSON3.read(line.match)
-                if !isempty(message["choices"][1]["delta"])
-                    print(message["choices"][1]["delta"]["content"])
-                    accumulated_content *= message["choices"][1]["delta"]["content"]
-                end
-            end
-        end
-        HTTP.closeread(io)
-    end
-    return accumulated_content
-end
-
-"""
-    stream_completion(llm::GroqLLM, prompt::Prompt)
-Returns a completion for the given prompt using the Groq LLM API.
+    stream_completion(llm::OpenAILLM, prompt::Prompt)
+Returns a completion for the given prompt using an OpenAI API-compatible model.
 The completion is streamed to the terminal as it is generated.
 """
-function stream_completion(llm::GroqLLM, prompt::Prompt)
+function stream_completion(llm::OpenAILLM, prompt::Prompt)
     headers = [
         "Authorization" => "Bearer $(llm.api_key)",
         "Content-Type" => "application/json",
@@ -197,29 +130,32 @@ function stream_completion(llm::GroqLLM, prompt::Prompt)
     accumulated_content = ""
     event_buffer = ""

-    HTTP.open(:POST, GROQ_URL, headers; body = body) do io
+    HTTP.open(:POST, llm.url, headers; body = body) do io
         write(io, body)
         HTTP.closewrite(io)
         HTTP.startread(io)
         while !eof(io)
             chunk = String(readavailable(io))
-            events = split(chunk, "\n\n")
+            event_buffer *= chunk
+            events = split(event_buffer, "\n\n")
             if !endswith(event_buffer, "\n\n")
                 event_buffer = events[end]
                 events = events[1:(end - 1)]
             else
                 event_buffer = ""
             end
-            events = join(events, "\n")
-            for line in eachmatch(r"(?<=data: ).*", events, overlap = true)
-                if line.match == "[DONE]"
-                    print("\n")
-                    break
-                end
-                message = JSON3.read(line.match)
-                if !isempty(message["choices"][1]["delta"])
-                    print(message["choices"][1]["delta"]["content"])
-                    accumulated_content *= message["choices"][1]["delta"]["content"]
-                end
+
+            for event in events
+                for line in eachmatch(r"(?<=data: ).*", event)
+                    if line.match == "[DONE]"
+                        print("\n")
+                        return accumulated_content
+                    end
+                    message = JSON3.read(line.match)
+                    if !isempty(message["choices"][1]["delta"])
+                        print(message["choices"][1]["delta"]["content"])
+                        accumulated_content *= message["choices"][1]["delta"]["content"]
+                    end
+                end
             end
         end
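
The `event_buffer *= chunk` line above is the substantive fix: an SSE event can straddle two HTTP chunks, and the old code parsed each chunk in isolation, so split events were dropped. A standalone sketch of the carry-over pattern (illustrative, not part of the diff):

```julia
# Only complete "\n\n"-terminated SSE events are consumed; a partial
# trailing event is buffered until the next chunk arrives.
function complete_events(chunks)
    event_buffer = ""
    out = String[]
    for chunk in chunks
        event_buffer *= chunk
        events = split(event_buffer, "\n\n")
        if !endswith(event_buffer, "\n\n")
            event_buffer = events[end]      # keep the partial tail
            events = events[1:(end - 1)]
        else
            event_buffer = ""
        end
        append!(out, filter(!isempty, events))
    end
    return out
end

# "data: b" is split across two chunks but still comes out whole:
complete_events(["data: a\n\nda", "ta: b\n\ndata: c\n\n"])
# -> ["data: a", "data: b", "data: c"]
```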
@@ -234,7 +170,8 @@ Returns a completion for the given prompt using the Google Gemini LLM API.
 The completion is streamed to the terminal as it is generated.
 """
 function stream_completion(llm::GoogleLLM, prompt::Prompt)
-    url = replace(GEMINI_URL_STREAM, "{{model_id}}" => llm.model_id)
+    url = replace(llm.url, "{{model_id}}" => llm.model_id)
+    url *= ":streamGenerateContent?alt=sse"
     headers = [
         "x-goog-api-key" => "$(llm.api_key)",
         "Content-Type" => "application/json",
@@ -253,14 +190,14 @@ function stream_completion(llm::GoogleLLM, prompt::Prompt)
         HTTP.startread(io)
         while !eof(io)
             chunk = String(readavailable(io))
-            line = match(r"(?<=data: ).*", chunk)
-            if isnothing(line)
-                print("\n")
-                break
+            for line in eachmatch(r"(?<=data: ).*", chunk)
+                if isnothing(line)
+                    continue
+                end
+                message = JSON3.read(line.match)
+                print(message["candidates"][1]["content"]["parts"][1]["text"])
+                accumulated_content *= String(message["candidates"][1]["content"]["parts"][1]["text"])
             end
-            message = JSON3.read(line.match)
-            print(message["candidates"][1]["content"]["parts"][1]["text"])
-            accumulated_content *= String(message["candidates"][1]["content"]["parts"][1]["text"])
         end
         HTTP.closeread(io)
     end
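
Both Google code paths now derive their endpoint from the single `GEMINI_URL` template, appending the method-specific suffix at the call site; schematically (model id illustrative):

```julia
# Shared base URL with the model id substituted in:
base = replace("https://generativelanguage.googleapis.com/v1beta/models/{{model_id}}",
    "{{model_id}}" => "gemini-1.5-flash")
url_sync = base * ":generateContent"                  # used by get_completion
url_stream = base * ":streamGenerateContent?alt=sse"  # used by stream_completion
```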
4 changes: 2 additions & 2 deletions src/translate.jl
@@ -140,15 +140,15 @@ function fix_syntax_errors(model::AbstractLLM, code::AbstractString, error::Abst
 end

 """
-    translate(model::AbstractLLM, description::AbstractString)
+    translate(model::AbstractLLM, description::AbstractString; interactive::Bool = false)
 Translate the natural-language `description` of an optimization problem into
 a Constraint Programming model by querying the Large Language Model `model`.
 If `interactive`, the user will be prompted via the command line to inspect the
 intermediate outputs of the LLM, and possibly modify them.
 """
 function translate(
     model::AbstractLLM,
-    description::AbstractString,
+    description::AbstractString;
     interactive::Bool = false,
 )
     constraints = String[]
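
Note the `,` to `;` change in the signature: `interactive` was previously an optional positional argument, so the keyword-style call shown in the README would not have matched a method; it is now a proper keyword argument:

```julia
# Call-site sketch after this change:
translate(llm, description)                      # non-interactive (default)
translate(llm, description, interactive = true)  # keyword form from the README
```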
2 changes: 1 addition & 1 deletion templates/FixJuliaSyntax.json
@@ -7,7 +7,7 @@
"_type": "metadatamessage"
   },
   {
-    "content": "You are an AI assistant specialized in writing Julia code. Your task is to examine a given code snippet alongside an error message related to syntax errors, and provide an updated version of the code snippet with the syntax errors resolved. \nIMPORTANT: 1. You must only fix the syntax errors without changing the functionality of the code.\n2. Think step-by-step, first describing the syntax errors in a bulleted list, and then providing the corrected code snippet in a Julia code block.\n3. You must report the complete code with the fix.",
+    "content": "You are an AI assistant specialized in writing Julia code. Your task is to examine a given code snippet alongside an error message related to syntax errors, and provide an updated version of the code snippet with the syntax errors resolved. \nIMPORTANT: 1. You must only fix the syntax errors without changing the functionality of the code.\n2. Think step-by-step, first describing the syntax errors in a bulleted list, and then providing the corrected code snippet in a Julia code block (i.e., ```julia [your code here] ```).\n3. You must report the complete code with the fix.",
     "variables": [],
     "_type": "systemmessage"
   },
2 changes: 1 addition & 1 deletion templates/JumpifyModel.json
@@ -7,7 +7,7 @@
"_type": "metadatamessage"
   },
   {
-    "content": "You are an AI assistant specialized in modeling Constraint Programming problems. Your task is to examine a given description of a Constraint Programming model and provide a code implementation in Julia, using JuMP and the CBLS solver. The code MUST: 1) Read the input data from external files into data structures according to the specifications provided in the description, using the appropriate Julia packages (e.g., DataFrames.jl, CSV.jl, etc.), 2) build the model, and 3) return the model.\nConstraints MUST be expressed with the following JuMP syntax: `@constraint(model, x in ConstraintName(kwargs))`, where `x` is a vector of variables, `ConstraintName` is the name of the constraint in camel-case (example: all different constraint -> AllDifferent()), and `kwargs` are the keyword arguments for the constraint (example: Sum(op=<=, val=10).\nIMPORTANT: 1. Output only the required function with no additional text or usage examples.\n2. You must write a docstring for the code.\n3. The code must be succinct and capture all the described constraints.\n4. You MUST use the provide syntax to express constraints. Do NOT express constraints in algebraic form.\n\n{{examples}}",
+    "content": "You are an AI assistant specialized in modeling Constraint Programming problems. Your task is to examine a given description of a Constraint Programming model and provide a code implementation in Julia, using JuMP and the CBLS solver. The code MUST: 1) Read the input data from external files into data structures according to the specifications provided in the description, using the appropriate Julia packages (e.g., DataFrames.jl, CSV.jl, etc.), 2) build the model, and 3) return the model.\nConstraints MUST be expressed with the following JuMP syntax: `@constraint(model, x in ConstraintName(kwargs))`, where `x` is a vector of variables, `ConstraintName` is the name of the constraint in camel-case (example: all different constraint -> AllDifferent()), and `kwargs` are the keyword arguments for the constraint (example: Sum(op=<=, val=10)).\nIMPORTANT: 1. Output only the required function with no additional text or usage examples. The code must be wrapped in a Julia code block (i.e., ```julia [your code here] ```).\n2. You must write a docstring for the code.\n3. The code must be succinct and capture all the described constraints.\n4. You MUST use the provided syntax to express constraints. Do NOT express constraints in algebraic form.\n\n{{examples}}",
     "variables": [
       "examples"
     ],
6 changes: 4 additions & 2 deletions test/JET.jl
@@ -1,3 +1,5 @@
 @testset "Code linting (JET.jl)" begin
-    JET.test_package(ConstraintsTranslator; target_defined_modules = true)
-end
+    if VERSION ≥ v"1.10"
+        JET.test_package(ConstraintsTranslator; target_defined_modules = true)
+    end
+end
