Skip to content

Commit

Permalink
Model Downloaders
Browse files Browse the repository at this point in the history
  • Loading branch information
bbengfort committed Dec 20, 2024
1 parent a1643ee commit a5afa9f
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 70 deletions.
71 changes: 71 additions & 0 deletions construe/cloud/download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""
Handle HTTP download requests from content URLs
"""

import os
import shutil
import zipfile

from tqdm import tqdm
from urllib.request import urlopen
from construe.exceptions import DownloadError

from .signature import sha256sum


# Download chunk size
CHUNK = 524288


def download_zip(url, out, signature, replace=False, extract=True):
    """
    Download a zipped file at the given URL saving it to the out directory. Once
    downloaded, verify the signature to make sure the download hasn't been tampered
    with or corrupted. If the file already exists it will be overwritten only if
    replace=True. If extract=True then the file will be unzipped.

    Parameters
    ----------
    url : str
        The content URL to fetch the zip archive from.
    out : str
        Path to the directory to save the archive (and extracted data) into.
    signature : str
        Expected SHA-256 hex digest of the downloaded archive.
    replace : bool, default=False
        If True, remove any existing archive and extracted data first;
        otherwise raise a DownloadError when the archive already exists.
    extract : bool, default=True
        If True, unzip the archive into a directory named after the file.

    Raises
    ------
    DownloadError
        If the archive already exists and replace is False, or if the
        downloaded archive does not match the expected signature.
    """
    # Get the name of the file from the URL
    basename = os.path.basename(url)
    name, _ = os.path.splitext(basename)

    # Get the archive and data directory paths
    archive = os.path.join(out, basename)
    datadir = os.path.join(out, name)

    # If the archive exists cleanup or raise override exception
    if os.path.exists(archive):
        if not replace:
            raise DownloadError(
                f"dataset already exists at {archive}, set replace=True to overwrite"
            )

        # The extracted directory may not exist (e.g. a previous download with
        # extract=False), so guard the rmtree to avoid FileNotFoundError.
        if os.path.exists(datadir):
            shutil.rmtree(datadir)
        os.remove(archive)

    # Create the output directory if it does not exist
    if not os.path.exists(datadir):
        os.makedirs(datadir)

    # Fetch the response in a streaming fashion and write it to disk, closing
    # the connection, file handle, and progress bar when complete.
    with urlopen(url) as response:
        content_length = int(response.headers["Content-Length"])

        with open(archive, "wb") as f:
            with tqdm(
                unit="B",
                total=content_length,
                desc=f"Downloading {basename}",
                leave=False,
            ) as pbar:
                while True:
                    chunk = response.read(CHUNK)
                    if not chunk:
                        break
                    f.write(chunk)
                    pbar.update(len(chunk))

    # Compare the signature of the archive to the expected one
    if sha256sum(archive) != signature:
        raise DownloadError("Download signature does not match hardcoded signature!")

    # If extract, extract the zipfile, ensuring the archive handle is closed.
    if extract:
        with zipfile.ZipFile(archive) as zf:
            zf.extractall(path=datadir)
6 changes: 4 additions & 2 deletions construe/cloud/gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import glob
import json

from ..exceptions import UploadError

try:
from google.cloud import storage
except ImportError:
Expand All @@ -24,7 +26,7 @@ def upload(name, path, client=None, bucket=CONSTRUE_BUCKET):
client = connect_storage()

if not os.path.exists(path) or not os.path.isfile(path):
raise ValueError("no zip file exists at " + path)
raise UploadError("no zip file exists at " + path)

bucket = client.get_bucket(bucket)
blob = bucket.blob(name)
Expand All @@ -44,7 +46,7 @@ def connect_storage(credentials=None):
credentials = credentials or find_service_account()

if credentials is None:
raise RuntimeError(
raise UploadError(
"could not find service account credentials: "
"set either $GOOGLE_APPLICATION_CREDENTIALS to the path "
"or store the credentials in the .secret folder"
Expand Down
66 changes: 9 additions & 57 deletions construe/datasets/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,14 @@
Handle downloading datasets from our content URL
"""

import os
import zipfile

from tqdm import tqdm
from functools import partial
from urllib.request import urlopen

from ..cloud.signature import sha256sum
from .path import get_data_home
from .manifest import load_manifest
from .path import get_data_home, cleanup_dataset
from ..cloud.download import download_zip
from .path import DIALECTS, LOWLIGHT, REDDIT, MOVIES, ESSAYS, AEGIS, NSFW


from construe.exceptions import DatasetsError


# Download chunk size
CHUNK = 524288
from ..exceptions import DatasetsError


