Skip to content

Commit

Permalink
Optimize SQL group order
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanglei1172 committed Jun 7, 2023
1 parent f382806 commit b3c0e6a
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 15 deletions.
114 changes: 114 additions & 0 deletions mlcopilot/experience.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

from collections import OrderedDict
from typing import Any, Dict, List, Optional

import langchain
import numpy as np
import orjson
import pandas as pd
from langchain.cache import InMemoryCache
from peewee import ModelSelect, fn

from mlcopilot.constants import *
from mlcopilot.orm import Knowledge, Solution, Space, Task, database_proxy
Expand Down Expand Up @@ -308,3 +310,115 @@ def _gen_experience_demos(space: Space, task: Task) -> str:
]
)
return demos


def _get_best_relevant_solutions(space: Space, task_desc: str) -> ModelSelect:
    """
    Query the solutions of a space, ordered by relevance to a task description.

    Relevance is measured by cosine similarity between `task_desc` and each
    stored task's embedding; rows for more similar tasks come first. Within
    each (space, task) partition a RANK() window ranks solutions by metric,
    best first, exposed as column "rnk".

    Parameters
    ----------
    space: Space
        The space whose solutions are queried.
    task_desc: str
        The task description compared against stored task embeddings.

    Returns
    -------
    ModelSelect
        A lazy query yielding rows with ``task_id``, ``demo`` and ``desc``
        columns (see the outer SELECT below).
    """
    SolutionAlias = Solution.alias()
    # Inner query: join each solution to its task and rank solutions by
    # metric (descending) inside each (space, task) partition.
    subquery = (
        SolutionAlias.select(
            SolutionAlias.demo,
            Task.task_id,
            Task.desc,
            Task.embedding,
            fn.RANK()
            .over(
                partition_by=[SolutionAlias.space, SolutionAlias.task],
                order_by=[SolutionAlias.metric.desc()],
            )
            .alias("rnk"),
        )
        .where(SolutionAlias.space == space)
        .join(Task, on=(SolutionAlias.task == Task.task_id))
        # Most relevant tasks first; rows of the same task share one
        # similarity value, so within-task order is not pinned here.
        .order_by(fn.cosine_similarity(task_desc, Task.embedding).desc())
        .alias("subq")
    )
    # NOTE(review): "rnk" is computed but never used to filter or order this
    # outer query; gen_experience() truncates to TOP_K per task based on row
    # order — confirm the backend preserves the window ordering, otherwise
    # the "best" solutions per task are not guaranteed.
    query = Solution.select(subquery.c.task_id, subquery.c.demo, subquery.c.desc).from_(
        subquery
    )
    return query


def _get_best_solutions(space: Space) -> ModelSelect:
    """
    Query the solutions of a space for every task, with a per-task metric rank.

    Within each (space, task) partition a RANK() window ranks solutions by
    metric, best first, exposed as column "rnk". Unlike
    ``_get_best_relevant_solutions`` no relevance ordering is applied, so the
    overall row order is backend-defined.

    Parameters
    ----------
    space: Space
        The space whose solutions are queried.

    Returns
    -------
    ModelSelect
        A lazy query yielding rows with ``task_id``, ``demo`` and ``desc``
        columns (see the outer SELECT below).
    """
    SolutionAlias = Solution.alias()
    # Inner query: join each solution to its task and rank solutions by
    # metric (descending) inside each (space, task) partition.
    subquery = (
        SolutionAlias.select(
            SolutionAlias.demo,
            Task.task_id,
            Task.desc,
            Task.embedding,
            fn.RANK()
            .over(
                partition_by=[SolutionAlias.space, SolutionAlias.task],
                order_by=[SolutionAlias.metric.desc()],
            )
            .alias("rnk"),
        )
        .where(SolutionAlias.space == space)
        .join(Task, on=(SolutionAlias.task == Task.task_id))
        .alias("subq")
    )
    # NOTE(review): "rnk" is computed but never used to filter or order this
    # outer query; gen_experience() truncates to TOP_K per task based on row
    # order — confirm the backend preserves the window ordering.
    query = Solution.select(subquery.c.task_id, subquery.c.demo, subquery.c.desc).from_(
        subquery
    )
    return query


