diff --git a/sciphi/llm/vllm_llm.py b/sciphi/llm/vllm_llm.py index 67a21ce..8258b48 100644 --- a/sciphi/llm/vllm_llm.py +++ b/sciphi/llm/vllm_llm.py @@ -59,5 +59,5 @@ def get_instruct_completion(self, prompt: str) -> str: def get_batch_instruct_completion(self, prompts: list[str]) -> list[str]: """Get batch instruction completion from local vLLM.""" - raw_outputs = self.model.generate([prompts], self.sampling_params) + raw_outputs = self.model.generate(prompts, self.sampling_params) return [ele.outputs[0].text for ele in raw_outputs]