Skip to content
This repository has been archived by the owner on Feb 12, 2024. It is now read-only.

Feature/iterate on vllm #20

Merged
merged 5 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sciphi/examples/data_generation/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,12 @@ def generate_random_hash() -> str:
llm_config,
)

# Initialize the prompt generator & data maker
# Initialize the prompt generator
data_config = DataConfig(
os.path.join(
get_data_config_dir(), f"{args.example_config}", "main.yaml"
)
)

prompt_generator = PromptGenerator(
data_config.config,
data_config.prompt_templates,
Expand All @@ -124,6 +123,7 @@ def generate_random_hash() -> str:
structure=PromptStructure.SINGLE,
)

# Initialize the data maker
data_maker = DataMaker(
DataGeneratorMode(data_config.generator_mode),
prompt_generator,
Expand Down
9 changes: 9 additions & 0 deletions sciphi/interface/vllm_interface.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""A module for interfacing with local vLLM models"""
import logging
from typing import List

from sciphi.interface.base import LLMInterface, ProviderName
from sciphi.interface.interface_manager import llm_provider
Expand Down Expand Up @@ -28,6 +29,14 @@ def get_completion(self, prompt: str) -> str:
)
return self.model.get_instruct_completion(prompt)

def get_batch_completion(self, prompts: List[str]) -> List[str]:
"""Get a completion from the local vLLM provider."""

logger.debug(
f"Requesting completion from local vLLM with model={self._model.config.model_name} and prompts={prompts}"
)
return self.model.get_batch_instruct_completion(prompts)

@property
def model(self) -> vLLM:
return self._model
17 changes: 10 additions & 7 deletions sciphi/llm/vllm_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ def __init__(
) -> None:
super().__init__(config)
try:
from vllm import LLM as vLLM
from vllm import LLM as vvLLM
from vllm import SamplingParams
except ImportError:
raise ImportError(
"Please install the vllm package before attempting to run with an vLLM model. This can be accomplished via `poetry install -E vllm_support, ...OTHER_DEPENDENCIES_HERE`."
)
self.model = vLLM(model=config.model_name)
self.model = vvLLM(model=config.model_name)
self.sampling_params = SamplingParams(
temperature=config.temperature,
top_p=config.top_p,
Expand All @@ -53,8 +53,11 @@ def get_chat_completion(self, messages: list[dict[str, str]]) -> str:
)

def get_instruct_completion(self, prompt: str) -> str:
"""Get an instruction completion from the OpenAI API based on the provided prompt."""
# outputs = self.model.generate([prompt], self.sampling_params)
raise NotImplementedError(
"Instruction completion not yet implemented for vLLM."
)
"""Get an instruction completion from local vLLM API."""
outputs = self.model.generate([prompt], self.sampling_params)
return outputs[0].outputs[0].text

def get_batch_instruct_completion(self, prompts: list[str]) -> list[str]:
"""Get batch instruction completion from local vLLM."""
raw_outputs = self.model.generate(prompts, self.sampling_params)
return [ele.outputs[0].text for ele in raw_outputs]
Loading