Skip to content

Commit

Permalink
fix: play error scenarios
Browse files Browse the repository at this point in the history
  • Loading branch information
codito committed Oct 7, 2024
1 parent 59a8015 commit 28a41ab
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 9 deletions.
2 changes: 1 addition & 1 deletion arey/data/play.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# generate a completion.

# Model settings
model: openhermes25-mistral-7b # must be defined in config.yml
model: TODO # must be defined in config.yml
#settings: # settings update reloads the model for non ollama models
# n_threads: 11 # default: cpu_count/2
# n_gpu_layers: 18 # default: 0, run on cpu
Expand Down
9 changes: 8 additions & 1 deletion arey/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,14 @@ def run_play_file(play_file_old: PlayFile) -> PlayFile:
or play_file_old.model_settings != play_file_mod.model_settings
):
with console.status("[message_footer]Loading model..."):
model_metrics = load_play_model(play_file_mod)
try:
model_metrics = load_play_model(play_file_mod)
except AreyError as ae:
if ae.category == "config":
console.print(ae.message)
return play_file_mod
raise

footer = f"✓ Model loaded. {model_metrics.init_latency_ms / 1000:.2f}s."
console.print(footer, style="message_footer")
console.print()
Expand Down
2 changes: 1 addition & 1 deletion arey/platform/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def capture_stderr() -> Generator[StringIO, None, None]:

with (
redirect_stderr(stderr) as err,
pipes(stdout=0, stderr=stderr, encoding="utf-8"),
pipes(stdout=0, stderr=stderr, encoding="utf-8"), # pyright: ignore[reportCallIssue, reportArgumentType]
):
yield err
except Exception:
Expand Down
15 changes: 9 additions & 6 deletions arey/play.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class PlayFile:

file_path: str

model_config: ModelConfig
model_config: ModelConfig | None
model_settings: dict[str, str]

prompt: str
Expand Down Expand Up @@ -85,7 +85,7 @@ def get_play_file(file_path: str) -> PlayFile:
play_file = frontmatter.load(f)

# FIXME validate settings
model_config = config.models[cast(str, play_file.metadata["model"])]
model_config = config.models.get(cast(str, play_file.metadata["model"]), None)
model_settings = cast(dict[str, Any], play_file.metadata.get("settings", {}))
completion_profile = cast(dict[str, Any], play_file.metadata.get("profile", {}))
output_settings = cast(dict[str, str], play_file.metadata.get("output", {}))
Expand All @@ -102,11 +102,14 @@ def get_play_file(file_path: str) -> PlayFile:
def load_play_model(play_file: PlayFile) -> ModelMetrics:
"""Load a model from play file."""
model_config = play_file.model_config
model_settings = play_file.model_settings
with capture_stderr():
model: CompletionModel = get_completion_llm(
model_config=model_config, settings=model_settings
if model_config is None:
raise AreyError(
"config", "Please specify a valid model configuration in play file."
)

model_config.settings |= play_file.model_settings
with capture_stderr():
model: CompletionModel = get_completion_llm(model_config=model_config)
model.load("")
play_file.model = model
return model.metrics
Expand Down
1 change: 1 addition & 0 deletions arey/prompt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Create a abstract class for chat prompts."""
# pyright: basic

import os
from dataclasses import dataclass, field
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,5 @@ select = ["D", "E", "F", "W"]
"**/*.py" = ["F405", "D203", "D213"]

[tool.basedpyright]
include = ["arey", "tests", "docs"]
reportUnusedCallResult = "none"

0 comments on commit 28a41ab

Please sign in to comment.