-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
255 additions
and
70 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
""" | ||
Handle HTTP download requests from content URLs | ||
""" | ||
|
||
import os | ||
import shutil | ||
import zipfile | ||
|
||
from tqdm import tqdm | ||
from urllib.request import urlopen | ||
from construe.exceptions import DownloadError | ||
|
||
from .signature import sha256sum | ||
|
||
|
||
# Download chunk size | ||
CHUNK = 524288 | ||
|
||
|
||
def download_zip(url, out, signature, replace=False, extract=True):
    """
    Download a zipped file at the given URL, saving it to the out directory. Once
    downloaded, verify the signature to make sure the download hasn't been tampered
    with or corrupted. If the file already exists it will be overwritten only if
    replace=True. If extract=True then the file will be unzipped.

    Parameters
    ----------
    url : str
        Location of the zip archive; the last path component of the URL is used
        as the archive file name.

    out : str
        Directory the archive is written to. The data directory (archive name
        without its extension) is created underneath it, along with any missing
        parent directories.

    signature : str
        Expected value of ``sha256sum(archive)`` for the downloaded file.

    replace : bool, default: False
        If the archive already exists, remove it (and its data directory) and
        download again; otherwise an existing archive raises an exception.

    extract : bool, default: True
        Unzip the verified archive into the data directory.

    Raises
    ------
    DownloadError
        If the archive already exists and replace is False, or if the signature
        of the downloaded archive does not match the expected signature.
    """
    # Get the name of the file from the URL
    basename = os.path.basename(url)
    name, _ = os.path.splitext(basename)

    # Get the archive and data directory paths
    archive = os.path.join(out, basename)
    datadir = os.path.join(out, name)

    # If the archive exists cleanup or raise override exception
    if os.path.exists(archive):
        if not replace:
            raise DownloadError(
                f"dataset already exists at {archive}, set replace=True to overwrite"
            )

        # The data directory can be absent even though the archive is present,
        # so only remove it when it is actually there.
        if os.path.isdir(datadir):
            shutil.rmtree(datadir)
        os.remove(archive)

    # Create the data directory (and any missing parents such as ``out``)
    os.makedirs(datadir, exist_ok=True)

    # Fetch the response in a streaming fashion and write it to disk; the context
    # managers make sure the connection, file and progress bar are always closed.
    with urlopen(url) as response:
        # Content-Length is optional; without it the progress bar has no total.
        content_length = response.headers.get("Content-Length")
        total = int(content_length) if content_length is not None else None

        with open(archive, "wb") as f, tqdm(
            unit="B", total=total, desc=f"Downloading {basename}", leave=False
        ) as pbar:
            while True:
                chunk = response.read(CHUNK)
                if not chunk:
                    break
                f.write(chunk)
                pbar.update(len(chunk))

    # Compare the signature of the archive to the expected one
    if sha256sum(archive) != signature:
        raise DownloadError("Download signature does not match hardcoded signature!")

    # If extract, extract the zipfile.
    if extract:
        with zipfile.ZipFile(archive) as zf:
            zf.extractall(path=datadir)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
""" | ||
Handle downloading models from content URLs | ||
""" | ||
|
||
from functools import partial | ||
|
||
from .path import get_models_home | ||
from .manifest import load_manifest | ||
from ..cloud.download import download_zip | ||
from .path import NSFW, LOWLIGHT, OFFENSIVE, GLINER | ||
from .path import MOONDREAM, WHISPER, MOBILENET, MOBILEVIT | ||
|
||
from ..exceptions import ModelsError | ||
|
||
|
||
def download_model(url, signature, model_home=None, replace=False, extract=True):
    """
    Fetch the zipped model at ``url`` and store it under the models directory
    resolved by ``get_models_home(model_home)``. The remaining arguments
    (``signature``, ``replace`` and ``extract``) are handed to ``download_zip``
    unchanged, which performs the download, verification and extraction.
    """
    destination = get_models_home(model_home)
    download_zip(
        url, destination, signature=signature, replace=replace, extract=extract
    )
|
||
|
||
def _download_model(name, model_home=None, replace=False, extract=True):
    """
    Downloads the zipped model file specified using the manifest URL, saving it to the
    models directory specified by ``get_models_home``. The download is verified with
    the signature from the manifest then extracted.

    The manifest entry for ``name`` is expanded into keyword arguments for
    ``download_model``, combined with the ``model_home``, ``replace`` and
    ``extract`` options given here.

    Raises
    ------
    ModelsError
        If the manifest has no entry for ``name``.
    """
    models = load_manifest()
    if name not in models:
        raise ModelsError(f"no model named {name} exists")

    # Work on a copy so the call options never leak into the manifest entry,
    # which may be shared or cached by ``load_manifest``.
    kwargs = dict(models[name])
    kwargs.update(model_home=model_home, replace=replace, extract=extract)
    download_model(**kwargs)
|
||
|
||
# One ready-made downloader per known model: each binds the model's name as the
# first argument of ``_download_model`` and still accepts the ``model_home``,
# ``replace`` and ``extract`` keyword arguments.
download_moondream = partial(_download_model, MOONDREAM)
download_whisper = partial(_download_model, WHISPER)
download_mobilenet = partial(_download_model, MOBILENET)
download_mobilevit = partial(_download_model, MOBILEVIT)
download_nsfw = partial(_download_model, NSFW)
download_lowlight = partial(_download_model, LOWLIGHT)
download_offensive = partial(_download_model, OFFENSIVE)
download_gliner = partial(_download_model, GLINER)


# Every per-model downloader, listed in the order they run when iterated.
DOWNLOADERS = [
    download_moondream,
    download_whisper,
    download_mobilenet,
    download_mobilevit,
    download_nsfw,
    download_lowlight,
    download_offensive,
    download_gliner,
]
|
||
|
||
def download_all_models(model_home=None, replace=True, extract=True):
    """
    Run every downloader in ``DOWNLOADERS`` in order, forwarding the same
    ``model_home``, ``replace`` and ``extract`` arguments to each of them. Note
    that ``replace`` defaults to True here.
    """
    options = {"model_home": model_home, "replace": replace, "extract": extract}
    for downloader in DOWNLOADERS:
        downloader(**options)
Oops, something went wrong.