Improve import_parquet.py example
xthexder committed Feb 8, 2024
1 parent e4f44aa commit 826911d
Showing 1 changed file with 42 additions and 13 deletions.
55 changes: 42 additions & 13 deletions examples/import_parquet.py
@@ -8,19 +8,33 @@
 import turbopuffer as tpuf
 import traceback
 import threading
-from queue import Queue
+from queue import Queue, Full
 
 NUM_THREADS = 4
+START_OFFSET = 0
 
-
 def read_docs_to_queue(queue, parquet_files):
+def read_docs_to_queue(queue, parquet_files, exiting):
     try:
         file_offset = 0
         for parquet_file in parquet_files:
-            df = pd.read_parquet(parquet_file).rename(columns={'emb': 'vector'})
+            while queue.full() and not exiting.is_set():
+                time.sleep(1)
+            if exiting.is_set():
+                return
+            # Add any attribute columns to include after 'emb'
+            df = pd.read_parquet(parquet_file, columns=['emb']).rename(columns={'emb': 'vector'})
             if 'id' not in df.keys():
                 df['id'] = range(file_offset, file_offset+len(df))
-            queue.put(df)
+            if file_offset >= START_OFFSET:
+                while not exiting.is_set():
+                    try:
+                        queue.put(df, timeout=1)
+                        break
+                    except Full:
+                        pass
+                print(f'Loaded {parquet_file}, file_offset from {file_offset} to {file_offset + len(df)}')
+            else:
+                print(f'Skipped {parquet_file}, file_offset from {file_offset} to {file_offset + len(df)}')
             file_offset += len(df)
     except Exception:
         print('Failed to read batch:')
Expand All @@ -29,16 +43,14 @@ def read_docs_to_queue(queue, parquet_files):
queue.put(None) # Signal the end of the documents


def upsert_docs_from_queue(input_queue, dataset_name):
def upsert_docs_from_queue(input_queue, dataset_name, exiting):
ns = tpuf.Namespace(dataset_name)

batch = input_queue.get()
while batch is not None:
while batch is not None and not exiting.is_set():
try:
ns.upsert(batch)
print(f"Completed {batch['id'][0]}..{batch['id'][batch.shape[0]-1]}")
except KeyboardInterrupt:
break
except Exception:
print(f"Failed to upsert batch: {batch['id'][0]}..{batch['id'][batch.shape[0]-1]}")
traceback.print_exc()
@@ -53,8 +65,23 @@ def main(dataset_name, input_path):
         print(f"No .parquet files found in: {input_glob}")
         sys.exit(1)
 
+    sorted_files = sorted(sorted(parquet_files), key=len)
+
+    ns = tpuf.Namespace(dataset_name)
+    if ns.exists():
+        print(f'The namespace "{ns.name}" already exists!')
+        existing_dims = ns.dimensions()
+        print(f'Vectors: {ns.approx_count()}, dimensions: {existing_dims}')
+        response = input('Delete namespace? [y/N]: ')
+        if response == 'y':
+            ns.delete_all()
+        else:
+            print('Cancelled')
+            sys.exit(1)
+
+    exiting = threading.Event()
     doc_queue = Queue(NUM_THREADS)
-    read_thread = threading.Thread(target=read_docs_to_queue, args=(doc_queue, parquet_files))
+    read_thread = threading.Thread(target=read_docs_to_queue, args=(doc_queue, sorted_files, exiting))
     upsert_threads = []
 
     start_time = time.monotonic()
@@ -63,17 +90,19 @@
         read_thread.start()
 
         for _ in range(NUM_THREADS):
-            upsert_thread = threading.Thread(target=upsert_docs_from_queue, args=(doc_queue, dataset_name))
+            upsert_thread = threading.Thread(target=upsert_docs_from_queue, args=(doc_queue, dataset_name, exiting))
             upsert_threads.append(upsert_thread)
             upsert_thread.start()
 
         read_thread.join()
 
         for upsert_thread in upsert_threads:
             upsert_thread.join()
+
     except KeyboardInterrupt:
+        exiting.set()
         sys.exit(1)
     finally:
         print('Upserted', doc_queue.qsize(), 'documents')
         print('DONE!')
         print('Took:', (time.monotonic() - start_time), 'seconds')
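The heart of the change is a cooperative-shutdown pattern: one reader thread and several upsert threads share a bounded Queue, and every blocking loop polls a threading.Event. A minimal standalone sketch of that pattern, with hypothetical producer/consumer/process names that are not part of the example:

import threading
from queue import Queue, Full, Empty

exiting = threading.Event()    # set from the main thread on KeyboardInterrupt
work_queue = Queue(maxsize=4)  # bounded, so the reader cannot run far ahead

def producer(batches):
    for batch in batches:
        while not exiting.is_set():
            try:
                work_queue.put(batch, timeout=1)  # retry until space frees or we exit
                break
            except Full:
                pass
    work_queue.put(None)       # sentinel: no more batches

def consumer(process):
    while not exiting.is_set():
        try:
            batch = work_queue.get(timeout=1)     # poll so the exit flag is rechecked
        except Empty:
            continue
        if batch is None:
            work_queue.put(None)  # pass the sentinel on to the other consumers
            return
        process(batch)

The timeout-and-retry put is what makes shutdown prompt: a bare queue.put() on a full queue blocks indefinitely and never re-checks the event, which the removed except KeyboardInterrupt branch could not reliably work around, since Python delivers KeyboardInterrupt to the main thread rather than to workers.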

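The new pre-flight block in main() also reads as a recipe for inspecting a namespace before writing to it. Condensed below, using only the turbopuffer calls that appear in the diff; the namespace name is hypothetical:

import turbopuffer as tpuf

ns = tpuf.Namespace('my-dataset')  # hypothetical name
if ns.exists():
    # Look at what is already stored before deciding to overwrite it.
    print(f'{ns.approx_count()} vectors, {ns.dimensions()} dimensions')
    ns.delete_all()                # irreversible: empties the namespace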

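Finally, a sketch of driving the example end to end. The argument order follows main(dataset_name, input_path) above, but the __main__ wiring and the exact glob pattern live outside the hunks shown, so treat the invocation as an assumption:

# Assumes TURBOPUFFER_API_KEY is set in the environment and that input_path
# is globbed for .parquet files (the f-string above prints an input_glob).
# To resume a partial import, raise START_OFFSET: files whose starting row
# offset falls below it hit the 'Skipped' branch instead of being re-queued.
main('my-dataset', './data')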