def gen_experience(space: Space, task_desc: Optional[str] = None) -> List[str]:
    """
    Generate experience content from a space and an optional task description.

    When ``task_desc`` is given, solutions are fetched in relevance order
    (most similar tasks first); otherwise all of the space's solutions are
    used. For each task, the task description is followed by up to ``TOP_K``
    "Configuration i: ..." lines, joined with newlines.

    Parameters
    ----------
    space: Space
        The space.
    task_desc
        The task description.

    Returns
    -------
    List[str]
        One experience string per task, in query order.
    """
    query = (
        _get_best_solutions(space)
        if task_desc is None
        else _get_best_relevant_solutions(space, task_desc)
    )

    # Group rows by task, preserving first-seen task order.
    grouped: OrderedDict = OrderedDict()
    for row in query:
        bucket = grouped.get(row.task_id)
        if bucket is None:
            # First row of this task: start with the task description.
            bucket = [row.desc]
            grouped[row.task_id] = bucket
        # The bucket holds the description plus up to TOP_K configurations.
        if len(bucket) <= TOP_K:
            bucket.append(f"Configuration {len(bucket)}: {row.demo}")
    return ["\n".join(lines) for lines in grouped.values()]
7 changes: 2 additions & 5 deletions mlcopilot/knowledge.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from langchain.prompts.example_selector import LengthBasedExampleSelector

from mlcopilot.constants import *
from mlcopilot.experience import gen_experience_per_task
from mlcopilot.experience import gen_experience
from mlcopilot.orm import Knowledge, Solution, Space, Task, database_proxy
from mlcopilot.surrogate_utils import evaluate_configs
from mlcopilot.utils import get_llm, parse_configs
Expand Down Expand Up @@ -140,10 +140,7 @@ def post_validation(
print("Knowledge already exists.")
return knowledge
quantile_infos = orjson.loads(space.quantile_info)
examples = [
gen_experience_per_task(space, task)
for task in Task.select().join(Solution).where(Solution.space == space)
]
examples = gen_experience(space)
best_score = float("-inf")
knowledge = None
for _ in range(3):
Expand Down
12 changes: 3 additions & 9 deletions mlcopilot/suggest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from peewee import fn

from mlcopilot.constants import *
from mlcopilot.experience import gen_experience_per_task
from mlcopilot.experience import gen_experience
from mlcopilot.knowledge import get_knowledge
from mlcopilot.orm import Knowledge, Solution, Space, Task, database_proxy
from mlcopilot.space import import_space, print_space
Expand Down Expand Up @@ -71,14 +71,8 @@ def suggest(space: Space, task_desc: str) -> Tuple[Dict[str, Any], Union[str, No
knowledge = get_knowledge(space)
task_desc = f"""Task: {task_desc}"""

tasks_select = (
Task.select()
.join(Solution)
.where(Solution.space == space)
.distinct()
.order_by(fn.cosine_similarity(task_desc, Task.embedding).desc())
) # TODO SQL groupby
examples = [gen_experience_per_task(space, task) for task in tasks_select]
examples = gen_experience(space, task_desc)

llm = get_llm("suggest")()
quantile_infos = orjson.loads(space.quantile_info)

Expand Down
22 changes: 21 additions & 1 deletion test/test_experience.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
import numpy as np
import pandas as pd
import pytest
from peewee import fn

from mlcopilot.constants import TOP_K
from mlcopilot.experience import (
_ingest_solution,
_ingest_space,
_ingest_task,
canonicalize_config,
gen_experience,
gen_experience_per_task,
get_quantile_stat,
ingest_experience,
)
from mlcopilot.orm import Task
from mlcopilot.orm import Solution, Task
from mlcopilot.utils import set_llms

from .llm import MockEmbeddingModel, MockKnowledgeLLM
Expand Down Expand Up @@ -99,3 +101,21 @@ def test_gen_experience_per_task():
experience_per_task = gen_experience_per_task(space, task)
assert isinstance(experience_per_task, str)
return experience_per_task


def test_gen_experience():
    """gen_experience must reproduce the legacy per-task generation, ordered by relevance."""
    description = "test task description"
    space, _ = test_ingest_space()
    test_ingest_task()

    # Legacy reference path: one experience string per task, tasks sorted by
    # cosine similarity of their embedding to the description.
    relevance = fn.cosine_similarity(description, Task.embedding).desc()
    relevant_tasks = (
        Task.select()
        .join(Solution)
        .where(Solution.space == space)
        .distinct()
        .order_by(relevance)
    )  # TODO SQL groupby
    expected = [gen_experience_per_task(space, task) for task in relevant_tasks]

    assert gen_experience(space, description) == expected

0 comments on commit b3c0e6a

Please sign in to comment.