From 5a633e24e9cd24eeeab7efac85723afaa40fe52b Mon Sep 17 00:00:00 2001 From: Vladimir Kalnitsky Date: Wed, 13 Nov 2024 01:31:22 +0400 Subject: [PATCH] Pass MODEL around. use it to select the current model. Implement --list-models. --- python/src/eval.py | 3 ++- python/src/loader.py | 3 ++- python/src/main.py | 10 ++++++++++ python/src/providers/openai.py | 5 ++++- python/src/runtime.py | 4 ++-- 5 files changed, 20 insertions(+), 5 deletions(-) diff --git a/python/src/eval.py b/python/src/eval.py index e241684..ce4769f 100644 --- a/python/src/eval.py +++ b/python/src/eval.py @@ -59,11 +59,12 @@ async def eval_ast(ast, runtime): evaluated_parameters[parameter] = await _collect_exprs( parameters[parameter], runtime ) + if "MODEL" not in evaluated_parameters: + evaluated_parameters["MODEL"] = old_env.get("MODEL") runtime.env = Env(evaluated_parameters) async for expr in eval_ast(loaded_ast, runtime): yield expr runtime.env = old_env - elif ast["type"] == "assign": var_name = ast["name"] value = await _collect_exprs(ast["exprs"], runtime) diff --git a/python/src/loader.py b/python/src/loader.py index 5137645..7567c4b 100644 --- a/python/src/loader.py +++ b/python/src/loader.py @@ -7,7 +7,8 @@ def _discover_variables(ast): if ast["type"] == "comment": return elif ast["type"] == "var": - yield {"type": "var", "name": ast["name"]} + if ast["name"] != "MODEL": + yield {"type": "var", "name": ast["name"]} elif ast["type"] == "assign": yield {"type": "assign", "name": ast["name"]} for key in ast: diff --git a/python/src/main.py b/python/src/main.py index 36ce18e..b3b9929 100644 --- a/python/src/main.py +++ b/python/src/main.py @@ -40,6 +40,8 @@ def parse_arguments(): default="interactive" # TODO: use dynamic model selection ) + parser.add_argument("--list-models", action="store_true", help="List available LLMs for use with --model, based on the available LLM providers") + parser.add_argument( "--set", action=ParseSetAction, @@ -55,6 +57,14 @@ def parse_arguments(): async def main(): args = parse_arguments() config = load_config() + + if args.list_models: + print("Available models:") + print() + print("\n".join(["- " + key for key in config.providers])) + print() + print("Use --model to specify") + return config.parameters = dict(args.variables or {}) for file_path in args.INPUT_FILES: if os.path.isfile(file_path): diff --git a/python/src/providers/openai.py b/python/src/providers/openai.py index 8b07129..aabfdff 100644 --- a/python/src/providers/openai.py +++ b/python/src/providers/openai.py @@ -12,7 +12,10 @@ class OpenAIProvider(ProviderConfig): def __init__(self, api_key: str = None, models=None, *args, **kwargs): super().__init__(self, *args, **kwargs) openai.api_key = api_key or os.getenv("OPENAI_API_KEY") - models = models or [model.id for model in openai.models.list().data] + models = models or [ + model.id for model in openai.models.list().data + if "gpt" in model.id + ] for model_name in models: self.add( model_name, diff --git a/python/src/runtime.py b/python/src/runtime.py index 48420b5..abe9792 100644 --- a/python/src/runtime.py +++ b/python/src/runtime.py @@ -8,11 +8,11 @@ class Runtime: def __init__(self, config, env): self.config = config self.env = env - self.model_stack = [config.model] + self.env.set("MODEL", config.model) self.cwd = os.getcwd() def get_current_model(self): - return self.model_stack[-1] + return self.env.get("MODEL") def set_variable(self, var_name, value): self.env.set(var_name, value)