test.py
import torch
from src.utils.prompting import *
from src.utils.load_data import CitationGraph
from src.model.llm import APIModel
from src.utils.process_output import *
from typing import Dict, Any
import random
import os
from dotenv import load_dotenv
load_dotenv()
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
random.seed(42)
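# This script runs a small aspect-extraction pipeline on one randomly chosen
# ICLR 2017 paper: the model is asked for the paper's research domains, then
# for aspects within the first domain; the aspects are merged, and review
# comments are generated at each stage and written to text files.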
def select_random_paper_with_comment(dataset: CitationGraph) -> Dict[str, Any]:
    """Return a randomly chosen paper from the dataset that has a comment attached."""
    papers_with_comments = [i for i in range(len(dataset)) if dataset[i]['comment']]
    if not papers_with_comments:
        raise ValueError("No papers with comments found in the dataset.")
    random_index = random.choice(papers_with_comments)
    return dataset[random_index]
# Example usage
Counter = tokenCounter()  # token counter utility (instantiated but not used below)
model = APIModel(model="gpt-4o-2024-05-13", api_key=OPENAI_API_KEY, api_url="https://api.openai.com/v1/chat/completions")
db = torch.load("processed/ICLR_2017.pt")  # preprocessed ICLR 2017 data
paper_dataset = CitationGraph(db)
random_paper = select_random_paper_with_comment(paper_dataset)
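# Step 1: prompt the model to identify the paper's research domains.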
prompts = [area_finding_prompt(random_paper)]  # avoid shadowing the imported prompt builder
outputs = model.batch_chat(text_batch=prompts, temperature=1)
domains = extract_domains(outputs)
with open('extracted_domains.txt', 'w', encoding='utf-8') as f:
    f.write(f"\n\n{'-'*200}\n\n".join(outputs))
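# Step 2: extract finer-grained aspects within the first domain.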
prompts = [topic_finding_prompt(random_paper, domains[0])]
outputs = model.batch_chat(text_batch=prompts, temperature=1)
aspects = extract_aspects(outputs)
with open('extracted_aspects.txt', 'w', encoding='utf-8') as f:
    f.write(f"\n\n{'-'*200}\n\n".join(outputs))
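# Step 3: generate review comments for the first aspect of the first domain.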
prompts = [aspect_writing_prompt(random_paper, domains[0], aspects[0])]
outputs = model.batch_chat(text_batch=prompts, temperature=1)
with open('aspect_comments.txt', 'w', encoding='utf-8') as f:
    f.write(f"\n\n{'-'*200}\n\n".join(outputs))
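# Step 4: merge the aspects extracted above into a consolidated list.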
aspect_txt = extract_txt_file('extracted_aspects.txt')
prompts = [merged_aspect_prompt(random_paper, domains[0], aspect_txt)]
outputs = model.batch_chat(text_batch=prompts, temperature=1)
merged_aspects = extract_merged_aspects(outputs)
with open('merged_aspects.txt', 'w', encoding='utf-8') as f:
    f.write(f"\n\n{'-'*200}\n\n".join(outputs))
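# Step 5: generate comments for the first merged aspect (domain 0).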
prompts = [aspect_writing_prompt(random_paper, domains[0], merged_aspects[0])]
outputs = model.batch_chat(text_batch=prompts, temperature=1)
with open('merged_aspect_comments_0.txt', 'w', encoding='utf-8') as f:
    f.write(f"\n\n{'-'*200}\n\n".join(outputs))
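# Step 6: repeat comment generation for domains 1-4, reusing the first merged aspect.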
for i in range(1, 5):  # assumes at least five extracted domains
    file_name = f'merged_aspects_comments_{i}.txt'
    prompts = [aspect_writing_prompt(random_paper, domains[i], merged_aspects[0])]
    outputs = model.batch_chat(text_batch=prompts, temperature=1)
    with open(file_name, 'w', encoding='utf-8') as f:
        f.write(f"\n\n{'-'*200}\n\n".join(outputs))