-
Notifications
You must be signed in to change notification settings - Fork 1
/
search_emojis_cli.py
106 lines (80 loc) · 2.81 KB
/
search_emojis_cli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#!/usr/bin/env python3
import numpy as np
from collections import defaultdict, Counter
import pickle
from tqdm import tqdm
from typing import List, Dict, Any, Tuple
import emoji as em
from sentence_transformers import SentenceTransformer
from qdrant_client import models, QdrantClient
# suppress warnings coming from the Hugging Face library
import warnings
warnings.filterwarnings('ignore')
# Read the emoji dictionary from disk.
# NOTE: pickle.load can execute arbitrary code while unpickling, so this file
# must only ever come from a trusted source.
with open('./data/emoji_embeddings_dict.pkl', 'rb') as file:
    emoji_dict = pickle.load(file)

# Initialize the sentence encoder used to embed user queries.
embedding_model = 'paraphrase-multilingual-MiniLM-L12-v2'
sentence_encoder = SentenceTransformer(embedding_model)

# Move every embedding into its own dict and remove it from the emoji entry.
# The removal is important because the remaining entry is uploaded as the
# Qdrant payload, and Qdrant cannot handle np.ndarray payloads.
embedding_dict: Dict[str, np.ndarray] = {}
for emoji, emoji_info in emoji_dict.items():
    embedding_dict[emoji] = np.array(emoji_info.pop('embedding'))

# Fail with a clear message instead of an obscure error further down when
# the pickle holds no emojis at all.
if not embedding_dict:
    raise ValueError("emoji dictionary is empty; cannot build the collection")

# The vector size is read from the first embedding; all embeddings are
# assumed to share the same dimensionality.
embedding_dim = next(iter(embedding_dict.values())).shape[0]

# Initialize the vector database client (":memory:" storage) and create the
# collection that will hold the emoji vectors, compared by cosine distance.
vector_DB_client = QdrantClient(":memory:")
vector_DB_client.create_collection(
    collection_name="EMOJIS",
    vectors_config=models.VectorParams(
        size=embedding_dim,
        distance=models.Distance.COSINE,
    ),
)

# Populate the collection: one point per emoji, with the embedding as the
# vector and the remaining emoji entry as the payload.
vector_DB_client.upload_points(
    collection_name="EMOJIS",
    points=[
        models.PointStruct(
            id=idx,
            vector=embedding_dict[emoji].tolist(),
            payload=emoji_dict[emoji],
        )
        for idx, emoji in enumerate(emoji_dict)
    ],
)
def return_simialr_emojis(query: str) -> None:
    """
    Print the emojis that are similar to a given query.

    The query is encoded with the module-level sentence encoder and used to
    search the "EMOJIS" collection for at most 15 hits. Every distinct emoji
    among the hits is printed once, on its own line, followed by its
    upper-cased description.

    Args:
        query (str): The query string to search for similar emojis.

    Returns:
        None. The results are written to standard output.
    """
    hits = vector_DB_client.search(
        collection_name="EMOJIS",
        query_vector=sentence_encoder.encode(query).tolist(),
        limit=15,
    )

    hit_emojis = set()
    for hit in hits:
        emoji_char = hit.payload['Emoji']

        # several hits may carry the same emoji; show each one only once
        if emoji_char in hit_emojis:
            continue
        hit_emojis.add(emoji_char)

        # demojize wraps the emoji name in ':' delimiters by default; strip
        # them (rather than slicing off the first and last character) so a
        # string that demojize leaves unchanged is not truncated
        emoji_desc = em.demojize(emoji_char).strip(':').replace('_', ' ').upper()

        # the column width grows with the number of code points so that
        # emojis made of several code points still get padding
        width = len(emoji_char) + 7
        print(f"{emoji_char:<{width}} {emoji_desc:<55}")
# Example of a one-off, non-interactive call:
#     return_simialr_emojis("innovation")
# Another sample query: "animal you can find in Australia"

# Interactive prompt: answer queries until the user submits an empty line
# or standard input is closed.
while True:
    try:
        query = input("\nEnter a query: ")
    except EOFError:
        # stdin was closed (Ctrl-D or an exhausted pipe); finish the prompt
        # line and stop, exactly as for an empty query
        print()
        break
    if not query:
        break
    return_simialr_emojis(query)