Skip to content

Commit

Permalink
Merge pull request #21 from Rishabh-git10/gpt4all
Browse files Browse the repository at this point in the history
Added gpt4all support & user input for questions
  • Loading branch information
Priyamakeshwari authored Oct 26, 2023
2 parents d39cf21 + fef2822 commit b1e6654
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 47 deletions.
Binary file modified __pycache__/config.cpython-310.pyc
Binary file not shown.
100 changes: 53 additions & 47 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import logging
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA, RetrievalQAWithSourcesChain
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.llms import LlamaCpp
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.callbacks.manager import CallbackManager
from langchain.vectorstores import Chroma
from langchain.prompts import PromptTemplate
from huggingface_hub import hf_hub_download
from gpt4all import GPT4All

from config import (
PERSIST_DIRECTORY,
MODEL_DIRECTORY,
SOURCE_DIR,
EMBEDDING_MODEL,
DEVICE_TYPE,
CHROMA_SETTINGS,
Expand All @@ -21,48 +21,55 @@
MAX_TOKEN_LENGTH,
)

def load_model(device_type:str = DEVICE_TYPE, model_id:str = MODEL_NAME, model_basename:str = MODEL_FILE, LOGGING=logging):
def load_model(model_choice, device_type=DEVICE_TYPE, model_id=MODEL_NAME, model_basename=MODEL_FILE, LOGGING=logging):
"""
Load a language model.
Load a language model (either LlamaCpp or GPT4All).
Args:
model_choice (str): The choice of the model to load ('LlamaCpp' or 'GPT4All').
device_type (str): The type of device to use ('cuda', 'mps', or 'cpu').
model_id (str): The ID of the model to load.
model_basename (str): The name of the model file.
LOGGING (logging): The logging object.
Returns:
LlamaCpp: The loaded language model.
LlamaCpp or GPT4All: The loaded language model.
"""
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
try:
model_path = hf_hub_download(
repo_id=model_id,
filename=model_basename,
resume_download=True,
cache_dir=MODEL_DIRECTORY,
)
kwargs = {
"model_path": model_path,
"max_tokens": MAX_TOKEN_LENGTH,
"n_ctx": MAX_TOKEN_LENGTH,
"n_batch": 512,
"callback_manager": callback_manager,
"verbose":False,
"f16_kv":True,
"streaming":True,
}
if device_type.lower() == "mps":
kwargs["n_gpu_layers"] = 1
if device_type.lower() == "cuda":
kwargs["n_gpu_layers"] = N_GPU_LAYERS # set this based on your GPU
llm = LlamaCpp(**kwargs)
LOGGING.info(f"Loaded {model_id} locally")
return llm # Returns a LlamaCpp object
if model_choice == 'LlamaCpp':
model_path = hf_hub_download(
repo_id=model_id,
filename=model_basename,
resume_download=True,
cache_dir=MODEL_DIRECTORY,
)
kwargs = {
"model_path": model_path,
"max_tokens": MAX_TOKEN_LENGTH,
"n_ctx": MAX_TOKEN_LENGTH,
"n_batch": 512,
"callback_manager": callback_manager,
"verbose": False,
"f16_kv": True,
"streaming": True,
}
if device_type.lower() == "mps":
kwargs["n_gpu_layers"] = 1
if device_type.lower() == "cuda":
kwargs["n_gpu_layers"] = N_GPU_LAYERS # set this based on your GPU
llm = LlamaCpp(**kwargs)
LOGGING.info(f"Loaded {model_id} locally")
return llm # Returns a LlamaCpp object
elif model_choice == 'GPT4All':
gpt4all_model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
return gpt4all_model
else:
LOGGING.info("Invalid model choice. Choose 'LlamaCpp' or 'GPT4All'.")
except Exception as e:
LOGGING.info(f"Error {e}")

def retriver(device_type:str = DEVICE_TYPE, LOGGING=logging):
def retriver(device_type=DEVICE_TYPE, LOGGING=logging):
"""
Retrieve information using a language model and Chroma database.
Expand All @@ -80,26 +87,25 @@ def retriver(device_type:str = DEVICE_TYPE, LOGGING=logging):
embedding_function=embeddings,
)
retriever = db.as_retriever()
LOGGING.info(f"Loaded Chroma DB Successfully")
llm = load_model(device_type, model_id=MODEL_NAME, model_basename=MODEL_FILE, LOGGING=logging)
template = """
[INST]
Context: {summaries}
User: {question}
[/INST]
"""
prompt = PromptTemplate(input_variables=["summaries", "question"], template=template)
chain = RetrievalQAWithSourcesChain.from_chain_type(
llm=llm,
retriever=retriever,
# chain_type="stuff",
chain_type_kwargs={"prompt": prompt},
)

chain({'question' : "What is the linux command to list files in direcotyu",},return_only_outputs=True)
model_choice = input("Choose a model (LlamaCpp or GPT4All): ")

model = load_model(model_choice, device_type, model_id=MODEL_NAME, model_basename=MODEL_FILE, LOGGING=logging)

if model_choice == 'LlamaCpp':
while True:
question = input("Enter your question (type 'exit' to quit): ")
if question.lower() == 'exit':
break
response = model(question)
print(response)
elif model_choice == 'GPT4All':
while True:
question = input("Enter your question (type 'exit' to quit): ")
if question.lower() == 'exit':
break
response = model.generate(question, max_tokens=50)
print(response)

if __name__ == '__main__':
retriver()


0 comments on commit b1e6654

Please sign in to comment.