forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9d3cfa4
commit f8d1673
Showing
4 changed files
with
125 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
from __future__ import annotations | ||
import fnmatch | ||
import os | ||
import shutil | ||
import tempfile | ||
from pathlib import Path | ||
from typing import List, Optional, Tuple | ||
import signal | ||
|
||
import boto3 | ||
|
||
|
||
def is_s3(model_or_path: str) -> bool: | ||
return model_or_path.lower().startswith('s3://') | ||
|
||
|
||
def _filter_allow(paths: List[str], patterns: List[str]) -> List[str]: | ||
return [ | ||
path for path in paths if any( | ||
fnmatch.fnmatch(path, pattern) for pattern in patterns) | ||
] | ||
|
||
|
||
def _filter_ignore(paths: List[str], patterns: List[str]) -> List[str]: | ||
return [ | ||
path for path in paths | ||
if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns) | ||
] | ||
|
||
|
||
def glob(s3 = None, | ||
path: str = "", | ||
allow_pattern: Optional[List[str]] = None) -> List[str]: | ||
if s3 is None: | ||
s3 = boto3.client("s3") | ||
bucket_name, _, paths = list_files( | ||
s3, path=path, allow_pattern=allow_pattern) | ||
return [f"s3://{bucket_name}/{path}" for path in paths] | ||
|
||
def list_files( | ||
s3, | ||
path: str, | ||
allow_pattern: Optional[List[str]] = None, | ||
ignore_pattern: Optional[List[str]] = None | ||
) -> Tuple[str, str, List[str]]: | ||
parts = path.removeprefix('s3://').split('/') | ||
prefix = '/'.join(parts[1:]) | ||
bucket_name = parts[0] | ||
|
||
objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix) | ||
paths = [obj['Key'] for obj in objects.get('Contents', [])] | ||
|
||
paths = _filter_ignore(paths, ["*/"]) | ||
if allow_pattern is not None: | ||
paths = _filter_allow(paths, allow_pattern) | ||
|
||
if ignore_pattern is not None: | ||
paths = _filter_ignore(paths, ignore_pattern) | ||
|
||
return bucket_name, prefix, paths | ||
|
||
|
||
class S3Model: | ||
def __init__(self) -> None: | ||
self.s3 = boto3.client('s3') | ||
for sig in (signal.SIGINT, signal.SIGTERM): | ||
existing_handler = signal.getsignal(sig) | ||
signal.signal(sig, self.close_by_signal(existing_handler)) | ||
self.dir = tempfile.mkdtemp(dir='/dev/shm') | ||
|
||
def __del__(self): | ||
self.close() | ||
|
||
def close(self) -> None: | ||
if os.path.exists(self.dir): | ||
shutil.rmtree(self.dir) | ||
|
||
def close_by_signal(self, existing_handler=None): | ||
def new_handler(signum, frame): | ||
self.close() | ||
if existing_handler: | ||
existing_handler(signum, frame) | ||
return new_handler | ||
|
||
def pull_files(self, | ||
s3_model_path: str = "", | ||
allow_pattern: Optional[List[str]] = None, | ||
ignore_pattern: Optional[List[str]] = None) -> str: | ||
bucket_name, base_dir, files = list_files(self.s3, s3_model_path, | ||
allow_pattern, | ||
ignore_pattern) | ||
if len(files) == 0: | ||
return self.dir | ||
|
||
for file in files: | ||
destination_file = self.dir + file.removeprefix(base_dir) | ||
local_dir = Path(destination_file).parent | ||
os.makedirs(local_dir, exist_ok=True) | ||
self.s3.download_file(bucket_name, file, destination_file) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters