-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
221 lines (164 loc) · 6.13 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
from datetime import datetime
import inspect
import re
import json
import traceback
from colorama import Fore, Style
import httpx
import os
import pandas as pd
# custom list class with a overwrite_key attribute
class YieldList(list):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.overwrite_key = None
def replace_whitespace(s):
pattern = re.compile(r'",\s*"')
return re.sub(pattern, '", "', s)
def fix_JSON(json_message=None):
result = json_message
json_message = replace_whitespace(json_message)
try:
# First, try to load the JSON string as is
result = json.loads(json_message)
except json.JSONDecodeError as e:
try:
# If the JSON string can't be loaded, it means there are unescaped characters
# Use Python's string escape to escape the string
escaped_message = json_message.encode("unicode_escape").decode("utf-8")
# Try loading the JSON string again
result = json.loads(escaped_message)
except Exception as e_inner:
# If it still fails, print the error
print("Error while trying to fix JSON string: ", str(e_inner))
return None
except Exception as e:
print("Unexpected error: ", str(e))
return None
return result
def api_response(ran_successfully=False, **extra):
"""Returns a JSON object with the ran_successfully key and any extra keys passed in."""
return {"ran_successfully": ran_successfully, **extra}
def missing_param_error(param_name):
"""Returns a JSON object with the error_message key and a message saying that the param_name is missing."""
return api_response(
error_message=f"Missing parameter in request: {param_name}. Request must contain question, agent, and/or generate_report/get_report params."
)
def success_str(msg=""):
return f"{Fore.GREEN}{Style.BRIGHT}{msg}{Style.RESET_ALL}"
def error_str(msg=""):
return f"{Fore.RED}{Style.BRIGHT}{msg}{Style.RESET_ALL}"
def log_str(msg=""):
return f"{Fore.BLUE}{Style.BRIGHT}{msg}{Style.RESET_ALL}"
def warn_str(msg=""):
return f"{Fore.YELLOW}{Style.BRIGHT}{msg}{Style.RESET_ALL}"
def log_success(msg=""):
print(f"{Fore.GREEN}{Style.BRIGHT}{msg}{Style.RESET_ALL}")
def log_error(msg=""):
print(f"{Fore.RED}{Style.BRIGHT}{msg}{Style.RESET_ALL}")
def log_msg(msg=""):
print(f"{Fore.BLUE}{Style.BRIGHT}{msg}{Style.RESET_ALL}")
def log_warn(msg=""):
print(f"{Fore.YELLOW}{Style.BRIGHT}{msg}{Style.RESET_ALL}")
simple_tool_types = {
"DBColumn": "Column name",
"DBColumnList": "List of column names",
"pandas.core.frame.DataFrame": "Dataframe",
"str": "String",
"int": "Integer",
"float": "Float",
"bool": "Boolean",
"list[str]": "List of strings",
"list": "List",
"DropdownSingleSelect": "String",
}
def create_simple_tool_types(_type):
# if type starts with DBColumnList...
if _type.startswith("DBColumnList"):
return "List of column names"
if _type.startswith("ListWithDefault"):
return "List"
else:
return simple_tool_types.get(_type, _type)
def get_clean_plan(analysis_data):
generated_plan = analysis_data.get("gen_steps", {}).get("steps", [])
cleaned_plan = []
for item in generated_plan:
cleaned_item = {}
for key, value in item.items():
if key in [
"tool_name",
"model_generated_inputs",
"outputs_storage_keys",
"done",
"error_message",
]:
# if key is model_generated_inputs, just change it to inputs
if key == "model_generated_inputs":
cleaned_item["inputs"] = value
else:
cleaned_item[key] = value
cleaned_plan.append(cleaned_item)
return cleaned_plan
def snake_case(s):
return re.sub(r"(?<!^)(?=[A-Z])", "_", s).lower()
class SqlExecutionError(Exception):
def __init__(self, sql, error_message):
# Call the base class constructor with the parameters it needs
super().__init__(f"{error_message}")
# Now for your custom code...
self.sql = sql
async def make_request(url, payload, verbose=False):
print(f"Making request to {url}")
async with httpx.AsyncClient() as client:
r = await client.post(
url,
json=payload,
)
if verbose:
print(f"Response: {r.text}")
return r
def deduplicate_columns(df: pd.DataFrame):
# de-duplicate column names
# if the same column name exists more than once, add a suffix
deduplicated_df = df.copy()
columns = deduplicated_df.columns.tolist()
seen = {}
for i, item in enumerate(columns):
if item in seen:
columns[i] = f"{item}_{seen[item]}"
seen[item] += 1
else:
seen[item] = 1
deduplicated_df.columns = columns
return deduplicated_df
def filter_function_inputs(fn, inputs):
"""
Used to filter down a dict's keys to only the parameters that are required by a function
Creates a filtering function that can be run on a dict
Which will filter the dict's keys based on the return value (True or False)
of the filter function
If the function takes kwargs, we always return true from the function
Otherwise, we check if the key is in the function's parameters
"""
# if a function takes kwargs, then we will
sig = inspect.signature(fn)
params = sig.parameters.values()
has_kwargs = any([True for p in params if p.kind == p.VAR_KEYWORD])
f = lambda _: True
if not has_kwargs:
param_names = [p.name for p in params]
f = lambda key: key in param_names
return {k: v for k, v in inputs.items() if f(k)}, f
def wrap_in_async(fn):
"""
If a function isn't async, wrap it in an async function for create_Task to work
"""
wrapped_fn = fn
if not inspect.iscoroutinefunction(fn):
async def async_fn(**kwargs):
return fn(**kwargs)
wrapped_fn = async_fn
return wrapped_fn
def add_indent(level=1):
return "...." * level