# inference.py
"""Run batch LLM inference on a task dataset and save the completions.

Intended to be executed as a script; see the ``--task`` and ``--model``
command-line options defined at the bottom of the file.
"""
import os

import pandas as pd
from dotenv import load_dotenv

# Load variables from a .env file into the process environment *before*
# importing utils.
# NOTE(review): presumably utils (or a client library it imports) reads
# configuration such as API keys from the environment at import time, which
# would explain why this call precedes the import — confirm before reordering.
load_dotenv()

from utils import LLMS, PROMPTS, TASKS, batch_format_prompt, completion, setup_logger

# Module-level logger obtained from the project's setup_logger helper, keyed
# by this module's name.
logger = setup_logger(__name__)
def main(task: str, model_name: str) -> None:
    """Run batch inference for ``task`` with ``model_name`` and save the results.

    The task dataset is read as JSON Lines from ``TASKS[task]``, turned into
    messages with the ``"stack-inf"`` prompt, and sent to ``completion`` using
    the provider, model parameters and sampling parameters configured in
    ``LLMS[model_name]``. The completions, together with the model identifier
    and its metadata, are added as columns and the frame is written to
    ``output/<task>/inf/<model_name>.jsonl``.

    Args:
        task: Key into ``TASKS`` selecting the dataset to run on.
        model_name: Key into ``LLMS`` selecting the model configuration.

    Raises:
        KeyError: If ``task`` is not a key of ``TASKS``.
        ValueError: If ``model_name`` is not a key of ``LLMS``.
    """
    # Resolve both configuration lookups up front so that a bad argument fails
    # before any dataset is read or any prompt is formatted. The task lookup
    # comes first to keep the original exception precedence.
    task_source = TASKS[task]
    if model_name not in LLMS:
        raise ValueError(f"Model {model_name} not found in the configuration.")
    model = LLMS[model_name]

    data = pd.read_json(task_source, lines=True)
    messages = batch_format_prompt(PROMPTS["stack-inf"], data)
    logger.info("Loaded %s messages for inference on task %s.", len(messages), task)

    logger.info("Launching batch completion using model %s...", model_name)
    data["completion"] = completion(
        messages,
        custom_llm_provider=model["custom_llm_provider"],
        **model["model_parameters"],
        **model["sample_parameters"],
        num_retries=3,
        timeout=60,
    )

    # ``.str.len()`` yields NaN for missing entries, so ``== 0`` alone would
    # silently skip None/NaN completions; count those as empty as well.
    completions = data["completion"]
    empty_completions = (completions.isna() | (completions.str.len() == 0)).sum()
    if empty_completions > 0:
        logger.warning("Found %s empty completions.", empty_completions)

    # A scalar is broadcast to every row.
    data["model"] = model["model_parameters"]["model"]
    # The metadata is repeated through an explicit list: pandas would try to
    # align a dict-like value on the index instead of storing it in each row.
    data["modelMetadata"] = [model["model_metadata"]] * len(data)

    output_path = f"output/{task}/inf/{model_name}.jsonl"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    data.to_json(output_path, lines=True, orient="records")
    logger.info("Saved completions to %s", output_path)
if __name__ == "__main__":
    # Script entry point: read the task and model from the command line and
    # pass them to main(). Both options are mandatory and restricted to the
    # keys of the TASKS / LLMS configuration mappings.
    import argparse

    cli = argparse.ArgumentParser(
        description="Run inference on a task using a specified model."
    )
    cli.add_argument(
        "--task",
        "-t",
        required=True,
        type=str,
        choices=TASKS.keys(),
        help="The task to run inference on.",
    )
    cli.add_argument(
        "--model",
        "-m",
        required=True,
        type=str,
        choices=LLMS.keys(),
        help="The model to use for inference.",
    )
    options = cli.parse_args()

    main(task=options.task, model_name=options.model)