diff --git a/mlcopilot/experience.py b/mlcopilot/experience.py index cd9aeb3..f66aead 100644 --- a/mlcopilot/experience.py +++ b/mlcopilot/experience.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections import OrderedDict from typing import Any, Dict, List, Optional import langchain @@ -7,6 +8,7 @@ import orjson import pandas as pd from langchain.cache import InMemoryCache +from peewee import ModelSelect, fn from mlcopilot.constants import * from mlcopilot.orm import Knowledge, Solution, Space, Task, database_proxy @@ -308,3 +310,115 @@ def _gen_experience_demos(space: Space, task: Task) -> str: ] ) return demos + + +def _get_best_relevant_solutions(space: Space, task_desc: str) -> ModelSelect: + """ + Get the best relevant solution for a task. + The relevance is measured by cosine similarity between task description embeddings, which affects the order of results. + + Parameters + ---------- + space: Space + The space. + task_desc: str + The task description. + + Returns + ------- + ModelSelect + The best relevant solution. + """ + SolutionAlias = Solution.alias() + subquery = ( + SolutionAlias.select( + SolutionAlias.demo, + Task.task_id, + Task.desc, + Task.embedding, + fn.RANK() + .over( + partition_by=[SolutionAlias.space, SolutionAlias.task], + order_by=[SolutionAlias.metric.desc()], + ) + .alias("rnk"), + ) + .where(SolutionAlias.space == space) + .join(Task, on=(SolutionAlias.task == Task.task_id)) + .order_by(fn.cosine_similarity(task_desc, Task.embedding).desc()) + .alias("subq") + ) + query = Solution.select(subquery.c.task_id, subquery.c.demo, subquery.c.desc).from_( + subquery + ) + return query + + +def _get_best_solutions(space: Space) -> ModelSelect: + """ + Get the best solution for each task. + + Parameters + ---------- + space: Space + The space. + + Returns + ------- + ModelSelect + The best solution for each task. + """ + SolutionAlias = Solution.alias() + subquery = ( + SolutionAlias.select( + SolutionAlias.demo, + Task.task_id, + Task.desc, + Task.embedding, + fn.RANK() + .over( + partition_by=[SolutionAlias.space, SolutionAlias.task], + order_by=[SolutionAlias.metric.desc()], + ) + .alias("rnk"), + ) + .where(SolutionAlias.space == space) + .join(Task, on=(SolutionAlias.task == Task.task_id)) + .alias("subq") + ) + query = Solution.select(subquery.c.task_id, subquery.c.demo, subquery.c.desc).from_( + subquery + ) + return query + + +def gen_experience(space: Space, task_desc: Optional[str] = None) -> List[str]: + """ + Generate experience content from space and optional task description. + + Parameters + ---------- + space: Space + The space. + task_desc + The task description. + + Returns + ------- + List[str] + The experience content. + """ + if task_desc is None: + query = _get_best_solutions(space) + else: + query = _get_best_relevant_solutions(space, task_desc) + examples = OrderedDict() + + for solution in query: + if solution.task_id not in examples: + examples[solution.task_id] = [solution.desc] + if len(examples[solution.task_id]) <= TOP_K: + examples[solution.task_id].append( + f"Configuration {len(examples[solution.task_id])}: {solution.demo}" + ) + return ["\n".join(e) for e in examples.values()] diff --git a/mlcopilot/knowledge.py b/mlcopilot/knowledge.py index 9b301ec..4ba0bf2 100644 --- a/mlcopilot/knowledge.py +++ b/mlcopilot/knowledge.py @@ -6,7 +6,7 @@ from langchain.prompts.example_selector import LengthBasedExampleSelector from mlcopilot.constants import * -from mlcopilot.experience import gen_experience_per_task +from mlcopilot.experience import gen_experience from mlcopilot.orm import Knowledge, Solution, Space, Task, database_proxy from mlcopilot.surrogate_utils import evaluate_configs from mlcopilot.utils import get_llm, parse_configs @@ -140,10 +140,7 @@ def post_validation( print("Knowledge already exists.") return knowledge quantile_infos = orjson.loads(space.quantile_info) - examples = [ - gen_experience_per_task(space, task) - for task in Task.select().join(Solution).where(Solution.space == space) - ] + examples = gen_experience(space) best_score = float("-inf") knowledge = None for _ in range(3): diff --git a/mlcopilot/suggest.py b/mlcopilot/suggest.py index 2f5bd51..653dd8c 100644 --- a/mlcopilot/suggest.py +++ b/mlcopilot/suggest.py @@ -6,7 +6,7 @@ from peewee import fn from mlcopilot.constants import * -from mlcopilot.experience import gen_experience_per_task +from mlcopilot.experience import gen_experience from mlcopilot.knowledge import get_knowledge from mlcopilot.orm import Knowledge, Solution, Space, Task, database_proxy from mlcopilot.space import import_space, print_space @@ -71,14 +71,8 @@ def suggest(space: Space, task_desc: str) -> Tuple[Dict[str, Any], Union[str, No knowledge = get_knowledge(space) task_desc = f"""Task: {task_desc}""" - tasks_select = ( - Task.select() - .join(Solution) - .where(Solution.space == space) - .distinct() - .order_by(fn.cosine_similarity(task_desc, Task.embedding).desc()) - ) # TODO SQL groupby - examples = [gen_experience_per_task(space, task) for task in tasks_select] + examples = gen_experience(space, task_desc) + llm = get_llm("suggest")() quantile_infos = orjson.loads(space.quantile_info) diff --git a/test/test_experience.py b/test/test_experience.py index 0f791ce..4782d5d 100644 --- a/test/test_experience.py +++ b/test/test_experience.py @@ -1,6 +1,7 @@ import numpy as np import pandas as pd import pytest +from peewee import fn from mlcopilot.constants import TOP_K from mlcopilot.experience import ( @@ -8,11 +9,12 @@ _ingest_space, _ingest_task, canonicalize_config, + gen_experience, gen_experience_per_task, get_quantile_stat, ingest_experience, ) -from mlcopilot.orm import Task +from mlcopilot.orm import Solution, Task from mlcopilot.utils import set_llms from .llm import MockEmbeddingModel, MockKnowledgeLLM @@ -99,3 +101,21 @@ def test_gen_experience_per_task(): experience_per_task = gen_experience_per_task(space, task) assert isinstance(experience_per_task, str) return experience_per_task + + +def test_gen_experience(): + task_desc = "test task description" + space, _ = test_ingest_space() + + test_ingest_task() + + tasks_select = ( + Task.select() + .join(Solution) + .where(Solution.space == space) + .distinct() + .order_by(fn.cosine_similarity(task_desc, Task.embedding).desc()) + ) # TODO SQL groupby + examples_gt = [gen_experience_per_task(space, task) for task in tasks_select] + examples = gen_experience(space, task_desc) + assert examples == examples_gt