From 6dbf812031c6a6598d5b5f18fefbfc08428a8578 Mon Sep 17 00:00:00 2001 From: rishiraj Date: Mon, 15 Apr 2024 04:05:41 +0530 Subject: [PATCH] fix numpy issue --- README.md | 30 ++++++++++++++++++++++++------ spanking/main.py | 26 +++++++++++++------------- 2 files changed, 37 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 8b26ade..054f8a4 100644 --- a/README.md +++ b/README.md @@ -39,14 +39,26 @@ vector_db.update_text(index, new_text) ``` This will update the text and its corresponding embedding at the specified index with the new text. -6. Iterate over the stored texts: +6. Save the database to a file: +```python +vector_db.save('vector_db.pkl') +``` +This will save the current state of the `VectorDB` instance to a file named 'vector_db.pkl'. + +7. Load the database from a file: +```python +vector_db = VectorDB.load('vector_db.pkl') +``` +This will load the `VectorDB` instance from the file named 'vector_db.pkl' and return it. + +8. Iterate over the stored texts: ```python for text in vector_db: print(text) ``` This will iterate over all the texts stored in the database. -7. Access individual texts by index: +9. Access individual texts by index: ```python index = 2 text = vector_db[index] @@ -54,7 +66,7 @@ print(text) ``` This will retrieve the text at the specified index. -8. Get the number of texts in the database: +10. Get the number of texts in the database: ```python num_texts = len(vector_db) print(num_texts) @@ -84,9 +96,15 @@ vector_db.update_text(1, "i enjoy playing chess") # Delete a text vector_db.delete_text(2) -# Iterate over the stored texts -print("\nStored texts:") -for text in vector_db: +# Save the database +vector_db.save('vector_db.pkl') + +# Load the database +loaded_vector_db = VectorDB.load('vector_db.pkl') + +# Iterate over the stored texts in the loaded database +print("\nStored texts in the loaded database:") +for text in loaded_vector_db: print(text) ``` diff --git a/spanking/main.py b/spanking/main.py index 37fc3ac..f6c224c 100644 --- a/spanking/main.py +++ b/spanking/main.py @@ -8,36 +8,36 @@ def __init__(self, model_name='BAAI/bge-base-en-v1.5'): self.model = SentenceTransformer(model_name) self.texts = [] self.embeddings = None - + def add_texts(self, texts): - new_embeddings = self.model.encode(texts, normalize_embeddings=True) + new_embeddings = jnp.array(self.model.encode(texts, normalize_embeddings=True)) if self.embeddings is None: self.embeddings = new_embeddings else: self.embeddings = jnp.concatenate((self.embeddings, new_embeddings), axis=0) self.texts.extend(texts) - + def delete_text(self, index): if 0 <= index < len(self.texts): self.texts.pop(index) - self.embeddings = jnp.delete(self.embeddings, index, axis=0) + self.embeddings = self.embeddings.at[index].delete() else: raise IndexError("Invalid index") - + def update_text(self, index, new_text): if 0 <= index < len(self.texts): self.texts[index] = new_text - new_embedding = self.model.encode([new_text], normalize_embeddings=True).squeeze() - self.embeddings = (self.embeddings).at[index].set(new_embedding) + new_embedding = jnp.array(self.model.encode([new_text], normalize_embeddings=True)).squeeze() + self.embeddings = self.embeddings.at[index].set(new_embedding) else: raise IndexError("Invalid index") - + def search(self, query, top_k=5): - query_embedding = self.model.encode([query], normalize_embeddings=True) + query_embedding = jnp.array(self.model.encode([query], normalize_embeddings=True)) similarities = jnp.dot(self.embeddings, query_embedding.T).squeeze() top_indices = jnp.argsort(similarities)[-top_k:][::-1] return [(self.texts[i], similarities[i]) for i in top_indices] - + def save(self, file_path): with open(file_path, 'wb') as file: pickle.dump(self, file) @@ -46,13 +46,13 @@ def save(self, file_path): def load(file_path): with open(file_path, 'rb') as file: return pickle.load(file) - + def __len__(self): return len(self.texts) - + def __getitem__(self, index): return self.texts[index] - + def __iter__(self): return iter(self.texts)