Add data_source.py; vdb bench can now download datasets from Aliyun OSS. Signed-off-by: yangxuan <[email protected]>
1 parent 34e5794 · commit fd2b186
Showing 10 changed files with 418 additions and 136 deletions.
@@ -26,6 +26,7 @@ dependencies = [
    "streamlit_extras",
    "tqdm",
    "s3fs",
+   "oss2",
    "psutil",
    "polars",
    "plotly",
@@ -0,0 +1,78 @@
import logging
import pathlib
import pytest
from vectordb_bench.backend.data_source import AliyunOSSReader, AwsS3Reader
from vectordb_bench.backend.dataset import Dataset, DatasetManager


log = logging.getLogger(__name__)


class TestReader:
    @pytest.mark.parametrize("size", [
        100_000,
        1_000_000,
        10_000_000,
    ])
    def test_cohere(self, size):
        cohere = Dataset.COHERE.manager(size)
        self.per_dataset_test(cohere)

    @pytest.mark.parametrize("size", [
        100_000,
        1_000_000,
    ])
    def test_gist(self, size):
        gist = Dataset.GIST.manager(size)
        self.per_dataset_test(gist)

    @pytest.mark.parametrize("size", [
        1_000_000,
    ])
    def test_glove(self, size):
        glove = Dataset.GLOVE.manager(size)
        self.per_dataset_test(glove)

    @pytest.mark.parametrize("size", [
        500_000,
        5_000_000,
        # 50_000_000,
    ])
    def test_sift(self, size):
        sift = Dataset.SIFT.manager(size)
        self.per_dataset_test(sift)

    @pytest.mark.parametrize("size", [
        50_000,
        500_000,
        5_000_000,
    ])
    def test_openai(self, size):
        openai = Dataset.OPENAI.manager(size)
        self.per_dataset_test(openai)

    def per_dataset_test(self, dataset: DatasetManager):
        s3_reader = AwsS3Reader()
        all_files = s3_reader.ls_all(dataset.data.dir_name)

        remote_f_names = []
        for file in all_files:
            remote_f = pathlib.Path(file).name
            if dataset.data.use_shuffled and remote_f.startswith("train"):
                continue
            elif (not dataset.data.use_shuffled) and remote_f.startswith("shuffle"):
                continue

            remote_f_names.append(remote_f)

        assert set(dataset.data.files) == set(remote_f_names)

        aliyun_reader = AliyunOSSReader()
        for fname in dataset.data.files:
            p = pathlib.Path("benchmark", dataset.data.dir_name, fname)
            assert aliyun_reader.bucket.object_exists(p.as_posix())

        log.info(f"downloading to {dataset.data_dir}")
        aliyun_reader.read(dataset.data.dir_name.lower(), dataset.data.files, dataset.data_dir)
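The Aliyun existence check in per_dataset_test can also be run standalone as a quick spot check against the public bucket, without pulling a full dataset. A small sketch, assuming vectordb_bench is importable; the dataset directory and file name are illustrative:

import pathlib
from vectordb_bench.backend.data_source import AliyunOSSReader

reader = AliyunOSSReader()
# Objects live under "benchmark/<dataset_dir>/<file>", as in the test above.
key = pathlib.Path("benchmark", "sift_small_500k", "test.parquet")
print(reader.bucket.object_exists(key.as_posix()))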
@@ -0,0 +1,204 @@
import logging
import pathlib
import typing
from enum import Enum
from tqdm import tqdm
from hashlib import md5
import os
from abc import ABC, abstractmethod

from .. import config

logging.getLogger("s3fs").setLevel(logging.CRITICAL)

log = logging.getLogger(__name__)

DatasetReader = typing.TypeVar("DatasetReader")


class DatasetSource(Enum):
    S3 = "S3"
    AliyunOSS = "AliyunOSS"

    def reader(self) -> DatasetReader:
        if self == DatasetSource.S3:
            return AwsS3Reader()

        if self == DatasetSource.AliyunOSS:
            return AliyunOSSReader()


class DatasetReader(ABC):
    source: DatasetSource
    remote_root: str

    @abstractmethod
    def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, check_etag: bool = True):
        """Read dataset files from remote_root into local_ds_root.

        Args:
            dataset(str): dataset directory name, for instance "sift_small_500k"
            files(list[str]): all filenames of the dataset
            local_ds_root(pathlib.Path): local directory to write the downloaded files into
            check_etag(bool): whether to check the etag
        """
        pass

    @abstractmethod
    def validate_file(self, remote: pathlib.Path, local: pathlib.Path) -> bool:
        pass

class AliyunOSSReader(DatasetReader):
    source: DatasetSource = DatasetSource.AliyunOSS
    remote_root: str = config.ALIYUN_OSS_URL

    def __init__(self):
        import oss2
        self.bucket = oss2.Bucket(oss2.AnonymousAuth(), self.remote_root, "benchmark", True)

    def validate_file(self, remote: pathlib.Path, local: pathlib.Path, check_etag: bool) -> bool:
        info = self.bucket.get_object_meta(remote.as_posix())

        # check size equal
        remote_size, local_size = info.content_length, os.path.getsize(local)
        if remote_size != local_size:
            log.info(f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]")
            return False

        # check etag equal
        if check_etag:
            return match_etag(info.etag.strip('"').lower(), local)

        return True

    def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, check_etag: bool = False):
        downloads = []
        if not local_ds_root.exists():
            log.info(f"local dataset root path not exist, creating it: {local_ds_root}")
            local_ds_root.mkdir(parents=True)
            downloads = [(pathlib.Path("benchmark", dataset, f), local_ds_root.joinpath(f)) for f in files]

        else:
            for file in files:
                remote_file = pathlib.Path("benchmark", dataset, file)
                local_file = local_ds_root.joinpath(file)

                if (not local_file.exists()) or (not self.validate_file(remote_file, local_file, check_etag)):
                    log.info(f"local file: {local_file} not match with remote: {remote_file}; add to downloading list")
                    downloads.append((remote_file, local_file))

        if len(downloads) == 0:
            return

        log.info(f"Start to downloading files, total count: {len(downloads)}")
        for remote_file, local_file in tqdm(downloads):
            log.debug(f"downloading file {remote_file} to {local_ds_root}")
            self.bucket.get_object_to_file(remote_file.as_posix(), local_file.as_posix())

        log.info(f"Succeed to download all files, downloaded file count = {len(downloads)}")

class AwsS3Reader(DatasetReader):
    source: DatasetSource = DatasetSource.S3
    remote_root: str = config.AWS_S3_URL

    def __init__(self):
        import s3fs
        self.fs = s3fs.S3FileSystem(
            anon=True,
            client_kwargs={'region_name': 'us-west-2'}
        )

    def ls_all(self, dataset: str):
        dataset_root_dir = pathlib.Path(self.remote_root, dataset)
        log.info(f"listing dataset: {dataset_root_dir}")
        names = self.fs.ls(dataset_root_dir)
        for n in names:
            log.info(n)
        return names

    def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path, check_etag: bool = True):
        downloads = []
        if not local_ds_root.exists():
            log.info(f"local dataset root path not exist, creating it: {local_ds_root}")
            local_ds_root.mkdir(parents=True)
            downloads = [pathlib.Path(self.remote_root, dataset, f) for f in files]

        else:
            for file in files:
                remote_file = pathlib.Path(self.remote_root, dataset, file)
                local_file = local_ds_root.joinpath(file)

                if (not local_file.exists()) or (not self.validate_file(remote_file, local_file, check_etag)):
                    log.info(f"local file: {local_file} not match with remote: {remote_file}; add to downloading list")
                    downloads.append(remote_file)

        if len(downloads) == 0:
            return

        log.info(f"Start to downloading files, total count: {len(downloads)}")
        for s3_file in tqdm(downloads):
            log.debug(f"downloading file {s3_file} to {local_ds_root}")
            self.fs.download(s3_file, local_ds_root.as_posix())

        log.info(f"Succeed to download all files, downloaded file count = {len(downloads)}")

    def validate_file(self, remote: pathlib.Path, local: pathlib.Path, check_etag: bool) -> bool:
        # info() uses ls() inside, maybe we only need to ls once
        info = self.fs.info(remote)

        # check size equal
        remote_size, local_size = info.get("size"), os.path.getsize(local)
        if remote_size != local_size:
            log.info(f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]")
            return False

        # check etag equal
        if check_etag:
            return match_etag(info.get('ETag', "").strip('"'), local)

        return True

def match_etag(expected_etag: str, local_file) -> bool:
    """Check if the local file's etag matches the remote one"""
    def factor_of_1MB(filesize, num_parts):
        x = filesize / int(num_parts)
        y = x % 1048576
        return int(x + 1048576 - y)

    def calc_etag(inputfile, partsize):
        md5_digests = []
        with open(inputfile, 'rb') as f:
            for chunk in iter(lambda: f.read(partsize), b''):
                md5_digests.append(md5(chunk).digest())
        return md5(b''.join(md5_digests)).hexdigest() + '-' + str(len(md5_digests))

    def possible_partsizes(filesize, num_parts):
        return lambda partsize: partsize < filesize and (float(filesize) / float(partsize)) <= num_parts

    filesize = os.path.getsize(local_file)
    le = ""
    if '-' not in expected_etag:  # not a multipart upload
        with open(local_file, 'rb') as f:
            le = md5(f.read()).hexdigest()
        log.debug(f"calculated local etag {le}, expected etag: {expected_etag}")
        return expected_etag == le
    else:
        num_parts = int(expected_etag.split('-')[-1])
        partsizes = [  ## Default Partsizes Map
            8388608,   # aws_cli/boto3
            15728640,  # s3cmd
            factor_of_1MB(filesize, num_parts),  # Used by many clients to upload large files
        ]

        for partsize in filter(possible_partsizes(filesize, num_parts), partsizes):
            le = calc_etag(local_file, partsize)
            log.debug(f"calculated local etag {le}, expected etag: {expected_etag}")
            if expected_etag == le:
                return True
        return False
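For multipart uploads, match_etag reproduces the S3-style ETag: the MD5 of the concatenated per-part MD5 digests, suffixed with the part count, tried against a few common client part sizes. A small self-contained check of that logic, assuming vectordb_bench is importable (the 20 MiB file and 8 MiB part size are arbitrary choices; 8 MiB is the aws_cli/boto3 default listed above):

import os
import tempfile
from hashlib import md5

from vectordb_bench.backend.data_source import match_etag

PART = 8 * 1024 * 1024  # 8 MiB parts, the aws_cli/boto3 default

# Write ~20 MiB of random data; a multipart upload would split it into 3 parts.
with tempfile.NamedTemporaryFile(delete=False) as f:
    f.write(os.urandom(20 * 1024 * 1024))
    path = f.name

# Compute the multipart-style ETag exactly as calc_etag does.
digests = []
with open(path, "rb") as fh:
    for chunk in iter(lambda: fh.read(PART), b""):
        digests.append(md5(chunk).digest())
expected = md5(b"".join(digests)).hexdigest() + f"-{len(digests)}"

assert match_etag(expected, path)  # the 8 MiB candidate part size matches
os.unlink(path)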