From 41cb87bea3d50f782b2ff8b8bef2f48ac1a66d8c Mon Sep 17 00:00:00 2001
From: Jon Perry
Date: Tue, 16 Jul 2024 17:04:54 -0400
Subject: [PATCH] chore: refactor llama package model_download script

---
 .../scripts/model_download.py       | 71 +++++++++++++++----
 packages/llama-cpp-python/zarf.yaml | 11 +--
 2 files changed, 66 insertions(+), 16 deletions(-)

diff --git a/packages/llama-cpp-python/scripts/model_download.py b/packages/llama-cpp-python/scripts/model_download.py
index 7a41889429..a5cc79e52b 100644
--- a/packages/llama-cpp-python/scripts/model_download.py
+++ b/packages/llama-cpp-python/scripts/model_download.py
@@ -1,17 +1,64 @@
 import os
+import hashlib
+import urllib.request
 
-from huggingface_hub import hf_hub_download
+REPO_ID = os.environ.get("REPO_ID", "")
+FILENAME = os.environ.get("FILENAME", "")
+REVISION = os.environ.get("REVISION", "main")
+CHECKSUM = os.environ.get("SHA256_CHECKSUM", "")
+OUTPUT_FILE = os.environ.get("OUTPUT_FILE", ".model/model.gguf")
 
-REPO_ID = os.environ.get("REPO_ID", "TheBloke/SynthIA-7B-v2.0-GGUF")
-FILENAME = os.environ.get("FILENAME", "synthia-7b-v2.0.Q4_K_M.gguf")
-REVISION = os.environ.get("REVISION", "3f65d882253d1f15a113dabf473a7c02a004d2b5")
+# Hash in 1 MiB chunks so a multi-gigabyte model is never held in memory at once.
+HASH_CHUNK_SIZE = 1024 * 1024
 
-os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 
-hf_hub_download(
-    repo_id=REPO_ID,
-    filename=FILENAME,
-    local_dir=".model",
-    local_dir_use_symlinks=False,
-    revision=REVISION,
-)
+def sha256_of_file(path):
+    """Return the hex SHA-256 digest of the file at path, reading it in chunks."""
+    digest = hashlib.sha256()
+    with open(path, "rb") as model_file:
+        for chunk in iter(lambda: model_file.read(HASH_CHUNK_SIZE), b""):
+            digest.update(chunk)
+    return digest.hexdigest()
+
+
+def download_model():
+    """Download the model to OUTPUT_FILE unless a verified copy already exists.
+
+    Configured through the REPO_ID, FILENAME, REVISION, SHA256_CHECKSUM and
+    OUTPUT_FILE environment variables.
+
+    Raises:
+        SystemExit: if REPO_ID or FILENAME is missing, or if the downloaded
+            file does not match SHA256_CHECKSUM.
+    """
+    expected = CHECKSUM.strip().lower()
+
+    # Check if the model is already downloaded.
+    if expected and os.path.isfile(OUTPUT_FILE):
+        if sha256_of_file(OUTPUT_FILE) == expected:
+            print("Model already downloaded.")
+            return
+
+    # Validate that required environment variables are provided. Exit non-zero
+    # so the calling build step fails instead of continuing without a model.
+    if not REPO_ID or not FILENAME:
+        raise SystemExit("Please provide REPO_ID and FILENAME environment variables.")
+
+    # Download the model!
+    print("Downloading model... \nThis may take a while.")
+    os.makedirs(os.path.dirname(OUTPUT_FILE) or ".", exist_ok=True)
+    urllib.request.urlretrieve(
+        f"https://huggingface.co/{REPO_ID}/resolve/{REVISION}/{FILENAME}", OUTPUT_FILE
+    )
+
+    # Verify the fresh download so a corrupt or truncated file is not packaged.
+    if expected:
+        actual = sha256_of_file(OUTPUT_FILE)
+        if actual != expected:
+            raise SystemExit(
+                f"Checksum mismatch for {OUTPUT_FILE}: expected {expected}, got {actual}"
+            )
+
+
+if __name__ == "__main__":
+    download_model()
diff --git a/packages/llama-cpp-python/zarf.yaml b/packages/llama-cpp-python/zarf.yaml
index d656b7a0b9..1f5a107873 100644
--- a/packages/llama-cpp-python/zarf.yaml
+++ b/packages/llama-cpp-python/zarf.yaml
@@ -38,7 +38,10 @@ components:
     actions:
       onCreate:
         before:
-          # TODO: Lets see if this can check the checksum of the model.gguf to shortcut this download..
-          - cmd: |
-              mkdir -p .model
-              wget https://huggingface.co/TheBloke/SynthIA-7B-v2.0-GGUF/resolve/main/synthia-7b-v2.0.Q4_K_M.gguf -q -O .model/model.gguf
+          # NOTE: This assumes python is installed and in $PATH
+          - cmd: python scripts/model_download.py
+            env:
+              - REPO_ID=TheBloke/SynthIA-7B-v2.0-GGUF
+              - FILENAME=synthia-7b-v2.0.Q4_K_M.gguf
+              - REVISION=3f65d882253d1f15a113dabf473a7c02a004d2b5
+              - SHA256_CHECKSUM=5d6369d456446c40a9fd149525747d8dc494196686861c43b00f9230a166ba82