
Feature/iterate on vllm (#20)
* test completion

* iter

* iter

* fix batch inference

* fix batch inference
emrgnt-cmplxty authored Sep 21, 2023
1 parent 3a2531f commit baf57b0
Showing 3 changed files with 21 additions and 9 deletions.
4 changes: 2 additions & 2 deletions sciphi/examples/data_generation/runner.py
@@ -100,13 +100,12 @@ def generate_random_hash() -> str:
         llm_config,
     )
 
-    # Initialize the prompt generator & data maker
+    # Initialize the prompt generator
     data_config = DataConfig(
         os.path.join(
             get_data_config_dir(), f"{args.example_config}", "main.yaml"
         )
     )
-
     prompt_generator = PromptGenerator(
         data_config.config,
         data_config.prompt_templates,
@@ -124,6 +123,7 @@ def generate_random_hash() -> str:
         structure=PromptStructure.SINGLE,
     )
 
+    # Initialize the data maker
     data_maker = DataMaker(
         DataGeneratorMode(data_config.generator_mode),
         prompt_generator,
9 changes: 9 additions & 0 deletions sciphi/interface/vllm_interface.py
@@ -1,5 +1,6 @@
"""A module for interfacing with local vLLM models"""
import logging
from typing import List

from sciphi.interface.base import LLMInterface, ProviderName
from sciphi.interface.interface_manager import llm_provider
@@ -28,6 +29,14 @@ def get_completion(self, prompt: str) -> str:
         )
         return self.model.get_instruct_completion(prompt)
 
+    def get_batch_completion(self, prompts: List[str]) -> List[str]:
+        """Get batch completions from the local vLLM provider."""
+
+        logger.debug(
+            f"Requesting batch completions from local vLLM with model={self._model.config.model_name} and prompts={prompts}"
+        )
+        return self.model.get_batch_instruct_completion(prompts)
+
     @property
     def model(self) -> vLLM:
         return self._model
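
As a quick orientation for the diff above, here is a minimal caller-side sketch of the new batch path. This is hypothetical glue code, not part of the commit: it assumes a vLLMProvider instance named provider has already been constructed (for example through the project's interface manager), and only get_completion and get_batch_completion are taken from the diff.

    # Hypothetical caller-side sketch; `provider` is assumed to be an
    # already-initialized vLLMProvider from sciphi.interface.
    prompts = [
        "Explain the difference between fission and fusion.",
        "Summarize the photoelectric effect in one sentence.",
    ]

    # Existing single-prompt path.
    single_result = provider.get_completion(prompts[0])

    # New batch path added in this commit: all prompts are forwarded to the
    # model in one get_batch_instruct_completion call.
    batch_results = provider.get_batch_completion(prompts)
    for prompt, completion in zip(prompts, batch_results):
        print(prompt, "->", completion[:80])
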
17 changes: 10 additions & 7 deletions sciphi/llm/vllm_llm.py
@@ -32,13 +32,13 @@ def __init__(
     ) -> None:
         super().__init__(config)
         try:
-            from vllm import LLM as vLLM
+            from vllm import LLM as vvLLM
             from vllm import SamplingParams
         except ImportError:
             raise ImportError(
                 "Please install the vllm package before attempting to run with a vLLM model. This can be accomplished via `poetry install -E vllm_support, ...OTHER_DEPENDENCIES_HERE`."
             )
-        self.model = vLLM(model=config.model_name)
+        self.model = vvLLM(model=config.model_name)
         self.sampling_params = SamplingParams(
             temperature=config.temperature,
             top_p=config.top_p,
@@ -53,8 +53,11 @@ def get_chat_completion(self, messages: list[dict[str, str]]) -> str:
         )
 
     def get_instruct_completion(self, prompt: str) -> str:
-        """Get an instruction completion from the OpenAI API based on the provided prompt."""
-        # outputs = self.model.generate([prompt], self.sampling_params)
-        raise NotImplementedError(
-            "Instruction completion not yet implemented for vLLM."
-        )
+        """Get an instruction completion from the local vLLM API."""
+        outputs = self.model.generate([prompt], self.sampling_params)
+        return outputs[0].outputs[0].text
+
+    def get_batch_instruct_completion(self, prompts: list[str]) -> list[str]:
+        """Get batch instruction completions from the local vLLM."""
+        raw_outputs = self.model.generate(prompts, self.sampling_params)
+        return [ele.outputs[0].text for ele in raw_outputs]
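
For reference, a standalone sketch of the generate() pattern these two methods wrap. The model name and sampling values below are placeholders, not taken from this commit; only LLM, SamplingParams, generate, and the outputs[0].text access come from the diff above and vLLM's public API.

    # Standalone sketch of the vLLM calls behind get_instruct_completion /
    # get_batch_instruct_completion.
    from vllm import LLM, SamplingParams

    # Placeholder model name; in the wrapper this comes from config.model_name.
    llm = LLM(model="facebook/opt-125m")
    sampling_params = SamplingParams(temperature=0.7, top_p=0.95)

    # Single prompt: generate() takes a list, so wrap the prompt and unwrap the result.
    outputs = llm.generate(["Explain entropy in one sentence."], sampling_params)
    print(outputs[0].outputs[0].text)

    # Batch: pass every prompt in one call and keep the first candidate per prompt.
    prompts = ["What is a qubit?", "Define Gibbs free energy."]
    raw_outputs = llm.generate(prompts, sampling_params)
    completions = [ele.outputs[0].text for ele in raw_outputs]

Passing the whole list to a single generate() call lets vLLM batch the prompts itself, which is presumably why get_batch_instruct_completion was added rather than looping over the single-prompt method.
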
