diff --git a/sciphi/examples/data_generation/runner.py b/sciphi/examples/data_generation/runner.py
index c7bec6e..5013286 100644
--- a/sciphi/examples/data_generation/runner.py
+++ b/sciphi/examples/data_generation/runner.py
@@ -100,13 +100,12 @@ def generate_random_hash() -> str:
         llm_config,
     )

-    # Initialize the prompt generator & data maker
+    # Initialize the prompt generator
     data_config = DataConfig(
         os.path.join(
             get_data_config_dir(), f"{args.example_config}", "main.yaml"
         )
     )
-
     prompt_generator = PromptGenerator(
         data_config.config,
         data_config.prompt_templates,
@@ -124,6 +123,7 @@ def generate_random_hash() -> str:
         structure=PromptStructure.SINGLE,
     )

+    # Initialize the data maker
     data_maker = DataMaker(
         DataGeneratorMode(data_config.generator_mode),
         prompt_generator,
diff --git a/sciphi/interface/vllm_interface.py b/sciphi/interface/vllm_interface.py
index 58c1587..d5a8767 100644
--- a/sciphi/interface/vllm_interface.py
+++ b/sciphi/interface/vllm_interface.py
@@ -1,5 +1,6 @@
 """A module for interfacing with local vLLM models"""
 import logging
+from typing import List

 from sciphi.interface.base import LLMInterface, ProviderName
 from sciphi.interface.interface_manager import llm_provider
@@ -28,6 +29,14 @@ def get_completion(self, prompt: str) -> str:
         )
         return self.model.get_instruct_completion(prompt)

+    def get_batch_completion(self, prompts: List[str]) -> List[str]:
+        """Get a completion from the local vLLM provider."""
+
+        logger.debug(
+            f"Requesting completion from local vLLM with model={self._model.config.model_name} and prompts={prompts}"
+        )
+        return self.model.get_batch_instruct_completion(prompts)
+
     @property
     def model(self) -> vLLM:
         return self._model
diff --git a/sciphi/llm/vllm_llm.py b/sciphi/llm/vllm_llm.py
index d38b2f5..8258b48 100644
--- a/sciphi/llm/vllm_llm.py
+++ b/sciphi/llm/vllm_llm.py
@@ -32,13 +32,13 @@ def __init__(
     ) -> None:
         super().__init__(config)
         try:
-            from vllm import LLM as vLLM
+            from vllm import LLM as vvLLM
             from vllm import SamplingParams
         except ImportError:
             raise ImportError(
                 "Please install the vllm package before attempting to run with an vLLM model. This can be accomplished via `poetry install -E vllm_support, ...OTHER_DEPENDENCIES_HERE`."
             )
-        self.model = vLLM(model=config.model_name)
+        self.model = vvLLM(model=config.model_name)
         self.sampling_params = SamplingParams(
             temperature=config.temperature,
             top_p=config.top_p,
@@ -53,8 +53,11 @@ def get_chat_completion(self, messages: list[dict[str, str]]) -> str:
         )

     def get_instruct_completion(self, prompt: str) -> str:
-        """Get an instruction completion from the OpenAI API based on the provided prompt."""
-        # outputs = self.model.generate([prompt], self.sampling_params)
-        raise NotImplementedError(
-            "Instruction completion not yet implemented for vLLM."
-        )
+        """Get an instruction completion from local vLLM API."""
+        outputs = self.model.generate([prompt], self.sampling_params)
+        return outputs[0].outputs[0].text
+
+    def get_batch_instruct_completion(self, prompts: list[str]) -> list[str]:
+        """Get batch instruction completion from local vLLM."""
+        raw_outputs = self.model.generate(prompts, self.sampling_params)
+        return [ele.outputs[0].text for ele in raw_outputs]
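
For reference, below is a minimal standalone sketch of what the new batch path does once these changes land: the interface's get_batch_completion() delegates to vLLM.get_batch_instruct_completion(), which issues a single vllm.LLM.generate() call over the whole prompt list and unpacks the first candidate of each result. The model name, sampling values, and prompts are placeholders, not values taken from this PR, and the sciphi wrapper classes are deliberately not imported here.

# Sketch of the batch-generation path wired up by this diff, using the
# vllm package directly. Placeholder model name and sampling parameters.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model
sampling_params = SamplingParams(temperature=0.7, top_p=0.95)

prompts = [
    "### Instruction:\nSummarize the benefits of batch inference.\n### Response:\n",
    "### Instruction:\nName three uses of synthetic data.\n### Response:\n",
]

# One generate() call over the whole batch, mirroring
# get_batch_instruct_completion(prompts) in sciphi/llm/vllm_llm.py.
raw_outputs = llm.generate(prompts, sampling_params)
completions = [ele.outputs[0].text for ele in raw_outputs]

for prompt, completion in zip(prompts, completions):
    print(prompt, "->", completion.strip())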