Skip to content

Commit

Permalink
modify prompt and speed up
Browse files Browse the repository at this point in the history
  • Loading branch information
shepardxia committed Jun 15, 2024
1 parent 5bc0ab4 commit 0664b6d
Show file tree
Hide file tree
Showing 17 changed files with 778 additions and 2,789 deletions.
4 changes: 2 additions & 2 deletions AgentS1.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def translate_query(self, ):


response = self.client.chat.completions.create(
model=self.model_trans,
model=self.model_dial,
response_format={ "type": "json_object" },
temperature=0,
messages=query_trans
Expand All @@ -158,7 +158,7 @@ def get_variable(self, question):

prop[-1]['content'] += "Additionally, we would focus on question like: " + question
prop[-1]['content'] += " Therefore, make sure the variable values at least cover those in the question. "
prop[-1]['content'] += "Think about the question first. What can you infer to help with modeling? Feel free to choose other variables that help model that question. However, only provide the variables but don't proceed to solve the question. "
prop[-1]['content'] += "Propose 2~3 variables in your final summary. "


response = self.client.chat.completions.create(
Expand Down
67 changes: 54 additions & 13 deletions AgentS2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import argparse
dotenv.load_dotenv('.env')
import concurrent.futures



Expand Down Expand Up @@ -184,7 +185,7 @@ def get_prob(self, message, target_var, ):

prob_prompt = self.prompt('get_prob')
vals = self.var_dict[target_var]
message += f"\nYou should extract {len(vals)} probabilities, in the order of " + ', '.join(vals) + "."
message += f"\nextract in the order of " + ', '.join(vals) + "."
prob_prompt.append({'role': 'user', 'content': message})

response = self.client.chat.completions.create(
Expand Down Expand Up @@ -279,6 +280,8 @@ def propose_constraints(self, shuffle=False):
prop_mess, prop_assist = self.chat(cons_prompt)

self.mega_records += [i.copy() for i in prop_mess[-2:]]
#print(schema_text)
#print(prop_assist)


#print("Translating the constraints...")
Expand Down Expand Up @@ -320,6 +323,7 @@ def populate_constraint_prob(self, question_js, add_city=True, conf=False, rejec
got = 0
looped += 1
if looped > 3:
print("Warning: could not get the probabilities on ", question_text)
return None
self.mega_records += [i.copy() for i in val_mess[-2:]]
else:
Expand Down Expand Up @@ -359,46 +363,83 @@ def compile(self, add_city=True, conf=True):

#print("Populating the marginals...")
# first ask for marginals on the variables
for var in schema['Variables']:
#print("Populating marginal for ", var['Name'])
def process_variable(var):
# Prepare the question JSON for each variable
question_js = {'Target': [{'Name': var['Name'], 'Value': var['Value']}], 'Condition': []}
question_js = self.populate_constraint_prob(question_js, add_city, conf)
if question_js is not None:
self.constraints_rec.append(question_js)
return question_js

# Use ThreadPoolExecutor to parallelize the API calls
with concurrent.futures.ThreadPoolExecutor() as executor:
# Submit tasks to the executor for each variable
futures = [executor.submit(process_variable, var) for var in schema['Variables']]

# Process the results as they complete
for future in concurrent.futures.as_completed(futures):
question_js = future.result()
if question_js is not None:
self.constraints_rec.append(question_js)
else:
print("Warning: skipping variable")


self.records['schema'] = schema
self.log()

#print("\n--------------------------------------------\n")
#print("Populating the constraints...")
seen = set()
for i, entry in enumerate(schema['Constraints']):

def process_entry(entry, i, var_dict, other, add_city, conf, seen):
if len(entry['Condition']) == 0 or len(entry['Target']) == 0:
print("Warning: Target and Condition should both be proposed.")
continue
return None
for cond in entry['Condition']:
if 'Value' in cond:
print("Warning: Condition should not have Value.")
#print("Populating constraint ", i)
tar = entry['Target'][0]['Name']
tar_vals = self.var_dict[tar]
tar_vals = var_dict[tar]
new_tar = [{'Name': tar, 'Value': tar_vals}]
cond = entry['Condition'][0]['Name']

if (tar, cond) in seen:
continue
return None
else:
seen.add((tar, cond))

for cond_val in self.var_dict[cond]:
if self.other:
new_cond = [{'Name': cond, 'Value': [cond_val], 'Other': [i for i in self.var_dict[cond] if i != cond_val]}]
constraints_rec = []
for cond_val in var_dict[cond]:
if other:
new_cond = [{'Name': cond, 'Value': [cond_val], 'Other': [i for i in var_dict[cond] if i != cond_val]}]
else:
new_cond = [{'Name': cond, 'Value': [cond_val]}]
question_js = {'Target': new_tar, 'Condition': new_cond}
question_js = self.populate_constraint_prob(question_js, add_city, conf)
if question_js is not None:
self.constraints_rec.append(question_js)
constraints_rec.append(question_js)
return constraints_rec

def parallel_process(schema, var_dict, other, add_city, conf):
    """Populate all proposed constraints concurrently and return the records.

    Each entry of ``schema['Constraints']`` is handed to ``process_entry`` on
    a thread pool. ``process_entry`` returns either a list of populated
    records or ``None``; the non-empty lists are concatenated and returned.

    Two properties are guaranteed here rather than left to thread timing:

    * De-duplication happens on the submitting thread. ``process_entry`` does
      a check-then-add on the shared ``seen`` set, which is not atomic, so two
      workers holding the same (target, condition) pair could both pass the
      check and both query the model. Only the first occurrence of each pair
      is submitted, so that race cannot happen. A repeated pair is therefore
      dropped before ``process_entry`` runs, and any warning it would have
      printed for that repeated entry is not emitted.
    * Results are collected in submission order, so the returned list follows
      the order of ``schema['Constraints']`` instead of whichever request
      happens to finish first.

    Args:
        schema: Mapping with a ``'Constraints'`` list of proposed constraints.
        var_dict: Passed through to ``process_entry``.
        other: Passed through to ``process_entry``.
        add_city: Passed through to ``process_entry``.
        conf: Passed through to ``process_entry``.

    Returns:
        A flat list of the records produced by ``process_entry``.

    Raises:
        Whatever ``process_entry`` raises; the exception is re-raised by
        ``Future.result()`` for the first failing entry in submission order.
    """
    # Still passed to the workers because process_entry's signature expects it.
    seen = set()
    # (target, condition) pairs that have already been given to a worker.
    queued = set()
    jobs = []
    all_constraints_rec = []

    with concurrent.futures.ThreadPoolExecutor() as executor:
        for i, entry in enumerate(schema['Constraints']):
            try:
                key = (entry['Target'][0]['Name'], entry['Condition'][0]['Name'])
                duplicate = key in queued
            except (KeyError, IndexError, TypeError):
                # Malformed or empty entry: submit it anyway so process_entry
                # reports it (warning or exception) as it would have before.
                key = None
                duplicate = False

            if duplicate:
                continue
            if key is not None:
                queued.add(key)

            jobs.append(
                executor.submit(process_entry, entry, i, var_dict, other, add_city, conf, seen)
            )

        # Iterate in submission order (not as_completed) for a stable result.
        for job in jobs:
            result = job.result()
            if result:
                all_constraints_rec.extend(result)

    return all_constraints_rec

# Assuming `self.populate_constraint_prob`, `self.var_dict`, etc. are defined elsewhere
all_constraints_rec = parallel_process(schema, self.var_dict, self.other, add_city, conf)
self.constraints_rec += all_constraints_rec




self.log()
Expand Down
Binary file modified Experiments/.DS_Store
Binary file not shown.
Binary file removed Experiments/BrowardCounty/.DS_Store
Binary file not shown.
Loading

0 comments on commit 0664b6d

Please sign in to comment.