Merge pull request #158 from NexaAI/perry/customize-nctx
made server also support -cm --context_maximum input, and updated docs
zhiyuan8 authored Oct 10, 2024
2 parents ab0ab06 + dc22834 commit a180e7c
Showing 7 changed files with 19 additions and 16 deletions.
9 changes: 7 additions & 2 deletions CLI.md
@@ -1,6 +1,7 @@
## CLI Reference

### Overview

```
usage: nexa [-h] [-V] {run,onnx,server,pull,remove,clean,list,login,whoami,logout} ...
@@ -116,6 +117,7 @@ Text generation options:
-p, --top_p TOP_P Top-p sampling parameter
-sw, --stop_words [STOP_WORDS ...]
List of stop words for early stopping
--nctx TEXT_CONTEXT Length of context window
```

##### Example
@@ -195,6 +197,7 @@ VLM generation options:
-p, --top_p TOP_P Top-p sampling parameter
-sw, --stop_words [STOP_WORDS ...]
List of stop words for early stopping
--nctx TEXT_CONTEXT Length of context window
```

##### Example
@@ -265,8 +268,9 @@ nexa server llama2
### Run Model Evaluation

Run evaluation using models on your local computer.

```
usage: nexa eval model_path [-h] [--tasks TASKS] [--limit LIMIT]
usage: nexa eval model_path [-h] [--tasks TASKS] [--limit LIMIT]
positional arguments:
model_path Path or identifier for the model in Nexa Model Hub
@@ -278,6 +282,7 @@ options:
```

#### Examples

```
nexa eval phi3 --tasks ifeval --limit 0.5
```
@@ -293,4 +298,4 @@ For `model_path` in nexa commands, it's better to follow the standard format to

- `gemma-2b:q4_0`
- `Meta-Llama-3-8B-Instruct:onnx-cpu-int8`
- `alanzhuly/Qwen2-1B-Instruct:q4_0`
- `liuhaotian/llava-v1.6-vicuna-7b:gguf-q4_0`
1 change: 1 addition & 0 deletions SERVER.md
@@ -14,6 +14,7 @@ usage: nexa server [-h] [--host HOST] [--port PORT] [--reload] model_path
- `--host`: Host to bind the server to
- `--port`: Port to bind the server to
- `--reload`: Enable automatic reloading on code changes
- `--nctx`: Maximum context length of the model you're using

### Example Commands:

4 changes: 2 additions & 2 deletions nexa/cli/entry.py
@@ -249,7 +249,7 @@ def main():
text_group.add_argument("-p", "--top_p", type=float, help="Top-p sampling parameter")
text_group.add_argument("-sw", "--stop_words", nargs="*", help="List of stop words for early stopping")
text_group.add_argument("--lora_path", type=str, help="Path to a LoRA file to apply to the model.")
text_group.add_argument("-cm", "--context_maximum", type=int, default=2048, help="Maximum context length of the model you're using")
text_group.add_argument("--nctx", type=int, default=2048, help="Maximum context length of the model you're using")

# Image generation arguments
image_group = run_parser.add_argument_group('Image generation options')
@@ -313,7 +313,7 @@ def main():
server_parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server to")
server_parser.add_argument("--port", type=int, default=8000, help="Port to bind the server to")
server_parser.add_argument("--reload", action="store_true", help="Enable automatic reloading on code changes")
server_parser.add_argument("--nctx", type=int, default=2048, help="Length of context window")
server_parser.add_argument("--nctx", type=int, default=2048, help="Maximum context length of the model you're using")

# Other commands
pull_parser = subparsers.add_parser("pull", help="Pull a model from official or hub.")
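For reference, a minimal standalone sketch of the argparse pattern the `run` and `server` subcommands converge on after this change — the parser structure and help strings here are illustrative, not copied from `entry.py`:

```python
import argparse

# Illustrative sketch: both subcommands expose the renamed --nctx flag
# with the same 2048-token default used in the diff above.
parser = argparse.ArgumentParser(prog="nexa")
subparsers = parser.add_subparsers(dest="command")

run_parser = subparsers.add_parser("run", help="Run a model locally")
run_parser.add_argument("model_path", type=str, help="Path or identifier for the model")
run_parser.add_argument("--nctx", type=int, default=2048,
                        help="Maximum context length of the model you're using")

server_parser = subparsers.add_parser("server", help="Start the local inference server")
server_parser.add_argument("model_path", type=str, help="Path or identifier for the model")
server_parser.add_argument("--nctx", type=int, default=2048,
                           help="Maximum context length of the model you're using")

args = parser.parse_args(["run", "llama2", "--nctx", "4096"])
print(args.command, args.model_path, args.nctx)  # run llama2 4096
```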
2 changes: 1 addition & 1 deletion nexa/constants.py
@@ -255,7 +255,7 @@ class ModelType(Enum):
DEFAULT_TEXT_GEN_PARAMS = {
"temperature": 0.7,
"max_new_tokens": 2048,
"context_maximum": 2048,
"nctx": 2048,
"top_k": 50,
"top_p": 1.0,
}
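Renaming the key in `DEFAULT_TEXT_GEN_PARAMS` keeps the flag name and the default dictionary in sync. As a hedged illustration of how a caller-supplied value would typically override that default (the `merge_params` helper below is hypothetical, not code from this repository):

```python
# Hypothetical illustration: overlay explicitly provided CLI values on the defaults.
DEFAULT_TEXT_GEN_PARAMS = {
    "temperature": 0.7,
    "max_new_tokens": 2048,
    "nctx": 2048,
    "top_k": 50,
    "top_p": 1.0,
}

def merge_params(cli_kwargs):
    """Return the defaults with any non-None caller values overriding them."""
    merged = dict(DEFAULT_TEXT_GEN_PARAMS)
    merged.update({k: v for k, v in cli_kwargs.items() if v is not None})
    return merged

print(merge_params({"nctx": 8192})["nctx"])  # 8192
print(merge_params({})["nctx"])              # falls back to 2048
```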
8 changes: 3 additions & 5 deletions nexa/gguf/nexa_inference_text.py
@@ -97,7 +97,6 @@ def create_embedding(
def _load_model(self):
logging.debug(f"Loading model from {self.downloaded_path}, use_cuda_or_metal : {is_gpu_available()}")
start_time = time.time()
print("context_maximum: ", self.params.get("context_maximum", 2048))
with suppress_stdout_stderr():
from nexa.gguf.llama.llama import Llama
try:
@@ -106,7 +105,7 @@ def _load_model(self):
model_path=self.downloaded_path,
verbose=self.profiling,
chat_format=self.chat_format,
n_ctx=self.params.get("context_maximum", 2048),
n_ctx=self.params.get("nctx", 2048),
n_gpu_layers=-1 if is_gpu_available() else 0,
lora_path=self.params.get("lora_path", ""),
)
@@ -116,7 +115,7 @@ def _load_model(self):
model_path=self.downloaded_path,
verbose=self.profiling,
chat_format=self.chat_format,
n_ctx=self.params.get("context_maximum", 2048),
n_ctx=self.params.get("nctx", 2048),
n_gpu_layers=0, # hardcode to use CPU
lora_path=self.params.get("lora_path", ""),
)
@@ -323,8 +322,7 @@ def run_streamlit(self, model_path: str, is_local_path = False, hf = False):
"-p", "--top_p", type=float, default=1.0, help="Top-p sampling parameter"
)
parser.add_argument(
"-cm",
"--context_maximum",
"--nctx",
type=int,
default=2048,
help="Maximum context length of the model you're using"
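The loader now reads `nctx` from the parameter dict and forwards it to the backend as `n_ctx`. A minimal sketch of that hand-off, assuming a llama-cpp-python style `Llama` constructor and a placeholder path to a local GGUF file (the class actually used lives in `nexa.gguf.llama.llama`):

```python
from llama_cpp import Llama  # assumption: llama-cpp-python installed; stands in for nexa.gguf.llama.llama.Llama

# Placeholder values for the sketch; in the CLI these come from argparse.
params = {"nctx": 4096}

llm = Llama(
    model_path="model.gguf",         # placeholder path to a local GGUF model
    n_ctx=params.get("nctx", 2048),  # user-facing --nctx becomes the backend's context window
    n_gpu_layers=0,                  # CPU-only for this sketch
    verbose=False,
)
```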
7 changes: 3 additions & 4 deletions nexa/gguf/nexa_inference_vlm.py
@@ -168,7 +168,7 @@ def _load_model(self):
chat_handler=self.projector,
verbose=False,
chat_format=self.chat_format,
n_ctx=self.params.get("context_maximum", 2048),
n_ctx=self.params.get("nctx", 2048),
n_gpu_layers=-1 if is_gpu_available() else 0,
)
except Exception as e:
@@ -181,7 +181,7 @@ def _load_model(self):
chat_handler=self.projector,
verbose=False,
chat_format=self.chat_format,
n_ctx=self.params.get("context_maximum", 2048),
n_ctx=self.params.get("nctx", 2048),
n_gpu_layers=0, # hardcode to use CPU
)

@@ -370,8 +370,7 @@ def run_streamlit(self, model_path: str, is_local_path = False, hf = False, proj
"-p", "--top_p", type=float, default=1.0, help="Top-p sampling parameter"
)
parser.add_argument(
"-cm",
"--context_maximum",
"--nctx",
type=int,
default=2048,
help="Maximum context length of the model you're using"
4 changes: 2 additions & 2 deletions nexa/gguf/server/nexa_service.py
@@ -282,7 +282,7 @@ async def load_model():
chat_handler=projector,
verbose=False,
chat_format=chat_format,
n_ctx=2048,
n_ctx=n_ctx,
n_gpu_layers=-1 if is_gpu_available() else 0,
)
except Exception as e:
@@ -295,7 +295,7 @@ async def load_model():
chat_handler=projector,
verbose=False,
chat_format=chat_format,
n_ctx=2048,
n_ctx=n_ctx,
n_gpu_layers=0, # hardcode to use CPU
)

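On the server side, the previously hard-coded `n_ctx=2048` now comes from the `--nctx` flag. A hedged sketch of how that value might be threaded from the CLI into `load_model()` — the module-level variable and startup wiring below are assumptions for illustration, not a copy of `nexa_service.py`:

```python
# Assumed plumbing: the CLI stores --nctx in a module-level variable that
# load_model() reads when constructing the backend. Names are illustrative.
from fastapi import FastAPI
import uvicorn

app = FastAPI()
n_ctx = 2048  # overwritten before the server starts

async def load_model():
    # ...construct the model with n_ctx=n_ctx instead of a hard-coded 2048...
    print(f"loading model with n_ctx={n_ctx}")

@app.on_event("startup")
async def startup():
    await load_model()

def run_nexa_ai_service(host: str = "0.0.0.0", port: int = 8000, nctx: int = 2048):
    global n_ctx
    n_ctx = nctx
    uvicorn.run(app, host=host, port=port)

# Equivalent CLI invocation once the flag lands: nexa server llama2 --nctx 4096
```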
