Merge pull request #89 from codelion/feat-add-local-inference
Feat add local inference
codelion authored Nov 13, 2024
2 parents 476719c + ad90fd8 commit 7381008
Showing 4 changed files with 1,586 additions and 39 deletions.
75 changes: 56 additions & 19 deletions README.md
@@ -48,22 +48,6 @@ python optillm.py
* Running on http://192.168.10.48:8000
2024-09-06 07:57:14,212 - INFO - Press CTRL+C to quit
```

### Starting the optillm proxy for a local server (e.g. llama.cpp)

- Set the `OPENAI_API_KEY` env variable to a placeholder value
- e.g. `export OPENAI_API_KEY="no_key"`
- Run `./llama-server -c 4096 -m path_to_model` to start the server with the specified model and a context length of 4096 tokens
- Run `python3 optillm.py --base_url base_url` to start the proxy
- e.g. for llama.cpp, run `python3 optillm.py --base_url http://localhost:8080/v1`

> [!WARNING]
> Note that llama-server currently does not support sampling multiple responses from a model, which limits the available approaches to the following:
> `cot_reflection`, `leap`, `plansearch`, `rstar`, `rto`, `self_consistency`, `re2`, and `z3`.

> [!NOTE]
> You'll later need to specify a model name in the OpenAI client configuration. Since llama-server was started with a single model, you can choose any name you want.

## Usage

Once the proxy is running, you can use it as a drop-in replacement for an OpenAI client by setting the `base_url` to `http://localhost:8000/v1`.
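
For example, with the OpenAI Python SDK this drop-in usage might look like the following minimal sketch (the `moa-gpt-4o-mini` model name is illustrative: the `moa-` prefix selects the optimization approach and the remainder is the underlying model forwarded to the upstream endpoint):

```python
import os
from openai import OpenAI

# Point the standard OpenAI client at the optillm proxy instead of the provider.
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"], base_url="http://localhost:8000/v1")

# The optimization approach is selected via the model-name prefix (e.g. `moa-`).
response = client.chat.completions.create(
    model="moa-gpt-4o-mini",
    messages=[{"role": "user", "content": "Which is larger, 9.9 or 9.11?"}],
)
print(response.choices[0].message.content)
```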
@@ -155,7 +139,60 @@ In the diagram:
- `A` is an existing tool (like [oobabooga](https://github.com/oobabooga/text-generation-webui/)), a framework (like [patchwork](https://github.com/patched-codes/patchwork)),
or your own code where you want to use the results from optillm. You can use it directly with any OpenAI client SDK.
- `B` is the optillm service (running directly or in a docker container) that will send requests to the `base_url`.
- `C` is any service providing an OpenAI API compatible chat completions endpoint.

### Local inference server

We support loading any HuggingFace model or LoRA directly in optillm. To use the built-in inference server, set `OPTILLM_API_KEY` to any value (e.g. `export OPTILLM_API_KEY="optillm"`)
and then pass the same value as the API key in your OpenAI client. You can pass any HuggingFace model in the `model` field. If it is a private model, make sure you set the `HF_TOKEN` environment variable
with your HuggingFace key. We also support adding any number of LoRAs on top of the model by using the `+` separator.

E.g. the following code loads the base model `meta-llama/Llama-3.2-1B-Instruct` and then adds two LoRAs on top: `patched-codes/Llama-3.2-1B-FixVulns` and `patched-codes/Llama-3.2-1B-FastApply`.
You can specify which LoRA to use via the `active_adapter` param in the `extra_body` field of the OpenAI SDK client. By default, the last specified adapter is loaded.

```python
from openai import OpenAI

OPENAI_BASE_URL = "http://localhost:8000/v1"
OPENAI_KEY = "optillm"
client = OpenAI(api_key=OPENAI_KEY, base_url=OPENAI_BASE_URL)

# `messages` is a standard chat-completion message list
response = client.chat.completions.create(
    model="meta-llama/Llama-3.2-1B-Instruct+patched-codes/Llama-3.2-1B-FastApply+patched-codes/Llama-3.2-1B-FixVulns",
    messages=messages,
    temperature=0.2,
    logprobs=True,
    top_logprobs=3,
    extra_body={"active_adapter": "patched-codes/Llama-3.2-1B-FastApply"},
)
```

You can also use the alternate decoding techniques like `cot_decoding` and `entropy_decoding` directly with the local inference server.

```python
response = client.chat.completions.create(
    model="meta-llama/Llama-3.2-1B-Instruct",
    messages=messages,
    temperature=0.2,
    extra_body={
        "decoding": "cot_decoding",  # or "entropy_decoding"
        # CoT specific params
        "k": 10,
        "aggregate_paths": True,
        # OR Entropy specific params
        "top_k": 27,
        "min_p": 0.03,
    }
)
```

### Starting the optillm proxy with an external server (e.g. llama.cpp or ollama)

- Set the `OPENAI_API_KEY` env variable to a placeholder value
- e.g. `export OPENAI_API_KEY="sk-no-key"`
- Run `./llama-server -c 4096 -m path_to_model` to start the server with the specified model and a context length of 4096 tokens
- Run `python3 optillm.py --base_url base_url` to start the proxy
- e.g. for llama.cpp, run `python3 optillm.py --base_url http://localhost:8080/v1`

> [!WARNING]
> Note that llama-server (and ollama) currently do not support sampling multiple responses from a model, which limits the available approaches to the following:
> `cot_reflection`, `leap`, `plansearch`, `rstar`, `rto`, `self_consistency`, `re2`, and `z3`. To use the remaining approaches, use the built-in local inference server instead.
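
For illustration, once llama-server is running and the proxy has been started with `--base_url http://localhost:8080/v1`, a client call might look like the following sketch (the `leap` approach is one of the supported options above; the model name after the prefix is a placeholder, since llama-server serves the single model it was started with):

```python
from openai import OpenAI

# The proxy was started with: python3 optillm.py --base_url http://localhost:8080/v1
client = OpenAI(api_key="sk-no-key", base_url="http://localhost:8000/v1")

# Use one of the approaches supported with llama-server, e.g. `leap`.
response = client.chat.completions.create(
    model="leap-my-local-model",
    messages=[{"role": "user", "content": "How many r's are in the word strawberry?"}],
)
print(response.choices[0].message.content)
```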

## Implemented techniques

@@ -256,9 +293,9 @@ Authorization: Bearer your_secret_api_key
### readurls&memory-gpt-4o-mini on Google FRAMES Benchmark (Oct 2024)
| Model | Accuracy |
| ----- | -------- |
| readlurls&memory-gpt-4o-mini | 65.66 |
| readurls&memory-gpt-4o-mini | 65.66 |
| gpt-4o-mini | 50.0 |
| readlurls&memory-Gemma2-9b | 30.1 |
| readurls&memory-Gemma2-9b | 30.1 |
| Gemma2-9b | 5.1 |
| Gemma2-27b | 30.8 |
| Gemini Flash 1.5 | 66.5 |
112 changes: 93 additions & 19 deletions optillm.py
@@ -11,6 +11,7 @@
import asyncio
import re
from concurrent.futures import ThreadPoolExecutor
from typing import Tuple, Optional, Union, Dict, Any, List

# Import approach modules
from optillm.mcts import chat_with_mcts
@@ -43,8 +44,13 @@

def get_config():
    API_KEY = None
    if os.environ.get("OPTILLM_API_KEY"):
        # Use local inference engine
        from optillm.inference import create_inference_client
        API_KEY = os.environ.get("OPTILLM_API_KEY")
        default_client = create_inference_client()
    # OpenAI, Azure, or LiteLLM API configuration
    if os.environ.get("OPENAI_API_KEY"):
    elif os.environ.get("OPENAI_API_KEY"):
        API_KEY = os.environ.get("OPENAI_API_KEY")
        base_url = server_config['base_url']
        if base_url != "":
@@ -78,7 +84,7 @@ def get_config():

# Server configuration
server_config = {
    'approach': 'bon',
    'approach': 'none',
    'mcts_simulations': 2,
    'mcts_exploration': 0.2,
    'mcts_depth': 1,
@@ -96,11 +102,52 @@ def get_config():
}

# List of known approaches
known_approaches = ["mcts", "bon", "moa", "rto", "z3", "self_consistency", "pvg", "rstar",
"cot_reflection", "plansearch", "leap", "re2"]
known_approaches = ["none", "mcts", "bon", "moa", "rto", "z3", "self_consistency",
"pvg", "rstar", "cot_reflection", "plansearch", "leap", "re2"]

plugin_approaches = {}

def none_approach(
    client: Any,
    model: str,
    original_messages: List[Dict[str, str]],
    **kwargs
) -> Dict[str, Any]:
    """
    Direct proxy approach that passes through all parameters to the underlying endpoint.
    Args:
        system_prompt: System prompt text (unused)
        initial_query: Initial query/conversation (unused)
        client: OpenAI client instance
        model: Model identifier
        original_messages: Original messages from the request
        **kwargs: Additional parameters to pass through
    Returns:
        Dict[str, Any]: Full OpenAI API response
    """
    # Strip 'none-' prefix from model if present
    if model.startswith('none-'):
        model = model[5:]

    try:
        # Make the direct completion call with original messages and parameters
        response = client.chat.completions.create(
            model=model,
            messages=original_messages,
            **kwargs
        )

        # Convert to dict if it's not already
        if hasattr(response, 'model_dump'):
            return response.model_dump()
        return response

    except Exception as e:
        logger.error(f"Error in none approach: {str(e)}")
        raise

def load_plugins():
    # Clear existing plugins first but modify the global dict in place
    plugin_approaches.clear()
@@ -158,7 +205,7 @@ def load_plugins():

def parse_combined_approach(model: str, known_approaches: list, plugin_approaches: dict):
    if model == 'auto':
        return 'SINGLE', ['bon'], model
        return 'SINGLE', ['none'], model

    parts = model.split('-')
    approaches = []
@@ -183,7 +230,7 @@ def parse_combined_approach(model: str, known_approaches: list, plugin_approaches: dict):
            model_parts.append(part)

    if not approaches:
        approaches = ['bon']
        approaches = ['none']
        operation = 'SINGLE'

    actual_model = '-'.join(model_parts)
@@ -192,8 +239,21 @@ def parse_combined_approach(model: str, known_approaches: list, plugin_approaches: dict):

def execute_single_approach(approach, system_prompt, initial_query, client, model):
    if approach in known_approaches:
        # Execute known approaches
        if approach == 'mcts':
        if approach == 'none':
            # Extract kwargs from the request data
            kwargs = {}
            if hasattr(request, 'json'):
                data = request.get_json()
                messages = data.get('messages', [])
                # Copy all parameters except 'model' and 'messages'
                kwargs = {k: v for k, v in data.items()
                          if k not in ['model', 'messages', 'optillm_approach']}
            response = none_approach(original_messages=messages, client=client, model=model, **kwargs)

            # For none approach, we return the response and a token count of 0
            # since the full token count is already in the response
            return response, 0
        elif approach == 'mcts':
            return chat_with_mcts(system_prompt, initial_query, client, model, server_config['mcts_simulations'],
                                  server_config['mcts_exploration'], server_config['mcts_depth'])
        elif approach == 'bon':
@@ -324,7 +384,6 @@ def proxy():
bearer_token = ""

if auth_header and auth_header.startswith("Bearer "):
# Extract the bearer token
bearer_token = auth_header.split("Bearer ")[1].strip()
logger.debug(f"Intercepted Bearer Token: {bearer_token}")

@@ -360,22 +419,37 @@ def proxy():
        client = default_client

    try:
        # Check if any of the approaches is 'none'
        contains_none = any(approach == 'none' for approach in approaches)

        if operation == 'SINGLE' and approaches[0] == 'none':
            # For none approach, return the response directly
            result, _ = execute_single_approach(approaches[0], system_prompt, initial_query, client, model)
            logger.debug(f'Direct proxy response: {result}')
            return jsonify(result), 200

        elif operation == 'AND' or operation == 'OR':
            if contains_none:
                raise ValueError("'none' approach cannot be combined with other approaches")

        # Handle non-none approaches
        if operation == 'SINGLE':
            final_response, completion_tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model)
            response, completion_tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model)
        elif operation == 'AND':
            final_response, completion_tokens = execute_combined_approaches(approaches, system_prompt, initial_query, client, model)
            response, completion_tokens = execute_combined_approaches(approaches, system_prompt, initial_query, client, model)
        elif operation == 'OR':
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            final_response, completion_tokens = loop.run_until_complete(execute_parallel_approaches(approaches, system_prompt, initial_query, client, model))
            response, completion_tokens = loop.run_until_complete(execute_parallel_approaches(approaches, system_prompt, initial_query, client, model))
        else:
            raise ValueError(f"Unknown operation: {operation}")

    except Exception as e:
        logger.error(f"Error processing request: {str(e)}")
        return jsonify({"error": str(e)}), 500

    if stream:
        return Response(generate_streaming_response(final_response, model), content_type='text/event-stream')
        return Response(generate_streaming_response(response, model), content_type='text/event-stream')
    else:
        response_data = {
            'model': model,
@@ -385,13 +459,13 @@ def proxy():
            }
        }

        if isinstance(final_response, list):
            for index, response in enumerate(final_response):
        if isinstance(response, list):
            for index, resp in enumerate(response):
                response_data['choices'].append({
                    'index': index,
                    'message': {
                        'role': 'assistant',
                        'content': response,
                        'content': resp,
                    },
                    'finish_reason': 'stop'
                })
        else:
            response_data['choices'].append({
                'index': 0,
                'message': {
                    'role': 'assistant',
                    'content': final_response,
                    'content': response,
                },
                'finish_reason': 'stop'
            })

        logger.debug(f'API response: {response_data}')
        return jsonify(response_data), 200
        logger.debug(f'API response: {response_data}')
        return jsonify(response_data), 200

@app.route('/v1/models', methods=['GET'])
def proxy_models():
