
Commit: Add rate runner
Signed-off-by: yangxuan <[email protected]>
XuanYang-cn committed Nov 14, 2024
1 parent 1ab46dd commit 49bccd1
Showing 9 changed files with 413 additions and 11 deletions.
88 changes: 88 additions & 0 deletions tests/test_rate_runner.py
@@ -0,0 +1,88 @@
from typing import Iterable
import argparse
from vectordb_bench.backend.dataset import Dataset, DatasetSource
from vectordb_bench.backend.runner.rate_runner import RatedMultiThreadingInsertRunner
from vectordb_bench.backend.runner.read_write_runner import ReadWriteRunner
from vectordb_bench.backend.clients import DB, VectorDB
from vectordb_bench.backend.clients.milvus.config import FLATConfig
from vectordb_bench.backend.clients.zilliz_cloud.config import AutoIndexConfig

import logging

log = logging.getLogger("vectordb_bench")
log.setLevel(logging.DEBUG)

def get_rate_runner(db):
cohere = Dataset.COHERE.manager(100_000)
prepared = cohere.prepare(DatasetSource.AliyunOSS)
assert prepared
runner = RatedMultiThreadingInsertRunner(
rate = 10,
db = db,
dataset = cohere,
)

return runner

def test_rate_runner(db, insert_rate):
runner = get_rate_runner(db)

_, t = runner.run_with_rate()
log.info(f"insert run done, time={t}")

def test_read_write_runner(db, insert_rate, conc: list, search_stage: Iterable[float], read_dur_after_write: int, local: bool=False):
cohere = Dataset.COHERE.manager(1_000_000)
if local is True:
source = DatasetSource.AliyunOSS
else:
source = DatasetSource.S3
prepared = cohere.prepare(source)
assert prepared

rw_runner = ReadWriteRunner(
db=db,
dataset=cohere,
insert_rate=insert_rate,
search_stage=search_stage,
read_dur_after_write=read_dur_after_write,
concurrencies=conc
)
rw_runner.run_read_write()


def get_db(db: str, config: dict) -> VectorDB:
if db == DB.Milvus.name:
return DB.Milvus.init_cls(dim=768, db_config=config, db_case_config=FLATConfig(metric_type="COSINE"), drop_old=True, pre_load=True)
elif db == DB.ZillizCloud.name:
return DB.ZillizCloud.init_cls(dim=768, db_config=config, db_case_config=AutoIndexConfig(metric_type="COSINE"), drop_old=True, pre_load=True)
else:
raise ValueError(f"unknown db: {db}")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-r", "--insert_rate", type=int, default="1000", help="insert entity row count per seconds, cps")
parser.add_argument("-d", "--db", type=str, default=DB.Milvus.name, help="db name")
parser.add_argument("-t", "--duration", type=int, default=300, help="stage search duration in seconds")
parser.add_argument("--use_s3", action='store_true', help="whether to use S3 dataset")

flags = parser.parse_args()

# TODO read uri, user, password from .env
config = {
"uri": "http://localhost:19530",
"user": "",
"password": "",
}

conc = (1, 15, 50)
search_stage = (0.5, 0.6, 0.7, 0.8, 0.9, 1.0)

db = get_db(flags.db, config)
test_read_write_runner(
db=db,
insert_rate=flags.insert_rate,
conc=conc,
search_stage=search_stage,
read_dur_after_write=flags.duration,
local=flags.use_s3)
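
Note that __main__ only exercises the read/write path; test_rate_runner is defined but never called here. For illustration, a minimal sketch of how the insert-only path could be driven with the same helpers (hypothetical wiring, not part of this commit):

# Hypothetical wiring, not in this commit: drive the insert-only rate test
# with the same helpers defined above.
config = {"uri": "http://localhost:19530", "user": "", "password": ""}
db = get_db(DB.Milvus.name, config)       # builds the client with pre_load=True
test_rate_runner(db, insert_rate=1000)    # calls RatedMultiThreadingInsertRunner.run_with_rate()
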
2 changes: 1 addition & 1 deletion vectordb_bench/__init__.py
@@ -17,7 +17,7 @@ class config:

DEFAULT_DATASET_URL = env.str("DEFAULT_DATASET_URL", AWS_S3_URL)
DATASET_LOCAL_DIR = env.path("DATASET_LOCAL_DIR", "/tmp/vectordb_bench/dataset")
NUM_PER_BATCH = env.int("NUM_PER_BATCH", 5000)
NUM_PER_BATCH = env.int("NUM_PER_BATCH", 100)

DROP_OLD = env.bool("DROP_OLD", True)
USE_SHUFFLED_DATA = env.bool("USE_SHUFFLED_DATA", True)
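
The smaller default matters for the new rate runner below: inserts are submitted in NUM_PER_BATCH-row chunks, and the per-second batch count is derived from it. A quick sanity check of that arithmetic, assuming the default --insert_rate of 1000:

# Pacing arithmetic used by RatedMultiThreadingInsertRunner (rate_runner.py below).
NUM_PER_BATCH = 100                        # rows per insert task (new default)
insert_rate = 1000                         # rows per second, the --insert_rate default
batch_rate = insert_rate // NUM_PER_BATCH  # insert tasks submitted per 1-second window
assert batch_rate == 10
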
5 changes: 3 additions & 2 deletions vectordb_bench/backend/clients/milvus/milvus.py
@@ -66,7 +66,8 @@ def __init__(
self.case_config.index_param(),
index_name=self._index_name,
)
# self._pre_load(coll)
if kwargs.get("pre_load") is True:
self._pre_load(col)

connections.disconnect("default")

@@ -92,7 +93,7 @@ def _optimize(self):
self._post_insert()
log.info(f"{self.name} optimizing before search")
try:
self.col.load()
self.col.load(refresh=True)
except Exception as e:
log.warning(f"{self.name} optimize error: {e}")
raise e from None
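
The pre-load step is now opt-in through **kwargs instead of being commented out, and it is enabled from the caller's side. A minimal sketch of that call, mirroring get_db() in tests/test_rate_runner.py above (the URI is an assumption):

# How the new pre_load flag is switched on by the caller
# (mirrors get_db() in tests/test_rate_runner.py; the URI is an assumption).
db = DB.Milvus.init_cls(
    dim=768,
    db_config={"uri": "http://localhost:19530", "user": "", "password": ""},
    db_case_config=FLATConfig(metric_type="COSINE"),
    drop_old=True,
    pre_load=True,  # forwarded through **kwargs and triggers self._pre_load(col) in __init__
)
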
13 changes: 8 additions & 5 deletions vectordb_bench/backend/dataset.py
@@ -57,11 +57,11 @@ class CustomDataset(BaseDataset):
dir: str
file_num: int
isCustom: bool = True

@validator("size")
def verify_size(cls, v):
return v

@property
def label(self) -> str:
return "Custom"
@@ -73,7 +73,8 @@ def dir_name(self) -> str:
@property
def file_count(self) -> int:
return self.file_num



class LAION(BaseDataset):
name: str = "LAION"
dim: int = 768
@@ -242,13 +243,15 @@ def __init__(self, dataset: DatasetManager):
self._cur = None
self._sub_idx = [0 for i in range(len(self._ds.train_files))] # iter num for each file

def __iter__(self):
return self

def _get_iter(self, file_name: str):
p = pathlib.Path(self._ds.data_dir, file_name)
log.info(f"Get iterator for {p.name}")
if not p.exists():
raise IndexError(f"No such file {p}")
log.warning(f"No such file: {p}")
return ParquetFile(p).iter_batches(config.NUM_PER_BATCH)
return ParquetFile(p, memory_map=True, pre_buffer=True).iter_batches(config.NUM_PER_BATCH)

def __next__(self) -> pd.DataFrame:
"""return the data in the next file of the training list"""
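
The iterator now opens each parquet file memory-mapped with read-ahead buffering and streams it in NUM_PER_BATCH-row record batches instead of loading it whole. A standalone sketch of the same pyarrow pattern (the file path is hypothetical):

# Standalone sketch of the pyarrow pattern used above; the path is made up.
from pyarrow.parquet import ParquetFile

pf = ParquetFile("/tmp/vectordb_bench/dataset/cohere/shuffle_train.parquet",
                 memory_map=True, pre_buffer=True)
for batch in pf.iter_batches(batch_size=100):   # batch_size mirrors NUM_PER_BATCH
    df = batch.to_pandas()                       # at most 100 rows per DataFrame
    ...                                          # hand the chunk to the insert runner
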
89 changes: 88 additions & 1 deletion vectordb_bench/backend/runner/mp_runner.py
@@ -64,7 +64,7 @@ def search(self, test_data: list[list[float]], q: mp.Queue, cond: mp.Condition)
log.warning(f"VectorDB search_embedding error: {e}")
traceback.print_exc(chain=True)
raise e from None

latencies.append(time.perf_counter() - s)
count += 1
# loop through the test data
@@ -87,6 +87,8 @@ def get_mp_context():
log.debug(f"MultiProcessingSearchRunner get multiprocessing start method: {mp_start_method}")
return mp.get_context(mp_start_method)



def _run_all_concurrencies_mem_efficient(self) -> float:
max_qps = 0
conc_num_list = []
@@ -145,3 +147,88 @@ def run(self) -> float:

def stop(self) -> None:
pass

def run_by_dur(self, duration: int) -> float:
return self._run_by_dur(duration)

def _run_by_dur(self, duration: int) -> float:
max_qps = 0
try:
for conc in self.concurrencies:
with mp.Manager() as m:
q, cond = m.Queue(), m.Condition()
with concurrent.futures.ProcessPoolExecutor(mp_context=self.get_mp_context(), max_workers=conc) as executor:
log.info(f"Start search_by_dur {duration}s in concurrency {conc}, filters: {self.filters}")
future_iter = [executor.submit(self.search_by_dur, duration, self.test_data, q, cond) for i in range(conc)]
# Sync all processes
while q.qsize() < conc:
sleep_t = conc if conc < 10 else 10
time.sleep(sleep_t)

with cond:
cond.notify_all()
log.info(f"Syncing all process and start concurrency search, concurrency={conc}")

start = time.perf_counter()
all_count = sum([r.result() for r in future_iter])
cost = time.perf_counter() - start

qps = round(all_count / cost, 4)
log.info(f"End search in concurrency {conc}: dur={cost}s, total_count={all_count}, qps={qps}")

if qps > max_qps:
max_qps = qps
log.info(f"Update largest qps with concurrency {conc}: current max_qps={max_qps}")
except Exception as e:
log.warning(f"Fail to search all concurrencies: {self.concurrencies}, max_qps before failure={max_qps}, reason={e}")
traceback.print_exc()

# No results available, raise exception
if max_qps == 0.0:
raise e from None

finally:
self.stop()

return max_qps


def search_by_dur(self, dur: int, test_data: list[list[float]], q: mp.Queue, cond: mp.Condition) -> int:
# sync all process
q.put(1)
with cond:
cond.wait()

with self.db.init():
num, idx = len(test_data), random.randint(0, len(test_data) - 1)

start_time = time.perf_counter()
count = 0
while time.perf_counter() < start_time + dur:
s = time.perf_counter()
try:
self.db.search_embedding(
test_data[idx],
self.k,
self.filters,
)
except Exception as e:
log.warning(f"VectorDB search_embedding error: {e}")
traceback.print_exc(chain=True)
raise e from None

count += 1
# loop through the test data
idx = idx + 1 if idx < num - 1 else 0

if count % 500 == 0:
log.debug(f"({mp.current_process().name:16}) search_count: {count}, latest_latency={time.perf_counter()-s}")

total_dur = round(time.perf_counter() - start_time, 4)
log.debug(
f"{mp.current_process().name:16} search {self.duration}s: "
f"actual_dur={total_dur}s, count={count}, qps in this process: {round(count / total_dur, 4):3}"
)

return count
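
The new _run_by_dur/search_by_dur pair uses the same start barrier as the existing concurrency runner: every worker checks in on a Manager queue, then blocks on a shared condition until the parent releases all of them at once. A minimal, self-contained variant of that barrier (a sketch, not library code; it checks in while holding the condition lock so the parent cannot notify before every worker is waiting):

# Minimal variant of the start barrier used above: workers check in on a shared
# queue, then block on a condition until the parent calls notify_all(), so all
# of them start searching at (nearly) the same instant.
import concurrent.futures
import multiprocessing as mp
import time

def worker(q, cond):
    with cond:
        q.put(1)          # check in while holding the lock...
        cond.wait()       # ...then wait; wait() releases the lock for the others
    return time.perf_counter()

if __name__ == "__main__":
    conc = 4
    with mp.Manager() as m:
        q, cond = m.Queue(), m.Condition()
        with concurrent.futures.ProcessPoolExecutor(max_workers=conc) as executor:
            futures = [executor.submit(worker, q, cond) for _ in range(conc)]
            while q.qsize() < conc:       # poll until every worker has checked in
                time.sleep(0.1)
            with cond:
                cond.notify_all()         # release all workers together
            print([f.result() for f in futures])
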

79 changes: 79 additions & 0 deletions vectordb_bench/backend/runner/rate_runner.py
@@ -0,0 +1,79 @@
import logging
import time
from concurrent.futures import ThreadPoolExecutor
import multiprocessing as mp


from vectordb_bench.backend.clients import api
from vectordb_bench.backend.dataset import DataSetIterator
from vectordb_bench.backend.utils import time_it
from vectordb_bench import config

from .util import get_data, is_futures_completed, get_future_exceptions
log = logging.getLogger(__name__)


class RatedMultiThreadingInsertRunner:
def __init__(
self,
rate: int, # numRows per second
db: api.VectorDB,
dataset_iter: DataSetIterator,
normalize: bool = False,
timeout: float | None = None,
):
self.timeout = timeout if isinstance(timeout, (int, float)) else None
self.dataset = dataset_iter
self.db = db
self.normalize = normalize
self.insert_rate = rate
self.batch_rate = rate // config.NUM_PER_BATCH

def send_insert_task(self, db, emb: list[list[float]], metadata: list[str]):
db.insert_embeddings(emb, metadata)

@time_it
def run_with_rate(self, q: mp.Queue):
with ThreadPoolExecutor(max_workers=mp.cpu_count()) as executor:
executing_futures = []

@time_it
def submit_by_rate() -> bool:
rate = self.batch_rate
for data in self.dataset:
emb, metadata = get_data(data, self.normalize)
executing_futures.append(executor.submit(self.send_insert_task, self.db, emb, metadata))
rate -= 1

if rate == 0:
return False
return rate == self.batch_rate

with self.db.init():
while True:
start_time = time.perf_counter()
finished, elapsed_time = submit_by_rate()
if finished is True:
q.put(None, block=True)
log.info(f"End of dataset, left unfinished={len(executing_futures)}")
return

q.put(True, block=False)
wait_interval = 1 - elapsed_time if elapsed_time < 1 else 0.001

e, completed = is_futures_completed(executing_futures, wait_interval)
if completed is True:
ex = get_future_exceptions(executing_futures)
if ex is not None:
log.warn(f"task error, terminating, err={ex}")
q.put(None)
executor.shutdown(wait=True, cancel_futures=True)
raise ex
else:
log.debug(f"Finished {len(executing_futures)} insert-{config.NUM_PER_BATCH} task in 1s, wait_interval={wait_interval:.2f}")
executing_futures = []
else:
log.warning(f"Failed to finish tasks in 1s, {e}, waited={wait_interval:.2f}, try to check the next round")
dur = time.perf_counter() - start_time
if dur < 1:
time.sleep(1 - dur)
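
submit_by_rate above caps each one-second window at batch_rate batches, and run_with_rate sleeps off whatever is left of the second. The pacing idea, stripped of the thread pool and the progress queue (a minimal sketch, not the runner's API):

# Minimal sketch of the pacing loop above, without the executor or queue:
# submit batch_rate batches per window, then sleep the remainder of the second.
import time

def paced_submit(batches, batch_rate, send):
    it = iter(batches)
    while True:
        start = time.perf_counter()
        for _ in range(batch_rate):
            try:
                send(next(it))             # one insert task per NUM_PER_BATCH-row chunk
            except StopIteration:
                return                     # dataset exhausted
        dur = time.perf_counter() - start
        if dur < 1:
            time.sleep(1 - dur)            # hold the overall rate at batch_rate batches/s

With NUM_PER_BATCH=100 and --insert_rate 1000 this submits 10 batches per second, matching the batch_rate computed in __init__.
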
