-
Notifications
You must be signed in to change notification settings - Fork 0
/
tot_baseline.py
46 lines (37 loc) · 1.39 KB
/
tot_baseline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# import os
# import json
# import argparse
# import csv
# from dotenv import load_dotenv
# from openai import OpenAI
# from tot.methods.bfs import solve
# from tot.tasks.game24 import Game24Task
# import time
# load_dotenv()
# client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# def get_solutions(index):
# args = argparse.Namespace(backend='gpt-4', temperature=0.7, task='game24',naive_run=False, prompt_sample=None, method_generate='propose', method_evaluate='value', method_select='greedy', n_generate_sample=1, n_evaluate_sample=3, n_select_sample=5)
# task = Game24Task()
# ys, infos = solve(args, task, 900)
# return ys, infos
# if __name__ == '__main__':
# print(get_solutions(900))
import os
import json
import argparse
import csv
from dotenv import load_dotenv
from openai import OpenAI
from tot.methods.bfs import solve
from tot.tasks.game24 import Game24Task
import time
# Populate os.environ from a local .env file (if present), then build an
# OpenAI client using OPENAI_API_KEY (None when the variable is unset).
# NOTE(review): `client` is not referenced anywhere else in this file; the
# tot library may make its own API calls — confirm before removing.
load_dotenv()
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
def get_solutions(index):
    """Run the Tree-of-Thoughts BFS solver on one Game of 24 puzzle.

    Args:
        index: Puzzle index, passed straight through to ``solve`` together
            with a fresh ``Game24Task``.

    Returns:
        The first value returned by ``solve`` (``ys``; presumably the final
        candidate solutions — confirm against ``tot.methods.bfs``). The
        second value returned by ``solve`` is discarded.
    """
    # Fixed solver configuration: GPT-4 backend at temperature 0.7, with
    # 'propose' generation, 'value' evaluation and 'greedy' selection.
    # See tot.methods.bfs for how each field is consumed.
    args = argparse.Namespace(
        backend='gpt-4',
        temperature=0.7,
        task='game24',
        naive_run=False,
        prompt_sample=None,
        method_generate='propose',
        method_evaluate='value',
        method_select='greedy',
        n_generate_sample=1,
        n_evaluate_sample=3,
        n_select_sample=5,
    )
    task = Game24Task()
    # Forward the caller's index. It used to be hard-coded to 900, so the
    # ``index`` argument was silently ignored.
    ys, _infos = solve(args, task, index)
    return ys
if __name__ == '__main__':
    # Script entry point: run the solver for puzzle index 900 and print
    # whatever get_solutions returns.
    result = get_solutions(900)
    print(result)