diff --git a/construe/cloud/download.py b/construe/cloud/download.py
new file mode 100644
index 0000000..2b3211b
--- /dev/null
+++ b/construe/cloud/download.py
@@ -0,0 +1,71 @@
+"""
+Handle HTTP download requests from content URLs
+"""
+
+import os
+import shutil
+import zipfile
+
+from tqdm import tqdm
+from urllib.request import urlopen
+from construe.exceptions import DownloadError
+
+from .signature import sha256sum
+
+
+# Download chunk size
+CHUNK = 524288
+
+
+def download_zip(url, out, signature, replace=False, extract=True):
+    """
+    Download a zipped file at the given URL saving it to the out directory. Once
+    downloaded, verify the signature to make sure the download hasn't been tampered
+    with or corrupted. If the file already exists it will be overwritten only if
+    replace=True. If extract=True then the file will be unzipped.
+    """
+    # Get the name of the file from the URL
+    basename = os.path.basename(url)
+    name, _ = os.path.splitext(basename)
+
+    # Get the archive and data directory paths
+    archive = os.path.join(out, basename)
+    datadir = os.path.join(out, name)
+
+    # If the archive exists cleanup or raise override exception
+    if os.path.exists(archive):
+        if not replace:
+            raise DownloadError(
+                f"dataset already exists at {archive}, set replace=True to overwrite"
+            )
+
+        shutil.rmtree(datadir, ignore_errors=True)
+        os.remove(archive)
+
+    # Create the output directory if it does not exist
+    if not os.path.exists(datadir):
+        os.mkdir(datadir)
+
+    # Fetch the response in a streaming fashion and write it to disk.
+ response = urlopen(url) + content_length = int(response.headers["Content-Length"]) + + with open(archive, "wb") as f: + pbar = tqdm( + unit="B", total=content_length, desc=f"Downloading {basename}", leave=False + ) + while True: + chunk = response.read(CHUNK) + if not chunk: + break + f.write(chunk) + pbar.update(len(chunk)) + + # Compare the signature of the archive to the expected one + if sha256sum(archive) != signature: + raise DownloadError("Download signature does not match hardcoded signature!") + + # If extract, extract the zipfile. + if extract: + zf = zipfile.ZipFile(archive) + zf.extractall(path=datadir) diff --git a/construe/cloud/gcp.py b/construe/cloud/gcp.py index 23f31b1..c84a406 100644 --- a/construe/cloud/gcp.py +++ b/construe/cloud/gcp.py @@ -6,6 +6,8 @@ import glob import json +from ..exceptions import UploadError + try: from google.cloud import storage except ImportError: @@ -24,7 +26,7 @@ def upload(name, path, client=None, bucket=CONSTRUE_BUCKET): client = connect_storage() if not os.path.exists(path) or not os.path.isfile(path): - raise ValueError("no zip file exists at " + path) + raise UploadError("no zip file exists at " + path) bucket = client.get_bucket(bucket) blob = bucket.blob(name) @@ -44,7 +46,7 @@ def connect_storage(credentials=None): credentials = credentials or find_service_account() if credentials is None: - raise RuntimeError( + raise UploadError( "could not find service account credentials: " "set either $GOOGLE_APPLICATION_CREDENTIALS to the path " "or store the credentials in the .secret folder" diff --git a/construe/datasets/download.py b/construe/datasets/download.py index ae5982f..3e1d399 100644 --- a/construe/datasets/download.py +++ b/construe/datasets/download.py @@ -2,24 +2,14 @@ Handle downloading datasets from our content URL """ -import os -import zipfile - -from tqdm import tqdm from functools import partial -from urllib.request import urlopen -from ..cloud.signature import sha256sum +from .path import 
get_data_home from .manifest import load_manifest -from .path import get_data_home, cleanup_dataset +from ..cloud.download import download_zip from .path import DIALECTS, LOWLIGHT, REDDIT, MOVIES, ESSAYS, AEGIS, NSFW - -from construe.exceptions import DatasetsError - - -# Downlod chunk size -CHUNK = 524288 +from ..exceptions import DatasetsError def download_data(url, signature, data_home=None, replace=False, extract=True): @@ -29,53 +19,15 @@ def download_data(url, signature, data_home=None, replace=False, extract=True): the download with the given signature and extracts the archive. """ data_home = get_data_home(data_home) - - # Get the name of the file from the URL - basename = os.path.basename(url) - name, _ = os.path.splitext(basename) - - # Get the archive and data directory paths - archive = os.path.join(data_home, basename) - datadir = os.path.join(data_home, name) - - # If the archive exists cleanup or raise override exception - if os.path.exists(archive): - if not replace: - raise DatasetsError( - f"dataset already exists at {archive}, set replace=False to overwrite" - ) - cleanup_dataset(name, data_home=data_home) - - # Create the output directory if it does not exist - if not os.path.exists(datadir): - os.mkdir(datadir) - - # Fetch the response in a streaming fashion and write it to disk. - response = urlopen(url) - content_length = int(response.headers["Content-Length"]) - - with open(archive, "wb") as f: - pbar = tqdm( - unit="B", total=content_length, desc=f"Downloading {basename}", leave=False - ) - while True: - chunk = response.read(CHUNK) - if not chunk: - break - f.write(chunk) - pbar.update(len(chunk)) - - # Compare the signature of the archive to the expected one - if sha256sum(archive) != signature: - raise ValueError("Download signature does not match hardcoded signature!") - - # If extract, extract the zipfile. 
- if extract: - zf = zipfile.ZipFile(archive) - zf.extractall(path=datadir) + download_zip(url, data_home, signature=signature, replace=replace, extract=extract) def _download_dataset(name, sample=True, data_home=True, replace=False, extract=True): + """ + Downloads the zipped data set specified using the manifest URL, saving it to the + data directory specified by ``get_data_home``. The download is verified with + the given signature then extracted. + """ if sample and not name.endswith("-sample"): name = name + "-sample" diff --git a/construe/datasets/path.py b/construe/datasets/path.py index 24d6923..f1c5770 100644 --- a/construe/datasets/path.py +++ b/construe/datasets/path.py @@ -55,8 +55,8 @@ def find_dataset_path(dataset, data_home=None, fname=None, ext=None, raises=True """ Looks up the path to the dataset specified in the data home directory, which is found using the ``get_data_home`` function. By default data home - is colocated with the code, but can be modified with the CONSTRUE_DATA - environment variable, or passing in a different directory. + is in a config directory in the user's home folder, but can be modified with the + $CONSTRUE_DATA environment variable, or passing in a different directory. If the dataset is not found a ``DatasetsError`` is raised by default. """ @@ -80,9 +80,7 @@ def find_dataset_path(dataset, data_home=None, fname=None, ext=None, raises=True return None raise DatasetsError( - ("could not find dataset at {} - does it need to be downloaded?").format( - path - ) + f"could not find dataset at {path} - does it need to be downloaded?" 
) return path diff --git a/construe/exceptions.py b/construe/exceptions.py index cacc4c8..5790dd9 100644 --- a/construe/exceptions.py +++ b/construe/exceptions.py @@ -11,10 +11,22 @@ class ConstrueError(ClickException): pass +class DownloadError(ConstrueError): + pass + + +class UploadError(ConstrueError): + pass + + class DatasetsError(ConstrueError): pass +class ModelsError(ConstrueError): + pass + + class DeviceError(ConstrueError): def __init__(self, e): diff --git a/construe/models/download.py b/construe/models/download.py new file mode 100644 index 0000000..4bffba4 --- /dev/null +++ b/construe/models/download.py @@ -0,0 +1,59 @@ +""" +Handle downloading models from content URLs +""" + +from functools import partial + +from .path import get_models_home +from .manifest import load_manifest +from ..cloud.download import download_zip +from .path import NSFW, LOWLIGHT, OFFENSIVE, GLINER +from .path import MOONDREAM, WHISPER, MOBILENET, MOBILEVIT + +from ..exceptions import ModelsError + + +def download_model(url, signature, model_home=None, replace=False, extract=True): + """ + Downloads the zipped model file specified at the given URL saving it to the models + directory specified by ``get_models_home``. The download is verified with the + given signature then extracted. + """ + model_home = get_models_home(model_home) + download_zip(url, model_home, signature=signature, replace=replace, extract=extract) + + +def _download_model(name, model_home=None, replace=False, extract=True): + """ + Downloads the zipped model file specified using the manifest URL, saving it to the + models directory specified by ``get_models_home``. The download is verified with + the given signature then extracted. 
+ """ + models = load_manifest() + if name not in models: + raise ModelsError(f"no model named {name} exists") + + info = models[name] + info.update({"model_home": model_home, "replace": replace, "extract": extract}) + download_model(**info) + + +download_moondream = partial(_download_model, MOONDREAM) +download_whisper = partial(_download_model, WHISPER) +download_mobilenet = partial(_download_model, MOBILENET) +download_mobilevit = partial(_download_model, MOBILEVIT) +download_nsfw = partial(_download_model, NSFW) +download_lowlight = partial(_download_model, LOWLIGHT) +download_offensive = partial(_download_model, OFFENSIVE) +download_gliner = partial(_download_model, GLINER) + + +DOWNLOADERS = [ + download_moondream, download_whisper, download_mobilenet, download_mobilevit, + download_nsfw, download_lowlight, download_offensive, download_gliner, +] + + +def download_all_models(model_home=None, replace=True, extract=True): + for f in DOWNLOADERS: + f(model_home=model_home, replace=replace, extract=extract) diff --git a/construe/models/path.py b/construe/models/path.py index 2ecaaaf..8bad6d0 100644 --- a/construe/models/path.py +++ b/construe/models/path.py @@ -3,8 +3,11 @@ """ import os +import shutil from pathlib import Path +from ..cloud.signature import sha256sum +from construe.exceptions import ModelsError # Fixtures is where model data being prepared is stored @@ -14,18 +17,27 @@ # Models dir is the location of downloaded model files MODELSDIR = Path.home() / ".construe" / "models" +# Names of the models +MOONDREAM = "moondream" +WHISPER = "whisper" +MOBILENET = "mobilenet" +MOBILEVIT = "mobilevit" +NSFW = "nsfw" +LOWLIGHT = "lowlight" +OFFENSIVE = "offensive" +GLINER = "gliner" + def get_models_home(path=None): """ Return the path of the Construe models directory. This folder is used by model loaders to avoid downloading model parameters several times. 
-    By default, this folder is colocated with the code in the install directory
-    so that data shipped with the package can be easily located. Alternatively
-    it can be set by the ``$CONSTRUE_DATA`` environment variable, or
-    programmatically by giving a folder path. Note that the ``'~'`` symbol is
-    expanded to the user home directory, and environment variables are also
-    expanded when resolving the path.
+    By default, this folder is in a config directory in the user's home folder so the
+    model can be easily located. Alternatively it can be set by the
+    ``$CONSTRUE_MODELS`` environment variable, or programmatically by giving a folder
+    path. Note that the ``'~'`` symbol is expanded to the user home directory, and
+    environment variables are also expanded when resolving the path.
     """
     if path is None:
         path = os.environ.get("CONSTRUE_MODELS", MODELSDIR)
@@ -37,3 +49,82 @@
         os.makedirs(path)
 
     return path
+
+
+def find_model_path(model, models_home=None, fname=None, ext=None, raises=True):
+    """
+    Looks up the path to the model specified in the models home directory. The storage
+    location of the models can be set with the $CONSTRUE_MODELS environment variable.
+
+    If the model is not found a ``ModelsError`` is raised by default.
+    """
+    # Resolve the root directory that stores the models
+    models_home = get_models_home(models_home)
+
+    # Determine the path to the model
+    if fname is None:
+        if ext is None:
+            path = os.path.join(models_home, model)
+        else:
+            path = os.path.join(models_home, model, "{}{}".format(model, ext))
+    else:
+        path = os.path.join(models_home, model, fname)
+
+    if not os.path.exists(path):
+        if not raises:
+            return None
+
+        raise ModelsError(
+            f"could not find model at {path} - does it need to be downloaded?"
+        )
+
+    return path
+
+
+def model_exists(model, model_home=None, fname=None, ext=None):
+    """
+    Checks to see if the specified model exists in the model home directory.
+ """ + path = find_model_path(model, model_home, fname, ext, False) + if path is not None: + return os.path.exists(path) + return False + + +def model_tflite_exists(model, model_home): + """ + Checks to see if the model .tflite file exists or not. + """ + return model_exists(model, model_home=model_home, ext=".tflite") + + +def model_archive(model, signature, model_home=None, ext=".zip"): + """ + Checks to see if the model archive file exists and determines if it is the latest + version by comparing the signature specified with the archive signature. + """ + model_home = get_models_home(model_home) + path = os.path.join(model_home, model+ext) + + if os.path.exists(path) and os.path.isfile(path): + return sha256sum(path) == signature + return False + + +def cleanup_model(model, model_home=None, archive=".zip"): + removed = 0 + model_home = get_models_home(model_home) + + # Paths to remove + datadir = os.path.join(model_home, model) + archive = os.path.join(model_home, model+archive) + + if os.path.exists(datadir): + shutil.rmtree(datadir) + removed += 1 + + if os.path.exists(archive): + os.remove(archive) + removed += 1 + + return removed