
Issue setting huggingface.prompt_builder = 'llama2' when using sagemaker as client #31

Open
bcarsley opened this issue Aug 17, 2023 · 2 comments


@bcarsley

So I'm building a class that can alternate between the huggingface and sagemaker clients, and I declare all my os.environ variables at the top of the class like so:

os.environ["AWS_ACCESS_KEY_ID"] = "<key_id>"
os.environ["AWS_SECRET_ACCESS_KEY"] = '<access_key>'
os.environ["AWS_DEFAULT_REGION"] = "us-east-1"
os.environ["HF_TOKEN"] = "<hf_token>"
os.environ["HUGGINGFACE_PROMPT"] = "llama2"

Even later on in the class, just to be sure, I tried a few things (sketched together below):

- declaring huggingface.prompt_builder = 'llama2'
- importing build_llama2_prompt directly and passing it as a callable, which also didn't work
- setting sagemaker.prompt_builder = 'llama2' just to see if that would do anything...nope
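Concretely, those attempts look roughly like this (I'm assuming build_llama2_prompt is importable from easyllm.prompt_utils; correct me if that path is wrong):

from easyllm.clients import huggingface, sagemaker
from easyllm.prompt_utils import build_llama2_prompt  # assuming this is the right import path

huggingface.prompt_builder = "llama2"             # attempt 1: string on the huggingface client
huggingface.prompt_builder = build_llama2_prompt  # attempt 2: passing the callable directly
sagemaker.prompt_builder = "llama2"               # attempt 3: string on the sagemaker client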

I still get the warning telling me I haven't set a prompt builder, which is kinda weird. It's also clear that the prompt is occasionally being formatted incorrectly, because the same prompt, formatted as in the example below and passed directly to the SageMaker endpoint, yields a noticeably better response from the same endpoint.

It's not a big deal that this doesn't work well for me (I might just be missing something), but below is how I've worked around it by calling SageMaker's HuggingFacePredictor class manually:

import sagemaker.huggingface.model

# sess is an existing sagemaker.Session(); 'llama-party' is the endpoint name
llm = sagemaker.huggingface.model.HuggingFacePredictor('llama-party', sess)

def build_llama2_prompt(messages):
    # wraps an OpenAI-style message list into the Llama 2 chat format
    startPrompt = "<s>[INST] "
    endPrompt = " [/INST]"
    conversation = []
    for index, message in enumerate(messages):
        if message["role"] == "system" and index == 0:
            conversation.append(f"<<SYS>>\n{message['content']}\n<</SYS>>\n\n")
        elif message["role"] == "user":
            conversation.append(message["content"].strip())
        else:
            # assistant turns close the current [INST] block and open the next one
            conversation.append(f" [/INST] {message['content']}</s><s>[INST] ")

    return startPrompt + "".join(conversation) + endPrompt

# messages is the same OpenAI-style list of {"role", "content"} dicts I pass to easyllm
prompt = build_llama2_prompt(messages)

payload = {
  "inputs":  prompt,
  "parameters": {
    "do_sample": True,
    "top_p": 0.6,
    "temperature": 0.9,
    "top_k": 50,
    "max_new_tokens": 512,
    "repetition_penalty": 1.03,
    "stop": ["</s>"]
  }
}

chat = llm.predict(payload)

print(chat[0]["generated_text"][len(prompt):])

This code was taken pretty much verbatim from the SageMaker LLaMA deployment blog post here: https://www.philschmid.de/sagemaker-llama-llm

It works fine; I just don't know why the same logic doesn't work correctly inside the lib (easyllm).
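For reference, with a hypothetical messages list like the one below, the helper above produces exactly the Llama 2 chat format I'd expect:

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is Amazon SageMaker?"},
]

print(build_llama2_prompt(messages))
# <s>[INST] <<SYS>>
# You are a helpful assistant.
# <</SYS>>
#
# What is Amazon SageMaker? [/INST]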

@philschmid
Owner

Can you check out this example for SageMaker and see if it works? https://philschmid.github.io/easyllm/examples/sagemaker-chat-completion-api/#1-import-the-easyllm-library
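The example boils down to roughly this (use your own endpoint name; see the linked page for the exact parameters):

import os
os.environ["AWS_DEFAULT_REGION"] = "us-east-1"  # plus your AWS credentials or profile

from easyllm.clients import sagemaker

# configure the prompt builder on the sagemaker client, not the huggingface one
sagemaker.prompt_builder = "llama2"

response = sagemaker.ChatCompletion.create(
    model="llama-party",  # the name of your SageMaker endpoint
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is Amazon SageMaker?"},
    ],
    temperature=0.9,
    top_p=0.6,
    max_tokens=256,
)
print(response["choices"][0]["message"]["content"])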

@murdadesmaeeli

Hi @bcarsley, were you able to resolve this?
