Commit

Customized function for LLM invocation, plus tutorial
HuXiangkun committed Sep 18, 2024
1 parent 038584e commit 6e4557d
Showing 10 changed files with 763 additions and 18 deletions.
12 changes: 8 additions & 4 deletions README.md
@@ -1,5 +1,9 @@
# RAGChecker: A Fine-grained Framework For Diagnosing RAG

<p align="center">
<a href="https://arxiv.org/pdf/2408.08067">RAGChecker Paper</a> &nbsp;&nbsp; | &nbsp;&nbsp; <a href="./tutorial/ragchecker_tutorial_en.md">Tutorial (English)</a> &nbsp;&nbsp; | &nbsp;&nbsp; <a href="./tutorial/ragchecker_tutorial_zh.md">Tutorial (Chinese)</a>
</p>

RAGChecker is an advanced automatic evaluation framework designed to assess and diagnose Retrieval-Augmented Generation (RAG) systems. It provides a comprehensive suite of metrics and tools for in-depth analysis of RAG performance.

<p align="center">
@@ -86,8 +90,8 @@ If you are using AWS Bedrock version of Llama3 70B for the claim extractor and checker
ragchecker-cli \
--input_path=examples/checking_inputs.json \
--output_path=examples/checking_outputs.json \
- --extractor_name=bedrock/meta.llama3-70b-instruct-v1:0 \
- --checker_name=bedrock/meta.llama3-70b-instruct-v1:0 \
+ --extractor_name=bedrock/meta.llama3-1-70b-instruct-v1:0 \
+ --checker_name=bedrock/meta.llama3-1-70b-instruct-v1:0 \
--batch_size_extractor=64 \
--batch_size_checker=64 \
--metrics all_metrics \
@@ -133,8 +137,8 @@ with open("examples/checking_inputs.json") as fp:

# set-up the evaluator
evaluator = RAGChecker(
- extractor_name="bedrock/meta.llama3-70b-instruct-v1:0",
- checker_name="bedrock/meta.llama3-70b-instruct-v1:0",
+ extractor_name="bedrock/meta.llama3-1-70b-instruct-v1:0",
+ checker_name="bedrock/meta.llama3-1-70b-instruct-v1:0",
batch_size_extractor=32,
batch_size_checker=32
)
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "ragchecker"
- version = "0.1.6"
+ version = "0.1.7"
description = "RAGChecker: A Fine-grained Framework For Diagnosing Retrieval-Augmented Generation (RAG) systems."
authors = [
"Xiangkun Hu <[email protected]>",
@@ -15,7 +15,7 @@ license = "Apache-2.0"

[tool.poetry.dependencies]
python = "^3.9"
- refchecker = "^0.2.6"
+ refchecker = "^0.2.10"
loguru = "^0.7"
dataclasses-json = "^0.6"

16 changes: 4 additions & 12 deletions ragchecker/evaluator.py
@@ -51,9 +51,7 @@ def __init__(
openai_api_key=None,
joint_check=True,
joint_check_num=5,
- sagemaker_client=None,
- sagemaker_params=None,
- sagemaker_get_response_func=None,
+ custom_llm_api_func=None,
**kwargs
):
if openai_api_key:
@@ -62,9 +60,7 @@
self.joint_check = joint_check
self.joint_check_num = joint_check_num
self.kwargs = kwargs
- self.sagemaker_client = sagemaker_client
- self.sagemaker_params = sagemaker_params
- self.sagemaker_get_response_func = sagemaker_get_response_func
+ self.custom_llm_api_func = custom_llm_api_func

self.extractor = LLMExtractor(
model=extractor_name,
@@ -111,9 +107,7 @@ def extract_claims(self, results: List[RAGResult], extract_type="gt_answer"):
batch_responses=texts,
batch_questions=questions,
max_new_tokens=self.extractor_max_new_tokens,
- sagemaker_client=self.sagemaker_client,
- sagemaker_params=self.sagemaker_params,
- sagemaker_get_response_func=self.sagemaker_get_response_func,
+ custom_llm_api_func=self.custom_llm_api_func,
**self.kwargs
)
claims = [[c.content for c in res.claims] for res in extraction_results]
@@ -174,9 +168,7 @@ def check_claims(self, results: RAGResults, check_type="answer2response"):
merge_psg=merge_psg,
is_joint=self.joint_check,
joint_check_num=self.joint_check_num,
- sagemaker_client=self.sagemaker_client,
- sagemaker_params=self.sagemaker_params,
- sagemaker_get_response_func=self.sagemaker_get_response_func,
+ custom_llm_api_func=self.custom_llm_api_func,
**self.kwargs
)
for i, result in enumerate(results):
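The `evaluator.py` change above collapses the three `sagemaker_*` arguments into a single `custom_llm_api_func` hook that is threaded through to the extractor and checker. A minimal sketch of what such a function might look like is below; the exact signature RAGChecker/RefChecker expects is an assumption here (a batch of prompt strings in, one response string per prompt out), and `my_llm_api_func` is a hypothetical name — see the tutorial linked in the README for the authoritative contract.

```python
from typing import List


def my_llm_api_func(prompts: List[str]) -> List[str]:
    """Hypothetical custom LLM hook: takes a batch of prompts and
    returns exactly one response string per prompt.

    Replace the stub body with a real call to your own endpoint
    (SageMaker, vLLM, an internal gateway, ...).
    """
    responses = []
    for prompt in prompts:
        # Stub: echo part of the prompt; a real implementation
        # would invoke the LLM here and append its completion.
        responses.append(f"[stub response for]: {prompt[:40]}")
    return responses


# Passed to the evaluator in place of the removed sagemaker_* arguments:
# evaluator = RAGChecker(custom_llm_api_func=my_llm_api_func, ...)
```

Keeping the hook as a plain batched `prompts -> responses` callable means any backend can be plugged in without RAGChecker knowing about its client objects or request parameters, which is what the three removed SageMaker-specific arguments previously hard-coded.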
Binary file added tutorial/figures/claim_checking.png
Binary file added tutorial/figures/generator_metrics.png
Binary file added tutorial/figures/overall_metrics.png
Binary file added tutorial/figures/rag_pipeline.png
Binary file added tutorial/figures/retriever_metrics.png
