[Model] Add support for OLMo architecture #3046

Merged · 8 commits into mlc-ai:main · Dec 14, 2024

Conversation

Lanssi (Contributor) commented Nov 24, 2024

This PR adds support for the OLMo architecture.

Additional support: clip-qkv.

Test: already tested on Android (Pixel 4) and CUDA (with tensor_parallel_shards=2).

Test models: amd/AMD-OLMo-1B (without clip_qkv) and allenai/OLMo-1B-0724-hf (with clip_qkv).
However, the generation quality of the latter is not as good as expected, even though I've tried different implementations of the clip_qkv mechanism, e.g. te.compute and nn.maximum/minimum.
I finally checked the docs, and the following is the most simplified version:

if self.clip_qkv is not None:
    qkv = qkv.maximum(-self.clip_qkv).minimum(self.clip_qkv)

But the result still isn't good enough.
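
For reference, the intended semantics are plain element-wise clipping of the fused QKV projection into [-clip_qkv, clip_qkv]. Here is a minimal NumPy sketch of that behavior (not the model code; the clip value is only illustrative):

import numpy as np

clip_qkv = 0.5                                   # illustrative value, not the model's actual config
qkv = np.random.randn(4, 8).astype("float32")

# maximum(-c) followed by minimum(c) clips every element into [-c, c]
clipped = np.minimum(np.maximum(qkv, -clip_qkv), clip_qkv)
assert np.allclose(clipped, np.clip(qkv, -clip_qkv, clip_qkv))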

This is the output from the CLI:

/AMD-OLMo-1B-SFT-q4f16_1-cuda.so --device cuda --overrides "tensor_parallel_shards=2"
[2024-11-24 09:24:04] INFO auto_device.py:79: Found device: cuda:0
[2024-11-24 09:24:04] INFO engine_base.py:143: Using library model: ./dist/libs/AMD-OLMo-1B-SFT-q4f16_1-cuda.so
[09:24:04] /workspace/mlc-llm/cpp/serve/config.cc:688: Under mode "local", max batch size will be set to 4, max KV cache token capacity will be set to 2048, prefill chunk size will be set to 2048. 
[09:24:04] /workspace/mlc-llm/cpp/serve/config.cc:688: Under mode "interactive", max batch size will be set to 1, max KV cache token capacity will be set to 2048, prefill chunk size will be set to 2048. 
[09:24:04] /workspace/mlc-llm/cpp/serve/config.cc:688: Under mode "server", max batch size will be set to 128, max KV cache token capacity will be set to 11593, prefill chunk size will be set to 2048. 
[09:24:04] /workspace/mlc-llm/cpp/serve/config.cc:769: The actual engine mode is "interactive". So max batch size is 1, max KV cache token capacity is 2048, prefill chunk size is 2048.
[09:24:04] /workspace/mlc-llm/cpp/serve/config.cc:774: Estimated total single GPU memory usage: 2026.319 MB (Parameters: 631.266 MB. KVCache: 336.268 MB. Temporary buffer: 1058.785 MB). The actual usage might be slightly larger than the estimated number.
You can use the following special commands:
  /help               print the special commands
  /exit               quit the cli
  /stats              print out stats of last request (token/sec)
  /metrics            print out full engine metrics
  /reset              restart a fresh chat
  /set [overrides]    override settings in the generation config. For example,
                      `/set temperature=0.5;top_p=0.8;seed=23;max_tokens=100;stop=str1,str2`
                      Note: Separate stop words in the `stop` option with commas (,).
  Multi-line input: Use escape+enter to start a new line.

>>> what is the result of 1 + 1?
The result is 2.
>>>

And this is the output from Android (Pixel 4):
(Screenshot from the Android device: Screenshot_20241016-021653)

Please note that this is my first PR. If I missed something, please point it out. Thanks!

tlopex (Contributor) commented Nov 24, 2024

@Lanssi Thanks for your contribution! I’ll take a look at your code once it passes CI.

Lanssi (Contributor, Author) commented Nov 24, 2024

@Lanssi Thanks for your contribution! I’ll take a look at your code once it passes CI.

Yes! Thanks!

tlopex self-assigned this Nov 24, 2024
tlopex self-requested a review Nov 24, 2024
Lanssi (Contributor, Author) commented Nov 27, 2024

@tlopex The following is a supplement.
I'm curious about the cause of the low generation quality (it happens with allenai/OLMo-1B-0724-hf; the AMD one looks fine). These are my guesses:

  1. quality of the model itself
  2. quantization degraded the generation quality
  3. something went wrong when converting the weights
  4. something wrong in the model script, specifically the clip_qkv mechanism
  5. a bad conv-template

And I did some tests.


  1. Concerning the first point, I used the official demo at https://huggingface.co/allenai/OLMo-1B-0724-hf (roughly what I ran is sketched below).
    Input: What is the answer of 1+1?  →  Output: A: 2\n\nQ: What is the answer to 2+2+2+
    Input: <|endoftext|>What is the answer of 1+1?<|endoftext|>  →  Output: Premise: "The man is wearing a black shirt and is holding a guitar." If this premise
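
Roughly what I ran, following the standard transformers usage (a sketch; the generation settings here are assumptions, not necessarily the exact demo from the model card):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "allenai/OLMo-1B-0724-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

prompt = "What is the answer of 1+1?"
inputs = tokenizer(prompt, return_tensors="pt")
# Greedy decoding; the model card's demo may use different settings
output = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))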

  2. Concerning the fourth point, I tested nn.Tensor.maximum/minimum. This is the script:

import tvm
from tvm import relax
from tvm.relax.frontend import nn
import numpy as np

# Minimal module that applies the same clipping as in the model script
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.clip_qkv = 0.5

    def forward(self, a: nn.Tensor):
        return a.maximum(-self.clip_qkv).minimum(self.clip_qkv)

MySpec = {
    "forward": {
        "a": nn.spec.Tensor([10, 10], dtype="float32"),
    }
}

# Export the nn.Module to a Relax IRModule and build it for CPU
mod, _ = MyModule().export_tvm(spec=MySpec)

target = tvm.target.Target("llvm")
ex = relax.build(mod, target)
device = tvm.cpu()
vm = relax.VirtualMachine(ex, device)

# Random input data
input = np.random.randn(10, 10).astype("float32")
tvm_input = tvm.nd.array(input, device=device)

# Run the compiled forward function and copy the result back to NumPy
output = vm["forward"](tvm_input).numpy()

print(input)
print('\n')
print(output)

And this is the result:

[[ 0.8002662   1.3287255   0.6966691   0.01359426  1.4751362  -2.1837575
   0.31622827  0.29916984  1.6361551  -0.5623241 ]
 [-0.06786224  0.7354883  -1.1480126   1.5657566   0.70848846 -1.8228586
   1.6091303  -0.49641845 -1.9719931   0.03000641]
 [ 0.24146347 -0.37006825  0.6260457  -0.650811    2.9995894   0.01520522
   0.37607008  2.2640219  -0.15675819 -1.4302232 ]
 [ 1.277718   -0.25887656  1.8693565   1.3728012  -1.1130005   0.57336384
  -0.6806153   1.4629978  -0.06308465  1.2775687 ]
 [ 1.1519296   1.1705179   0.5521642   0.5356738  -0.02438113 -3.465944
   1.1917357   0.85313046 -0.41413978 -1.0877508 ]
 [ 0.584423    1.5158817  -0.20225888 -0.50226486 -0.8140207  -0.8861371
  -1.6989475  -0.44207108  0.9688998   1.9392946 ]
 [ 0.2277535  -1.3662196   0.34696814 -0.48505998  0.09624723  1.2511327
  -0.7616591   1.414671    0.5722716   0.67328775]
 [-0.8591674   0.81546694  0.8905721   0.8360764   0.22661494 -0.591125
   1.121892   -0.50307155 -0.6415272  -0.21494472]
 [-0.3014515   2.2457047   0.53163207 -2.7661052   1.279993    0.3640672
   0.9343894  -1.8342396   1.0801686   0.14708626]
 [ 1.0259734   0.63859665  1.3871721  -0.27964365  2.4129944   0.54449517
  -0.3690681  -1.0342233   0.9981198   0.7621166 ]]


[[ 0.5         0.5         0.5         0.01359426  0.5        -0.5
   0.31622827  0.29916984  0.5        -0.5       ]
 [-0.06786224  0.5        -0.5         0.5         0.5        -0.5
   0.5        -0.49641845 -0.5         0.03000641]
 [ 0.24146347 -0.37006825  0.5        -0.5         0.5         0.01520522
   0.37607008  0.5        -0.15675819 -0.5       ]
 [ 0.5        -0.25887656  0.5         0.5        -0.5         0.5
  -0.5         0.5        -0.06308465  0.5       ]
 [ 0.5         0.5         0.5         0.5        -0.02438113 -0.5
   0.5         0.5        -0.41413978 -0.5       ]
 [ 0.5         0.5        -0.20225888 -0.5        -0.5        -0.5
  -0.5        -0.44207108  0.5         0.5       ]
 [ 0.2277535  -0.5         0.34696814 -0.48505998  0.09624723  0.5
  -0.5         0.5         0.5         0.5       ]
 [-0.5         0.5         0.5         0.5         0.22661494 -0.5
   0.5        -0.5        -0.5        -0.21494472]
 [-0.3014515   0.5         0.5        -0.5         0.5         0.3640672
   0.5        -0.5         0.5         0.14708626]
 [ 0.5         0.5         0.5        -0.27964365  0.5         0.5
  -0.3690681  -0.5         0.5         0.5       ]]

It seems to be running well.
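
As an extra check, appending the following to the script above confirms the output matches NumPy's reference clipping with the same bounds:

# Cross-check against NumPy's reference clipping (append to the script above)
expected = np.clip(input, -0.5, 0.5)
assert np.allclose(output, expected), "mismatch between nn maximum/minimum and np.clip"
print("matches np.clip")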

  3. Concerning the fifth point, I tried modifying the conv-template in "mlc_chat_config.json" by removing system_prefix_token_ids or setting add_role_after_system_message to false, respectively (a sketch of that edit is below). But the generation quality is still not good.
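
This is roughly how I applied those edits (a sketch only; the path and the assumption that conv_template is a JSON object in the generated config may need adjusting to the actual file):

# Hypothetical helper for the conv-template experiment above; path and config
# layout are assumptions, not the exact files from this PR.
import json

path = "./dist/OLMo-1B-0724-hf-q4f16_1-MLC/mlc_chat_config.json"
with open(path) as f:
    cfg = json.load(f)

conv = cfg["conv_template"]
conv.pop("system_prefix_token_ids", None)        # variant 1: drop the prefix token ids
conv["add_role_after_system_message"] = False    # variant 2: don't add the role

with open(path, "w") as f:
    json.dump(cfg, f, indent=2)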

Maybe we can further test the 7B variant (https://huggingface.co/allenai/OLMo-7B-0724-hf), but at the moment I don't have the hardware to test it.

tlopex (Contributor) commented Nov 27, 2024

@Lanssi Thank you so much for the additional details and testing!

First, about your fourth point, the implementation of clip-qkv: I think using nn.Tensor.maximum/minimum is a good approach, so I don't think the bad quality is caused by it.

Second, could you please tell me which quantization you used for the model?

Besides, I can help test the 7B variant if needed.
cc @MasterJH5574

Lanssi (Contributor, Author) commented Nov 27, 2024

@Lanssi Thank you so much for the additional details and testing!

First, about your fourth point, the implementation of clip-qkv: I think using nn.Tensor.maximum/minimum is a good approach, so I don't think the bad quality is caused by it.

Second, could you please tell me which quantization you used for the model?

Besides, I can help test the 7B variant if needed. cc @MasterJH5574

Sure. I tested q0f16 only on Android, and q4f16_1 and q4f32_1 on both Android and CUDA. And thanks for your help!

tlopex (Contributor) commented Dec 13, 2024

Overall it looks great to me! I tested it on my CUDA device with q4f16_1 and it worked well.
But I am not sure whether we can support AWQ quant and per_tensor_quant for this model.
Please have a look at it.
cc @MasterJH5574

Lanssi (Contributor, Author) commented Dec 13, 2024

Overall it looks great to me! I tested it on my CUDA device with q4f16_1 and it worked well. But I am not sure whether we can support AWQ quant and per_tensor_quant for this model. Please have a look at it. cc @MasterJH5574

@tlopex Thanks for reviewing my code. I will check it.

MasterJH5574 (Member) left a comment

Thank you @Lanssi for contributing! We can remove some quantizations in follow-up PRs if they are not supported.

MasterJH5574 merged commit 385cef2 into mlc-ai:main on Dec 14, 2024
2 checks passed