def download_data(url, signature, data_home=None, replace=False, extract=True):
Expand All @@ -29,53 +19,15 @@ def download_data(url, signature, data_home=None, replace=False, extract=True):
the download with the given signature and extracts the archive.
"""
data_home = get_data_home(data_home)

# Get the name of the file from the URL
basename = os.path.basename(url)
name, _ = os.path.splitext(basename)

# Get the archive and data directory paths
archive = os.path.join(data_home, basename)
datadir = os.path.join(data_home, name)

# If the archive exists cleanup or raise override exception
if os.path.exists(archive):
if not replace:
raise DatasetsError(
f"dataset already exists at {archive}, set replace=False to overwrite"
)
cleanup_dataset(name, data_home=data_home)

# Create the output directory if it does not exist
if not os.path.exists(datadir):
os.mkdir(datadir)

# Fetch the response in a streaming fashion and write it to disk.
response = urlopen(url)
content_length = int(response.headers["Content-Length"])

with open(archive, "wb") as f:
pbar = tqdm(
unit="B", total=content_length, desc=f"Downloading {basename}", leave=False
)
while True:
chunk = response.read(CHUNK)
if not chunk:
break
f.write(chunk)
pbar.update(len(chunk))

# Compare the signature of the archive to the expected one
if sha256sum(archive) != signature:
raise ValueError("Download signature does not match hardcoded signature!")

# If extract, extract the zipfile.
if extract:
zf = zipfile.ZipFile(archive)
zf.extractall(path=datadir)
download_zip(url, data_home, signature=signature, replace=replace, extract=extract)


def _download_dataset(name, sample=True, data_home=True, replace=False, extract=True):
"""
Downloads the zipped data set specified using the manifest URL, saving it to the
data directory specified by ``get_data_home``. The download is verified with
the given signature then extracted.
"""
if sample and not name.endswith("-sample"):
name = name + "-sample"

Expand Down
8 changes: 3 additions & 5 deletions construe/datasets/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def find_dataset_path(dataset, data_home=None, fname=None, ext=None, raises=True
"""
Looks up the path to the dataset specified in the data home directory,
which is found using the ``get_data_home`` function. By default data home
is colocated with the code, but can be modified with the CONSTRUE_DATA
environment variable, or passing in a different directory.
is in a config directory in the user's home folder, but can be modified with the
$CONSTRUE_DATA environment variable, or passing in a different directory.
If the dataset is not found a ``DatasetsError`` is raised by default.
"""
Expand All @@ -80,9 +80,7 @@ def find_dataset_path(dataset, data_home=None, fname=None, ext=None, raises=True
return None

raise DatasetsError(
("could not find dataset at {} - does it need to be downloaded?").format(
path
)
f"could not find dataset at {path} - does it need to be downloaded?"
)

return path
Expand Down
12 changes: 12 additions & 0 deletions construe/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,22 @@ class ConstrueError(ClickException):
pass


class DownloadError(ConstrueError):
    """
    Raised when downloading an archive fails, e.g. the file already exists and
    replace is not set, or the downloaded data does not match its signature.
    """
    pass


class UploadError(ConstrueError):
    """
    Raised when uploading to cloud storage fails, e.g. the local file to
    upload does not exist or service account credentials cannot be found.
    """
    pass


class DatasetsError(ConstrueError):
    """
    Raised for dataset management problems, e.g. a dataset path that cannot
    be found and may need to be downloaded first.
    """
    pass


class ModelsError(ConstrueError):
    """
    Raised for model management problems, e.g. a requested model name that
    does not exist in the models manifest.
    """
    pass


class DeviceError(ConstrueError):

def __init__(self, e):
Expand Down
59 changes: 59 additions & 0 deletions construe/models/download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""
Handle downloading models from content URLs
"""

from functools import partial

from .path import get_models_home
from .manifest import load_manifest
from ..cloud.download import download_zip
from .path import NSFW, LOWLIGHT, OFFENSIVE, GLINER
from .path import MOONDREAM, WHISPER, MOBILENET, MOBILEVIT

from ..exceptions import ModelsError


def download_model(url, signature, model_home=None, replace=False, extract=True):
    """
    Fetch the zipped model archive from the given URL and store it in the
    models directory resolved by ``get_models_home``. The archive's checksum
    is checked against the given signature and the contents are extracted.
    """
    home = get_models_home(model_home)
    download_zip(
        url, home, signature=signature, replace=replace, extract=extract
    )


def _download_model(name, model_home=None, replace=False, extract=True):
    """
    Downloads the zipped model file specified using the manifest URL, saving it to the
    models directory specified by ``get_models_home``. The download is verified with
    the given signature then extracted.

    Raises
    ------
    ModelsError
        If no model with the given name is listed in the manifest.
    """
    models = load_manifest()
    if name not in models:
        raise ModelsError(f"no model named {name} exists")

    # Copy the manifest entry before adding the download options so the
    # manifest data is not mutated in place (e.g. if load_manifest caches
    # or shares its result between calls -- TODO confirm).
    info = dict(models[name])
    info.update({"model_home": model_home, "replace": replace, "extract": extract})
    download_model(**info)


# Convenience downloaders with the model name pre-bound; each accepts the
# same keyword arguments as _download_model (model_home, replace, extract).
download_moondream = partial(_download_model, MOONDREAM)
download_whisper = partial(_download_model, WHISPER)
download_mobilenet = partial(_download_model, MOBILENET)
download_mobilevit = partial(_download_model, MOBILEVIT)
download_nsfw = partial(_download_model, NSFW)
download_lowlight = partial(_download_model, LOWLIGHT)
download_offensive = partial(_download_model, OFFENSIVE)
download_gliner = partial(_download_model, GLINER)


# The complete set of model downloaders, used by download_all_models.
DOWNLOADERS = [
    download_moondream, download_whisper, download_mobilenet, download_mobilevit,
    download_nsfw, download_lowlight, download_offensive, download_gliner,
]


def download_all_models(model_home=None, replace=True, extract=True):
    """
    Download every model in the manifest by invoking each registered
    downloader with the same options. Note that unlike the individual
    downloaders, replace defaults to True here.
    """
    for downloader in DOWNLOADERS:
        downloader(model_home=model_home, replace=replace, extract=extract)
Loading

0 comments on commit a5afa9f

Please sign in to comment.