Add example ranker
ultmaster committed Jan 11, 2024
1 parent abbf52f commit 36b9c16
Showing 1 changed file with 28 additions and 3 deletions.
coml/core.py: 31 changes (28 additions, 3 deletions)
@@ -4,11 +4,13 @@
 import random
 import re
 import warnings
-from typing import Any, cast, Literal, Callable
+from typing import Any, cast, Literal, Callable, TypeVar

 import colorama
 from langchain.chat_models.base import BaseChatModel
+from langchain.embeddings.base import Embeddings
 from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
+from scipy.spatial.distance import cosine as cosine_distance

 from .prompt_utils import (
     CHECK_INSTRUCTION,
@@ -35,6 +37,8 @@

 _debug_mode: bool = False

+_Type = TypeVar("_Type")
+

 def debug_messages(*messages: BaseMessage) -> None:
     if not _debug_mode:
@@ -126,6 +130,7 @@ def __init__(
         ] = "vcr",
         ensemble: int | None = None,
         ensemble_shuffle: bool = True,
+        example_ranking: Embeddings | None = None,
     ):
         self.llm = llm
         self.prompt_version = prompt_version
@@ -136,6 +141,7 @@ def __init__(
         self.context_order = context_order
         self.ensemble = ensemble
         self.ensemble_shuffle = ensemble_shuffle
+        self.example_ranking = example_ranking

     def _fix_context_from_any_context(
         self, context: GenerateContext | FixContext, **kwargs: Any
@@ -240,6 +246,25 @@ def _generate(self, messages: list[BaseMessage]) -> BaseMessage:
         messages = self._pre_generation(messages)
         return self.llm(messages)

+    def _select_examples(self, query: str, fewshots: list[_Type]) -> list[_Type]:
+        """Select examples from the fewshots."""
+        if self.num_examples == 0:
+            return []
+
+        if self.example_ranking is not None:
+            documents = [cast(Any, shot).get("request", "N/A") for shot in fewshots]
+            embeddings = self.example_ranking.embed_documents(documents)
+            query_embedding = self.example_ranking.embed_query(query)
+            similarities = [
+                (cosine_distance(query_embedding, embedding), shot)
+                for embedding, shot in zip(embeddings, fewshots)
+            ]
+            similarities.sort(key=lambda x: x[0])
+            fewshots = [shot for _, shot in similarities]
+
+        num_shots = max(int(len(fewshots) * self.num_examples), 1)
+        return fewshots[:num_shots]
+
     def generate_code(
         self,
         request: str,
@@ -254,8 +279,7 @@ def generate_code(
         else:
             messages.append(SystemMessage(content=GENERATE_INSTRUCTION))

-        num_shots = max(int(len(fewshots) * self.num_examples), 1)
-        for shot in fewshots[:num_shots]:
+        for shot in self._select_examples(request, fewshots):
             question, answer = render_generate_context(
                 shot, cot=self.chain_of_thought, context_order=self.context_order
             )
@@ -288,6 +312,7 @@ def fix_code(
         prev_context: GenerateContext | FixContext,
     ) -> FixContext | None:
         fewshots = cached_fix_fewshots()
+        fewshots = self._select_examples(prev_context["request"] or "N/A", fewshots)
         messages: list[BaseMessage] = [
             SystemMessage(content=FIX_INSTRUCTION),
         ]
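A minimal usage sketch of the new example-ranking hook follows. It is an assumption, not part of the diff, that the constructor shown above belongs to the agent class exported from coml/core.py (called CoMLAgent below) and that llm and num_examples are sibling constructor arguments; any LangChain Embeddings implementation works, and OpenAIEmbeddings is used here purely as an example. When example_ranking is set, _select_examples embeds each cached few-shot "request" and the incoming query, sorts the shots by cosine distance to the query (smallest distance, i.e. most similar, first), and only then truncates to the num_examples fraction; when it is left as None, the shots keep their stored order as before.

# Hypothetical usage sketch; the class name `CoMLAgent` and the sibling
# constructor arguments are assumptions inferred from the diff, not verified.
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings

from coml.core import CoMLAgent  # assumed class name / import path

agent = CoMLAgent(
    llm=ChatOpenAI(temperature=0),
    num_examples=0.5,                    # keep the top half of the ranked shots
    example_ranking=OpenAIEmbeddings(),  # enables embedding-based ranking
)

# Subsequent generate_code(...) / fix_code(...) calls now route their few-shot
# pool through _select_examples, which ranks the examples by embedding
# similarity to the current request before truncating, instead of slicing the
# examples in their stored order.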
