diff --git a/libs/community/langchain_community/utilities/you.py b/libs/community/langchain_community/utilities/you.py index b74841c868b17..6fefecc29398d 100644 --- a/libs/community/langchain_community/utilities/you.py +++ b/libs/community/langchain_community/utilities/you.py @@ -2,6 +2,8 @@ In order to set this up, follow instructions at: """ + +import warnings from typing import Any, Dict, List, Literal, Optional import aiohttp @@ -44,11 +46,11 @@ class YouDocument(BaseModel): class YouSearchAPIWrapper(BaseModel): - """Wrapper for you.com Search API. + """Wrapper for you.com Search and News API. To connect to the You.com api requires an API key which you can get at https://api.you.com. - You can check out the docs at https://documentation.you.com. + You can check out the docs at https://documentation.you.com/api-reference/. You need to set the environment variable `YDC_API_KEY` for retriever to operate. @@ -56,33 +58,50 @@ class YouSearchAPIWrapper(BaseModel): ---------- ydc_api_key: str, optional you.com api key, if YDC_API_KEY is not set in the environment + endpoint_type: str, optional + you.com endpoints: search, news, rag; + `web` and `snippet` alias `search` + `rag` returns `{'message': 'Forbidden'}` + @todo `news` endpoint num_web_results: int, optional - The max number of web results to return, must be under 20 + The max number of web results to return, must be under 20. + This is mapped to the `count` query parameter for the News API. safesearch: str, optional Safesearch settings, one of off, moderate, strict, defaults to moderate country: str, optional - Country code, ex: 'US' for united states, see api docs for list + Country code, ex: 'US' for United States, see api docs for list + search_lang: str, optional + (News API) Language codes, ex: 'en' for English, see api docs for list + ui_lang: str, optional + (News API) User interface language for the response, ex: 'en' for English, + see api docs for list + spellcheck: bool, optional + (News API) Whether to spell check query or not, defaults to True k: int, optional max number of Documents to return using `results()` n_hits: int, optional, deprecated Alias for num_web_results n_snippets_per_hit: int, optional limit the number of snippets returned per hit - endpoint_type: str, optional - you.com endpoints: search, news, rag; - `web` and `snippet` alias `search` - `rag` returns `{'message': 'Forbidden'}` - @todo `news` endpoint """ ydc_api_key: Optional[str] = None + + # @todo deprecate `snippet`, not part of API + endpoint_type: Literal["search", "news", "rag", "snippet"] = "search" + + # Common fields between Search and News API num_web_results: Optional[int] = None - safesearch: Optional[str] = None + safesearch: Optional[Literal["off", "moderate", "strict"]] = None country: Optional[str] = None + + # News API specific fields + search_lang: Optional[str] = None + ui_lang: Optional[str] = None + spellcheck: Optional[bool] = None + k: Optional[int] = None n_snippets_per_hit: Optional[int] = None - # @todo deprecate `snippet`, not part of API - endpoint_type: Literal["search", "news", "rag", "snippet"] = "search" # should deprecate n_hits n_hits: Optional[int] = None @@ -94,6 +113,74 @@ def validate_environment(cls, values: Dict) -> Dict: return values + @root_validator + def warn_if_set_fields_have_no_effect(cls, values: Dict) -> Dict: + if values["endpoint_type"] != "news": + news_api_fields = ("search_lang", "ui_lang", "spellcheck") + for field in news_api_fields: + if values[field]: + warnings.warn( + ( + f"News API-specific field '{field}' is set but " + f"`endpoint_type=\"{values['endpoint_type']}\"`. " + "This will have no effect." + ), + UserWarning, + ) + if values["endpoint_type"] not in ("search", "snippet"): + if values["n_snippets_per_hit"]: + warnings.warn( + ( + "Field 'n_snippets_per_hit' only has effect on " + '`endpoint_type="search"`.' + ), + UserWarning, + ) + return values + + @root_validator + def warn_if_deprecated_endpoints_are_used(cls, values: Dict) -> Dict: + if values["endpoint_type"] == "snippets": + warnings.warn( + ( + f"`endpoint_type=\"{values['endpoint_type']}\"` is deprecated. " + 'Use `endpoint_type="search"` instead.' + ), + DeprecationWarning, + ) + return values + + def _generate_params(self, query: str, **kwargs: Any) -> Dict: + """ + Parse parameters required for different You.com APIs. + + Args: + query: The query to search for. + """ + params = { + "safesearch": self.safesearch, + "country": self.country, + **kwargs, + } + + # Add endpoint-specific params + if self.endpoint_type in ("search", "snippet"): + params.update( + query=query, + num_web_results=self.num_web_results, + ) + elif self.endpoint_type == "news": + params.update( + q=query, + count=self.num_web_results, + search_lang=self.search_lang, + ui_lang=self.ui_lang, + spellcheck=self.spellcheck, + ) + + params = {k: v for k, v in params.items() if v is not None} + return params + def _parse_results(self, raw_search_results: Dict) -> List[Document]: """ Extracts snippets from each hit and puts them in a Document @@ -105,9 +192,12 @@ def _parse_results(self, raw_search_results: Dict) -> List[Document]: # return news results if self.endpoint_type == "news": + news_results = raw_search_results["news"]["results"] + if self.k is not None: + news_results = news_results[: self.k] return [ Document(page_content=result["description"], metadata=result) - for result in raw_search_results["news"]["results"] + for result in news_results ] docs = [] @@ -138,26 +228,10 @@ def raw_results( Args: query: The query to search for. - num_web_results: The maximum number of results to return. - safesearch: Safesearch settings, - one of off, moderate, strict, defaults to moderate - country: Country code Returns: YouAPIOutput """ headers = {"X-API-Key": self.ydc_api_key or ""} - params = { - "query": query, - "num_web_results": self.num_web_results, - "safesearch": self.safesearch, - "country": self.country, - **kwargs, - } - - params = {k: v for k, v in params.items() if v is not None} - # news endpoint expects `q` instead of `query` - if self.endpoint_type == "news": - params["q"] = params["query"] - del params["query"] + params = self._generate_params(query, **kwargs) # @todo deprecate `snippet`, not part of API if self.endpoint_type == "snippet": @@ -192,18 +266,7 @@ async def raw_results_async( """Get results from the you.com Search API asynchronously.""" headers = {"X-API-Key": self.ydc_api_key or ""} - params = { - "query": query, - "num_web_results": self.num_web_results, - "safesearch": self.safesearch, - "country": self.country, - **kwargs, - } - params = {k: v for k, v in params.items() if v is not None} - # news endpoint expects `q` instead of `query` - if self.endpoint_type == "news": - params["q"] = params["query"] - del params["query"] + params = self._generate_params(query, **kwargs) # @todo deprecate `snippet`, not part of API if self.endpoint_type == "snippet":