Skip to content

Commit

Permalink
community[minor]: Added classification_location parameter in PebbloSa…
Browse files Browse the repository at this point in the history
…feLoader. (langchain-ai#22565)

Description: Add classifier_location feature flag. This flag enables
Pebblo to decide the classifier location, local or pebblo-cloud.
Unit Tests: N/A
Documentation: N/A

---------

Signed-off-by: Rahul Tripathi <[email protected]>
Co-authored-by: Rahul Tripathi <[email protected]>
  • Loading branch information
2 people authored and Matt DeGenaro committed Jul 8, 2024
1 parent 6b4d8e3 commit 8227ea8
Show file tree
Hide file tree
Showing 4 changed files with 243 additions and 169 deletions.
182 changes: 122 additions & 60 deletions libs/community/langchain_community/chains/pebblo_retrieval/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import datetime
import inspect
import json
import logging
from http import HTTPStatus
from typing import Any, Dict, List, Optional
Expand Down Expand Up @@ -72,7 +73,9 @@ class PebbloRetrievalQA(Chain):
"""Pebblo cloud API key for app."""
classifier_url: str = CLASSIFIER_URL #: :meta private:
"""Classifier endpoint."""
_discover_sent: bool = False #: :meta private:
classifier_location: str = "local" #: :meta private:
"""Classifier location. It could be either of 'local' or 'pebblo-cloud'."""
_discover_sent = False #: :meta private:
"""Flag to check if discover payload has been sent."""
_prompt_sent: bool = False #: :meta private:
"""Flag to check if prompt payload has been sent."""
Expand All @@ -94,6 +97,7 @@ def _call(
answer, docs = res['result'], res['source_documents']
"""
prompt_time = datetime.datetime.now().isoformat()
PebbloRetrievalQA.set_prompt_sent(value=False)
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
auth_context = inputs.get(self.auth_context_key, {})
Expand All @@ -115,7 +119,9 @@ def _call(
"name": self.app_name,
"context": [
{
"retrieved_from": doc.metadata.get("source"),
"retrieved_from": doc.metadata.get(
"full_path", doc.metadata.get("source")
),
"doc": doc.page_content,
"vector_db": self.retriever.vectorstore.__class__.__name__,
}
Expand All @@ -131,6 +137,7 @@ def _call(
"user_identities": auth_context.user_auth
if auth_context and hasattr(auth_context, "user_auth")
else [],
"classifier_location": self.classifier_location,
}
qa_payload = Qa(**qa)
self._send_prompt(qa_payload)
Expand Down Expand Up @@ -220,6 +227,7 @@ def from_chain_type(
chain_type_kwargs: Optional[dict] = None,
api_key: Optional[str] = None,
classifier_url: str = CLASSIFIER_URL,
classifier_location: str = "local",
**kwargs: Any,
) -> "PebbloRetrievalQA":
"""Load chain from chain type."""
Expand All @@ -231,7 +239,7 @@ def from_chain_type(
)

# generate app
app = PebbloRetrievalQA._get_app_details(
app: App = PebbloRetrievalQA._get_app_details(
app_name=app_name,
description=description,
owner=owner,
Expand All @@ -240,7 +248,10 @@ def from_chain_type(
)

PebbloRetrievalQA._send_discover(
app, api_key=api_key, classifier_url=classifier_url
app,
api_key=api_key,
classifier_url=classifier_url,
classifier_location=classifier_location,
)

return cls(
Expand All @@ -250,6 +261,7 @@ def from_chain_type(
description=description,
api_key=api_key,
classifier_url=classifier_url,
classifier_location=classifier_location,
**kwargs,
)

Expand Down Expand Up @@ -300,7 +312,9 @@ async def _aget_docs(
)

@staticmethod
def _get_app_details(app_name, owner, description, llm, **kwargs) -> App: # type: ignore
def _get_app_details( # type: ignore
app_name: str, owner: str, description: str, llm: BaseLanguageModel, **kwargs
) -> App:
"""Fetch app details. Internal method.
Returns:
App: App details.
Expand All @@ -319,38 +333,49 @@ def _get_app_details(app_name, owner, description, llm, **kwargs) -> App: # typ
return app

@staticmethod
def _send_discover(app, api_key, classifier_url) -> None: # type: ignore
def _send_discover(
app: App,
api_key: Optional[str],
classifier_url: str,
classifier_location: str,
) -> None: # type: ignore
"""Send app discovery payload to pebblo-server. Internal method."""
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
payload = app.dict(exclude_unset=True)
app_discover_url = f"{classifier_url}{APP_DISCOVER_URL}"
try:
pebblo_resp = requests.post(
app_discover_url, headers=headers, json=payload, timeout=20
)
logger.debug("discover-payload: %s", payload)
logger.debug(
"send_discover[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(len(pebblo_resp.request.body if pebblo_resp.request.body else [])),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloRetrievalQA.set_discover_sent()
else:
logger.warning(
f"Received unexpected HTTP response code: {pebblo_resp.status_code}"
if classifier_location == "local":
app_discover_url = f"{classifier_url}{APP_DISCOVER_URL}"
try:
pebblo_resp = requests.post(
app_discover_url, headers=headers, json=payload, timeout=20
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)
logger.debug("discover-payload: %s", payload)
logger.debug(
"send_discover[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(
len(
pebblo_resp.request.body if pebblo_resp.request.body else []
)
),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloRetrievalQA.set_discover_sent()
else:
logger.warning(
"Received unexpected HTTP response code:"
+ f"{pebblo_resp.status_code}"
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)

if api_key:
try:
Expand Down Expand Up @@ -385,48 +410,82 @@ def set_discover_sent(cls) -> None:
cls._discover_sent = True

@classmethod
def set_prompt_sent(cls) -> None:
cls._prompt_sent = True
def set_prompt_sent(cls, value: bool = True) -> None:
cls._prompt_sent = value

def _send_prompt(self, qa_payload: Qa) -> None:
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
app_discover_url = f"{self.classifier_url}{PROMPT_URL}"
try:
pebblo_resp = requests.post(
app_discover_url, headers=headers, json=qa_payload.dict(), timeout=20
)
logger.debug("prompt-payload: %s", qa_payload)
logger.debug(
"send_prompt[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(len(pebblo_resp.request.body if pebblo_resp.request.body else [])),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloRetrievalQA.set_prompt_sent()
else:
logger.warning(
f"Received unexpected HTTP response code: {pebblo_resp.status_code}"
pebblo_resp = None
payload = qa_payload.dict(exclude_unset=True)
if self.classifier_location == "local":
try:
pebblo_resp = requests.post(
app_discover_url,
headers=headers,
json=payload,
timeout=20,
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)
logger.debug("prompt-payload: %s", payload)
logger.debug(
"send_prompt[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(
len(
pebblo_resp.request.body if pebblo_resp.request.body else []
)
),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloRetrievalQA.set_prompt_sent()
else:
logger.warning(
"Received unexpected HTTP response code:"
+ f"{pebblo_resp.status_code}"
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)

# If classifier location is local, then response, context and prompt
# should be fetched from pebblo_resp and replaced in payload.
if self.api_key:
if self.classifier_location == "local":
if pebblo_resp:
payload["response"] = (
json.loads(pebblo_resp.text)
.get("retrieval_data", {})
.get("response", {})
)
payload["context"] = (
json.loads(pebblo_resp.text)
.get("retrieval_data", {})
.get("context", [])
)
payload["prompt"] = (
json.loads(pebblo_resp.text)
.get("retrieval_data", {})
.get("prompt", {})
)
else:
payload["response"] = None
payload["context"] = None
payload["prompt"] = None
headers.update({"x-api-key": self.api_key})
pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{PROMPT_URL}"
try:
headers.update({"x-api-key": self.api_key})
pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{PROMPT_URL}"
pebblo_cloud_response = requests.post(
pebblo_cloud_url,
headers=headers,
json=qa_payload.dict(),
json=payload,
timeout=20,
)

Expand All @@ -449,9 +508,12 @@ def _send_prompt(self, qa_payload: Qa) -> None:
logger.warning("Unable to reach Pebblo cloud server.")
except Exception as e:
logger.warning("An Exception caught in _send_prompt: cloud %s", e)
elif self.classifier_location == "pebblo-cloud":
logger.warning("API key is missing for sending prompt to Pebblo cloud.")
raise NameError("API key is missing for sending prompt to Pebblo cloud.")

@classmethod
def get_chain_details(cls, llm, **kwargs): # type: ignore
def get_chain_details(cls, llm: BaseLanguageModel, **kwargs): # type: ignore
llm_dict = llm.__dict__
chain = [
{
Expand All @@ -474,6 +536,6 @@ def get_chain_details(cls, llm, **kwargs): # type: ignore
),
}
],
}
},
]
return chain
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Models for the PebbloRetrievalQA chain."""

from typing import Any, List, Optional
from typing import Any, List, Optional, Union

from langchain_core.pydantic_v1 import BaseModel

Expand Down Expand Up @@ -137,9 +137,10 @@ class Prompt(BaseModel):

class Qa(BaseModel):
name: str
context: List[Optional[Context]]
prompt: Prompt
response: Prompt
context: Union[List[Optional[Context]], Optional[Context]]
prompt: Optional[Prompt]
response: Optional[Prompt]
prompt_time: str
user: str
user_identities: Optional[List[str]]
classifier_location: str
Loading

0 comments on commit 8227ea8

Please sign in to comment.