Allow llm and embedding arguments
Eric-Shang committed Dec 6, 2024
1 parent 96d7648 commit 1295120
Showing 3 changed files with 50 additions and 31 deletions.
26 changes: 12 additions & 14 deletions autogen/agentchat/contrib/graph_rag/neo4j_graph_query_engine.py
@@ -5,8 +5,10 @@
from typing import Dict, List, Optional, TypeAlias, Union

from llama_index.core import PropertyGraphIndex, SimpleDirectoryReader
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.indices.property_graph import SchemaLLMPathExtractor
from llama_index.core.indices.property_graph.transformations.schema_llm import Triple
from llama_index.core.llms import LLM
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore
from llama_index.llms.openai import OpenAI
@@ -36,9 +38,8 @@ def __init__(
database: str = "neo4j",
username: str = "neo4j",
password: str = "neo4j",
model: str = "gpt-3.5-turbo",
temperature: float = 0.0,
embed_model: str = "text-embedding-3-small",
llm: LLM = OpenAI(model="gpt-3.5-turbo", temperature=0.0),
embedding: BaseEmbedding = OpenAIEmbedding(model_name="text-embedding-3-small"),
entities: Optional[TypeAlias] = None,
relations: Optional[TypeAlias] = None,
validation_schema: Optional[Union[Dict[str, str], List[Triple]]] = None,
@@ -55,9 +56,8 @@ def __init__(
database (str): Neo4j database name.
username (str): Neo4j username.
password (str): Neo4j password.
model (str): LLM model to use for Neo4j to build and retrieve from the graph, default to use OAI gpt-3.5-turbo.
temperature (float): LLM temperature.
include_embeddings (bool): Whether to include embeddings in the graph.
llm (LLM): Language model to use for extracting triplets.
embedding (BaseEmbedding): Embedding model to use for constructing the index and querying.
entities (Optional[TypeAlias]): Custom possible entities to include in the graph.
relations (Optional[TypeAlias]): Custom possible relations to include in the graph.
validation_schema (Optional[Union[Dict[str, str], List[Triple]]]): Custom schema to validate the extracted triplets
@@ -68,9 +68,8 @@ def __init__(
self.database = database
self.username = username
self.password = password
self.model = model
self.temperature = temperature
self.embed_model = embed_model
self.llm = llm
self.embedding = embedding
self.entities = entities
self.relations = relations
self.validation_schema = validation_schema
@@ -94,7 +93,7 @@ def init_db(self, input_doc: List[Document] | None = None):
database=self.database,
)

# delete all entities and relationships if a graph pre-exists
# delete all entities and relationships in case a graph pre-exists
self._clear()

self.documents = SimpleDirectoryReader(input_files=self.input_files).load_data()
@@ -103,7 +102,7 @@ def init_db(self, input_doc: List[Document] | None = None):
# To add more extractors, please refer to https://docs.llamaindex.ai/en/latest/module_guides/indexing/lpg_index_guide/#construction
self.kg_extractors = [
SchemaLLMPathExtractor(
llm=OpenAI(model=self.model, temperature=self.temperature),
llm=self.llm,
possible_entities=self.entities,
possible_relations=self.relations,
kg_validation_schema=self.validation_schema,
Expand All @@ -113,7 +112,7 @@ def init_db(self, input_doc: List[Document] | None = None):

self.index = PropertyGraphIndex.from_documents(
self.documents,
embed_model=OpenAIEmbedding(model_name=self.embed_model),
embed_model=self.embedding,
kg_extractors=self.kg_extractors,
property_graph_store=self.graph_store,
show_progress=True,
@@ -180,8 +179,7 @@ def query(self, question: str, n_results: int = 1, **kwargs) -> GraphStoreQueryR
def _clear(self) -> None:
"""
Delete all entities and relationships in the graph.
# TODO: Delete all the data in the database including indexes and constraints.
TODO: Delete all the data in the database including indexes and constraints.
"""
# %%
with self.graph_store._driver.session() as session:
session.run("MATCH (n) DETACH DELETE n;")
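
To make the effect of this change concrete, here is a minimal usage sketch (not part of the commit): after this change a caller can inject any llama_index LLM and BaseEmbedding rather than relying on the previously hard-coded model, temperature, and embed_model strings. The import path and class name Neo4jGraphQueryEngine are inferred from the changed file name and are not shown in the diff itself; the model names below are placeholders.

from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI

# Assumed import path and class name, inferred from the changed file, not confirmed by the diff.
from autogen.agentchat.contrib.graph_rag.neo4j_graph_query_engine import Neo4jGraphQueryEngine

# Any llama_index-compatible LLM / BaseEmbedding pair can now be passed in;
# internally they are forwarded to SchemaLLMPathExtractor and PropertyGraphIndex.
query_engine = Neo4jGraphQueryEngine(
    username="neo4j",
    password="password",
    database="neo4j",
    llm=OpenAI(model="gpt-4o", temperature=0.0),                     # replaces the old model/temperature arguments
    embedding=OpenAIEmbedding(model_name="text-embedding-3-large"),  # replaces the old embed_model argument
)

Callers who keep the defaults still get gpt-3.5-turbo and text-embedding-3-small, since those remain the default argument values in the new signature.
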
@@ -54,8 +54,9 @@ def _reply_using_neo4j_query(
config: Optional[Any] = None,
) -> Tuple[bool, Union[str, Dict, None]]:
"""
Query neo4j and return the message. Internally, it utilises OpenAI to generate a reply based on the given messages.
The performance will be improved in future releases.
Query neo4j and return the message. Internally, it queries the Property graph
and returns the answer from the graph query engine.
TODO: reply with a dictionary including both the answer and semantic source triplets.
Args:
recipient: The agent instance that will receive the message.