Skip to content

Commit

Permalink
Refactored name
Browse files Browse the repository at this point in the history
  • Loading branch information
Xpitfire committed Apr 20, 2023
1 parent 8996617 commit 8d3d27b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion examples/exp_graph_coherence.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def classify(self, usr: ai.Symbol):
assert isinstance(usr, ai.Symbol)
usr_embed = usr.embed()
embeddings = self.embed_opt()
similarities = [usr_embed.cos_sim(emb) for emb in embeddings]
similarities = [usr_embed.similarity(emb) for emb in embeddings]
similarities = sorted(zip(self.options, similarities), key=lambda x: x[1], reverse=True)

return similarities[0][0]
Expand Down
6 changes: 3 additions & 3 deletions symai/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,11 +576,11 @@ def _func(_) -> list:
pass
return self._sym_return_type(_func(self))

def cos_sim(self, other: Any) -> float:
def similarity(self, other: Any, metric='cosine') -> float:
if not isinstance(self.value, np.ndarray): v = np.array(self.value).squeeze()[:, None]
if not isinstance(other, np.ndarray): o = np.array(other).squeeze()[:, None]
if not isinstance(other, np.ndarray): other = np.array(other).squeeze()[:, None]

return (v.T@o / (v.T@v)**.5 * (o.T@o)**.5).item()
return (v.T@other / (v.T@v)**.5 * (other.T@other)**.5).item()

# TODO: improve how to set max_tokens
def stream(self, expr: "Expression",
Expand Down

0 comments on commit 8d3d27b

Please sign in to comment.