From b15620ee9c8a83b4f61cb0a8929b103d56661f50 Mon Sep 17 00:00:00 2001 From: Lance Martin <122662504+rlancemartin@users.noreply.github.com> Date: Wed, 13 Dec 2023 07:59:12 -0800 Subject: [PATCH] Minor clean, add Mixtral (#123) --- .../ss_eval_chunk_sizes.ipynb | 64 ++++++++++++------- .../ss_eval_long_context.ipynb | 25 ++++++-- .../ss_eval_multi_vector.ipynb | 4 +- 3 files changed, 63 insertions(+), 30 deletions(-) diff --git a/docs/source/notebooks/retrieval/semi_structured_benchmarking/ss_eval_chunk_sizes.ipynb b/docs/source/notebooks/retrieval/semi_structured_benchmarking/ss_eval_chunk_sizes.ipynb index ee61735e..1ee15282 100644 --- a/docs/source/notebooks/retrieval/semi_structured_benchmarking/ss_eval_chunk_sizes.ipynb +++ b/docs/source/notebooks/retrieval/semi_structured_benchmarking/ss_eval_chunk_sizes.ipynb @@ -24,7 +24,7 @@ "outputs": [], "source": [ "# %pip install -U langchain langsmith langchain_benchmarks\n", - "# %pip install --quiet chromadb openai" + "# %pip install --quiet chromadb openai pypdf tiktoken fireworks-ai" ] }, { @@ -38,7 +38,7 @@ "import os\n", "\n", "os.environ[\"LANGCHAIN_ENDPOINT\"] = \"https://api.smith.langchain.com\"\n", - "env_vars = [\"LANGCHAIN_API_KEY\", \"OPENAI_API_KEY\"]\n", + "env_vars = [\"LANGCHAIN_API_KEY\", \"OPENAI_API_KEY\", \"FIREWORKS_API_KEY\"]\n", "for var in env_vars:\n", " if var not in os.environ:\n", " os.environ[var] = getpass.getpass(prompt=f\"Enter your {var}: \")" @@ -56,7 +56,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "id": "a94d9aa5-acd8-4032-ad8f-f995dec4d13c", "metadata": {}, "outputs": [], @@ -84,10 +84,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "1ecca7af-c3e7-42d1-97dd-c7d9777207cb", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset Semi-structured Reports already exists. Skipping.\n", + "You can access the dataset at https://smith.langchain.com/o/1fa8b1f4-fcb9-4072-9aa9-983e35ad61b8/datasets/6549a3a5-1cb9-463f-951d-0166cb9cf45c.\n" + ] + } + ], "source": [ "clone_public_dataset(task.dataset_id, dataset_name=task.name)" ] @@ -106,9 +115,7 @@ "cell_type": "code", "execution_count": null, "id": "7eb9e333-77e6-48f9-b221-9bded023b978", - "metadata": { - "scrolled": true - }, + "metadata": {}, "outputs": [], "source": [ "from langchain.document_loaders import PyPDFLoader\n", @@ -119,6 +126,9 @@ "from langchain.prompts import ChatPromptTemplate\n", "from langchain.schema.output_parser import StrOutputParser\n", "from langchain.schema.runnable import RunnablePassthrough\n", + "from langchain.chat_models import ChatFireworks\n", + "from langchain.callbacks.manager import CallbackManager\n", + "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n", "\n", "\n", "def load_and_split(file, token_count, split_document=True):\n", @@ -148,7 +158,7 @@ " return texts\n", "\n", "\n", - "def load_files(files, directory, token_count, split_document):\n", + "def load_files(files, token_count, split_document):\n", " \"\"\"\n", " Load files.\n", "\n", @@ -161,7 +171,7 @@ "\n", " texts = []\n", " for fi in files:\n", - " texts.extend(load_and_split(directory + fi, token_count, split_document))\n", + " texts.extend(load_and_split(fi, token_count, split_document))\n", " return texts\n", "\n", "\n", @@ -180,12 +190,13 @@ " return retriever\n", "\n", "\n", - "def rag_chain(retriever):\n", + "def rag_chain(retriever, llm):\n", " \"\"\"\n", " RAG chain.\n", "\n", " Args:\n", " retriever: The retriever to use.\n", + " llm: The llm to use.\n", " \"\"\"\n", "\n", " # Prompt template\n", @@ -196,7 +207,12 @@ " prompt = ChatPromptTemplate.from_template(template)\n", "\n", " # LLM\n", - " model = ChatOpenAI(temperature=0, model=\"gpt-4\")\n", + " if llm == \"mixtral\":\n", + " model = ChatFireworks(\n", + " model=\"accounts/fireworks/models/mixtral-8x7b-instruct\", temperature=0\n", + " )\n", + " else:\n", + " model = ChatOpenAI(temperature=0, model=\"gpt-4\")\n", "\n", " # RAG pipeline\n", " chain = (\n", @@ -213,18 +229,19 @@ "\n", "# Experiment configurations\n", "experiments = [\n", - " (None, False, \"page_split\"),\n", - " (50, True, \"50_tok_split\"),\n", - " (100, True, \"100_tok_split\"),\n", - " (250, True, \"250_tok_split\"),\n", + " (None, False, \"page_split-oai\", \"oai\"),\n", + " (50, True, \"50_tok_split-oai\", \"oai\"),\n", + " (100, True, \"100_tok_split-oai\", \"oai\"),\n", + " (250, True, \"250_tok_split-oai\", \"oai\"),\n", + " (250, True, \"250_tok_split-mixtral\", \"mixtral\"),\n", "]\n", "\n", "# Run\n", "stor_chain = {}\n", - "for token_count, split_document, expt in experiments:\n", - " texts = load_files(files, directory, token_count, split_document)\n", + "for token_count, split_document, expt, llm in experiments:\n", + " texts = load_files(files, token_count, split_document)\n", " retriever = make_retriever(texts, expt)\n", - " stor_chain[expt] = rag_chain(retriever)" + " stor_chain[expt] = rag_chain(retriever, llm)" ] }, { @@ -256,10 +273,11 @@ "\n", "# Experiments\n", "chain_map = {\n", - " \"page_split\": stor_chain[\"page_split\"],\n", - " \"baseline-50-tok\": stor_chain[\"50_tok_split\"],\n", - " \"baseline-100-tok\": stor_chain[\"100_tok_split\"],\n", - " \"baseline-250-tok\": stor_chain[\"250_tok_split\"],\n", + " \"page_split\": stor_chain[\"page_split-oai\"],\n", + " \"baseline-50-tok\": stor_chain[\"50_tok_split-oai\"],\n", + " \"baseline-100-tok\": stor_chain[\"100_tok_split-oai\"],\n", + " \"baseline-250-tok\": stor_chain[\"250_tok_split-oai\"],\n", + " \"baseline-250-tok-mixtral\": stor_chain[\"250_tok_split-mixtral\"],\n", "}\n", "\n", "# Run evaluation\n", diff --git a/docs/source/notebooks/retrieval/semi_structured_benchmarking/ss_eval_long_context.ipynb b/docs/source/notebooks/retrieval/semi_structured_benchmarking/ss_eval_long_context.ipynb index 71322cfe..2e7f2aa8 100644 --- a/docs/source/notebooks/retrieval/semi_structured_benchmarking/ss_eval_long_context.ipynb +++ b/docs/source/notebooks/retrieval/semi_structured_benchmarking/ss_eval_long_context.ipynb @@ -54,7 +54,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 1, "id": "61b816df-b43f-45b4-9b58-883d9847dd40", "metadata": {}, "outputs": [], @@ -82,16 +82,31 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 2, "id": "ead966fe-bfac-4d09-b8b8-00c0e8fca991", "metadata": {}, "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8c76b8af35ec486abdd6b061df4c9ac1", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/30 [00:00