Skip to content

Commit

Permalink
Merge pull request #7 from rishiraj/development
Browse files Browse the repository at this point in the history
fix numpy issue
  • Loading branch information
rishiraj authored Apr 14, 2024
2 parents e7d178c + 6dbf812 commit 6fc2e74
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 19 deletions.
30 changes: 24 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,34 @@ vector_db.update_text(index, new_text)
```
This will update the text and its corresponding embedding at the specified index with the new text.

6. Iterate over the stored texts:
6. Save the database to a file:
```python
vector_db.save('vector_db.pkl')
```
This will save the current state of the `VectorDB` instance to a file named 'vector_db.pkl'.

7. Load the database from a file:
```python
vector_db = VectorDB.load('vector_db.pkl')
```
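This will load the `VectorDB` instance from the file named 'vector_db.pkl' and return it. The file is a pickle, so only load files from sources you trust: unpickling can execute arbitrary code.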

8. Iterate over the stored texts:
```python
for text in vector_db:
print(text)
```
This will iterate over all the texts stored in the database.

7. Access individual texts by index:
9. Access individual texts by index:
```python
index = 2
text = vector_db[index]
print(text)
```
This will retrieve the text at the specified index.

8. Get the number of texts in the database:
10. Get the number of texts in the database:
```python
num_texts = len(vector_db)
print(num_texts)
Expand Down Expand Up @@ -84,9 +96,15 @@ vector_db.update_text(1, "i enjoy playing chess")
# Delete a text
vector_db.delete_text(2)

# Iterate over the stored texts
print("\nStored texts:")
for text in vector_db:
# Save the database
vector_db.save('vector_db.pkl')

# Load the database
loaded_vector_db = VectorDB.load('vector_db.pkl')

# Iterate over the stored texts in the loaded database
print("\nStored texts in the loaded database:")
for text in loaded_vector_db:
print(text)
```

Expand Down
26 changes: 13 additions & 13 deletions spanking/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,36 +8,36 @@ def __init__(self, model_name='BAAI/bge-base-en-v1.5'):
self.model = SentenceTransformer(model_name)
self.texts = []
self.embeddings = None

def add_texts(self, texts):
    """Embed ``texts`` and append them, with their embeddings, to the store.

    Args:
        texts: A string or an iterable of strings. A single string is
            treated as one text rather than as a sequence of characters.

    The batch is encoded exactly once with ``normalize_embeddings=True``.
    An empty batch leaves the store untouched.
    """
    # A bare string would otherwise be extended into ``self.texts`` one
    # character at a time; materialise other iterables so they can be both
    # encoded and stored.
    texts = [texts] if isinstance(texts, str) else list(texts)
    if not texts:
        return

    # Wrap the encoder output in a JAX array so the stored embeddings
    # support jnp operations and ``.at[...]`` updates.
    new_embeddings = jnp.array(
        self.model.encode(texts, normalize_embeddings=True)
    )
    if self.embeddings is None:
        self.embeddings = new_embeddings
    else:
        self.embeddings = jnp.concatenate(
            (self.embeddings, new_embeddings), axis=0
        )
    self.texts.extend(texts)

def delete_text(self, index):
    """Remove the text at ``index`` together with its embedding row.

    Args:
        index: Position of the text to remove; must satisfy
            ``0 <= index < len(self.texts)``.

    Raises:
        IndexError: If ``index`` is out of range.
    """
    if not 0 <= index < len(self.texts):
        raise IndexError("Invalid index")

    # jnp.delete returns a new array without the given row. JAX's
    # ``.at[...]`` helper offers no ``delete`` operation, so it is not used.
    remaining = jnp.delete(self.embeddings, index, axis=0)

    # Mutate the text list only after the array operation succeeded, so
    # texts and embeddings cannot get out of step.
    self.texts.pop(index)
    self.embeddings = remaining

def update_text(self, index, new_text):
    """Replace the text at ``index`` and recompute its embedding.

    Args:
        index: Position of the text to replace; must satisfy
            ``0 <= index < len(self.texts)``.
        new_text: The replacement text.

    Raises:
        IndexError: If ``index`` is out of range.
    """
    if not 0 <= index < len(self.texts):
        raise IndexError("Invalid index")

    # Encode once. The encoder is given a one-element batch, so drop only
    # the leading batch axis to obtain the single embedding row.
    encoded = jnp.array(
        self.model.encode([new_text], normalize_embeddings=True)
    )
    new_embedding = encoded.squeeze(axis=0)

    # JAX arrays are immutable; ``.at[index].set`` returns an updated copy.
    self.embeddings = self.embeddings.at[index].set(new_embedding)
    # Store the text last so a failed encode leaves the old pair intact.
    self.texts[index] = new_text

def search(self, query, top_k=5):
    """Return the stored texts most similar to ``query``.

    Similarity is the dot product between each stored embedding and the
    query embedding, both produced with ``normalize_embeddings=True``.

    Args:
        query: The query text.
        top_k: Maximum number of results to return.

    Returns:
        A list of ``(text, similarity)`` tuples ordered from most to least
        similar. The list is empty when nothing is stored or ``top_k`` is
        not positive.
    """
    # Nothing stored yet (embeddings is None or every text was deleted), or
    # no results requested. Note that ``[-0:]`` would select everything.
    if self.embeddings is None or not self.texts or top_k <= 0:
        return []

    query_embedding = jnp.array(
        self.model.encode([query], normalize_embeddings=True)
    )
    # reshape(-1) keeps a 1-D vector even when only one text is stored,
    # where a bare squeeze() would collapse the result to a 0-d scalar.
    similarities = jnp.dot(self.embeddings, query_embedding.T).reshape(-1)

    top_k = min(top_k, similarities.shape[0])
    # argsort is ascending: take the last ``top_k`` entries and reverse them.
    top_indices = jnp.argsort(similarities)[-top_k:][::-1]
    return [(self.texts[int(i)], similarities[i]) for i in top_indices]

def save(self, file_path):
    """Pickle this object and write it to ``file_path`` in binary mode.

    Args:
        file_path: Destination path; the file is opened with mode ``'wb'``.
    """
    with open(file_path, 'wb') as sink:
        # The whole instance is serialized, not just its texts/embeddings.
        pickle.dump(self, sink)
Expand All @@ -46,13 +46,13 @@ def save(self, file_path):
def load(file_path):
    """Read a pickled object back from ``file_path`` and return it.

    Args:
        file_path: Path of the pickle file; opened with mode ``'rb'``.

    Returns:
        Whatever object was stored in the file.

    NOTE(review): ``pickle.load`` can execute arbitrary code while
    unpickling, so this must only be used on files from a trusted source.
    """
    with open(file_path, 'rb') as source:
        restored = pickle.load(source)
    return restored

def __len__(self):
    """Return the number of texts currently held in ``self.texts``."""
    stored = self.texts
    return len(stored)

def __getitem__(self, index):
    """Return the stored text at ``index`` by indexing ``self.texts``."""
    stored = self.texts
    return stored[index]

def __iter__(self):
    """Return an iterator over the stored texts, in list order."""
    stored = self.texts
    return iter(stored)

Expand Down

0 comments on commit 6fc2e74

Please sign in to comment.