ruff format
aniketmaurya committed May 6, 2024
1 parent 7fe57ac commit fdb64e9
Showing 8 changed files with 242 additions and 143 deletions.
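The hunks below are the output of Ruff's formatter. The commit itself doesn't record how the pass was run, but assuming a plain CLI invocation over the src/ tree these diffs touch, a sketch like this would reproduce it:

# Hedged sketch: one way to run the same formatting pass, assuming ruff is
# installed; the exact invocation used for this commit isn't recorded.
import subprocess

subprocess.run(["ruff", "format", "src/"], check=True)   # rewrite files in place
subprocess.run(["ruff", "format", "--check", "src/"])    # report-only, no rewrites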
5 changes: 1 addition & 4 deletions src/agents/awesome.py
@@ -2,8 +2,5 @@
 
 model_inference = ModelInference(chat_template="chatml")
 model_inference.generate_function_call(
-    "I need the current stock price of Tesla (TSLA)",
-    "chatml",
-    None,
-    5
+    "I need the current stock price of Tesla (TSLA)", "chatml", None, 5
 )
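(The four positional arguments map to query, chat_template, num_fewshot, and max_depth in generate_function_call's signature, shown in functioncall.py below.)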
2 changes: 1 addition & 1 deletion src/agents/hermes/__init__.py
@@ -1,2 +1,2 @@
 # Credits NousResearch
-# https://github.com/NousResearch/Hermes-Function-Calling
+# https://github.com/NousResearch/Hermes-Function-Calling
116 changes: 75 additions & 41 deletions src/agents/hermes/functioncall.py
@@ -3,9 +3,7 @@
 import argparse
 import json
 
-from transformers import (
-    AutoTokenizer
-)
+from transformers import AutoTokenizer
 from llama_cpp import Llama
 
 from . import functions
@@ -17,21 +15,25 @@
     inference_logger,
     get_assistant_message,
     get_chat_template,
-    validate_and_extract_tool_calls
+    validate_and_extract_tool_calls,
 )
 
 
 class ModelInference:
-    def __init__(self, model_path="/Users/aniket/weights/llama-cpp/Hermes-2-Pro-Llama-3-8B-Q8_0.gguf",
-                 chat_template="chatml"):
+    def __init__(
+        self,
+        model_path="/Users/aniket/weights/llama-cpp/Hermes-2-Pro-Llama-3-8B-Q8_0.gguf",
+        chat_template="chatml",
+    ):
         inference_logger.info(print_nous_text_art())
         self.prompter = PromptManager()
-        self.model = Llama(model_path=model_path,
-                           n_gpu_layers=1,
-                           n_ctx=4096,
-                           verbose=False)
+        self.model = Llama(
+            model_path=model_path, n_gpu_layers=1, n_ctx=10000, verbose=False
+        )
 
-        self.tokenizer = AutoTokenizer.from_pretrained('NousResearch/Hermes-2-Pro-Llama-3-8B', trust_remote_code=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            "NousResearch/Hermes-2-Pro-Llama-3-8B", trust_remote_code=True
+        )
         self.tokenizer.pad_token = self.tokenizer.eos_token
         self.tokenizer.padding_side = "left"
 
@@ -40,14 +42,19 @@ def __init__(self, model_path="/Users/aniket/weights/llama-cpp/Hermes-2-Pro-Llam
         self.tokenizer.chat_template = get_chat_template(chat_template)
 
     def process_completion_and_validate(self, completion, chat_template):
-
-        assistant_message = get_assistant_message(completion, chat_template, self.tokenizer.eos_token)
+        assistant_message = get_assistant_message(
+            completion, chat_template, self.tokenizer.eos_token
+        )
 
         if assistant_message:
-            validation, tool_calls, error_message = validate_and_extract_tool_calls(assistant_message)
+            validation, tool_calls, error_message = validate_and_extract_tool_calls(
+                assistant_message
+            )
 
             if validation:
-                inference_logger.info(f"parsed tool calls:\n{json.dumps(tool_calls, indent=2)}")
+                inference_logger.info(
+                    f"parsed tool calls:\n{json.dumps(tool_calls, indent=2)}"
+                )
                 return tool_calls, assistant_message, error_message
             else:
                 tool_calls = None
@@ -68,18 +75,11 @@ def execute_function_call(self, tool_call):
 
     def run_inference(self, prompt) -> str:
         inputs = self.tokenizer.apply_chat_template(
-            prompt,
-            add_generation_prompt=True,
-            tokenize=False
+            prompt, add_generation_prompt=True, tokenize=False
         )
         inference_logger.info(f"inputs:\n{inputs}")
         print()
-        completions = self.model(
-            inputs,
-            max_tokens=2000,
-            temperature=0.8,
-            echo=True
-        )
+        completions = self.model(inputs, max_tokens=2000, temperature=0.8, echo=True)
         completion = completions["choices"][0]["text"]
         inference_logger.info(f"completion:\n{completion}")
         return completion
@@ -95,24 +95,34 @@ def generate_function_call(self, query, chat_template, num_fewshot, max_depth=5)
 
         def recursive_loop(prompt, completion, depth):
             nonlocal max_depth
-            tool_calls, assistant_message, error_message = self.process_completion_and_validate(completion,
-                                                                                                chat_template)
+            tool_calls, assistant_message, error_message = (
+                self.process_completion_and_validate(completion, chat_template)
+            )
             prompt.append({"role": "assistant", "content": assistant_message})
 
-            tool_message = f"Agent iteration {depth} to assist with user query: {query}\n"
+            tool_message = (
+                f"Agent iteration {depth} to assist with user query: {query}\n"
+            )
             if tool_calls:
                 inference_logger.info(f"Assistant Message:\n{assistant_message}")
 
                 for tool_call in tool_calls:
-                    validation, message = validate_function_call_schema(tool_call, tools)
+                    validation, message = validate_function_call_schema(
+                        tool_call, tools
+                    )
                     if validation:
                         try:
-                            function_response = self.execute_function_call(tool_call)
+                            function_response = self.execute_function_call(
+                                tool_call
+                            )
                             tool_message += f"<tool_response>\n{function_response}\n</tool_response>\n"
                             inference_logger.info(
-                                f"Here's the response from the function call: {tool_call.get('name')}\n{function_response}")
+                                f"Here's the response from the function call: {tool_call.get('name')}\n{function_response}"
+                            )
                         except Exception as e:
-                            inference_logger.info(f"Could not execute function: {e}")
+                            inference_logger.info(
+                                f"Could not execute function: {e}"
+                            )
                             tool_message += f"<tool_response>\nThere was an error when executing the function: {tool_call.get('name')}\nHere's the error traceback: {e}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
                     else:
                         inference_logger.info(message)
@@ -121,7 +131,9 @@ def recursive_loop(prompt, completion, depth):
 
                 depth += 1
                 if depth >= max_depth:
-                    print(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.")
+                    print(
+                        f"Maximum recursion depth reached ({max_depth}). Stopping recursion."
+                    )
                     return
 
                 completion = self.run_inference(prompt)
@@ -133,7 +145,9 @@ def recursive_loop(prompt, completion, depth):
 
                 depth += 1
                 if depth >= max_depth:
-                    print(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.")
+                    print(
+                        f"Maximum recursion depth reached ({max_depth}). Stopping recursion."
+                    )
                     return
 
                 completion = self.run_inference(prompt)
@@ -151,19 +165,39 @@ def recursive_loop(prompt, completion, depth):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run recursive function calling loop")
     parser.add_argument("--model_path", type=str, help="Path to the model folder")
-    parser.add_argument("--chat_template", type=str, default="chatml", help="Chat template for prompt formatting")
-    parser.add_argument("--num_fewshot", type=int, default=None, help="Option to use json mode examples")
-    parser.add_argument("--load_in_4bit", type=str, default="False", help="Option to load in 4bit with bitsandbytes")
-    parser.add_argument("--query", type=str, default="I need the current stock price of Tesla (TSLA)")
-    parser.add_argument("--max_depth", type=int, default=5, help="Maximum number of recursive iteration")
+    parser.add_argument(
+        "--chat_template",
+        type=str,
+        default="chatml",
+        help="Chat template for prompt formatting",
+    )
+    parser.add_argument(
+        "--num_fewshot", type=int, default=None, help="Option to use json mode examples"
+    )
+    parser.add_argument(
+        "--load_in_4bit",
+        type=str,
+        default="False",
+        help="Option to load in 4bit with bitsandbytes",
+    )
+    parser.add_argument(
+        "--query", type=str, default="I need the current stock price of Tesla (TSLA)"
+    )
+    parser.add_argument(
+        "--max_depth", type=int, default=5, help="Maximum number of recursive iteration"
+    )
     args = parser.parse_args()
 
     # specify custom model path
     if args.model_path:
-        inference = ModelInference(args.model_path, args.chat_template, args.load_in_4bit)
+        inference = ModelInference(
+            args.model_path, args.chat_template, args.load_in_4bit
+        )
     else:
-        model_path = 'NousResearch/Hermes-2-Pro-Llama-3-8B'
+        model_path = "NousResearch/Hermes-2-Pro-Llama-3-8B"
         inference = ModelInference(model_path, args.chat_template, args.load_in_4bit)
 
     # Run the model evaluator
-    inference.generate_function_call(args.query, args.chat_template, args.num_fewshot, args.max_depth)
+    inference.generate_function_call(
+        args.query, args.chat_template, args.num_fewshot, args.max_depth
+    )
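Two things stand out in this file's diff: recursive_loop alternates model completions with tool execution until the model stops emitting tool calls or max_depth is hit, and the __main__ block still passes args.load_in_4bit as a third positional argument even though __init__, as formatted above, accepts only model_path and chat_template. A standalone sketch of the loop pattern follows; complete, parse_tool_calls, and run_tool are illustrative stand-ins for run_inference, validate_and_extract_tool_calls, and execute_function_call, not APIs from this repo:

# Illustrative sketch of the recursive tool-call loop, written as a plain
# loop; the three callables are hypothetical stand-ins, not repo functions.
def agent_loop(prompt, complete, parse_tool_calls, run_tool, max_depth=5):
    completion = complete(prompt)
    for depth in range(1, max_depth + 1):
        prompt.append({"role": "assistant", "content": completion})
        tool_calls = parse_tool_calls(completion)
        if not tool_calls:
            return completion  # model answered without requesting a tool
        tool_message = f"Agent iteration {depth} to assist with user query\n"
        for call in tool_calls:
            try:
                tool_message += f"<tool_response>\n{run_tool(call)}\n</tool_response>\n"
            except Exception as e:
                tool_message += f"<tool_response>\nError: {e}\n</tool_response>\n"
        prompt.append({"role": "tool", "content": tool_message})
        completion = complete(prompt)
    print(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.")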
102 changes: 63 additions & 39 deletions src/agents/hermes/functions.py
@@ -35,8 +35,8 @@ def code_interpreter(code_markdown: str) -> dict | str:
     """
     try:
         # Extracting code from Markdown code block
-        code_lines = code_markdown.split('\n')[1:-1]
-        code_without_markdown = '\n'.join(code_lines)
+        code_lines = code_markdown.split("\n")[1:-1]
+        code_without_markdown = "\n".join(code_lines)
 
         # Create a new namespace for code execution
         exec_namespace = {}
@@ -53,9 +53,11 @@ def code_interpreter(code_markdown: str) -> dict | str:
                 except TypeError:
                     # If the function requires arguments, attempt to call it with arguments from the namespace
                     arg_names = inspect.getfullargspec(value).args
-                    args = {arg_name: exec_namespace.get(arg_name) for arg_name in arg_names}
+                    args = {
+                        arg_name: exec_namespace.get(arg_name) for arg_name in arg_names
+                    }
                     result_dict[name] = value(**args)
-            elif not name.startswith('_'):  # Exclude variables starting with '_'
+            elif not name.startswith("_"):  # Exclude variables starting with '_'
                 result_dict[name] = value
 
         return result_dict
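For context, the extraction step these hunks reformat strips the opening and closing fence lines of a Markdown code block and executes the remainder. A standalone sketch of that step (not an import from this repo):

# Standalone sketch of code_interpreter's markdown-stripping step.
code_markdown = "```python\nx = 21 * 2\n```"
code_lines = code_markdown.split("\n")[1:-1]   # drop the fence lines
code_without_markdown = "\n".join(code_lines)  # "x = 21 * 2"

exec_namespace = {}
exec(code_without_markdown, exec_namespace)
print(exec_namespace["x"])  # 42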
@@ -78,51 +80,73 @@ def google_search_and_scrape(query: str) -> dict:
         list: A list of dictionaries containing the URL, text content, and table data for each scraped page.
     """
     num_results = 2
-    url = 'https://www.google.com/search'
-    params = {'q': query, 'num': num_results}
+    url = "https://www.google.com/search"
+    params = {"q": query, "num": num_results}
     headers = {
-        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/94.0.4606.61 Safari/537.3'}
+        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/94.0.4606.61 Safari/537.3"
+    }
 
-    inference_logger.info(f"Performing google search with query: {query}\nplease wait...")
+    inference_logger.info(
+        f"Performing google search with query: {query}\nplease wait..."
+    )
     response = requests.get(url, params=params, headers=headers)
-    soup = BeautifulSoup(response.text, 'html.parser')
-    urls = [result.find('a')['href'] for result in soup.find_all('div', class_='tF2Cxc')]
+    soup = BeautifulSoup(response.text, "html.parser")
+    urls = [
+        result.find("a")["href"] for result in soup.find_all("div", class_="tF2Cxc")
+    ]
 
     inference_logger.info(f"Scraping text from urls, please wait...")
     [inference_logger.info(url) for url in urls]
     with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
-        futures = [executor.submit(
-            lambda url: (url, requests.get(url, headers=headers).text if isinstance(url, str) else None), url) for url
-            in urls[:num_results] if isinstance(url, str)]
+        futures = [
+            executor.submit(
+                lambda url: (
+                    url,
+                    requests.get(url, headers=headers).text
+                    if isinstance(url, str)
+                    else None,
+                ),
+                url,
+            )
+            for url in urls[:num_results]
+            if isinstance(url, str)
+        ]
         results = []
        for future in concurrent.futures.as_completed(futures):
             url, html = future.result()
-            soup = BeautifulSoup(html, 'html.parser')
-            paragraphs = [p.text.strip() for p in soup.find_all('p') if p.text.strip()]
-            text_content = ' '.join(paragraphs)
-            text_content = re.sub(r'\s+', ' ', text_content)
-            table_data = [[cell.get_text(strip=True) for cell in row.find_all('td')] for table in soup.find_all('table')
-                          for row in table.find_all('tr')]
+            soup = BeautifulSoup(html, "html.parser")
+            paragraphs = [p.text.strip() for p in soup.find_all("p") if p.text.strip()]
+            text_content = " ".join(paragraphs)
+            text_content = re.sub(r"\s+", " ", text_content)
+            table_data = [
+                [cell.get_text(strip=True) for cell in row.find_all("td")]
+                for table in soup.find_all("table")
+                for row in table.find_all("tr")
+            ]
             if text_content or table_data:
-                results.append({'url': url, 'content': text_content, 'tables': table_data})
+                results.append(
+                    {"url": url, "content": text_content, "tables": table_data}
+                )
     return results
 
 
 @tool
 def get_current_stock_price(symbol: str) -> float:
     """
-    Get the current stock price for a given symbol.
+    Get the current stock price for a given symbol.
 
-    Args:
-    symbol (str): The stock symbol.
+    Args:
+        symbol (str): The stock symbol.
 
-    Returns:
-    float: The current stock price, or None if an error occurs.
-    """
+    Returns:
+        float: The current stock price, or None if an error occurs.
+    """
     try:
         stock = yf.Ticker(symbol)
         # Use "regularMarketPrice" for regular market hours, or "currentPrice" for pre/post market
-        current_price = stock.info.get("regularMarketPrice", stock.info.get("currentPrice"))
+        current_price = stock.info.get(
+            "regularMarketPrice", stock.info.get("currentPrice")
+        )
         return current_price if current_price else None
     except Exception as e:
         print(f"Error fetching current price for {symbol}: {e}")
@@ -157,18 +181,18 @@ def get_stock_fundamentals(symbol: str) -> dict:
         stock = yf.Ticker(symbol)
         info = stock.info
         fundamentals = {
-            'symbol': symbol,
-            'company_name': info.get('longName', ''),
-            'sector': info.get('sector', ''),
-            'industry': info.get('industry', ''),
-            'market_cap': info.get('marketCap', None),
-            'pe_ratio': info.get('forwardPE', None),
-            'pb_ratio': info.get('priceToBook', None),
-            'dividend_yield': info.get('dividendYield', None),
-            'eps': info.get('trailingEps', None),
-            'beta': info.get('beta', None),
-            '52_week_high': info.get('fiftyTwoWeekHigh', None),
-            '52_week_low': info.get('fiftyTwoWeekLow', None)
+            "symbol": symbol,
+            "company_name": info.get("longName", ""),
+            "sector": info.get("sector", ""),
+            "industry": info.get("industry", ""),
+            "market_cap": info.get("marketCap", None),
+            "pe_ratio": info.get("forwardPE", None),
+            "pb_ratio": info.get("priceToBook", None),
+            "dividend_yield": info.get("dividendYield", None),
+            "eps": info.get("trailingEps", None),
+            "beta": info.get("beta", None),
+            "52_week_high": info.get("fiftyTwoWeekHigh", None),
+            "52_week_low": info.get("fiftyTwoWeekLow", None),
         }
         return fundamentals
     except Exception as e:
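Both stock helpers in this file read fields from yfinance's Ticker.info mapping. A minimal sketch of that underlying lookup, assuming yfinance is installed (field availability varies by symbol and market session, which is why the helpers fall back to None):

# Minimal sketch of the yfinance lookup the helpers above wrap.
import yfinance as yf

info = yf.Ticker("TSLA").info
price = info.get("regularMarketPrice", info.get("currentPrice"))
print(price)  # a float during market hours, possibly None otherwise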