Merge pull request #155 from NexaAI/perry/customize-nctx
Feature implementation for issue #109: support the -cm/--context_maximum argument for text and VLM inference
zhiyuan8 authored Oct 8, 2024
2 parents 666230f + b58bc65 commit 67f1370
Showing 4 changed files with 21 additions and 4 deletions.
1 change: 1 addition & 0 deletions nexa/cli/entry.py
@@ -209,6 +209,7 @@ def main():
 text_group.add_argument("-p", "--top_p", type=float, help="Top-p sampling parameter")
 text_group.add_argument("-sw", "--stop_words", nargs="*", help="List of stop words for early stopping")
 text_group.add_argument("--lora_path", type=str, help="Path to a LoRA file to apply to the model.")
+text_group.add_argument("-cm", "--context_maximum", type=int, default=2048, help="Maximum context length of the model you're using")
 
 # Image generation arguments
 image_group = run_parser.add_argument_group('Image generation options')
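
For reference, a minimal standalone sketch of the new CLI option. The flag names, type, default, and help text are taken from the diff above; the bare ArgumentParser is only a stand-in for the real run_parser / text_group setup in entry.py.

    import argparse

    # Stand-in parser; in entry.py the option is registered on the
    # "Text generation options" argument group of the run subcommand.
    parser = argparse.ArgumentParser()
    parser.add_argument("-cm", "--context_maximum", type=int, default=2048,
                        help="Maximum context length of the model you're using")

    args = parser.parse_args(["-cm", "4096"])
    print(args.context_maximum)  # -> 4096
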
1 change: 1 addition & 0 deletions nexa/constants.py
@@ -255,6 +255,7 @@ class ModelType(Enum):
 DEFAULT_TEXT_GEN_PARAMS = {
     "temperature": 0.7,
     "max_new_tokens": 2048,
+    "context_maximum": 2048,
     "top_k": 50,
     "top_p": 1.0,
 }
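
The new default mirrors the CLI default of 2048. Below is a hedged sketch of how such a defaults dict is typically combined with user-supplied values before _load_model reads them back via params.get; the merge step itself is an assumption, only the dict comes from this diff.

    DEFAULT_TEXT_GEN_PARAMS = {
        "temperature": 0.7,
        "max_new_tokens": 2048,
        "context_maximum": 2048,  # new default context window
        "top_k": 50,
        "top_p": 1.0,
    }

    # Assumed pattern: user-provided keyword arguments override the defaults.
    user_kwargs = {"context_maximum": 8192}
    params = {**DEFAULT_TEXT_GEN_PARAMS, **user_kwargs}
    print(params.get("context_maximum", 2048))  # -> 8192
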
12 changes: 10 additions & 2 deletions nexa/gguf/nexa_inference_text.py
@@ -97,6 +97,7 @@ def create_embedding(
 def _load_model(self):
     logging.debug(f"Loading model from {self.downloaded_path}, use_cuda_or_metal : {is_gpu_available()}")
     start_time = time.time()
+    print("context_maximum: ", self.params.get("context_maximum", 2048))
     with suppress_stdout_stderr():
         from nexa.gguf.llama.llama import Llama
         try:
@@ -105,7 +106,7 @@ def _load_model(self):
     model_path=self.downloaded_path,
     verbose=self.profiling,
     chat_format=self.chat_format,
-    n_ctx=2048,
+    n_ctx=self.params.get("context_maximum", 2048),
     n_gpu_layers=-1 if is_gpu_available() else 0,
     lora_path=self.params.get("lora_path", ""),
 )
@@ -115,7 +116,7 @@ def _load_model(self):
     model_path=self.downloaded_path,
     verbose=self.profiling,
     chat_format=self.chat_format,
-    n_ctx=2048,
+    n_ctx=self.params.get("context_maximum", 2048),
     n_gpu_layers=0, # hardcode to use CPU
     lora_path=self.params.get("lora_path", ""),
 )
@@ -321,6 +322,13 @@ def run_streamlit(self, model_path: str, is_local_path = False, hf = False):
 parser.add_argument(
     "-p", "--top_p", type=float, default=1.0, help="Top-p sampling parameter"
 )
+parser.add_argument(
+    "-cm",
+    "--context_maximum",
+    type=int,
+    default=2048,
+    help="Maximum context length of the model you're using"
+)
 parser.add_argument(
     "-sw",
     "--stop_words",
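
Condensed, the flow this file now implements is: the parsed context_maximum lands in self.params and is forwarded to the llama.cpp context size. A simplified sketch follows; the import path and constructor arguments are those visible in the diff, while the model path and params dict are placeholders.

    from nexa.gguf.llama.llama import Llama

    params = {"context_maximum": 4096, "lora_path": ""}   # placeholder params dict

    model = Llama(
        model_path="/path/to/model.gguf",                 # placeholder path
        verbose=False,
        n_ctx=params.get("context_maximum", 2048),        # user-controlled context window
        n_gpu_layers=-1,                                  # -1 offloads all layers when a GPU is available
        lora_path=params.get("lora_path", ""),
    )
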
11 changes: 9 additions & 2 deletions nexa/gguf/nexa_inference_vlm.py
@@ -168,7 +168,7 @@ def _load_model(self):
         chat_handler=self.projector,
         verbose=False,
         chat_format=self.chat_format,
-        n_ctx=2048,
+        n_ctx=self.params.get("context_maximum", 2048),
         n_gpu_layers=-1 if is_gpu_available() else 0,
     )
 except Exception as e:
@@ -181,7 +181,7 @@ def _load_model(self):
     chat_handler=self.projector,
     verbose=False,
     chat_format=self.chat_format,
-    n_ctx=2048,
+    n_ctx=self.params.get("context_maximum", 2048),
     n_gpu_layers=0, # hardcode to use CPU
 )

@@ -369,6 +369,13 @@ def run_streamlit(self, model_path: str, is_local_path = False, hf = False, proj
 parser.add_argument(
     "-p", "--top_p", type=float, default=1.0, help="Top-p sampling parameter"
 )
+parser.add_argument(
+    "-cm",
+    "--context_maximum",
+    type=int,
+    default=2048,
+    help="Maximum context length of the model you're using"
+)
 parser.add_argument(
     "-sw",
     "--stop_words",
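
The VLM loader mirrors the text loader: it tries GPU offload first and falls back to CPU, passing the same user-configured context length in both branches. A condensed sketch of that fallback is shown below; the constructor arguments are those visible in the diff, while the helper function and its signature are illustrative only.

    from nexa.gguf.llama.llama import Llama

    def load_vlm(params, downloaded_path, projector, chat_format):
        # Illustrative helper; in nexa_inference_vlm.py this logic lives in _load_model.
        n_ctx = params.get("context_maximum", 2048)
        try:
            # Prefer full GPU offload when available.
            return Llama(model_path=downloaded_path, chat_handler=projector,
                         verbose=False, chat_format=chat_format,
                         n_ctx=n_ctx, n_gpu_layers=-1)
        except Exception:
            # Fall back to CPU-only loading with the same context length.
            return Llama(model_path=downloaded_path, chat_handler=projector,
                         verbose=False, chat_format=chat_format,
                         n_ctx=n_ctx, n_gpu_layers=0)
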
