feat: add question network view
aymenfurter committed Jul 24, 2024
1 parent 93ddeb7 commit 4f15ef5
Showing 6 changed files with 261 additions and 33 deletions.
33 changes: 33 additions & 0 deletions app/azure_openai.py
@@ -3,6 +3,8 @@
import requests
from flask import Response
from openai import AzureOpenAI
import numpy as np


def get_azure_openai_client(api_key: str = None, api_version: str = None, azure_endpoint: str = None) -> AzureOpenAI:
"""Create and return an AzureOpenAI client."""
@@ -76,6 +78,37 @@ def get_openai_config() -> Dict[str, str]:
"SEARCH_SERVICE_API_KEY": os.environ.get('SEARCH_SERVICE_API_KEY', '')
}

def get_openai_embedding(text: str) -> Dict[str, Any]:
"""Calculate OpenAI embedding value for a given text."""
config = get_openai_config()
url = f"{config['OPENAI_ENDPOINT']}/openai/deployments/text-embedding-ada-002/embeddings?api-version=2024-02-15-preview"
headers = {
"Content-Type": "application/json",
"api-key": config['AOAI_API_KEY']
}
payload = {
"input": text,
"model": "text-embedding-ada-002"
}

response = get_response(url, headers, payload)

if response.get("error"):
return response

embedding = response["data"][0]["embedding"]

return {"embedding": np.array(embedding)}

def calculate_cosine_similarity(vector1: np.ndarray, vector2: np.ndarray) -> float:
"""Calculate cosine similarity between two vectors."""
dot_product = np.dot(vector1, vector2)
norm_vector1 = np.linalg.norm(vector1)
norm_vector2 = np.linalg.norm(vector2)
similarity = dot_product / (norm_vector1 * norm_vector2)

return similarity

def _get_image_analysis_prompt() -> str:
"""Return the prompt for image analysis."""
return """
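
For orientation, a minimal usage sketch (not part of the commit) showing how the two new helpers compose; the import path and sample questions are assumptions, and the endpoint, API key, and text-embedding-ada-002 deployment must already be configured:

# Illustrative only -- not part of the diff above.
from app.azure_openai import get_openai_embedding, calculate_cosine_similarity

first = get_openai_embedding("How do I rotate an API key?")
second = get_openai_embedding("What is the process for rotating API keys?")

# Each call returns {"embedding": np.ndarray} on success, or an error payload.
if not first.get("error") and not second.get("error"):
    score = calculate_cosine_similarity(first["embedding"], second["embedding"])
    print(f"cosine similarity: {score:.3f}")  # near 1.0 for near-duplicate questions
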
31 changes: 26 additions & 5 deletions app/research.py
@@ -5,10 +5,11 @@
from typing import Dict, Any, Callable, Annotated, Generator, Union
from flask import Response
from autogen import AssistantAgent, UserProxyAgent, GroupChat, GroupChatManager
from .azure_openai import create_payload, create_data_source, get_openai_config
from .azure_openai import create_payload, create_data_source, get_openai_config, get_openai_embedding, calculate_cosine_similarity
from .index_manager import create_index_manager, ContainerNameTooLongError
import time
import requests
import numpy as np

class RateLimitException(Exception):
pass
@@ -135,6 +136,13 @@ def generate_final_conclusion(chat_result: Any) -> str:

return response.json()["choices"][0]["message"]["content"], response.json()

@retry_request
def get_embedding_with_retry(text: str) -> np.ndarray:
response = get_openai_embedding(text)
if "error" in response:
raise RateLimitException(response["error"])
return response["embedding"]

def extract_citations(text):
citations = []
citation_pattern = r'\[([^\]]+)\]\(([^)]+)\)'
@@ -163,6 +171,8 @@ def research_with_data(data: Dict[str, Any], user_id: str) -> Generator[str, Non
}

message_queue = queue.Queue()
previous_queries = []
previous_embeddings = []

def yield_update(update_type, content):
event = json.dumps({"type": update_type, "content": content}) + '\n'
@@ -203,15 +213,26 @@ def yield_update(update_type, content):

def create_lookup_function(index: str) -> Callable[[Annotated[str, f"Use this function to search for information on the data source: {index_name}"]], str]:
def lookup_information(question: Annotated[str, f"Use this function to search for information on the data source: {index_name}"]) -> str:
yield_update('search', {'index': index, 'query': question})
# Calculate similarity with previous queries
query_embedding = get_embedding_with_retry(question)
previous_embeddings.append(query_embedding)
previous_queries.append(question)

related_query = None
if len(previous_queries) > 1:
similarities = [calculate_cosine_similarity(query_embedding, prev_embedding) for prev_embedding in previous_embeddings[:-1]]
most_similar_index = np.argmax(similarities)
related_query = previous_queries[most_similar_index]

yield_update('search', {'index': index, 'query': question, 'relatedQuery': related_query})
result, full_response = search(question, index)
yield_update('search_complete', {'index': index, 'result': result, 'full_response': full_response})

# Extract and yield individual citations
citations = extract_citations(result)
for citation in citations:
yield_update('citation', citation)
yield_update('citation', {'query': question, **citation})

return result
return lookup_information

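For clarity, here is the selection step above in isolation (an illustrative sketch, not part of the diff; the function name and import path are assumptions). The diff appends the new embedding before comparing and therefore slices off the last element with previous_embeddings[:-1]; comparing before appending, as below, is equivalent:

# Illustrative sketch of the related-query selection that builds the question network.
from typing import List, Optional
import numpy as np
from app.azure_openai import calculate_cosine_similarity  # assumed import path

def pick_related_query(query_embedding: np.ndarray,
                       previous_queries: List[str],
                       previous_embeddings: List[np.ndarray]) -> Optional[str]:
    """Return the earlier query most similar to the current one, or None if none exist."""
    if not previous_queries:
        return None
    similarities = [calculate_cosine_similarity(query_embedding, prev)
                    for prev in previous_embeddings]
    return previous_queries[int(np.argmax(similarities))]
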
@@ -276,4 +297,4 @@ def chat_thread():

citations = extract_citations(final_conclusion)
for citation in citations:
yield json.dumps({"type": "final_citation", "content": citation}) + '\n'
yield json.dumps({"type": "final_citation", "content": citation}) + '\n'
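
Each update is streamed as one newline-delimited JSON object with "type" and "content" fields; the 'search' events now carry relatedQuery, which is what the new network view links on. A hypothetical consumer sketch (the route, host, and request payload are assumptions, not taken from this repository):

# Illustrative NDJSON consumer; URL and payload shape are assumptions.
import json
import requests

with requests.post("http://localhost:5000/research",
                   json={"question": "..."}, stream=True) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        event = json.loads(line)
        if event["type"] == "search":
            c = event["content"]
            print(f"query: {c['query']}  related: {c.get('relatedQuery')}")
        elif event["type"] == "final_citation":
            print("citation:", event["content"])
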
81 changes: 80 additions & 1 deletion frontend/package-lock.json

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion frontend/package.json
@@ -10,7 +10,8 @@
"react-dom": "^17.0.2",
"react-scripts": "4.0.3",
"recharts": "^2.12.7",
"styled-components": "^5.3.0"
"styled-components": "^5.3.0",
"vis-network": "^9.1.9"
},
"scripts": {
"start": "react-scripts start",