Skip to content

Commit

Permalink
chore: refactor llama package model_download script
Browse files Browse the repository at this point in the history
  • Loading branch information
YrrepNoj committed Jul 25, 2024
1 parent 80327c3 commit 41cb87b
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 16 deletions.
41 changes: 29 additions & 12 deletions packages/llama-cpp-python/scripts/model_download.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,34 @@
import os
import hashlib
import urllib.request

from huggingface_hub import hf_hub_download
# Hugging Face repository that hosts the model. No default: when this is
# empty, download_model() prints a message and does not download anything.
REPO_ID = os.environ.get("REPO_ID", "")
# Name of the file to fetch from that repository. No default; handled the
# same way as an empty REPO_ID.
FILENAME = os.environ.get("FILENAME", "")
# Revision segment of the download URL; defaults to "main".
REVISION = os.environ.get("REVISION", "main")
# Optional SHA-256 hex digest (read from SHA256_CHECKSUM). When set, an
# existing OUTPUT_FILE whose digest matches is reused instead of re-downloaded.
CHECKSUM = os.environ.get("SHA256_CHECKSUM", "")
# Path the downloaded model file is written to.
OUTPUT_FILE = os.environ.get("OUTPUT_FILE", ".model/model.gguf")

REPO_ID = os.environ.get("REPO_ID", "TheBloke/SynthIA-7B-v2.0-GGUF")
FILENAME = os.environ.get("FILENAME", "synthia-7b-v2.0.Q4_K_M.gguf")
REVISION = os.environ.get("REVISION", "3f65d882253d1f15a113dabf473a7c02a004d2b5")

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
def _sha256_of_file(path, chunk_size=1024 * 1024):
    """Return the hex SHA-256 digest of the file at *path*.

    The file is read in chunks of *chunk_size* bytes so that a multi-gigabyte
    model file is never held in memory in one piece, and the handle is closed
    as soon as hashing finishes.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def download_model():
    """Download the configured model file to OUTPUT_FILE.

    Uses the module-level REPO_ID, FILENAME, REVISION, CHECKSUM and
    OUTPUT_FILE settings. The download is skipped when OUTPUT_FILE already
    exists and its SHA-256 digest matches CHECKSUM. When REPO_ID or FILENAME
    is empty, a message is printed and nothing is downloaded. When CHECKSUM is
    set, the freshly downloaded file is verified against it.

    Raises:
        RuntimeError: If the downloaded file does not match CHECKSUM.
    """
    # hexdigest() is lowercase; normalise the expected value so an upper-case
    # or whitespace-padded checksum still compares equal.
    expected = CHECKSUM.strip().lower()

    # Check if the model is already downloaded.
    if expected and os.path.isfile(OUTPUT_FILE):
        if _sha256_of_file(OUTPUT_FILE) == expected:
            print("Model already downloaded.")
            return

    # Validate that the required environment variables are provided.
    if REPO_ID == "" or FILENAME == "":
        print("Please provide REPO_ID and FILENAME environment variables.")
        return

    # Download the model!
    print("Downloading model... This may take a while.")
    # Create the directory OUTPUT_FILE actually lives in (".model" for the
    # default path) rather than a hard-coded one.
    output_dir = os.path.dirname(OUTPUT_FILE)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    urllib.request.urlretrieve(
        f"https://huggingface.co/{REPO_ID}/resolve/{REVISION}/{FILENAME}", OUTPUT_FILE
    )

    # A checksum that is only used to skip downloads would let a truncated or
    # wrong file through silently; verify what was just fetched.
    if expected:
        actual = _sha256_of_file(OUTPUT_FILE)
        if actual != expected:
            raise RuntimeError(
                f"Checksum mismatch for {OUTPUT_FILE}: "
                f"expected {expected}, got {actual}"
            )


# Run the download only when this file is executed as a script, not on import.
if __name__ == "__main__":
download_model()
11 changes: 7 additions & 4 deletions packages/llama-cpp-python/zarf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ components:
actions:
onCreate:
before:
# TODO: Let's see if this can check the checksum of the model.gguf to shortcut this download.
- cmd: |
mkdir -p .model
wget https://huggingface.co/TheBloke/SynthIA-7B-v2.0-GGUF/resolve/main/synthia-7b-v2.0.Q4_K_M.gguf -q -O .model/model.gguf
# NOTE: This assumes python is installed and in $PATH
- cmd: python scripts/model_download.py
env:
- REPO_ID=TheBloke/SynthIA-7B-v2.0-GGUF
- FILENAME=synthia-7b-v2.0.Q4_K_M.gguf
- REVISION=3f65d882253d1f15a113dabf473a7c02a004d2b5
- SHA256_CHECKSUM=5d6369d456446c40a9fd149525747d8dc494196686861c43b00f9230a166ba82

0 comments on commit 41cb87b

Please sign in to comment.