This repository has been archived by the owner on Feb 12, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 53
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feature/update 01 02 24 rebased (#136)
* update model * tally distribution * update synthesizer project id * update output script * updates * fix scripts
- Loading branch information
1 parent
8cddbfd
commit d45900f
Showing
12 changed files
with
400 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,7 @@ authors = ["Owen Colegrove <[email protected]>"] | |
license = "Apache-2.0" | ||
readme = "README.md" | ||
name = 'sciphi-synthesizer' | ||
version = '1.0.3' | ||
version = '1.0.5' | ||
packages = [ | ||
{ include = "synthesizer" } | ||
] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .base import BingRAGConfig, BingRAGInterface | ||
|
||
__all__ = ["BingRAGConfig", "BingRAGInterface"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import os
from dataclasses import dataclass
from typing import Optional

from synthesizer.core import RAGProviderName
from synthesizer.interface.base import (
    RAGInterface,
    RAGProviderConfig,
    RagResult,
)
from synthesizer.interface.rag_interface_manager import (
    rag_config,
    rag_provider,
)

from .bing_client import BingSearchClient  # Import your BingSearchClient
|
||
|
||
@dataclass
@rag_config
class BingRAGConfig(RAGProviderConfig):
    """Settings for the Bing-backed RAG provider."""

    # Marks this configuration as belonging to the Bing provider.
    provider_name: RAGProviderName = RAGProviderName.BING
    # Bing Web Search v7 endpoint.
    api_base: str = "https://api.bing.microsoft.com/v7.0/search"
    # How many results to request per query.
    limit_results: int = 30
|
||
|
||
@rag_provider
class BingRAGInterface(RAGInterface):
    """A RAG provider that uses Bing web search as the retrieval source."""

    provider_name = RAGProviderName.BING
    FORMAT_INDENT = " "

    def __init__(
        self, config: Optional[BingRAGConfig] = None, *args, **kwargs
    ) -> None:
        """Build the provider and its Bing client.

        Args:
            config: Provider settings. When omitted, a fresh
                ``BingRAGConfig`` is created for this instance so that
                instances never share (and mutate) one default object.
            *args: Accepted for signature compatibility; unused here.
            **kwargs: Accepted for signature compatibility; unused here.

        Raises:
            ValueError: If neither ``config.api_key`` nor the
                ``BING_API_KEY`` environment variable supplies a key.
        """
        if config is None:
            config = BingRAGConfig()
        super().__init__(config)
        self.config: BingRAGConfig = config
        # The config is deliberately not printed/logged: it carries the
        # API key.
        api_key = self.config.api_key or os.getenv("BING_API_KEY")
        if not api_key:
            raise ValueError(
                "No API key provided. Please provide an API key or set the BING_API_KEY environment variable."
            )
        # NOTE(review): config.api_base is not forwarded to the client,
        # which is constructed from the key alone -- confirm whether the
        # configured endpoint should be honoured.
        self.client = BingSearchClient(api_key)

    def get_rag_context(self, query) -> RagResult:
        """Retrieve context for a given query using Bing.

        Args:
            query: The search query sent to Bing.

        Returns:
            A ``RagResult`` whose ``context`` is a numbered, blank-line
            separated listing of URL / title / snippet per hit, and whose
            ``meta_data`` holds ``to_string_dict()`` of every hit.
        """
        raw_results = self.client.search(query, self.config.limit_results)
        serp_results = self.client.format_as_serp_results(raw_results)

        split_marker = "/"
        entries = []
        for position, result in enumerate(serp_results, start=1):
            # Keep only the first four "/"-separated pieces, i.e.
            # "scheme://host/first-path-segment".
            short_url = split_marker.join(
                result.url.split(split_marker)[0:4]
            )
            entries.append(
                f"{position}. URL: {short_url}\nTitle: {result.title}\nSnippet:\n{result.text}"
            )

        return RagResult(
            context="\n\n".join(entries),
            meta_data=[ele.to_string_dict() for ele in serp_results],
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
from typing import Any, Dict, List | ||
|
||
import requests | ||
|
||
from .bing_types import ( | ||
Creator, | ||
DisplayConfig, | ||
Entity, | ||
ImageInfo, | ||
Publisher, | ||
SearchResult, | ||
Video, | ||
WebPage, | ||
) | ||
|
||
|
||
class BingSearchClient:
    """Thin client for the Bing Web Search v7 API.

    Fetches a search response, parses its answer sections into the
    ``bing_types`` models, and renders or converts the parsed results.
    """

    # Maps raw Bing field names to the labels used when printing results.
    FIELD_NAME_MAPPING = {
        "name": "Title",
        "description": "Snippet",
        "snippet": "Snippet",
        "url": "URL",
        "contentUrl": "URL",
    }

    DEFAULT_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
    # Seconds to wait on the HTTP request; without a timeout a stalled
    # connection would block the caller indefinitely.
    DEFAULT_TIMEOUT = 30.0

    def __init__(
        self,
        subscription_key: str,
        search_url: str = DEFAULT_SEARCH_URL,
        timeout: float = DEFAULT_TIMEOUT,
    ):
        """Store credentials and request settings.

        Args:
            subscription_key: Bing subscription key, sent in the
                ``Ocp-Apim-Subscription-Key`` header.
            search_url: Endpoint to query; defaults to the v7 search URL.
            timeout: Timeout in seconds passed to ``requests.get``.
        """
        self.subscription_key = subscription_key
        self.search_url = search_url
        self.timeout = timeout
        self.headers = {"Ocp-Apim-Subscription-Key": self.subscription_key}

    def search(self, query: str, count: int = 30) -> Dict[str, Any]:
        """Run a search and return the parsed answer sections.

        Args:
            query: The search query.
            count: Number of results to request.

        Returns:
            A dict with the keys ``"entities"``, ``"related_queries"``,
            ``"web_pages"`` and ``"videos"``; each holds a (possibly
            empty) list.

        Raises:
            requests.HTTPError: If the response status indicates failure.
            requests.Timeout: If the request exceeds ``self.timeout``.
        """
        params = {
            "q": query,
            "textDecorations": True,
            "textFormat": "HTML",
            "count": count,
        }
        response = requests.get(
            self.search_url,
            headers=self.headers,
            params=params,
            timeout=self.timeout,
        )
        response.raise_for_status()
        search_results = response.json()

        # Each section is optional in the response; default to an empty
        # dict so the parsers yield empty lists.
        return {
            "entities": self.parse_entities(
                search_results.get("entities", {})
            ),
            "related_queries": self.parse_related_queries(
                search_results.get("relatedSearches", {})
            ),
            "web_pages": self.parse_web_pages(
                search_results.get("webPages", {})
            ),
            "videos": self.parse_videos(search_results.get("videos", {})),
        }

    def parse_entities(self, entities_data: Dict[str, Any]) -> "List[Entity]":
        """Build an ``Entity`` from every item under ``"value"``."""
        return [
            Entity.construct(**entity)
            for entity in entities_data.get("value", [])
        ]

    def parse_related_queries(
        self, related_queries_data: Dict[str, Any]
    ) -> List[str]:
        """Return the ``"text"`` of every related query (``"N/A"`` if absent)."""
        return [
            query.get("text", "N/A")
            for query in related_queries_data.get("value", [])
        ]

    def parse_web_pages(
        self, web_pages_data: Dict[str, Any]
    ) -> "List[WebPage]":
        """Build a ``WebPage`` from every item under ``"value"``."""
        return [
            WebPage.construct(**web_page)
            for web_page in web_pages_data.get("value", [])
        ]

    def parse_videos(self, videos_data: Dict[str, Any]) -> "List[Video]":
        """Build a ``Video`` from every item under ``"value"``.

        Raises:
            KeyError: If a video item lacks one of the required fields
                read with ``video[...]`` below.
        """
        videos = videos_data.get("value", [])
        return [
            Video.construct(
                webSearchUrl=video["webSearchUrl"],
                name=video["name"],
                description=video["description"],
                thumbnail=ImageInfo(
                    thumbnailUrl=video["thumbnailUrl"],
                    hostPageUrl=video["hostPageUrl"],
                    width=video["width"],
                    height=video["height"],
                    # Fall back to the display size when the source size
                    # is not provided.
                    sourceWidth=video.get("sourceWidth", video["width"]),
                    sourceHeight=video.get("sourceHeight", video["height"]),
                ),
                datePublished=video["datePublished"],
                publisher=[
                    Publisher(name=p["name"]) for p in video["publisher"]
                ],
                creator=Creator(name=video["creator"]["name"])
                if video.get("creator")
                else None,
                contentUrl=video["contentUrl"],
                hostPageUrl=video["hostPageUrl"],
                encodingFormat=video["encodingFormat"],
                hostPageDisplayUrl=video["hostPageDisplayUrl"],
                duration=video.get("duration"),
                viewCount=video.get("viewCount"),
            )
            for video in videos
        ]

    def print_search_results(
        self, search_results: Dict[str, Any], config: "DisplayConfig"
    ) -> str:
        """Render parsed results as text with one running index.

        Args:
            search_results: Parsed sections as returned by ``search``.
            config: Display switches (``show_*``) and the per-section
                field lists (``*_fields``) to print.

        Returns:
            The enabled sections, each under its heading, separated by
            blank lines. Items are numbered continuously across sections.
        """
        sections: List[str] = []
        next_index = 1  # Running item number shared by all sections.

        def format_item(item, fields):
            nonlocal next_index
            item_info = ", ".join(
                f"{BingSearchClient.FIELD_NAME_MAPPING.get(field, field)}: {getattr(item, field)}"
                for field in fields
            )
            formatted_item = f"{next_index}.) {item_info}"
            next_index += 1
            return formatted_item

        if config.show_entities and "entities" in search_results:
            lines = ["Entities:"] + [
                format_item(entity, config.entity_fields)
                for entity in search_results["entities"]
            ]
            sections.append("\n".join(lines))

        if config.show_related_queries and "related_queries" in search_results:
            queries = search_results["related_queries"]
            # Number each query individually; previously every query was
            # printed with the same index.
            lines = ["Related Queries:"] + [
                f"{index}. {query}"
                for index, query in enumerate(queries, start=next_index)
            ]
            next_index += len(queries)
            sections.append("\n".join(lines))

        if config.show_web_pages and "web_pages" in search_results:
            lines = ["Web Pages:"] + [
                format_item(web_page, config.web_page_fields)
                for web_page in search_results["web_pages"]
            ]
            sections.append("\n".join(lines))

        if config.show_videos and "videos" in search_results:
            lines = ["Videos:"] + [
                format_item(video, config.video_fields)
                for video in search_results["videos"]
            ]
            sections.append("\n".join(lines))

        return "\n\n".join(sections)

    def format_as_serp_results(
        self, search_results: Dict[str, Any]
    ) -> "List[SearchResult]":
        """Convert the parsed web pages into ``SearchResult`` objects.

        Args:
            search_results: Parsed sections as returned by ``search``;
                only ``"web_pages"`` is read.

        Returns:
            One ``SearchResult`` per web page. Missing attributes become
            empty strings; the text is the page's ``description`` or,
            when that is empty, its ``snippet``.
        """
        return [
            SearchResult(
                url=getattr(web_page, "url", ""),
                title=getattr(web_page, "name", ""),
                dataset="Bing Search",
                metadata="",
                text=getattr(web_page, "description", "")
                or getattr(web_page, "snippet", ""),
            )
            for web_page in search_results.get("web_pages", [])
        ]
Oops, something went wrong.