get_embed.py
"""
Generate embeddings for referring expressions or names of known landmarks.
"""
import os
import argparse
import json
from pathlib import Path
import openai
from gpt import GPT3
from utils import load_from_file, save_to_file
openai.api_key = os.getenv("OPENAI_API_KEY")


def load_names(fpath):
    """
    Load names of known objects in given environment.
    Assume 1 name per line in txt file, e.g. data/osm/osm_landmarks_corlw.txt
    Assume 1 dictionary of key: landmark name, value: semantic info in json file, e.g. data/osm/lmks/boston.json
    """
    ftype = os.path.splitext(fpath)[-1][1:]
    if ftype == "txt":
        with open(fpath, 'r') as rf:
            names = [line.strip() for line in rf.readlines()]
    else:
        raise ValueError(f"ERROR: file type {ftype} not recognized")
    return names
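
# Example usage (sketch): reading known landmark names from a plain-text file with one name per
# line, using the file referenced in the docstring above. The path is illustrative and assumes
# the data directory from this repo is present.
#   names = load_names("data/osm/osm_landmarks_corlw.txt")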


def generate_embeds(embed_model, save_dpath, lmk2sem, keep_keys=(), embed_engine=None, exp_name="lang2ltl-api", update_embed=True):
    """
    Generate a database of known landmarks and their embeddings.
    :param embed_model: model used to generate embeddings.
    :param save_dpath: folder to save generated embeddings.
    :param lmk2sem: known landmarks and their semantic information, as a dict or a file path.
    :param keep_keys: keys used to filter the semantic information of landmarks when constructing embeddings.
    :param embed_engine: embedding engine to use with the embedding model, e.g., text-embedding-ada-002.
    :param exp_name: experiment ID used in the file name.
    :param update_embed: whether to append new embeddings to existing embeddings, overwriting entries with the same landmark name.
    """
    # Load existing embeddings
    embed_dpath = os.path.join(save_dpath, "lmk_sem_embeds")
    os.makedirs(embed_dpath, exist_ok=True)
    lmk_fname = Path(lmk2sem).stem if isinstance(lmk2sem, str) else exp_name
    if embed_model == "gpt3":
        save_fpath = os.path.join(embed_dpath, f"obj2embed_{lmk_fname}_{embed_model}-{embed_engine}.pkl")
    else:
        save_fpath = os.path.join(embed_dpath, f"obj2embed_{lmk_fname}_{embed_model}.pkl")
    name2embed = {}
    if os.path.isfile(save_fpath):
        name2embed = load_from_file(save_fpath)

    # Generate new embeddings if needed
    lmk2sem = load_from_file(lmk2sem) if isinstance(lmk2sem, str) else lmk2sem
    if embed_model == "gpt3":
        embed_module = GPT3(embed_engine)
    else:
        raise ValueError(f"ERROR: embedding module not recognized: {embed_model}")
    for lmk, sem in lmk2sem.items():
        if lmk not in name2embed or update_embed:
            sem_filtered = {"name": lmk}
            if keep_keys:
                sem_filtered.update({k: v for k, v in sem.items() if k in keep_keys})
            name2embed[lmk] = embed_module.get_embedding(json.dumps(sem_filtered))
    save_to_file(name2embed, save_fpath)
    return name2embed, save_fpath
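
# Example usage (sketch): generating embeddings for a small in-memory landmark dict. The landmark
# name and semantic keys below are hypothetical; this assumes OPENAI_API_KEY is set and that
# utils.load_from_file/save_to_file handle the .pkl path produced above. Since lmk2sem is a dict
# rather than a file path, the output file name falls back to exp_name ("lang2ltl-api").
#   lmk2sem = {"Main Street Cafe": {"amenity": "cafe", "cuisine": "coffee_shop"}}
#   name2embed, save_fpath = generate_embeds(
#       "gpt3", "data/osm", lmk2sem,
#       keep_keys=("amenity", "cuisine"),
#       embed_engine="text-embedding-ada-002",
#   )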


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", type=str, default="cleanup", choices=["osm", "cleanup"], help="fpath or dpath to lmks.")
    parser.add_argument("--model", type=str, default="gpt3", choices=["gpt3", "llama"])
    parser.add_argument("--embed_engine", type=str, default="text-embedding-ada-002")
    args = parser.parse_args()

    env_dpath = os.path.join("data", args.env)
    lmk_dpath = os.path.join(env_dpath, "lmks")
    lmk_fpaths = [os.path.join(lmk_dpath, fname) for fname in os.listdir(lmk_dpath) if "json" in fname]
    # Semantic keys kept when constructing embedding text for OSM landmarks; empty for the cleanup environment.
    keep_keys = [
        "amenity", "shop", "addr:street",
        "short_name", "building", "building:part", "leisure", "tourism", "historic",
        "healthcare", "area", "landuse", "waterway", "aeroway", "highway", "office", "operator", "brand", "branch",
        "cuisine", "beauty", "official_name", "alt_name", "station", "railway", "subway",
    ] if args.env == "osm" else []

    for idx, lmk_fpath in enumerate(lmk_fpaths):
        print(f"generating landmark embedding for {lmk_fpath}")
        _, save_fpath = generate_embeds(args.model, env_dpath, lmk_fpath, keep_keys, args.embed_engine)
        print(f"{idx}: embeddings generated by model: {args.model}-{args.embed_engine}\nstored at: {save_fpath}\n")