Enable Aliyun OSS
Add data_source.py; vdb bench can now download datasets from Aliyun OSS.

Signed-off-by: yangxuan <[email protected]>
XuanYang-cn authored and alwayslove2013 committed Jan 18, 2024
1 parent 34e5794 commit fd2b186
Showing 10 changed files with 418 additions and 136 deletions.
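In short, this commit adds a DatasetSource abstraction with an Aliyun OSS reader alongside the existing AWS S3 one. A minimal sketch of how the new source might be exercised (the dataset name follows the docstring example in data_source.py; the file name and local path are hypothetical):

import pathlib

from vectordb_bench.backend.data_source import DatasetSource

# Pick the Aliyun OSS source and download a dataset's files into a local
# directory; read() creates the directory if needed and skips files that
# already match the remote copy (size and, optionally, etag).
reader = DatasetSource.AliyunOSS.reader()
reader.read(
    dataset="sift_small_500k",        # remote dataset directory, as in the docstring example
    files=["shuffle_train.parquet"],  # hypothetical file name, for illustration only
    local_ds_root=pathlib.Path("/tmp/vectordb_bench/dataset/sift_small_500k"),
)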
1 change: 1 addition & 0 deletions pyproject.toml
@@ -26,6 +26,7 @@ dependencies = [
    "streamlit_extras",
    "tqdm",
    "s3fs",
    "oss2",
    "psutil",
    "polars",
    "plotly",
78 changes: 78 additions & 0 deletions tests/test_data_source.py
@@ -0,0 +1,78 @@
import logging
import pathlib
import pytest
from vectordb_bench.backend.data_source import AliyunOSSReader, AwsS3Reader
from vectordb_bench.backend.dataset import Dataset, DatasetManager

log = logging.getLogger(__name__)

class TestReader:
    @pytest.mark.parametrize("size", [
        100_000,
        1_000_000,
        10_000_000,
    ])
    def test_cohere(self, size):
        cohere = Dataset.COHERE.manager(size)
        self.per_dataset_test(cohere)

    @pytest.mark.parametrize("size", [
        100_000,
        1_000_000,
    ])
    def test_gist(self, size):
        gist = Dataset.GIST.manager(size)
        self.per_dataset_test(gist)

    @pytest.mark.parametrize("size", [
        1_000_000,
    ])
    def test_glove(self, size):
        glove = Dataset.GLOVE.manager(size)
        self.per_dataset_test(glove)

    @pytest.mark.parametrize("size", [
        500_000,
        5_000_000,
        # 50_000_000,
    ])
    def test_sift(self, size):
        sift = Dataset.SIFT.manager(size)
        self.per_dataset_test(sift)

    @pytest.mark.parametrize("size", [
        50_000,
        500_000,
        5_000_000,
    ])
    def test_openai(self, size):
        openai = Dataset.OPENAI.manager(size)
        self.per_dataset_test(openai)


    def per_dataset_test(self, dataset: DatasetManager):
        s3_reader = AwsS3Reader()
        all_files = s3_reader.ls_all(dataset.data.dir_name)

        remote_f_names = []
        for file in all_files:
            remote_f = pathlib.Path(file).name
            if dataset.data.use_shuffled and remote_f.startswith("train"):
                continue
            elif (not dataset.data.use_shuffled) and remote_f.startswith("shuffle"):
                continue

            remote_f_names.append(remote_f)

        assert set(dataset.data.files) == set(remote_f_names)

        aliyun_reader = AliyunOSSReader()
        for fname in dataset.data.files:
            p = pathlib.Path("benchmark", dataset.data.dir_name, fname)
            assert aliyun_reader.bucket.object_exists(p.as_posix())

        log.info(f"downloading to {dataset.data_dir}")
        aliyun_reader.read(dataset.data.dir_name.lower(), dataset.data.files, dataset.data_dir)
33 changes: 32 additions & 1 deletion tests/test_dataset.py
@@ -1,4 +1,4 @@
from vectordb_bench.backend.dataset import Dataset
from vectordb_bench.backend.dataset import Dataset, get_files
import logging
import pytest
from pydantic import ValidationError
@@ -34,3 +34,34 @@ def test_iter_cohere(self):
        for i in cohere_10m:
            log.debug(i.head(1))


class TestGetFiles:
    @pytest.mark.parametrize("train_count", [
        1,
        10,
        50,
        100,
    ])
    @pytest.mark.parametrize("with_gt", [True, False])
    def test_train_count(self, train_count, with_gt):
        files = get_files(train_count, True, with_gt)
        log.info(files)

        if with_gt:
            assert len(files) - 4 == train_count
        else:
            assert len(files) - 1 == train_count

    @pytest.mark.parametrize("use_shuffled", [True, False])
    def test_use_shuffled(self, use_shuffled):
        files = get_files(1, use_shuffled, True)
        log.info(files)

        trains = [f for f in files if "train" in f]
        if use_shuffled:
            for t in trains:
                assert "shuffle_train" in t
        else:
            for t in trains:
                assert "shuffle" not in t
                assert "train" in t
6 changes: 4 additions & 2 deletions vectordb_bench/__init__.py
@@ -8,10 +8,12 @@
env.read_env(".env")

class config:
    ALIYUN_OSS_URL = "assets.zilliz.com.cn/benchmark/"
    AWS_S3_URL = "assets.zilliz.com/benchmark/"

    LOG_LEVEL = env.str("LOG_LEVEL", "INFO")

    DEFAULT_DATASET_URL = env.str("DEFAULT_DATASET_URL", "assets.zilliz.com/benchmark/")
    DEFAULT_DATASET_URL_ALIYUN = env.str("DEFAULT_DATASET_URL", "assets.zilliz.com.cn/benchmark/")
    DEFAULT_DATASET_URL = env.str("DEFAULT_DATASET_URL", AWS_S3_URL)
    DATASET_LOCAL_DIR = env.path("DATASET_LOCAL_DIR", "/tmp/vectordb_bench/dataset")
    NUM_PER_BATCH = env.int("NUM_PER_BATCH", 5000)
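The dataset endpoint can still be overridden through the environment. A small sketch (assuming the Aliyun mirror URL from the config above, and that the variable is set before the package is imported, since the config class reads the environment at import time):

import os

# Must be set before importing vectordb_bench, because `class config`
# evaluates env.str(...) when the module is first imported.
os.environ["DEFAULT_DATASET_URL"] = "assets.zilliz.com.cn/benchmark/"

from vectordb_bench import config

print(config.DEFAULT_DATASET_URL)  # assets.zilliz.com.cn/benchmark/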

14 changes: 11 additions & 3 deletions vectordb_bench/backend/assembler.py
@@ -2,6 +2,7 @@
from .task_runner import CaseRunner, RunningStatus, TaskRunner
from ..models import TaskConfig
from ..backend.clients import EmptyDBCaseConfig
from ..backend.data_source import DatasetSource
import logging


@@ -10,7 +11,7 @@

class Assembler:
    @classmethod
    def assemble(cls, run_id , task: TaskConfig) -> CaseRunner:
    def assemble(cls, run_id , task: TaskConfig, source: DatasetSource) -> CaseRunner:
        c_cls = task.case_config.case_id.case_cls

        c = c_cls()
@@ -22,14 +23,21 @@ def assemble(cls, run_id , task: TaskConfig) -> CaseRunner:
            config=task,
            ca=c,
            status=RunningStatus.PENDING,
            dataset_source=source,
        )

        return runner

    @classmethod
    def assemble_all(cls, run_id: str, task_label: str, tasks: list[TaskConfig]) -> TaskRunner:
    def assemble_all(
        cls,
        run_id: str,
        task_label: str,
        tasks: list[TaskConfig],
        source: DatasetSource,
    ) -> TaskRunner:
        """group by case type, db, and case dataset"""
        runners = [cls.assemble(run_id, task) for task in tasks]
        runners = [cls.assemble(run_id, task, source) for task in tasks]
        load_runners = [r for r in runners if r.ca.label == CaseLabel.Load]
        perf_runners = [r for r in runners if r.ca.label == CaseLabel.Performance]
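Callers of the assembler now have to pass the dataset source explicitly. Roughly (a sketch, with tasks assumed to be an already-built list of TaskConfig, as in the existing interface):

from vectordb_bench.backend.assembler import Assembler
from vectordb_bench.backend.data_source import DatasetSource

# run_id, task_label, and tasks come from the caller as before; only the
# new `source` argument is specific to this commit. Values shown here are
# hypothetical placeholders.
task_runner = Assembler.assemble_all(
    run_id="example-run-id",
    task_label="example-label",
    tasks=tasks,                     # list[TaskConfig], built elsewhere
    source=DatasetSource.AliyunOSS,  # or DatasetSource.S3, the previous behavior
)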

204 changes: 204 additions & 0 deletions vectordb_bench/backend/data_source.py
@@ -0,0 +1,204 @@
import logging
import pathlib
import typing
from enum import Enum
from tqdm import tqdm
from hashlib import md5
import os
from abc import ABC, abstractmethod

from .. import config

logging.getLogger("s3fs").setLevel(logging.CRITICAL)

log = logging.getLogger(__name__)

DatasetReader = typing.TypeVar("DatasetReader")

class DatasetSource(Enum):
    S3 = "S3"
    AliyunOSS = "AliyunOSS"

    def reader(self) -> DatasetReader:
        if self == DatasetSource.S3:
            return AwsS3Reader()

        if self == DatasetSource.AliyunOSS:
            return AliyunOSSReader()


class DatasetReader(ABC):
    source: DatasetSource
    remote_root: str

    @abstractmethod
    def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, check_etag: bool = True):
        """read dataset files from remote_root to local_ds_root

        Args:
            dataset(str): for instance "sift_small_500k"
            files(list[str]): all filenames of the dataset
            local_ds_root(pathlib.Path): local directory to write the downloaded files to
            check_etag(bool): whether to check the etag
        """
        pass

    @abstractmethod
    def validate_file(self, remote: pathlib.Path, local: pathlib.Path) -> bool:
        pass


class AliyunOSSReader(DatasetReader):
    source: DatasetSource = DatasetSource.AliyunOSS
    remote_root: str = config.ALIYUN_OSS_URL

    def __init__(self):
        import oss2
        self.bucket = oss2.Bucket(oss2.AnonymousAuth(), self.remote_root, "benchmark", True)

    def validate_file(self, remote: pathlib.Path, local: pathlib.Path, check_etag: bool) -> bool:
        info = self.bucket.get_object_meta(remote.as_posix())

        # check size equal
        remote_size, local_size = info.content_length, os.path.getsize(local)
        if remote_size != local_size:
            log.info(f"local file: {local} size[{local_size}] does not match remote size[{remote_size}]")
            return False

        # check etag equal
        if check_etag:
            return match_etag(info.etag.strip('"').lower(), local)

        return True

    def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, check_etag: bool = False):
        downloads = []
        if not local_ds_root.exists():
            log.info(f"local dataset root path does not exist, creating it: {local_ds_root}")
            local_ds_root.mkdir(parents=True)
            downloads = [(pathlib.Path("benchmark", dataset, f), local_ds_root.joinpath(f)) for f in files]

        else:
            for file in files:
                remote_file = pathlib.Path("benchmark", dataset, file)
                local_file = local_ds_root.joinpath(file)

                if (not local_file.exists()) or (not self.validate_file(remote_file, local_file, check_etag)):
                    log.info(f"local file: {local_file} does not match remote: {remote_file}; adding to download list")
                    downloads.append((remote_file, local_file))

        if len(downloads) == 0:
            return

        log.info(f"Start downloading files, total count: {len(downloads)}")
        for remote_file, local_file in tqdm(downloads):
            log.debug(f"downloading file {remote_file} to {local_ds_root}")
            self.bucket.get_object_to_file(remote_file.as_posix(), local_file.as_posix())

        log.info(f"Finished downloading all files, downloaded file count = {len(downloads)}")



class AwsS3Reader(DatasetReader):
    source: DatasetSource = DatasetSource.S3
    remote_root: str = config.AWS_S3_URL

    def __init__(self):
        import s3fs
        self.fs = s3fs.S3FileSystem(
            anon=True,
            client_kwargs={'region_name': 'us-west-2'}
        )

    def ls_all(self, dataset: str):
        dataset_root_dir = pathlib.Path(self.remote_root, dataset)
        log.info(f"listing dataset: {dataset_root_dir}")
        names = self.fs.ls(dataset_root_dir)
        for n in names:
            log.info(n)
        return names

    def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, check_etag: bool = True):
        downloads = []
        if not local_ds_root.exists():
            log.info(f"local dataset root path does not exist, creating it: {local_ds_root}")
            local_ds_root.mkdir(parents=True)
            downloads = [pathlib.Path(self.remote_root, dataset, f) for f in files]

        else:
            for file in files:
                remote_file = pathlib.Path(self.remote_root, dataset, file)
                local_file = local_ds_root.joinpath(file)

                if (not local_file.exists()) or (not self.validate_file(remote_file, local_file, check_etag)):
                    log.info(f"local file: {local_file} does not match remote: {remote_file}; adding to download list")
                    downloads.append(remote_file)

        if len(downloads) == 0:
            return

        log.info(f"Start downloading files, total count: {len(downloads)}")
        for s3_file in tqdm(downloads):
            log.debug(f"downloading file {s3_file} to {local_ds_root}")
            self.fs.download(s3_file, local_ds_root.as_posix())

        log.info(f"Finished downloading all files, downloaded file count = {len(downloads)}")

    def validate_file(self, remote: pathlib.Path, local: pathlib.Path, check_etag: bool) -> bool:
        # info() uses ls() inside, maybe we only need to ls once
        info = self.fs.info(remote)

        # check size equal
        remote_size, local_size = info.get("size"), os.path.getsize(local)
        if remote_size != local_size:
            log.info(f"local file: {local} size[{local_size}] does not match remote size[{remote_size}]")
            return False

        # check etag equal
        if check_etag:
            return match_etag(info.get('ETag', "").strip('"'), local)

        return True


def match_etag(expected_etag: str, local_file) -> bool:
    """Check if the local file's etag matches the one reported by S3/OSS"""
    def factor_of_1MB(filesize, num_parts):
        x = filesize / int(num_parts)
        y = x % 1048576
        return int(x + 1048576 - y)

    def calc_etag(inputfile, partsize):
        md5_digests = []
        with open(inputfile, 'rb') as f:
            for chunk in iter(lambda: f.read(partsize), b''):
                md5_digests.append(md5(chunk).digest())
        return md5(b''.join(md5_digests)).hexdigest() + '-' + str(len(md5_digests))

    def possible_partsizes(filesize, num_parts):
        return lambda partsize: partsize < filesize and (float(filesize) / float(partsize)) <= num_parts

    filesize = os.path.getsize(local_file)
    le = ""
    if '-' not in expected_etag:  # not a multipart upload
        with open(local_file, 'rb') as f:
            le = md5(f.read()).hexdigest()
        log.debug(f"calculated local etag {le}, expected etag: {expected_etag}")
        return expected_etag == le
    else:
        num_parts = int(expected_etag.split('-')[-1])
        partsizes = [  # Default Partsizes Map
            8388608,   # aws_cli/boto3
            15728640,  # s3cmd
            factor_of_1MB(filesize, num_parts),  # Used by many clients to upload large files
        ]

        for partsize in filter(possible_partsizes(filesize, num_parts), partsizes):
            le = calc_etag(local_file, partsize)
            log.debug(f"calculated local etag {le}, expected etag: {expected_etag}")
            if expected_etag == le:
                return True
    return False
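For single-part uploads the ETag is just the file's MD5, while multipart ETags carry a "-<part count>" suffix and are rebuilt from per-part digests as above. A quick sanity sketch of the simple path (throwaway temp file, standard library only):

import hashlib
import pathlib
import tempfile

from vectordb_bench.backend.data_source import match_etag

# Write a small throwaway file and compute its plain MD5, which is what
# S3/OSS report as the ETag for single-part uploads (no '-' suffix).
with tempfile.NamedTemporaryFile(delete=False) as f:
    f.write(b"hello vectordb bench")
    local_file = pathlib.Path(f.name)

expected = hashlib.md5(local_file.read_bytes()).hexdigest()
assert match_etag(expected, local_file)      # single-part path: plain MD5 comparison
assert not match_etag("0" * 32, local_file)  # a mismatch comes back as False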