Skip to content

Commit

Permalink
Merge pull request #122 from Azure-Samples/testeval6
Browse files Browse the repository at this point in the history
Add seed parameter (optional) and custom evaluation metric for citations overlap
  • Loading branch information
pamelafox authored Oct 23, 2024
2 parents 02ed71a + da05c77 commit 469474d
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 3 deletions.
5 changes: 3 additions & 2 deletions evals/eval_config.json
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
{
"testdata_path": "ground_truth.jsonl",
"results_dir": "results/experiment<TIMESTAMP>",
"requested_metrics": ["gpt_groundedness", "gpt_relevance", "answer_length", "latency", "citation_match"],
"requested_metrics": ["gpt_groundedness", "gpt_relevance", "answer_length", "latency", "citations_matched"],
"target_url": "http://127.0.0.1:8000/chat",
"target_parameters": {
"overrides": {
"use_advanced_flow": true,
"top": 3,
"retrieval_mode": "hybrid",
"temperature": 0.3
"temperature": 0.3,
"seed": 42
}
},
"target_response_answer_jmespath": "message.content",
Expand Down
28 changes: 28 additions & 0 deletions evals/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,44 @@
import argparse
import logging
import os
import re
from pathlib import Path
from typing import Any

import azure.identity
from dotenv import load_dotenv
from evaltools.eval.evaluate import run_evaluate_from_config
from evaltools.eval.evaluate_metrics import register_metric
from evaltools.eval.evaluate_metrics.base_metric import BaseMetric
from rich.logging import RichHandler

logger = logging.getLogger("ragapp")


class CitationsMatchedMetric(BaseMetric):
    """Custom metric: fraction of ground-truth citations ([n]) that also appear in the response.

    A score of 1.0 means every citation from the ground-truth answer is present in
    the generated response; 0.0 means none are. The sentinel value -1 marks rows
    where the metric could not be computed (missing response, or a ground truth
    with no citations) and is excluded from aggregate statistics.
    """

    METRIC_NAME = "citations_matched"

    @classmethod
    def evaluator_fn(cls, **kwargs):
        def citations_overlap(*, response, ground_truth, **kwargs):
            if response is None:
                logger.warning("Received response of None, can't compute %s metric. Setting to -1.", cls.METRIC_NAME)
                return {cls.METRIC_NAME: -1}
            # Citations look like [1], [2], ... — collect the distinct numbers on each side.
            truth_citations = set(re.findall(r"\[(\d+)\]", ground_truth))
            response_citations = set(re.findall(r"\[(\d+)\]", response))
            # Guard: a ground truth without citations would divide by zero below.
            if not truth_citations:
                logger.warning(
                    "Ground truth contains no citations, can't compute %s metric. Setting to -1.", cls.METRIC_NAME
                )
                return {cls.METRIC_NAME: -1}
            # Fraction of ground-truth citations that are present in the response.
            num_matched_citations = len(truth_citations.intersection(response_citations))
            return {cls.METRIC_NAME: num_matched_citations / len(truth_citations)}

        return citations_overlap

    @classmethod
    def get_aggregate_stats(cls, df):
        # Drop sentinel -1 rows (uncomputable) so they don't drag the mean down.
        df = df[df[cls.METRIC_NAME] != -1]
        return {"mean": round(df[cls.METRIC_NAME].mean(), 2)}


def get_openai_config() -> dict:
openai_config: dict[str, Any]
if os.environ.get("OPENAI_CHAT_HOST") == "azure":
Expand Down Expand Up @@ -60,6 +87,7 @@ def get_openai_config() -> dict:

openai_config = get_openai_config()

register_metric(CitationsMatchedMetric)
run_evaluate_from_config(
working_dir=Path(__file__).parent,
config_path="eval_config.json",
Expand Down
1 change: 1 addition & 0 deletions src/backend/fastapi_app/api_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class ChatRequestOverrides(BaseModel):
retrieval_mode: RetrievalMode = RetrievalMode.HYBRID
use_advanced_flow: bool = True
prompt_template: str | None = None
seed: int | None = None


class ChatRequestContext(BaseModel):
Expand Down
9 changes: 8 additions & 1 deletion src/backend/fastapi_app/rag_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ def __init__(
self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True)

async def generate_search_query(
self, original_user_query: str, past_messages: list[ChatCompletionMessageParam], query_response_token_limit: int
self,
original_user_query: str,
past_messages: list[ChatCompletionMessageParam],
query_response_token_limit: int,
seed: int | None = None,
) -> tuple[list[ChatCompletionMessageParam], Any | str | None, list]:
"""Generate an optimized keyword search query based on the chat history and the last question"""

Expand Down Expand Up @@ -63,6 +67,7 @@ async def generate_search_query(
n=1,
tools=tools,
tool_choice=tool_choice,
seed=seed,
)

query_text, filters = extract_search_arguments(original_user_query, chat_completion)
Expand All @@ -76,6 +81,7 @@ async def prepare_context(
original_user_query=chat_params.original_user_query,
past_messages=chat_params.past_messages,
query_response_token_limit=500,
seed=chat_params.seed,
)

# Retrieve relevant rows from the database with the GPT optimized query
Expand Down Expand Up @@ -142,6 +148,7 @@ async def answer(
max_tokens=chat_params.response_token_limit,
n=1,
stream=False,
seed=chat_params.seed,
)

return RetrievalResponse(
Expand Down
1 change: 1 addition & 0 deletions src/backend/fastapi_app/rag_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def get_params(self, messages: list[ChatCompletionMessageParam], overrides: Chat
return ChatParams(
top=overrides.top,
temperature=overrides.temperature,
seed=overrides.seed,
retrieval_mode=overrides.retrieval_mode,
use_advanced_flow=overrides.use_advanced_flow,
response_token_limit=response_token_limit,
Expand Down
2 changes: 2 additions & 0 deletions src/backend/fastapi_app/rag_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ async def answer(
max_tokens=chat_params.response_token_limit,
n=1,
stream=False,
seed=chat_params.seed,
)

return RetrievalResponse(
Expand Down Expand Up @@ -130,6 +131,7 @@ async def answer_stream(
max_tokens=chat_params.response_token_limit,
n=1,
stream=True,
seed=chat_params.seed,
)

yield RetrievalResponseDelta(
Expand Down

0 comments on commit 469474d

Please sign in to comment.