Skip to content

Commit

Permalink
Pass MODEL around. use it to select the current model. Implement --li…
Browse files Browse the repository at this point in the history
…st-models.
  • Loading branch information
klntsky committed Nov 12, 2024
1 parent dfbf8e1 commit 5a633e2
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 5 deletions.
3 changes: 2 additions & 1 deletion python/src/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,12 @@ async def eval_ast(ast, runtime):
evaluated_parameters[parameter] = await _collect_exprs(
parameters[parameter], runtime
)
if "MODEL" not in evaluated_parameters:
evaluated_parameters["MODEL"] = old_env.get("MODEL")
runtime.env = Env(evaluated_parameters)
async for expr in eval_ast(loaded_ast, runtime):
yield expr
runtime.env = old_env

elif ast["type"] == "assign":
var_name = ast["name"]
value = await _collect_exprs(ast["exprs"], runtime)
Expand Down
3 changes: 2 additions & 1 deletion python/src/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ def _discover_variables(ast):
if ast["type"] == "comment":
return
elif ast["type"] == "var":
yield {"type": "var", "name": ast["name"]}
if ast["name"] != "MODEL":
yield {"type": "var", "name": ast["name"]}
elif ast["type"] == "assign":
yield {"type": "assign", "name": ast["name"]}
for key in ast:
Expand Down
10 changes: 10 additions & 0 deletions python/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def parse_arguments():
default="interactive" # TODO: use dynamic model selection
)

parser.add_argument("--list-models", action="store_true", help="List available LLMs for use with --model, based on the available LLM providers")

parser.add_argument(
"--set",
action=ParseSetAction,
Expand All @@ -55,6 +57,14 @@ def parse_arguments():
async def main():
args = parse_arguments()
config = load_config()

if args.list_models:
print("Available models:")
print()
print("\n".join(["- " + key for key in config.providers]))
print()
print("Use --model to specify")
return
config.parameters = dict(args.variables or {})
for file_path in args.INPUT_FILES:
if os.path.isfile(file_path):
Expand Down
5 changes: 4 additions & 1 deletion python/src/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ class OpenAIProvider(ProviderConfig):
def __init__(self, api_key: str = None, models=None, *args, **kwargs):
super().__init__(self, *args, **kwargs)
openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
models = models or [model.id for model in openai.models.list().data]
models = models or [
model.id for model in openai.models.list().data
if "gpt" in model.id
]
for model_name in models:
self.add(
model_name,
Expand Down
4 changes: 2 additions & 2 deletions python/src/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ class Runtime:
def __init__(self, config, env):
self.config = config
self.env = env
self.model_stack = [config.model]
self.env.set("MODEL", config.model)
self.cwd = os.getcwd()

def get_current_model(self):
return self.model_stack[-1]
return self.env.get("MODEL")

def set_variable(self, var_name, value):
self.env.set(var_name, value)
Expand Down

0 comments on commit 5a633e2

Please sign in to comment.