Skip to content

Commit

Permalink
fix: added download code for falcon
Browse files Browse the repository at this point in the history
  • Loading branch information
ishaansehgal99 committed Oct 8, 2023
1 parent 6273298 commit 144c930
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,22 @@ import (
"sync"
)

func downloadFile(fp string, url string, token string, wg *sync.WaitGroup) {
// Link types accepted as the first command-line argument, plus the local
// folder that downloaded files are written into.
const (
// PublicLink selects the hard-coded huggingface.co falcon URLs in
// getURLsForModel.
PublicLink = "public"
// PrivateLink makes main require AUTH_TOKEN_ENV_VAR and build the base URL
// from the external IP and port arguments.
PrivateLink = "private"
// DownloadFolder is the directory main creates and passes to downloadFile
// as the destination folder.
DownloadFolder = "weights"
)

// getFilenameFromURL returns the last path element of url, as computed by
// filepath.Base. The string is treated as a plain path: it is not parsed as a
// URL, so anything after the final separator is returned unchanged.
func getFilenameFromURL(url string) string {
	fileName := filepath.Base(url)
	return fileName
}

func downloadFile(folderPath string, url string, token string, wg *sync.WaitGroup) {
defer wg.Done()

fileName := getFilenameFromURL(url)
fp := filepath.Join(folderPath, fileName)

// Create the file
out, err := os.Create(fp)
if err != nil {
Expand All @@ -25,8 +38,10 @@ func downloadFile(fp string, url string, token string, wg *sync.WaitGroup) {
if err != nil {
log.Fatal(err)
}
// Add token to request header
req.Header.Add("Authorization", "Bearer "+token)
// If token is provided, add to request header
if token != "" {
req.Header.Add("Authorization", "Bearer "+token)
}

// Execute the request
client := &http.Client{}
Expand Down Expand Up @@ -72,7 +87,44 @@ func (wc *WriteCounter) Write(p []byte) (int, error) {
return n, nil
}

func getURLsForModel(baseURL, modelVersion string) []string {
// falconCommonURLs returns the huggingface.co "raw/main" URLs of the
// configuration, tokenizer and Python source files for the tiiuae model named
// by modelVersion. The order of the returned URLs is fixed.
func falconCommonURLs(modelVersion string) []string {
	// File names under raw/main, in the order they are returned.
	fileNames := []string{
		"config.json",
		"pytorch_model.bin.index.json",
		"tokenizer.json",
		"tokenizer_config.json",
		"special_tokens_map.json",
		"configuration_falcon.py",
		"generation_config.json",
		"modeling_falcon.py",
	}

	urls := make([]string, 0, len(fileNames))
	for _, name := range fileNames {
		urls = append(urls, fmt.Sprintf("https://huggingface.co/tiiuae/%s/raw/main/%s", modelVersion, name))
	}
	return urls
}

// falconModelURLs returns the huggingface.co "resolve/main" URLs of the
// pytorch_model shard files for the tiiuae model named by modelVersion.
// Shards are numbered 1 through count, with both the shard number and count
// zero-padded to five digits. The result is nil when count is less than 1.
func falconModelURLs(modelVersion string, count int) (urls []string) {
	const shardFormat = "https://huggingface.co/tiiuae/%s/resolve/main/pytorch_model-%05d-of-%05d.bin"

	for done := 0; done < count; done++ {
		// Shard numbers in the file names are 1-based.
		urls = append(urls, fmt.Sprintf(shardFormat, modelVersion, done+1, count))
	}
	return urls
}

// getURLsForModel returns the list of URLs to download for modelVersion.
//
// When linkType equals PublicLink, the list is the falcon shard URLs followed
// by the common falcon file URLs; baseURL is not used, and an unrecognized
// modelVersion terminates the program through log.Fatalf. For any other
// linkType the lookup is delegated to getPrivateURLsForModel with baseURL.
func getURLsForModel(linkType, baseURL, modelVersion string) []string {
	if linkType != PublicLink {
		return getPrivateURLsForModel(baseURL, modelVersion)
	}

	// Number of pytorch_model shard files for each supported falcon model.
	var shardCount int
	switch modelVersion {
	case "falcon-7b", "falcon-7b-instruct":
		shardCount = 2
	case "falcon-40b", "falcon-40b-instruct":
		shardCount = 9
	default:
		log.Fatalf("Invalid model version for public link: %s", modelVersion)
		return nil
	}

	shardURLs := falconModelURLs(modelVersion, shardCount)
	return append(shardURLs, falconCommonURLs(modelVersion)...)
}

func getPrivateURLsForModel(baseURL, modelVersion string) []string {
switch modelVersion {
case "llama-2-7b":
return []string{
Expand Down Expand Up @@ -114,8 +166,9 @@ func getURLsForModel(baseURL, modelVersion string) []string {
baseURL + "llama-2-70b-chat/consolidated.06.pth",
baseURL + "llama-2-70b-chat/consolidated.07.pth",
}

default:
log.Fatalf("Invalid model version: %s", modelVersion)
log.Fatalf("Invalid model version for private link: %s", modelVersion)
return nil
}
}
Expand All @@ -130,29 +183,32 @@ func ensureDirExists(dirName string) {
}

func main() {
if len(os.Args) != 4 {
log.Fatalf("Usage: %s <model_version> <external_IP> <external_port>", os.Args[0])
}

token := os.Getenv("AUTH_TOKEN_ENV_VAR")
if token == "" {
log.Fatal("AUTH_TOKEN_ENV_VAR not set!")
if len(os.Args) < 3 {
log.Fatalf("Usage: %s <link_type> <model_version> [external_IP] [external_port]", os.Args[0])
}
externalIP := os.Args[2]
externalPort := os.Args[3]
baseURL := "http://" + externalIP + ":" + externalPort + "/download/"

ensureDirExists("weights")
linkType := os.Args[1]
modelVersion := os.Args[2]
ensureDirExists(DownloadFolder)

modelVersion := os.Args[1]
urls := getURLsForModel(baseURL, modelVersion)
token := ""
baseURL := ""
if linkType == PrivateLink {
token = os.Getenv("AUTH_TOKEN_ENV_VAR")
if token == "" {
log.Fatal("AUTH_TOKEN_ENV_VAR not set!")
}
externalIP := os.Args[2]
externalPort := os.Args[3]
baseURL = "http://" + externalIP + ":" + externalPort + "/download/"
}

urls := getURLsForModel(linkType, baseURL, modelVersion)
var wg sync.WaitGroup

for i, url := range urls {
fp := fmt.Sprintf("weights/consolidated.%02d.pth", i)
for _, url := range urls {
wg.Add(1)
go downloadFile(fp, url, token, &wg)
go downloadFile(DownloadFolder, url, token, &wg)
}

wg.Wait()
Expand Down
10 changes: 5 additions & 5 deletions docker/presets/falcon/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ FROM nvcr.io/nvidia/pytorch:23.06-py3
# Set the working directory
WORKDIR /workspace/falcon

# Install Go
RUN apt-get update && apt-get install -y --no-install-recommends golang-go

# First, copy just the requirements.txt file and install dependencies
# This is done before copying the code to utilize Docker's layer caching and
# avoid reinstalling dependencies unless the requirements file changes.
Expand All @@ -12,11 +15,8 @@ RUN pip install --no-cache-dir -r requirements.txt

ARG FALCON_MODEL_NAME

# Download the model and tokenizer
RUN python3 -c "from transformers import AutoModelForCausalLM, AutoTokenizer; \
model_name = '${FALCON_MODEL_NAME}'; \
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True); \
AutoTokenizer.from_pretrained(model_name)"
COPY docker/presets/download_script.go /workspace/download_script.go
RUN go run /workspace/download_script.go "public" ${FALCON_MODEL_NAME}

# Copy the entire 'presets/falcon' folder to the working directory
COPY pkg/presets/falcon /workspace/falcon
4 changes: 2 additions & 2 deletions docker/presets/llama-2/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ ARG WEB_SERVER_AUTH_TOKEN
ENV AUTH_TOKEN_ENV_VAR=${WEB_SERVER_AUTH_TOKEN}

# Copy Go download script into the Docker image
COPY docker/presets/llama-2/download_script.go /workspace/download_script.go
COPY docker/presets/download_script.go /workspace/download_script.go

# Use Go download script to fetch model weights
RUN go run /workspace/download_script.go ${LLAMA_VERSION} ${EXTERNAL_IP} ${EXTERNAL_PORT}
RUN go run /workspace/download_script.go "private" ${LLAMA_VERSION} ${EXTERNAL_IP} ${EXTERNAL_PORT}

ADD ${SRC_DIR} /workspace/llama/llama-2

0 comments on commit 144c930

Please sign in to comment.