-
Notifications
You must be signed in to change notification settings - Fork 0
/
fewshot.py
52 lines (39 loc) · 1.68 KB
/
fewshot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import pandas as pd
import json
class FewShot:
    """Load processed post JSON files and serve posts filtered by length and tag.

    Each input file is expected to hold a JSON array of post objects that
    carry at least a numeric ``line_count`` and a ``tags`` list. On load,
    every post gets a derived ``length`` category (see ``categorize_length``)
    and all tags seen are collected into ``unique_tags``.

    Attributes:
        df: DataFrame with one row per loaded post (empty if nothing loaded).
        unique_tags: Set of every tag found across the loaded posts.
    """

    # Files loaded when the constructor is called without explicit paths.
    DEFAULT_FILE_PATHS = (
        "Data/hisham_sarwar_processed.json",
        "Data/irfan_malik_processed.json",
        "Data/usman_asif_processed.json",
    )

    def __init__(self, file_paths=None):
        """Create the store and immediately load posts.

        Args:
            file_paths: Iterable of JSON file paths to load. ``None`` loads
                ``DEFAULT_FILE_PATHS``; an empty iterable loads nothing.
        """
        self.df = pd.DataFrame()
        self.unique_tags = set()
        if file_paths is None:
            file_paths = list(self.DEFAULT_FILE_PATHS)
        self.load_posts(file_paths)

    @staticmethod
    def _is_tag_collection(tags):
        """Return True when *tags* is a real collection of tags.

        Guards against missing values (e.g. NaN produced by pandas for posts
        without a ``tags`` field), which do not support iteration or ``in``.
        """
        return isinstance(tags, (list, tuple, set, frozenset))

    def load_posts(self, file_paths):
        """Read every file in *file_paths* and append its posts to ``self.df``.

        Files whose JSON content is empty are skipped. State is only updated
        after all files have been read, so a failure part-way through leaves
        ``df`` and ``unique_tags`` unchanged.

        Args:
            file_paths: Iterable of paths to UTF-8 encoded JSON files.

        Raises:
            OSError: If a file cannot be opened.
            json.JSONDecodeError: If a file does not contain valid JSON.
            KeyError: If a non-empty file has no ``line_count`` or ``tags``
                field in any of its posts.
        """
        new_frames = []
        new_tags = set()
        for file_path in file_paths:
            with open(file_path, encoding="utf-8") as f:
                posts = json.load(f)
            if not posts:
                # json_normalize of an empty payload yields a frame without
                # columns, so there is nothing to categorize or collect.
                continue
            temp_df = pd.json_normalize(posts)
            temp_df['length'] = temp_df['line_count'].apply(self.categorize_length)
            for tags in temp_df['tags']:
                if self._is_tag_collection(tags):
                    new_tags.update(tags)
            new_frames.append(temp_df)

        if new_frames:
            # Concatenate once instead of once per file, and leave an empty
            # starting frame out of the concat entirely.
            existing = [] if self.df.empty else [self.df]
            self.df = pd.concat(existing + new_frames, ignore_index=True)
        self.unique_tags.update(new_tags)

    def get_filtered_posts(self, length, tag):
        """Return the posts that match both a length category and a tag.

        Args:
            length: Length category to match against the ``length`` column
                ("Short", "Medium" or "Long").
            tag: Tag that must be present in the post's ``tags``.

        Returns:
            A list of dicts, one per matching post; empty when nothing
            matches or when no posts have been loaded.
        """
        if self.df.empty:
            return []
        has_tag = self.df['tags'].apply(
            lambda tags: self._is_tag_collection(tags) and tag in tags
        )
        df_filtered = self.df[has_tag & (self.df['length'] == length)]
        return df_filtered.to_dict(orient='records')

    def categorize_length(self, line_count):
        """Map a line count to a length category.

        Returns:
            "Short" for fewer than 5 lines, "Medium" for 5 to 10 lines
            inclusive, and "Long" otherwise.
        """
        if line_count < 5:
            return "Short"
        if line_count <= 10:
            return "Medium"
        return "Long"

    def get_tags(self):
        """Return all known tags as a sorted list (stable across runs)."""
        return sorted(self.unique_tags)
if __name__ == "__main__":
    # Manual smoke run: load the three processed post files and print every
    # medium-length post tagged "Artificial Intelligence".
    source_files = [
        "Data/hisham_sarwar_processed.json",
        "Data/irfan_malik_processed.json",
        "Data/usman_asif_processed.json",
    ]
    few_shot = FewShot(source_files)
    matching_posts = few_shot.get_filtered_posts("Medium", "Artificial Intelligence")
    print(matching_posts)