From fdb64e97673ca264ed41199a61f0c9f306244afc Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Mon, 6 May 2024 01:14:11 +0100 Subject: [PATCH] ruff format --- src/agents/awesome.py | 5 +- src/agents/hermes/__init__.py | 2 +- src/agents/hermes/functioncall.py | 116 +++++++++++++++++++----------- src/agents/hermes/functions.py | 102 ++++++++++++++++---------- src/agents/hermes/prompter.py | 38 +++++----- src/agents/hermes/schema.py | 3 + src/agents/hermes/utils.py | 63 ++++++++++------ src/agents/hermes/validator.py | 56 ++++++++++----- 8 files changed, 242 insertions(+), 143 deletions(-) diff --git a/src/agents/awesome.py b/src/agents/awesome.py index db16217..f0c5cb2 100644 --- a/src/agents/awesome.py +++ b/src/agents/awesome.py @@ -2,8 +2,5 @@ model_inference = ModelInference(chat_template="chatml") model_inference.generate_function_call( - "I need the current stock price of Tesla (TSLA)", - "chatml", - None, - 5 + "I need the current stock price of Tesla (TSLA)", "chatml", None, 5 ) diff --git a/src/agents/hermes/__init__.py b/src/agents/hermes/__init__.py index 6b48493..0c67b3f 100644 --- a/src/agents/hermes/__init__.py +++ b/src/agents/hermes/__init__.py @@ -1,2 +1,2 @@ # Credits NousResearch -# https://github.com/NousResearch/Hermes-Function-Calling \ No newline at end of file +# https://github.com/NousResearch/Hermes-Function-Calling diff --git a/src/agents/hermes/functioncall.py b/src/agents/hermes/functioncall.py index 43f607c..ab732e7 100644 --- a/src/agents/hermes/functioncall.py +++ b/src/agents/hermes/functioncall.py @@ -3,9 +3,7 @@ import argparse import json -from transformers import ( - AutoTokenizer -) +from transformers import AutoTokenizer from llama_cpp import Llama from . import functions @@ -17,21 +15,25 @@ inference_logger, get_assistant_message, get_chat_template, - validate_and_extract_tool_calls + validate_and_extract_tool_calls, ) class ModelInference: - def __init__(self, model_path="/Users/aniket/weights/llama-cpp/Hermes-2-Pro-Llama-3-8B-Q8_0.gguf", - chat_template="chatml"): + def __init__( + self, + model_path="/Users/aniket/weights/llama-cpp/Hermes-2-Pro-Llama-3-8B-Q8_0.gguf", + chat_template="chatml", + ): inference_logger.info(print_nous_text_art()) self.prompter = PromptManager() - self.model = Llama(model_path=model_path, - n_gpu_layers=1, - n_ctx=4096, - verbose=False) + self.model = Llama( + model_path=model_path, n_gpu_layers=1, n_ctx=10000, verbose=False + ) - self.tokenizer = AutoTokenizer.from_pretrained('NousResearch/Hermes-2-Pro-Llama-3-8B', trust_remote_code=True) + self.tokenizer = AutoTokenizer.from_pretrained( + "NousResearch/Hermes-2-Pro-Llama-3-8B", trust_remote_code=True + ) self.tokenizer.pad_token = self.tokenizer.eos_token self.tokenizer.padding_side = "left" @@ -40,14 +42,19 @@ def __init__(self, model_path="/Users/aniket/weights/llama-cpp/Hermes-2-Pro-Llam self.tokenizer.chat_template = get_chat_template(chat_template) def process_completion_and_validate(self, completion, chat_template): - - assistant_message = get_assistant_message(completion, chat_template, self.tokenizer.eos_token) + assistant_message = get_assistant_message( + completion, chat_template, self.tokenizer.eos_token + ) if assistant_message: - validation, tool_calls, error_message = validate_and_extract_tool_calls(assistant_message) + validation, tool_calls, error_message = validate_and_extract_tool_calls( + assistant_message + ) if validation: - inference_logger.info(f"parsed tool calls:\n{json.dumps(tool_calls, indent=2)}") + inference_logger.info( + f"parsed tool calls:\n{json.dumps(tool_calls, indent=2)}" + ) return tool_calls, assistant_message, error_message else: tool_calls = None @@ -68,18 +75,11 @@ def execute_function_call(self, tool_call): def run_inference(self, prompt) -> str: inputs = self.tokenizer.apply_chat_template( - prompt, - add_generation_prompt=True, - tokenize=False + prompt, add_generation_prompt=True, tokenize=False ) inference_logger.info(f"inputs:\n{inputs}") print() - completions = self.model( - inputs, - max_tokens=2000, - temperature=0.8, - echo=True - ) + completions = self.model(inputs, max_tokens=2000, temperature=0.8, echo=True) completion = completions["choices"][0]["text"] inference_logger.info(f"completion:\n{completion}") return completion @@ -95,24 +95,34 @@ def generate_function_call(self, query, chat_template, num_fewshot, max_depth=5) def recursive_loop(prompt, completion, depth): nonlocal max_depth - tool_calls, assistant_message, error_message = self.process_completion_and_validate(completion, - chat_template) + tool_calls, assistant_message, error_message = ( + self.process_completion_and_validate(completion, chat_template) + ) prompt.append({"role": "assistant", "content": assistant_message}) - tool_message = f"Agent iteration {depth} to assist with user query: {query}\n" + tool_message = ( + f"Agent iteration {depth} to assist with user query: {query}\n" + ) if tool_calls: inference_logger.info(f"Assistant Message:\n{assistant_message}") for tool_call in tool_calls: - validation, message = validate_function_call_schema(tool_call, tools) + validation, message = validate_function_call_schema( + tool_call, tools + ) if validation: try: - function_response = self.execute_function_call(tool_call) + function_response = self.execute_function_call( + tool_call + ) tool_message += f"\n{function_response}\n\n" inference_logger.info( - f"Here's the response from the function call: {tool_call.get('name')}\n{function_response}") + f"Here's the response from the function call: {tool_call.get('name')}\n{function_response}" + ) except Exception as e: - inference_logger.info(f"Could not execute function: {e}") + inference_logger.info( + f"Could not execute function: {e}" + ) tool_message += f"\nThere was an error when executing the function: {tool_call.get('name')}\nHere's the error traceback: {e}\nPlease call this function again with correct arguments within XML tags \n\n" else: inference_logger.info(message) @@ -121,7 +131,9 @@ def recursive_loop(prompt, completion, depth): depth += 1 if depth >= max_depth: - print(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.") + print( + f"Maximum recursion depth reached ({max_depth}). Stopping recursion." + ) return completion = self.run_inference(prompt) @@ -133,7 +145,9 @@ def recursive_loop(prompt, completion, depth): depth += 1 if depth >= max_depth: - print(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.") + print( + f"Maximum recursion depth reached ({max_depth}). Stopping recursion." + ) return completion = self.run_inference(prompt) @@ -151,19 +165,39 @@ def recursive_loop(prompt, completion, depth): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run recursive function calling loop") parser.add_argument("--model_path", type=str, help="Path to the model folder") - parser.add_argument("--chat_template", type=str, default="chatml", help="Chat template for prompt formatting") - parser.add_argument("--num_fewshot", type=int, default=None, help="Option to use json mode examples") - parser.add_argument("--load_in_4bit", type=str, default="False", help="Option to load in 4bit with bitsandbytes") - parser.add_argument("--query", type=str, default="I need the current stock price of Tesla (TSLA)") - parser.add_argument("--max_depth", type=int, default=5, help="Maximum number of recursive iteration") + parser.add_argument( + "--chat_template", + type=str, + default="chatml", + help="Chat template for prompt formatting", + ) + parser.add_argument( + "--num_fewshot", type=int, default=None, help="Option to use json mode examples" + ) + parser.add_argument( + "--load_in_4bit", + type=str, + default="False", + help="Option to load in 4bit with bitsandbytes", + ) + parser.add_argument( + "--query", type=str, default="I need the current stock price of Tesla (TSLA)" + ) + parser.add_argument( + "--max_depth", type=int, default=5, help="Maximum number of recursive iteration" + ) args = parser.parse_args() # specify custom model path if args.model_path: - inference = ModelInference(args.model_path, args.chat_template, args.load_in_4bit) + inference = ModelInference( + args.model_path, args.chat_template, args.load_in_4bit + ) else: - model_path = 'NousResearch/Hermes-2-Pro-Llama-3-8B' + model_path = "NousResearch/Hermes-2-Pro-Llama-3-8B" inference = ModelInference(model_path, args.chat_template, args.load_in_4bit) # Run the model evaluator - inference.generate_function_call(args.query, args.chat_template, args.num_fewshot, args.max_depth) + inference.generate_function_call( + args.query, args.chat_template, args.num_fewshot, args.max_depth + ) diff --git a/src/agents/hermes/functions.py b/src/agents/hermes/functions.py index ec1e0fd..6aa41d1 100644 --- a/src/agents/hermes/functions.py +++ b/src/agents/hermes/functions.py @@ -35,8 +35,8 @@ def code_interpreter(code_markdown: str) -> dict | str: """ try: # Extracting code from Markdown code block - code_lines = code_markdown.split('\n')[1:-1] - code_without_markdown = '\n'.join(code_lines) + code_lines = code_markdown.split("\n")[1:-1] + code_without_markdown = "\n".join(code_lines) # Create a new namespace for code execution exec_namespace = {} @@ -53,9 +53,11 @@ def code_interpreter(code_markdown: str) -> dict | str: except TypeError: # If the function requires arguments, attempt to call it with arguments from the namespace arg_names = inspect.getfullargspec(value).args - args = {arg_name: exec_namespace.get(arg_name) for arg_name in arg_names} + args = { + arg_name: exec_namespace.get(arg_name) for arg_name in arg_names + } result_dict[name] = value(**args) - elif not name.startswith('_'): # Exclude variables starting with '_' + elif not name.startswith("_"): # Exclude variables starting with '_' result_dict[name] = value return result_dict @@ -78,51 +80,73 @@ def google_search_and_scrape(query: str) -> dict: list: A list of dictionaries containing the URL, text content, and table data for each scraped page. """ num_results = 2 - url = 'https://www.google.com/search' - params = {'q': query, 'num': num_results} + url = "https://www.google.com/search" + params = {"q": query, "num": num_results} headers = { - 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/94.0.4606.61 Safari/537.3'} + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/94.0.4606.61 Safari/537.3" + } - inference_logger.info(f"Performing google search with query: {query}\nplease wait...") + inference_logger.info( + f"Performing google search with query: {query}\nplease wait..." + ) response = requests.get(url, params=params, headers=headers) - soup = BeautifulSoup(response.text, 'html.parser') - urls = [result.find('a')['href'] for result in soup.find_all('div', class_='tF2Cxc')] + soup = BeautifulSoup(response.text, "html.parser") + urls = [ + result.find("a")["href"] for result in soup.find_all("div", class_="tF2Cxc") + ] inference_logger.info(f"Scraping text from urls, please wait...") [inference_logger.info(url) for url in urls] with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: - futures = [executor.submit( - lambda url: (url, requests.get(url, headers=headers).text if isinstance(url, str) else None), url) for url - in urls[:num_results] if isinstance(url, str)] + futures = [ + executor.submit( + lambda url: ( + url, + requests.get(url, headers=headers).text + if isinstance(url, str) + else None, + ), + url, + ) + for url in urls[:num_results] + if isinstance(url, str) + ] results = [] for future in concurrent.futures.as_completed(futures): url, html = future.result() - soup = BeautifulSoup(html, 'html.parser') - paragraphs = [p.text.strip() for p in soup.find_all('p') if p.text.strip()] - text_content = ' '.join(paragraphs) - text_content = re.sub(r'\s+', ' ', text_content) - table_data = [[cell.get_text(strip=True) for cell in row.find_all('td')] for table in soup.find_all('table') - for row in table.find_all('tr')] + soup = BeautifulSoup(html, "html.parser") + paragraphs = [p.text.strip() for p in soup.find_all("p") if p.text.strip()] + text_content = " ".join(paragraphs) + text_content = re.sub(r"\s+", " ", text_content) + table_data = [ + [cell.get_text(strip=True) for cell in row.find_all("td")] + for table in soup.find_all("table") + for row in table.find_all("tr") + ] if text_content or table_data: - results.append({'url': url, 'content': text_content, 'tables': table_data}) + results.append( + {"url": url, "content": text_content, "tables": table_data} + ) return results @tool def get_current_stock_price(symbol: str) -> float: """ - Get the current stock price for a given symbol. + Get the current stock price for a given symbol. - Args: - symbol (str): The stock symbol. + Args: + symbol (str): The stock symbol. - Returns: - float: The current stock price, or None if an error occurs. - """ + Returns: + float: The current stock price, or None if an error occurs. + """ try: stock = yf.Ticker(symbol) # Use "regularMarketPrice" for regular market hours, or "currentPrice" for pre/post market - current_price = stock.info.get("regularMarketPrice", stock.info.get("currentPrice")) + current_price = stock.info.get( + "regularMarketPrice", stock.info.get("currentPrice") + ) return current_price if current_price else None except Exception as e: print(f"Error fetching current price for {symbol}: {e}") @@ -157,18 +181,18 @@ def get_stock_fundamentals(symbol: str) -> dict: stock = yf.Ticker(symbol) info = stock.info fundamentals = { - 'symbol': symbol, - 'company_name': info.get('longName', ''), - 'sector': info.get('sector', ''), - 'industry': info.get('industry', ''), - 'market_cap': info.get('marketCap', None), - 'pe_ratio': info.get('forwardPE', None), - 'pb_ratio': info.get('priceToBook', None), - 'dividend_yield': info.get('dividendYield', None), - 'eps': info.get('trailingEps', None), - 'beta': info.get('beta', None), - '52_week_high': info.get('fiftyTwoWeekHigh', None), - '52_week_low': info.get('fiftyTwoWeekLow', None) + "symbol": symbol, + "company_name": info.get("longName", ""), + "sector": info.get("sector", ""), + "industry": info.get("industry", ""), + "market_cap": info.get("marketCap", None), + "pe_ratio": info.get("forwardPE", None), + "pb_ratio": info.get("priceToBook", None), + "dividend_yield": info.get("dividendYield", None), + "eps": info.get("trailingEps", None), + "beta": info.get("beta", None), + "52_week_high": info.get("fiftyTwoWeekHigh", None), + "52_week_low": info.get("fiftyTwoWeekLow", None), } return fundamentals except Exception as e: diff --git a/src/agents/hermes/prompter.py b/src/agents/hermes/prompter.py index e430d57..130eb9a 100644 --- a/src/agents/hermes/prompter.py +++ b/src/agents/hermes/prompter.py @@ -4,25 +4,25 @@ from pydantic import BaseModel from typing import Dict from .schema import FunctionCall -from .utils import ( - get_fewshot_examples -) +from .utils import get_fewshot_examples import yaml import json import os + class PromptSchema(BaseModel): Role: str Objective: str Tools: str Examples: str Schema: str - Instructions: str + Instructions: str + class PromptManager: def __init__(self): self.script_dir = os.path.dirname(os.path.abspath(__file__)) - + def format_yaml_prompt(self, prompt_schema: PromptSchema, variables: Dict) -> str: formatted_prompt = "" for field, value in prompt_schema.dict().items(): @@ -37,21 +37,21 @@ def format_yaml_prompt(self, prompt_schema: PromptSchema, variables: Dict) -> st return formatted_prompt def read_yaml_file(self, file_path: str) -> PromptSchema: - with open(file_path, 'r') as file: + with open(file_path, "r") as file: yaml_content = yaml.safe_load(file) - + prompt_schema = PromptSchema( - Role=yaml_content.get('Role', ''), - Objective=yaml_content.get('Objective', ''), - Tools=yaml_content.get('Tools', ''), - Examples=yaml_content.get('Examples', ''), - Schema=yaml_content.get('Schema', ''), - Instructions=yaml_content.get('Instructions', ''), + Role=yaml_content.get("Role", ""), + Objective=yaml_content.get("Objective", ""), + Tools=yaml_content.get("Tools", ""), + Examples=yaml_content.get("Examples", ""), + Schema=yaml_content.get("Schema", ""), + Instructions=yaml_content.get("Instructions", ""), ) return prompt_schema - + def generate_prompt(self, user_prompt, tools, num_fewshot=None): - prompt_path = os.path.join(self.script_dir, 'prompt_assets', 'sys_prompt.yml') + prompt_path = os.path.join(self.script_dir, "prompt_assets", "sys_prompt.yml") prompt_schema = self.read_yaml_file(prompt_path) if num_fewshot is not None: @@ -60,18 +60,16 @@ def generate_prompt(self, user_prompt, tools, num_fewshot=None): examples = None schema_json = json.loads(FunctionCall.schema_json()) - #schema = schema_json.get("properties", {}) + # schema = schema_json.get("properties", {}) variables = { "date": datetime.date.today(), "tools": tools, "examples": examples, - "schema": schema_json + "schema": schema_json, } sys_prompt = self.format_yaml_prompt(prompt_schema, variables) - prompt = [ - {'content': sys_prompt, 'role': 'system'} - ] + prompt = [{"content": sys_prompt, "role": "system"}] prompt.extend(user_prompt) return prompt diff --git a/src/agents/hermes/schema.py b/src/agents/hermes/schema.py index 7db49f9..b21421b 100644 --- a/src/agents/hermes/schema.py +++ b/src/agents/hermes/schema.py @@ -3,6 +3,7 @@ from pydantic import BaseModel from typing import List, Dict, Literal, Optional + class FunctionCall(BaseModel): arguments: dict """ @@ -15,11 +16,13 @@ class FunctionCall(BaseModel): name: str """The name of the function to call.""" + class FunctionDefinition(BaseModel): name: str description: Optional[str] = None parameters: Optional[Dict[str, object]] = None + class FunctionSignature(BaseModel): function: FunctionDefinition type: Literal["function"] diff --git a/src/agents/hermes/utils.py b/src/agents/hermes/utils.py index f28b231..cb7c58a 100644 --- a/src/agents/hermes/utils.py +++ b/src/agents/hermes/utils.py @@ -29,12 +29,16 @@ file_handler = RotatingFileHandler(log_file_path, maxBytes=0, backupCount=0) file_handler.setLevel(logging.INFO) -formatter = logging.Formatter("%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s", datefmt="%Y-%m-%d:%H:%M:%S") +formatter = logging.Formatter( + "%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s", + datefmt="%Y-%m-%d:%H:%M:%S", +) file_handler.setFormatter(formatter) inference_logger = logging.getLogger("function-calling-inference") inference_logger.addHandler(file_handler) + def print_nous_text_art(suffix=None): font = "nancyj" ascii_text = " nousresearch" @@ -43,45 +47,61 @@ def print_nous_text_art(suffix=None): ascii_art = text2art(ascii_text, font=font) print(ascii_art) + def get_fewshot_examples(num_fewshot): """return a list of few shot examples""" - example_path = os.path.join(script_dir, 'prompt_assets', 'few_shot.json') - with open(example_path, 'r') as file: - examples = json.load(file) # Use json.load with the file object, not the file path + example_path = os.path.join(script_dir, "prompt_assets", "few_shot.json") + with open(example_path, "r") as file: + examples = json.load( + file + ) # Use json.load with the file object, not the file path if num_fewshot > len(examples): - raise ValueError(f"Not enough examples (got {num_fewshot}, but there are only {len(examples)} examples).") + raise ValueError( + f"Not enough examples (got {num_fewshot}, but there are only {len(examples)} examples)." + ) return examples[:num_fewshot] + def get_chat_template(chat_template): """read chat template from jinja file""" - template_path = os.path.join(script_dir, 'chat_templates', f"{chat_template}.j2") + template_path = os.path.join(script_dir, "chat_templates", f"{chat_template}.j2") if not os.path.exists(template_path): print inference_logger.error(f"Template file not found: {chat_template}") return None try: - with open(template_path, 'r') as file: + with open(template_path, "r") as file: template = file.read() return template except Exception as e: print(f"Error loading template: {e}") return None + def get_assistant_message(completion, chat_template, eos_token): """define and match pattern to find the assistant message""" completion = completion.strip() if chat_template == "zephyr": - assistant_pattern = re.compile(r'<\|assistant\|>((?:(?!<\|assistant\|>).)*)$', re.DOTALL) + assistant_pattern = re.compile( + r"<\|assistant\|>((?:(?!<\|assistant\|>).)*)$", re.DOTALL + ) elif chat_template == "chatml": - assistant_pattern = re.compile(r'<\|im_start\|>\s*assistant((?:(?!<\|im_start\|>\s*assistant).)*)$', re.DOTALL) + assistant_pattern = re.compile( + r"<\|im_start\|>\s*assistant((?:(?!<\|im_start\|>\s*assistant).)*)$", + re.DOTALL, + ) elif chat_template == "vicuna": - assistant_pattern = re.compile(r'ASSISTANT:\s*((?:(?!ASSISTANT:).)*)$', re.DOTALL) + assistant_pattern = re.compile( + r"ASSISTANT:\s*((?:(?!ASSISTANT:).)*)$", re.DOTALL + ) else: - raise NotImplementedError(f"Handling for chat_template '{chat_template}' is not implemented.") - + raise NotImplementedError( + f"Handling for chat_template '{chat_template}' is not implemented." + ) + assistant_match = assistant_pattern.search(completion) if assistant_match: assistant_content = assistant_match.group(1).strip() @@ -93,6 +113,7 @@ def get_assistant_message(completion, chat_template, eos_token): inference_logger.info("No match found for the assistant pattern") return assistant_content + def validate_and_extract_tool_calls(assistant_content): inference_logger.info(f"assistant_content: {assistant_content}") validation_result = False @@ -120,10 +141,12 @@ def validate_and_extract_tool_calls(assistant_content): # Fallback to ast.literal_eval if json.loads fails json_data = ast.literal_eval(json_text) except (SyntaxError, ValueError) as eval_err: - error_message = f"JSON parsing failed with both json.loads and ast.literal_eval:\n"\ - f"- JSON Decode Error: {json_err}\n"\ - f"- Fallback Syntax/Value Error: {eval_err}\n"\ - f"- Problematic JSON text: {json_text}" + error_message = ( + f"JSON parsing failed with both json.loads and ast.literal_eval:\n" + f"- JSON Decode Error: {json_err}\n" + f"- Fallback Syntax/Value Error: {eval_err}\n" + f"- Problematic JSON text: {json_text}" + ) inference_logger.error(error_message) continue except Exception as e: @@ -141,17 +164,18 @@ def validate_and_extract_tool_calls(assistant_content): # Return default values if no valid data is extracted return validation_result, tool_calls, error_message + def extract_json_from_markdown(text): """ Extracts the JSON string from the given text using a regular expression pattern. - + Args: text (str): The input text containing the JSON string. - + Returns: dict: The JSON data loaded from the extracted string, or None if the JSON string is not found. """ - json_pattern = r'```json\r?\n(.*?)\r?\n```' + json_pattern = r"```json\r?\n(.*?)\r?\n```" match = re.search(json_pattern, text, re.DOTALL) if match: json_string = match.group(1) @@ -163,4 +187,3 @@ def extract_json_from_markdown(text): else: print("JSON string not found in the text.") return None - diff --git a/src/agents/hermes/validator.py b/src/agents/hermes/validator.py index 4679218..7883ab6 100644 --- a/src/agents/hermes/validator.py +++ b/src/agents/hermes/validator.py @@ -7,6 +7,7 @@ from .utils import inference_logger, extract_json_from_markdown from .schema import FunctionCall, FunctionSignature + def validate_function_call_schema(call, signatures): try: call_data = FunctionCall(**call) @@ -18,18 +19,26 @@ def validate_function_call_schema(call, signatures): signature_data = FunctionSignature(**signature) if signature_data.function.name == call_data.name: # Validate types in function arguments - for arg_name, arg_schema in signature_data.function.parameters.get('properties', {}).items(): + for arg_name, arg_schema in signature_data.function.parameters.get( + "properties", {} + ).items(): if arg_name in call_data.arguments: call_arg_value = call_data.arguments[arg_name] if call_arg_value: try: - validate_argument_type(arg_name, call_arg_value, arg_schema) + validate_argument_type( + arg_name, call_arg_value, arg_schema + ) except Exception as arg_validation_error: return False, str(arg_validation_error) # Check if all required arguments are present - required_arguments = signature_data.function.parameters.get('required', []) - result, missing_arguments = check_required_arguments(call_data.arguments, required_arguments) + required_arguments = signature_data.function.parameters.get( + "required", [] + ) + result, missing_arguments = check_required_arguments( + call_data.arguments, required_arguments + ) if not result: return False, f"Missing required arguments: {missing_arguments}" @@ -41,21 +50,24 @@ def validate_function_call_schema(call, signatures): # No matching function signature found return False, f"No matching function signature found for function: {call_data.name}" + def check_required_arguments(call_arguments, required_arguments): missing_arguments = [arg for arg in required_arguments if arg not in call_arguments] return not bool(missing_arguments), missing_arguments + def validate_enum_value(arg_name, arg_value, enum_values): if arg_value not in enum_values: raise Exception( f"Invalid value '{arg_value}' for parameter {arg_name}. Expected one of {', '.join(map(str, enum_values))}" ) + def validate_argument_type(arg_name, arg_value, arg_schema): - arg_type = arg_schema.get('type', None) + arg_type = arg_schema.get("type", None) if arg_type: - if arg_type == 'string' and 'enum' in arg_schema: - enum_values = arg_schema['enum'] + if arg_type == "string" and "enum" in arg_schema: + enum_values = arg_schema["enum"] if None not in enum_values and enum_values != []: try: validate_enum_value(arg_name, arg_value, enum_values) @@ -65,20 +77,24 @@ def validate_argument_type(arg_name, arg_value, arg_schema): python_type = get_python_type(arg_type) if not isinstance(arg_value, python_type): - raise Exception(f"Type mismatch for parameter {arg_name}. Expected: {arg_type}, Got: {type(arg_value)}") + raise Exception( + f"Type mismatch for parameter {arg_name}. Expected: {arg_type}, Got: {type(arg_value)}" + ) + def get_python_type(json_type): type_mapping = { - 'string': str, - 'number': (int, float), - 'integer': int, - 'boolean': bool, - 'array': list, - 'object': dict, - 'null': type(None), + "string": str, + "number": (int, float), + "integer": int, + "boolean": bool, + "array": list, + "object": dict, + "null": type(None), } return type_mapping[json_type] + def validate_json_data(json_object, json_schema): valid = False error_message = None @@ -97,7 +113,9 @@ def validate_json_data(json_object, json_schema): result_json = extract_json_from_markdown(json_object) except Exception as e: error_message = f"JSON decoding error: {e}" - inference_logger.info(f"Validation failed for JSON data: {error_message}") + inference_logger.info( + f"Validation failed for JSON data: {error_message}" + ) return valid, result_json, error_message # Return early if both json.loads and ast.literal_eval fail @@ -111,7 +129,9 @@ def validate_json_data(json_object, json_schema): for index, item in enumerate(result_json): try: validate(instance=item, schema=json_schema) - inference_logger.info(f"Item {index+1} is valid against the schema.") + inference_logger.info( + f"Item {index+1} is valid against the schema." + ) except ValidationError as e: error_message = f"Validation failed for item {index+1}: {e}" break @@ -131,4 +151,4 @@ def validate_json_data(json_object, json_schema): else: inference_logger.info(f"Validation failed for JSON data: {error_message}") - return valid, result_json, error_message \ No newline at end of file + return valid, result_json, error_message