From bcaa8a3689e6ae2e84ec6b57e995ee8a7904a19e Mon Sep 17 00:00:00 2001 From: Casper Date: Thu, 15 Feb 2024 16:49:21 +0100 Subject: [PATCH] v0.2.0 (#330) Co-authored-by: jinz2014 <7799920+jinz2014@users.noreply.github.com> Co-authored-by: Jin Z <5zj@cousteau.ftpn.ornl.gov> --- .github/workflows/build.yaml | 2 +- .github/workflows/docs.yaml | 28 ++ README.md | 31 +- awq/__init__.py | 4 +- awq/evaluation/__init__.py | 2 +- awq/evaluation/eval_utils.py | 65 ++- awq/evaluation/humaneval_utils.py | 111 +++-- awq/evaluation/kl_divergence.py | 70 ++- awq/models/_config.py | 50 ++- awq/models/aquila.py | 121 ++--- awq/models/auto.py | 16 +- awq/models/baichuan.py | 110 ++--- awq/models/base.py | 232 +++++++--- awq/models/bloom.py | 57 ++- awq/models/falcon.py | 108 +++-- awq/models/gpt_bigcode.py | 56 ++- awq/models/gpt_neox.py | 50 ++- awq/models/gptj.py | 59 ++- awq/models/llama.py | 121 ++--- awq/models/llava.py | 123 +++--- awq/models/mistral.py | 121 ++--- awq/models/mixtral.py | 180 +++++--- awq/models/mpt.py | 100 +++-- awq/models/opt.py | 77 ++-- awq/models/qwen.py | 2 +- awq/models/qwen2.py | 104 +++-- awq/models/yi.py | 118 ++--- awq/modules/act.py | 3 +- awq/modules/fused/attn.py | 119 +++-- awq/modules/fused/block.py | 192 ++++++-- awq/modules/fused/cache.py | 35 +- awq/modules/fused/mlp.py | 30 +- awq/modules/fused/model.py | 104 +++-- awq/modules/fused/moe.py | 461 ++++++++++++++++++++ awq/modules/linear/__init__.py | 2 +- awq/modules/linear/gemm.py | 43 +- awq/modules/linear/marlin.py | 2 +- awq/quantize/quantizer.py | 16 +- awq/quantize/scale.py | 56 ++- awq/utils/calib_data.py | 28 +- awq/utils/fused_utils.py | 161 +++++-- awq/utils/module.py | 9 +- awq/utils/packing_utils.py | 3 +- awq/utils/parallel.py | 3 +- awq/utils/quant_utils.py | 2 +- awq/utils/utils.py | 52 ++- docs/examples.md | 318 ++++++++++++++ docs/index.md | 51 +++ docs/reference/index.md | 6 + examples/README.md | 4 + examples/awq_to_gguf_quant.py | 48 -- examples/basic_transformers.py | 30 -- examples/basic_vllm.py | 56 --- examples/benchmark.py | 12 +- examples/exllama_generate.py | 28 -- examples/{basic_generate.py => generate.py} | 0 examples/llava_generate.py | 25 -- examples/llava_quant.py | 23 - examples/marlin_generate.py | 33 -- examples/marlin_quant.py | 22 - examples/mixtral_quant.py | 30 -- examples/quant_custom_data.py | 35 -- examples/{basic_quant.py => quantize.py} | 5 +- examples/tinyllama_generate.py | 36 -- examples/{awq_train.py => train.py} | 6 +- mkdocs.yml | 82 ++++ setup.py | 8 +- 67 files changed, 2819 insertions(+), 1478 deletions(-) create mode 100644 .github/workflows/docs.yaml create mode 100644 awq/modules/fused/moe.py create mode 100644 docs/examples.md create mode 100644 docs/index.md create mode 100644 docs/reference/index.md create mode 100644 examples/README.md delete mode 100644 examples/awq_to_gguf_quant.py delete mode 100644 examples/basic_transformers.py delete mode 100644 examples/basic_vllm.py delete mode 100644 examples/exllama_generate.py rename examples/{basic_generate.py => generate.py} (100%) delete mode 100644 examples/llava_generate.py delete mode 100644 examples/llava_quant.py delete mode 100644 examples/marlin_generate.py delete mode 100644 examples/marlin_quant.py delete mode 100644 examples/mixtral_quant.py delete mode 100644 examples/quant_custom_data.py rename examples/{basic_quant.py => quantize.py} (83%) delete mode 100644 examples/tinyllama_generate.py rename examples/{awq_train.py => train.py} (91%) create mode 100644 mkdocs.yml diff --git 
a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 382654a5..436c1d4f 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -91,7 +91,7 @@ jobs: # Install torch $cudaVersion = $env:CUDA_VERSION.Replace('.', '') $cudaVersionPytorch = $cudaVersion.Substring(0, $cudaVersion.Length - 1) - if ([int]$cudaVersionPytorch -gt 118) { $pytorchVersion = "torch==2.1.0" } else {$pytorchVersion = "torch==2.0.1"} + if ([int]$cudaVersionPytorch -gt 118) { $pytorchVersion = "torch==2.2.0" } else {$pytorchVersion = "torch==2.0.1"} python -m pip install --upgrade --no-cache-dir $pytorchVersion+cu$cudaVersionPytorch --index-url https://download.pytorch.org/whl/cu$cudaVersionPytorch python -m pip install build setuptools wheel ninja diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml new file mode 100644 index 00000000..1cec7346 --- /dev/null +++ b/.github/workflows/docs.yaml @@ -0,0 +1,28 @@ +name: Documentation +on: + push: + branches: + - main +permissions: + contents: write +jobs: + deploy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Git Credentials + run: | + git config user.name github-actions[bot] + git config user.email 41898282+github-actions[bot]@users.noreply.github.com + - uses: actions/setup-python@v4 + with: + python-version: 3.x + - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV + - uses: actions/cache@v3 + with: + key: mkdocs-material-${{ env.cache_id }} + path: .cache + restore-keys: | + mkdocs-material-docs + - run: pip install mkdocstrings-python mkdocs-material griffe-typingdoc + - run: mkdocs gh-deploy --force \ No newline at end of file diff --git a/README.md b/README.md index 9c4d1a5a..96d2cb0d 100644 --- a/README.md +++ b/README.md @@ -70,33 +70,6 @@ All three methods will install the latest and correct kernels for your system fr If your system is not supported (i.e. not on the release page), you can build the kernels yourself by following the instructions in [AutoAWQ_Kernels](https://github.com/casper-hansen/AutoAWQ_kernels/releases) and then install AutoAWQ from source. -## Supported models - -The detailed support list: - -| Models | Sizes | -| -------- | --------------------------- | -| LLaMA-2 | 7B/13B/70B | -| LLaMA | 7B/13B/30B/65B | -| Mistral | 7B | -| Vicuna | 7B/13B | -| MPT | 7B/30B | -| Falcon | 7B/40B | -| OPT | 125m/1.3B/2.7B/6.7B/13B/30B | -| Bloom | 560m/3B/7B/ | -| GPTJ | 6.7B | -| Aquila | 7B | -| Aquila2 | 7B/34B | -| Yi | 6B/34B | -| Qwen | 1.8B/7B/14B/72B | -| BigCode | 1B/7B/15B | -| GPT NeoX | 20B | -| GPT-J | 6B | -| LLaVa | 7B/13B | -| Mixtral | 8x7B | -| Baichuan | 7B/13B | -| QWen | 1.8B/7B/14/72B | - ## Usage Under examples, you can find examples of how to quantize, run inference, and benchmark AutoAWQ models. @@ -122,7 +95,7 @@ Fused modules are a large part of the speedup you get from AutoAWQ. The idea is - Fused modules are activated when you use `fuse_layers=True`. - A custom cache is implemented. It preallocates based on batch size and sequence length. - You cannot change the sequence length after you have created your model. - - Reference: `AutoAWQForCausalLM.from_quantized(max_new_tokens=seq_len, batch_size=batch_size)` + - Reference: `AutoAWQForCausalLM.from_quantized(max_seq_len=seq_len, batch_size=batch_size)` - The main accelerator in the fused modules comes from FasterTransformer, which is only compatible with Linux. - The `past_key_values` from `model.generate()` are only dummy values, so they cannot be used after generation. 
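To make the fused-module notes above concrete, here is a minimal loading sketch; the checkpoint name and prompt are placeholders and a CUDA device is assumed, so adjust both to your setup:

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer

quant_path = "TheBloke/zephyr-7B-beta-AWQ"  # example AWQ checkpoint

# fuse_layers=True activates the fused modules; the custom cache is
# preallocated from max_seq_len and batch_size and cannot be resized
# after the model has been created.
model = AutoAWQForCausalLM.from_quantized(
    quant_path,
    fuse_layers=True,
    max_seq_len=2048,
    batch_size=1,
)
tokenizer = AutoTokenizer.from_pretrained(quant_path)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

tokens = tokenizer("What is AWQ quantization?", return_tensors="pt").input_ids.cuda()
generation_output = model.generate(tokens, streamer=streamer)
```

Note that `batch_size` here only sizes the preallocated cache, matching the behavior described above.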
@@ -194,7 +167,7 @@ tokens = tokenizer( generation_output = model.generate( tokens, streamer=streamer, - max_new_tokens=512 + max_seq_len=512 ) ``` diff --git a/awq/__init__.py b/awq/__init__.py index c6c72583..b8a4b6a4 100644 --- a/awq/__init__.py +++ b/awq/__init__.py @@ -1,2 +1,2 @@ -__version__ = "0.1.8" -from awq.models.auto import AutoAWQForCausalLM \ No newline at end of file +__version__ = "0.2.0" +from awq.models.auto import AutoAWQForCausalLM diff --git a/awq/evaluation/__init__.py b/awq/evaluation/__init__.py index 8089ba48..a6af32b5 100644 --- a/awq/evaluation/__init__.py +++ b/awq/evaluation/__init__.py @@ -4,4 +4,4 @@ eval_mmlu, ) from awq.evaluation.humaneval_utils import eval_humaneval -from awq.evaluation.kl_divergence import eval_kl_divergence \ No newline at end of file +from awq.evaluation.kl_divergence import eval_kl_divergence diff --git a/awq/evaluation/eval_utils.py b/awq/evaluation/eval_utils.py index 04d2e86d..aa601a63 100644 --- a/awq/evaluation/eval_utils.py +++ b/awq/evaluation/eval_utils.py @@ -9,56 +9,61 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.models.whisper.english_normalizer import BasicTextNormalizer + def get_device(): if torch.backends.mps.is_available(): - return 'mps' + return "mps" elif torch.cuda.is_available(): - return 'cuda:0' + return "cuda:0" else: - return 'cpu' + return "cpu" + def evaluate_perplexity(model, tokenizer): def _perplexity(nlls, n_samples, seqlen): return torch.exp(torch.stack(nlls).sum() / (n_samples * seqlen)) - + # load and prepare dataset - data = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') - data = tokenizer("\n\n".join(data['text']), return_tensors='pt') + data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + data = tokenizer("\n\n".join(data["text"]), return_tensors="pt") data = data.input_ids.to(model.device) seqlen = 2048 model = model.eval() n_samples = data.numel() // seqlen - + nlls = [] with tqdm(range(n_samples), desc="Perplexity -") as progress_bar: for i in progress_bar: - start_index = (i * seqlen) - end_index = ((i + 1) * seqlen) + start_index = i * seqlen + end_index = (i + 1) * seqlen batch = data[:, start_index:end_index].to(model.device) with torch.no_grad(): logits = model(batch).logits shift_logits = logits[:, :-1, :].contiguous().float() shift_labels = data[:, start_index:end_index][:, 1:] loss_fct = nn.CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) + ) neg_log_likelihood = loss.float() * seqlen nlls.append(neg_log_likelihood) - curr_ppl = _perplexity(nlls, i+1, seqlen) + curr_ppl = _perplexity(nlls, i + 1, seqlen) progress_bar.set_description(f"Perplexity {curr_ppl:.3f}") ppl = _perplexity(nlls, n_samples, seqlen) - + return ppl.item() + def eval_librispeech(model_id, num_samples=100, batch_size=4): try: import jiwer, librosa, soundfile except ImportError: print("Please install the following: pip install jiwer librosa soundfile") - + dataset = load_dataset("librispeech_asr", "clean", split="test", streaming=True) # Load the Whisper model pipeline for automatic speech recognition @@ -72,14 +77,15 @@ def eval_librispeech(model_id, num_samples=100, batch_size=4): # Word normalizer normalizer = BasicTextNormalizer() - + # Load the WER metric wer_metric = load_metric("wer") texts = [] audio = [] for i, data in tqdm(enumerate(dataset), total=num_samples, desc="Loading dataset"): - if len(audio) 
== num_samples: break + if len(audio) == num_samples: + break audio.append(data["audio"]) texts.append(data["text"]) @@ -88,8 +94,8 @@ def eval_librispeech(model_id, num_samples=100, batch_size=4): with tqdm(range(0, num_samples, batch_size), desc="Word Error Rate: -") as pbar: for i in pbar: - batch_audio = audio[i:i+batch_size] - batch_texts = texts[i:i+batch_size] + batch_audio = audio[i : i + batch_size] + batch_texts = texts[i : i + batch_size] # inference results = pipe(batch_audio, batch_size=len(batch_audio)) @@ -102,16 +108,26 @@ def eval_librispeech(model_id, num_samples=100, batch_size=4): references.extend(normalized_texts) # word error rate computation - wer = wer_metric.compute(predictions=predictions, references=references) * 100 + wer = ( + wer_metric.compute(predictions=predictions, references=references) * 100 + ) pbar.set_description(f"Word Error Rate: {wer:.3f}%") -def eval_mmlu(model_path="gpt2", num_fewshot=1, batch_size=1, device="cuda:0", task_use_pretrained=False): + +def eval_mmlu( + model_path="gpt2", + num_fewshot=1, + batch_size=1, + device="cuda:0", + task_use_pretrained=False, +): try: import vllm + VLLM_INSTALLED = True except ImportError: VLLM_INSTALLED = False - + initialize_tasks(verbosity="DEBUG") if VLLM_INSTALLED: @@ -133,12 +149,12 @@ def eval_mmlu(model_path="gpt2", num_fewshot=1, batch_size=1, device="cuda:0", t dtype="float16", trust_remote_code=True, ) - model_args = ",".join([f"{k}={v}" for k,v in model_args.items()]) + model_args = ",".join([f"{k}={v}" for k, v in model_args.items()]) results = evaluator.simple_evaluate( model=model, model_args=model_args, - tasks=['mmlu'], + tasks=["mmlu"], num_fewshot=num_fewshot, batch_size=batch_size, device=device, @@ -147,7 +163,8 @@ def eval_mmlu(model_path="gpt2", num_fewshot=1, batch_size=1, device="cuda:0", t print(evaluator.make_table(results)) -if __name__ == '__main__': + +if __name__ == "__main__": ### PERPLEXITY # model_path = 'mistralai/Mistral-7B-Instruct-v0.1' # model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto") @@ -156,5 +173,5 @@ def eval_mmlu(model_path="gpt2", num_fewshot=1, batch_size=1, device="cuda:0", t ### WORD ERROR RATE # model_id = "distil-whisper/distil-small.en" # 3.594 - model_id = "distil-whisper/distil-medium.en" # 3.436 + model_id = "distil-whisper/distil-medium.en" # 3.436 eval_librispeech(model_id) diff --git a/awq/evaluation/humaneval_utils.py b/awq/evaluation/humaneval_utils.py index 4f3f3b00..73fb7f1b 100644 --- a/awq/evaluation/humaneval_utils.py +++ b/awq/evaluation/humaneval_utils.py @@ -30,25 +30,27 @@ PreTrainedTokenizer, ) + def eval_humaneval( model: PreTrainedModel, tokenizer: PreTrainedTokenizer, out_path: str = "humaneval_out.jsonl", format_tabs: bool = True, ): - problems = {example["task_id"]: example for example in load_dataset("openai_humaneval")["test"]} - + problems = { + example["task_id"]: example + for example in load_dataset("openai_humaneval")["test"] + } + samples = [] - for i, (task_id, task) in tqdm(enumerate(problems.items()), total=len(problems) ): + for i, (task_id, task) in tqdm(enumerate(problems.items()), total=len(problems)): if format_tabs: prompt = task["prompt"].replace(" ", "\t") else: prompt = task["prompt"] - batch_completions = generate_batch_completion( - model, tokenizer, prompt, 1 - ) + batch_completions = generate_batch_completion(model, tokenizer, prompt, 1) for sample in batch_completions: result = dict( @@ -58,10 +60,10 @@ def eval_humaneval( samples += [result] - with open(out_path, 'wb') as fp: + with 
open(out_path, "wb") as fp: for x in samples: - fp.write((json.dumps(x) + "\n").encode('utf-8')) - + fp.write((json.dumps(x) + "\n").encode("utf-8")) + results = evaluate_functional_correctness( sample_file=out_path, k=[1], @@ -71,6 +73,7 @@ def eval_humaneval( print(results) + @torch.inference_mode() def generate_batch_completion( model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prompt, batch_size @@ -82,7 +85,7 @@ def generate_batch_completion( generated_ids = model.generate( **inputs, use_cache=True, - max_new_tokens=512, + max_seq_len=512, temperature=0.2, top_p=0.95, do_sample=True, @@ -106,11 +109,12 @@ def fix_indents(text: str) -> str: return [filter_code(fix_indents(completion)) for completion in batch_completions] -def check_correctness(problem: Dict, completion: str, timeout: float, - completion_id: Optional[int] = None) -> Dict: +def check_correctness( + problem: Dict, completion: str, timeout: float, completion_id: Optional[int] = None +) -> Dict: """ Evaluates the functional correctness of a completion by running the test - suite provided in the problem. + suite provided in the problem. :param completion_id: an optional completion ID so we can match the results later even if execution finishes asynchronously. @@ -121,6 +125,7 @@ def unsafe_execute(): # These system calls are needed when cleaning up tempdir. import os import shutil + rmtree = shutil.rmtree rmdir = os.rmdir chdir = os.chdir @@ -130,11 +135,14 @@ def unsafe_execute(): # Construct the check program and run it. check_program = ( - problem["prompt"] + completion + "\n" + - problem["test"] + "\n" + - f"check({problem['entry_point']})" + problem["prompt"] + + completion + + "\n" + + problem["test"] + + "\n" + + f"check({problem['entry_point']})" ) - + try: exec_globals = {} with swallow_io(): @@ -175,6 +183,7 @@ def unsafe_execute(): def time_limit(seconds: float): def signal_handler(signum, frame): raise TimeoutException("Timed out!") + signal.setitimer(signal.ITIMER_REAL, seconds) signal.signal(signal.SIGALRM, signal_handler) try: @@ -204,7 +213,7 @@ class TimeoutException(Exception): class WriteOnlyStringIO(io.StringIO): - """ StringIO that throws an exception when it's read from """ + """StringIO that throws an exception when it's read from""" def read(self, *args, **kwargs): raise IOError @@ -216,12 +225,12 @@ def readlines(self, *args, **kwargs): raise IOError def readable(self, *args, **kwargs): - """ Returns True if the IO object can be read. 
""" + """Returns True if the IO object can be read.""" return False class redirect_stdin(contextlib._RedirectStream): # type: ignore - _stream = 'stdin' + _stream = "stdin" @contextlib.contextmanager @@ -238,13 +247,14 @@ def chdir(root): finally: os.chdir(cwd) + def stream_jsonl(filename: str) -> Iterable[Dict]: """ Parses each jsonl line and yields it as a dictionary """ if filename.endswith(".gz"): with open(filename, "rb") as gzfp: - with gzip.open(gzfp, 'rt') as fp: + with gzip.open(gzfp, "rt") as fp: for line in fp: if any(not x.isspace() for x in line): yield json.loads(line) @@ -254,6 +264,7 @@ def stream_jsonl(filename: str) -> Iterable[Dict]: if any(not x.isspace() for x in line): yield json.loads(line) + def estimate_pass_at_k( num_samples: Union[int, List[int], np.ndarray], num_correct: Union[List[int], np.ndarray], @@ -288,7 +299,10 @@ def evaluate_functional_correctness( n_workers: int = 4, timeout: float = 3.0, ): - problems = {example["task_id"]: example for example in load_dataset("openai_humaneval")["test"]} + problems = { + example["task_id"]: example + for example in load_dataset("openai_humaneval")["test"] + } # Check the generated samples against test suites. with ThreadPoolExecutor(max_workers=n_workers) as executor: @@ -308,9 +322,11 @@ def evaluate_functional_correctness( n_samples += 1 if len(completion_id) < len(problems): - include_keys = list(problems.keys())[:len(completion_id)] - print(f"Only found {len(completion_id)} solutions, reducing problems from {len(problems)}...") - problems = {k:v for k,v in problems.items() if k in include_keys} + include_keys = list(problems.keys())[: len(completion_id)] + print( + f"Only found {len(completion_id)} solutions, reducing problems from {len(problems)}..." + ) + problems = {k: v for k, v in problems.items() if k in include_keys} assert len(completion_id) == len(problems), "Some problems are not attempted." @@ -347,6 +363,7 @@ def combine_results(): return pass_at_k + def reliability_guard(maximum_memory_bytes: Optional[int] = None): """ This disables various destructive functions and prevents the generated code @@ -355,7 +372,7 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None): WARNING This function is NOT a security sandbox. Untrusted code, including, model- - generated code, should not be blindly executed outside of one. See the + generated code, should not be blindly executed outside of one. See the Codex paper for more information about OpenAI's code sandbox, and proceed with caution. 
""" @@ -364,19 +381,28 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None): if maximum_memory_bytes is not None: import resource - resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)) - resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)) - if not platform.uname().system == 'Darwin': - resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) + + resource.setrlimit( + resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes) + ) + resource.setrlimit( + resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes) + ) + if not platform.uname().system == "Darwin": + resource.setrlimit( + resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes) + ) faulthandler.disable() import builtins + builtins.exit = None builtins.quit = None import os - os.environ['OMP_NUM_THREADS'] = '1' + + os.environ["OMP_NUM_THREADS"] = "1" os.kill = None os.system = None @@ -407,25 +433,32 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None): os.chdir = None import shutil + shutil.rmtree = None shutil.move = None shutil.chown = None import subprocess + subprocess.Popen = None # type: ignore import sys - sys.modules['ipdb'] = None - sys.modules['joblib'] = None - sys.modules['resource'] = None - sys.modules['psutil'] = None - sys.modules['tkinter'] = None -if __name__ == '__main__': + sys.modules["ipdb"] = None + sys.modules["joblib"] = None + sys.modules["resource"] = None + sys.modules["psutil"] = None + sys.modules["tkinter"] = None + + +if __name__ == "__main__": os.environ["TOKENIZERS_PARALLELISM"] = "true" from awq import AutoAWQForCausalLM - model_path = 'TheBloke/zephyr-7B-beta-AWQ' - model = AutoAWQForCausalLM.from_quantized(model_path, device_map="auto", max_new_tokens=2048) + + model_path = "TheBloke/zephyr-7B-beta-AWQ" + model = AutoAWQForCausalLM.from_quantized( + model_path, device_map="auto", max_seq_len=2048 + ) tokenizer = AutoTokenizer.from_pretrained(model_path) eval_humaneval(model, tokenizer) diff --git a/awq/evaluation/kl_divergence.py b/awq/evaluation/kl_divergence.py index f19f6bed..d24efe18 100644 --- a/awq/evaluation/kl_divergence.py +++ b/awq/evaluation/kl_divergence.py @@ -15,17 +15,20 @@ from scipy.stats import bayes_mvs from scipy.stats import t as student_t from scipy.stats.mstats import mquantiles_cimj + SCIPY_INSTALLED = True except: SCIPY_INSTALLED = False + @torch.jit.script def rel_entr(x, y): mask = (x > 0) & (y > 0) result = torch.where(mask, x * torch.log(x / y), torch.zeros_like(x)) - result[(x > 0) & (y <= 0)] = float('inf') + result[(x > 0) & (y <= 0)] = float("inf") return result + def bin_conf(p, n, z): # Binomial distribution confidence bounds # Bayes estimator when p is degenerate @@ -33,15 +36,23 @@ def bin_conf(p, n, z): p = 1 / (n + 2) if p == 1: p = 1 - 1 / (n + 2) - return z * torch.sqrt(p*(1-p)/n) + return z * torch.sqrt(p * (1 - p) / n) + -def eval_kl_divergence(ref_model: PreTrainedModel, eval_model: PreTrainedModel, tokenizer: PreTrainedTokenizer, seqlen: int): +def eval_kl_divergence( + ref_model: PreTrainedModel, + eval_model: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + seqlen: int, +): if not SCIPY_INSTALLED: - raise Exception("SciPy needs to be installed for KL Divergence evaluation: pip install scipy") + raise Exception( + "SciPy needs to be installed for KL Divergence evaluation: pip install scipy" + ) # load dataset - data = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') - data = 
tokenizer("\n\n".join(data['text']), return_tensors='pt') + data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + data = tokenizer("\n\n".join(data["text"]), return_tensors="pt") data = data.input_ids.to(ref_model.device) n_samples = data.numel() // seqlen @@ -59,11 +70,11 @@ def eval_kl_divergence(ref_model: PreTrainedModel, eval_model: PreTrainedModel, # start eval with tqdm(range(n_samples), desc="KL Div") as progress_bar: for i in progress_bar: - start_index = (i * seqlen) - end_index = ((i + 1) * seqlen) - batch_len = end_index-start_index + start_index = i * seqlen + end_index = (i + 1) * seqlen + batch_len = end_index - start_index batch = data[:, start_index:end_index] - + # get logits with torch.no_grad(): y1 = ref_model(batch)[0] @@ -75,7 +86,7 @@ def eval_kl_divergence(ref_model: PreTrainedModel, eval_model: PreTrainedModel, relative_entropy = rel_entr(y1_probs, y2_probs) kl_div = torch.sum(relative_entropy, dim=-1).squeeze(0) kls.append(torch.nan_to_num(kl_div).tolist()) - + # stats eval_argmax = torch.argmax(y2, axis=-1).squeeze(0) ref_argmax = torch.argmax(y1, axis=-1).squeeze(0) @@ -96,10 +107,10 @@ def eval_kl_divergence(ref_model: PreTrainedModel, eval_model: PreTrainedModel, f"Top 5: {top5 / samples:.4g}, " f"Top 10: {top10 / samples:.4g}" ) - - z = student_t.ppf(1 - alpha/2, samples) - m_conf = z*np.sqrt(np.mean([k**2 for k in kls])/len(kls)) - m, _, __ = bayes_mvs(kls, 1-alpha) + + z = student_t.ppf(1 - alpha / 2, samples) + m_conf = z * np.sqrt(np.mean([k**2 for k in kls]) / len(kls)) + m, _, __ = bayes_mvs(kls, 1 - alpha) q90 = np.quantile(kls, 0.90) q95 = np.quantile(kls, 0.95) q99 = np.quantile(kls, 0.99) @@ -116,20 +127,33 @@ def eval_kl_divergence(ref_model: PreTrainedModel, eval_model: PreTrainedModel, print(f"max: {np.max(kls):.4g}") print(" -- ") print("Reference top token in eval top-n probability:") - print(f" ** ref_top1: {top1 / samples:.4g} ± {bin_conf(top1/samples, samples, z):.4g}") - print(f" ** ref_top5: {top5 / samples:.4g} ± {bin_conf(top5/samples, samples, z):.4g}") - print(f" ** ref_top10: {top10 / samples:4g} ± {bin_conf(top10/samples, samples, z):.4g}") + print( + f" ** ref_top1: {top1 / samples:.4g} ± {bin_conf(top1/samples, samples, z):.4g}" + ) + print( + f" ** ref_top5: {top5 / samples:.4g} ± {bin_conf(top5/samples, samples, z):.4g}" + ) + print( + f" ** ref_top10: {top10 / samples:4g} ± {bin_conf(top10/samples, samples, z):.4g}" + ) print("Eval top token in reference top-n probability:") - print(f" ** eval_top5: {eval_top5 / samples:.4g} ± {bin_conf(eval_top5/samples, samples, z):.4g}") - print(f" ** eval_top10: {eval_top10 / samples:4g} ± {bin_conf(eval_top10/samples, samples, z):.4g}") + print( + f" ** eval_top5: {eval_top5 / samples:.4g} ± {bin_conf(eval_top5/samples, samples, z):.4g}" + ) + print( + f" ** eval_top10: {eval_top10 / samples:4g} ± {bin_conf(eval_top10/samples, samples, z):.4g}" + ) + -if __name__ == '__main__': +if __name__ == "__main__": # ref_model_path = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" # eval_model_path = "TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T" ref_model_path = eval_model_path = "gpt2" tokenizer = AutoTokenizer.from_pretrained(ref_model_path) ref_model = AutoModelForCausalLM.from_pretrained(ref_model_path, device_map="auto") - eval_model = AutoModelForCausalLM.from_pretrained(eval_model_path, device_map="auto") + eval_model = AutoModelForCausalLM.from_pretrained( + eval_model_path, device_map="auto" + ) - eval_kl_divergence(ref_model, eval_model, tokenizer, 
seqlen=1024) \ No newline at end of file + eval_kl_divergence(ref_model, eval_model, tokenizer, seqlen=1024) diff --git a/awq/models/_config.py b/awq/models/_config.py index 2972f7a7..0fbfc3a2 100644 --- a/awq/models/_config.py +++ b/awq/models/_config.py @@ -1,35 +1,28 @@ import os import json -import logging from typing import Dict, Optional, List -from dataclasses import dataclass, field, fields +from dataclasses import dataclass, field from transformers.utils.hub import PushToHubMixin, cached_file + @dataclass class AwqConfig(PushToHubMixin): quant_method: str = field(default="awq") zero_point: bool = field(default=True) q_group_size: int = field(default=128) w_bit: int = field(default=4) - version: str = field(default="GEMM") - config_file_name = "quant_config.json" + version: str = field(default="gemm") + config_file_name = "config.json" modules_to_not_convert: Optional[List] = None - def save_pretrained(self, save_dir: str, **kwargs): - logging.warning( - "`quant_config.json` is being deprecated in the future" - " in favor of quantization_config in config.json." - ) - with open(os.path.join(save_dir, self.config_file_name), "w+", encoding="utf-8") as file: - file.write(json.dumps(self.to_dict(), indent=4)) - @classmethod - def from_dict(cls, quant_config: Dict={}): + def from_dict(cls, quant_config: Dict = {}): if not quant_config: quant_config = cls() else: quant_config = cls(**quant_config) - + quant_config.version = quant_config.version.lower() + return quant_config @classmethod @@ -46,7 +39,7 @@ def from_pretrained(cls, save_dir: str, **kwargs): if os.path.isdir(save_dir): # Local resolved_config_file = os.path.join(save_dir, cls.config_file_name) - else: # Remote + else: # Remote resolved_config_file = cached_file( save_dir, cls.config_file_name, @@ -62,14 +55,21 @@ def from_pretrained(cls, save_dir: str, **kwargs): _raise_exceptions_for_connection_errors=False, _commit_hash=commit_hash, ) - + + quant_config = None if os.path.exists(resolved_config_file): - with open(resolved_config_file, 'r', encoding="utf-8") as file: + with open(resolved_config_file, "r", encoding="utf-8") as file: loaded_config = json.loads(file.read()) - quant_config = cls(**loaded_config) - else: + + quant_config = loaded_config.get("quantization_config") + + if quant_config is not None: + awq_config = cls.from_transformers_dict(cls, quant_config) + quant_config = cls(**awq_config) + + if quant_config is None: quant_config = cls() - + return quant_config def to_dict(self): @@ -90,3 +90,13 @@ def to_transformers_dict(self): "version": self.version.lower(), "modules_to_not_convert": self.modules_to_not_convert, } + + def from_transformers_dict(self, transformers_dict: Dict): + return { + "quant_method": transformers_dict.get("quant_method"), + "zero_point": transformers_dict.get("zero_point"), + "q_group_size": transformers_dict.get("group_size"), + "w_bit": transformers_dict.get("bits"), + "version": transformers_dict.get("version"), + "modules_to_not_convert": transformers_dict.get("modules_to_not_convert"), + } diff --git a/awq/models/aquila.py b/awq/models/aquila.py index 203e3f23..c2b8ce07 100644 --- a/awq/models/aquila.py +++ b/awq/models/aquila.py @@ -6,13 +6,14 @@ from awq.modules.fused.model import LlamaLikeModel from transformers.models.llama.modeling_llama import ( LlamaDecoderLayer as OldAquilaDecoderLayer, - LlamaForCausalLM as OldAquilaForCausalLM + LlamaForCausalLM as OldAquilaForCausalLM, ) from awq.modules.fused.norm import FasterTransformerRMSNorm + class 
AquilaAWQForCausalLM(BaseAWQForCausalLM): layer_type = "AquilaDecoderLayer" - max_new_tokens_key = "max_position_embeddings" + max_seq_len_key = "max_position_embeddings" @staticmethod def fuse_layers(model: OldAquilaForCausalLM): @@ -22,53 +23,65 @@ def fuse_layers(model: OldAquilaForCausalLM): @staticmethod def get_model_layers(model: OldAquilaForCausalLM): return model.model.layers - + @staticmethod def get_act_for_scaling(module: OldAquilaDecoderLayer): - return dict( - is_scalable=False - ) - + return dict(is_scalable=False) + @staticmethod def move_embed(model: OldAquilaForCausalLM, device: str): model.model.embed_tokens = model.model.embed_tokens.to(device) - + @staticmethod - def get_layers_for_scaling(module: OldAquilaDecoderLayer, input_feat, module_kwargs): + def get_layers_for_scaling( + module: OldAquilaDecoderLayer, input_feat, module_kwargs + ): layers = [] # attention input - layers.append(dict( - prev_op=module.input_layernorm, - layers=[module.self_attn.q_proj, - module.self_attn.k_proj, module.self_attn.v_proj], - inp=input_feat['self_attn.q_proj'], - module2inspect=module.self_attn, kwargs=module_kwargs, - )) + layers.append( + dict( + prev_op=module.input_layernorm, + layers=[ + module.self_attn.q_proj, + module.self_attn.k_proj, + module.self_attn.v_proj, + ], + inp=input_feat["self_attn.q_proj"], + module2inspect=module.self_attn, + kwargs=module_kwargs, + ) + ) # attention out # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: - layers.append(dict( - prev_op=module.self_attn.v_proj, - layers=[module.self_attn.o_proj], - inp=input_feat['self_attn.o_proj'], - )) - + layers.append( + dict( + prev_op=module.self_attn.v_proj, + layers=[module.self_attn.o_proj], + inp=input_feat["self_attn.o_proj"], + ) + ) + # linear 1 - layers.append(dict( - prev_op=module.post_attention_layernorm, - layers=[module.mlp.gate_proj, module.mlp.up_proj], - inp=input_feat['mlp.gate_proj'], - module2inspect=module.mlp, - )) + layers.append( + dict( + prev_op=module.post_attention_layernorm, + layers=[module.mlp.gate_proj, module.mlp.up_proj], + inp=input_feat["mlp.gate_proj"], + module2inspect=module.mlp, + ) + ) # linear 2 - layers.append(dict( - prev_op=module.mlp.up_proj, - layers=[module.mlp.down_proj], - inp=input_feat['mlp.down_proj'], - )) + layers.append( + dict( + prev_op=module.mlp.up_proj, + layers=[module.mlp.down_proj], + inp=input_feat["mlp.down_proj"], + ) + ) return layers @@ -78,10 +91,11 @@ def __init__(self, model: OldAquilaForCausalLM): self.model = model self.aquila_blocks: List[Tuple[str, OldAquilaDecoderLayer]] = [ - (name, module) for name, module in self.model.named_modules() - if 'AquilaDecoderLayer'.lower() in module.__class__.__name__.lower() + (name, module) + for name, module in self.model.named_modules() + if "AquilaDecoderLayer".lower() in module.__class__.__name__.lower() ] - + def fuse_transformer(self): blocks = [] @@ -92,29 +106,30 @@ def fuse_transformer(self): module, module.self_attn.q_proj, module.self_attn.k_proj, - module.self_attn.v_proj + module.self_attn.v_proj, ) norm_1 = FasterTransformerRMSNorm( - module.input_layernorm.weight, - module.input_layernorm.variance_epsilon + module.input_layernorm.weight, module.input_layernorm.variance_epsilon ) norm_2 = FasterTransformerRMSNorm( module.post_attention_layernorm.weight, - module.post_attention_layernorm.variance_epsilon + module.post_attention_layernorm.variance_epsilon, ) - 
blocks.append(LlamaLikeBlock( - hidden_size=self.model.config.hidden_size, - n_heads=self.model.config.num_attention_heads, - n_kv_heads=self.model.config.num_key_value_heads, - qkv_layer=qkv, - o_proj=module.self_attn.o_proj, - mlp=module.mlp, - norm_1=norm_1, - norm_2=norm_2, - dev=device, - max_seq_len=self.model.config.max_new_tokens - )) - + blocks.append( + LlamaLikeBlock( + hidden_size=self.model.config.hidden_size, + n_heads=self.model.config.num_attention_heads, + n_kv_heads=self.model.config.num_key_value_heads, + qkv_layer=qkv, + o_proj=module.self_attn.o_proj, + mlp=module.mlp, + norm_1=norm_1, + norm_2=norm_2, + dev=device, + max_seq_len=self.model.config.max_seq_len, + ) + ) + self.model.model = LlamaLikeModel( self.model.config.vocab_size, blocks, diff --git a/awq/models/auto.py b/awq/models/auto.py index c060b47f..c992061f 100644 --- a/awq/models/auto.py +++ b/awq/models/auto.py @@ -1,4 +1,5 @@ import os +import logging from transformers import AutoConfig from awq.models import * from awq.models.base import BaseAWQForCausalLM @@ -21,7 +22,7 @@ "qwen": QwenAWQForCausalLM, "baichuan": BaichuanAWQForCausalLM, "llava": LlavaAWQForCausalLM, - "qwen2": Qwen2AWQForCausalLM + "qwen2": Qwen2AWQForCausalLM, } @@ -47,7 +48,7 @@ def from_pretrained( self, model_path, trust_remote_code=True, - safetensors=False, + safetensors=True, device_map=None, **model_init_kwargs, ) -> BaseAWQForCausalLM: @@ -69,7 +70,7 @@ def from_quantized( self, quant_path, quant_filename="", - max_new_tokens=None, + max_seq_len=2048, trust_remote_code=True, fuse_layers=True, use_exllama=False, @@ -83,11 +84,18 @@ def from_quantized( os.environ["AWQ_BATCH_SIZE"] = str(batch_size) model_type = check_and_get_model_type(quant_path, trust_remote_code) + if config_kwargs.get("max_new_tokens") is not None: + max_seq_len = config_kwargs["max_new_tokens"] + logging.warning( + "max_new_tokens argument is deprecated... gracefully " + "setting max_seq_len=max_new_tokens." 
+ ) + return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized( quant_path, model_type, quant_filename, - max_new_tokens, + max_seq_len, trust_remote_code=trust_remote_code, fuse_layers=fuse_layers, use_exllama=use_exllama, diff --git a/awq/models/baichuan.py b/awq/models/baichuan.py index 2e93d8d9..230ccd05 100644 --- a/awq/models/baichuan.py +++ b/awq/models/baichuan.py @@ -8,9 +8,10 @@ ) from awq.modules.fused.norm import FasterTransformerRMSNorm + class BaichuanAWQForCausalLM(BaseAWQForCausalLM): layer_type = "BaichuanLayer" - max_new_tokens_key = "model_max_length" + max_seq_len_key = "model_max_length" @staticmethod def fuse_layers(model): @@ -20,29 +21,30 @@ def fuse_layers(model): @staticmethod def get_model_layers(model): return model.model.layers - + @staticmethod def get_act_for_scaling(module): - return dict( - is_scalable=False - ) - + return dict(is_scalable=False) + @staticmethod def move_embed(model, device: str): model.model.embed_tokens = model.model.embed_tokens.to(device) - + @staticmethod # def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs): def get_layers_for_scaling(module, input_feat, module_kwargs): layers = [] # attention input - layers.append(dict( - prev_op=module.input_layernorm, - layers=[module.self_attn.W_pack], - inp=input_feat['self_attn.W_pack'], - module2inspect=module.self_attn, kwargs=module_kwargs, - )) + layers.append( + dict( + prev_op=module.input_layernorm, + layers=[module.self_attn.W_pack], + inp=input_feat["self_attn.W_pack"], + module2inspect=module.self_attn, + kwargs=module_kwargs, + ) + ) # # attention out # # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 @@ -55,26 +57,32 @@ def get_layers_for_scaling(module, input_feat, module_kwargs): # attention out # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 - layers.append(dict( - prev_op=module.self_attn.W_pack, - layers=[module.self_attn.o_proj], - inp=input_feat['self_attn.o_proj'], - )) - + layers.append( + dict( + prev_op=module.self_attn.W_pack, + layers=[module.self_attn.o_proj], + inp=input_feat["self_attn.o_proj"], + ) + ) + # linear 1 - layers.append(dict( - prev_op=module.post_attention_layernorm, - layers=[module.mlp.gate_proj, module.mlp.up_proj], - inp=input_feat['mlp.gate_proj'], - module2inspect=module.mlp, - )) + layers.append( + dict( + prev_op=module.post_attention_layernorm, + layers=[module.mlp.gate_proj, module.mlp.up_proj], + inp=input_feat["mlp.gate_proj"], + module2inspect=module.mlp, + ) + ) # linear 2 - layers.append(dict( - prev_op=module.mlp.up_proj, - layers=[module.mlp.down_proj], - inp=input_feat['mlp.down_proj'], - )) + layers.append( + dict( + prev_op=module.mlp.up_proj, + layers=[module.mlp.down_proj], + inp=input_feat["mlp.down_proj"], + ) + ) return layers @@ -84,10 +92,11 @@ def __init__(self, model): self.model = model self.llama_blocks: List[Tuple[str, OldLlamaDecoderLayer]] = [ - (name, module) for name, module in self.model.named_modules() - if 'LlamaDecoderLayer'.lower() in module.__class__.__name__.lower() + (name, module) + for name, module in self.model.named_modules() + if "LlamaDecoderLayer".lower() in module.__class__.__name__.lower() ] - + def fuse_transformer(self): blocks = [] @@ -101,27 +110,28 @@ def fuse_transformer(self): # ) qkv = module.self_attn.W_pack norm_1 = FasterTransformerRMSNorm( - module.input_layernorm.weight, - module.input_layernorm.epsilon + module.input_layernorm.weight, module.input_layernorm.epsilon ) norm_2 = 
FasterTransformerRMSNorm( module.post_attention_layernorm.weight, - module.post_attention_layernorm.epsilon + module.post_attention_layernorm.epsilon, ) - blocks.append(LlamaLikeBlock( - hidden_size=self.model.config.hidden_size, - n_heads=self.model.config.num_attention_heads, - n_kv_heads=self.model.config.num_attention_heads, - qkv_layer=qkv, - o_proj=module.self_attn.o_proj, - mlp=module.mlp, - norm_1=norm_1, - norm_2=norm_2, - dev=device, - max_seq_len=self.model.config.max_new_tokens, - use_alibi=True - )) - + blocks.append( + LlamaLikeBlock( + hidden_size=self.model.config.hidden_size, + n_heads=self.model.config.num_attention_heads, + n_kv_heads=self.model.config.num_attention_heads, + qkv_layer=qkv, + o_proj=module.self_attn.o_proj, + mlp=module.mlp, + norm_1=norm_1, + norm_2=norm_2, + dev=device, + max_seq_len=self.model.config.max_seq_len, + use_alibi=True, + ) + ) + self.model.model = LlamaLikeModel( self.model.config.vocab_size, blocks, diff --git a/awq/models/base.py b/awq/models/base.py index 36e86e0a..53ee2f50 100644 --- a/awq/models/base.py +++ b/awq/models/base.py @@ -6,8 +6,9 @@ import torch.nn as nn from tqdm import tqdm -from typing import List, Union +from typing import List, Union, Dict from safetensors.torch import save_file +from typing_extensions import Doc, Annotated from huggingface_hub import snapshot_download from transformers.modeling_utils import shard_checkpoint @@ -27,6 +28,7 @@ PretrainedConfig, AutoProcessor, CLIPImageProcessor, + PreTrainedTokenizer, ) from accelerate.big_modeling import ( init_empty_weights, @@ -64,8 +66,21 @@ class BaseAWQForCausalLM(nn.Module): def __init__( - self, model, model_type, is_quantized, config, quant_config, processor + self, + model: Annotated[PreTrainedModel, Doc("The pretrained or quantized model.")], + model_type: Annotated[str, Doc("The model type, found in config.json.")], + is_quantized: Annotated[ + bool, Doc("Indicates if the current model is quantized.") + ], + config: Annotated[PretrainedConfig, Doc("The config of the model.")], + quant_config: Annotated[ + AwqConfig, Doc("The quantization config of the model.") + ], + processor: Annotated[ + AutoProcessor, Doc("An optional processor, e.g. for vision models.") + ], ): + """The base model for all AutoAWQ models.""" super().__init__() self.model: PreTrainedModel = model self.model_type: str = model_type @@ -75,30 +90,68 @@ def __init__( self.quant_config: AwqConfig = quant_config self.processor: CLIPImageProcessor = processor - def to(self, device: str): + def to(self, device: Annotated[str, Doc("The device to move your model to.")]): + """A utility function for moving the model to a device.""" return self.model.to(device) def forward(self, *args, **kwargs): + """A forward function that mimics the torch forward.""" return self.model(*args, **kwargs) def generate(self, *args, **kwargs): + """A generate function that mimics the HF generate function.""" with torch.inference_mode(): return self.model.generate(*args, **kwargs) @torch.no_grad() def quantize( self, - tokenizer=None, - quant_config={}, - calib_data: Union[str, List[str]] = "pileval", - split="train", - text_column="text", - duo_scaling=True, - modules_to_not_convert=None, - export_compatible=False, + tokenizer: Annotated[ + PreTrainedTokenizer, Doc("The tokenizer to use for quantization.") + ] = None, + quant_config: Annotated[ + Dict, Doc("The quantization config you want to use.") + ] = {}, + calib_data: Annotated[ + Union[str, List[str]], + Doc( + "The calibration dataset. 
Either a string pointing to Huggingface or a list of preloaded examples." + ), + ] = "pileval", + split: Annotated[str, Doc("The split of calib_data.")] = "train", + text_column: Annotated[str, Doc("The text column of calib_data.")] = "text", + duo_scaling: Annotated[ + bool, Doc("Whether to scale using both w/x or just x.") + ] = True, + export_compatible: Annotated[ + bool, + Doc( + "This argument avoids real quantization by only applying the scales without quantizing down to FP16." + ), + ] = False, ): + """ + The main quantization function that you can use to quantize your model. + + Example: + + ```python + from awq import AutoAWQForCausalLM + from transformers import AutoTokenizer + + model_path = "..." + model = AutoAWQForCausalLM.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_path) + + quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" } + model.quantize(tokenizer, quant_config) + ``` + """ self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config) + if hasattr(self, "modules_to_not_convert"): + self.quant_config.modules_to_not_convert = self.modules_to_not_convert + self.quantizer = AwqQuantizer( self, self.model, @@ -111,7 +164,7 @@ def quantize( split, text_column, duo_scaling, - modules_to_not_convert=modules_to_not_convert, + modules_to_not_convert=self.quant_config.modules_to_not_convert, export_compatible=export_compatible, ) self.quantizer.quantize() @@ -124,6 +177,9 @@ def pack(self): A utility function for the following scenario. Note that save_quantized will overwrite existing weights if you use the same quant_path. + Example: + + ```python model.quantize( tokenizer, quant_config=quant_config, @@ -132,6 +188,7 @@ def pack(self): model.save_quantized(...) # produces GGUF/other compat weights model.pack(...) # makes the model CUDA compat model.save_quantized(...) # produces CUDA compat weights + ``` """ self.quantizer.pack() @@ -139,7 +196,16 @@ def pack(self): def fuse_layers(model): pass - def save_quantized(self, save_dir, safetensors=True, shard_size="10GB"): + def save_quantized( + self, + save_dir: Annotated[str, Doc("The directory to save your model to.")], + safetensors: Annotated[ + bool, Doc("Whether to save the model as safetensors or torch files.") + ] = True, + shard_size: Annotated[ + str, Doc("The shard size for sharding large models into multiple chunks.") + ] = "5GB", + ): save_dir = save_dir[:-1] if save_dir[-1] == "/" else save_dir # Save model @@ -154,7 +220,6 @@ def forward(self, x): self.model.config.quantization_config = self.quant_config.to_transformers_dict() self.model.generation_config.do_sample = True self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict()) - self.quant_config.save_pretrained(save_dir) # Vision transformers have a processor if self.processor is not None: @@ -195,14 +260,37 @@ def forward(self, x): @classmethod def from_pretrained( self, - model_path, - model_type, - torch_dtype: torch.dtype = torch.float16, - trust_remote_code=True, - safetensors=False, - device_map=None, - **model_init_kwargs, + model_path: Annotated[str, Doc("A Huggingface path or local path to a model.")], + model_type: Annotated[str, Doc("The model type, loaded from config.json.")], + torch_dtype: Annotated[ + torch.dtype, + Doc( + "The dtype to load the model as. May not work with other values than float16." 
+ ), + ] = torch.float16, + trust_remote_code: Annotated[ + bool, + Doc( + "Useful for Huggingface repositories that have not been integrated into transformers yet." + ), + ] = True, + safetensors: Annotated[ + bool, Doc("Whether to download/load safetensors instead of torch weights.") + ] = True, + device_map: Annotated[ + Union[str, Dict], + Doc( + "A device map that will be passed onto the model loading method from transformers." + ), + ] = None, + **model_init_kwargs: Annotated[ + Dict, + Doc( + "Additional kwargs that are passed to the model during initialization." + ), + ], ): + """A method for initialization of pretrained models, usually in FP16.""" # Get weights path and quant config model_weights_path, config, quant_config = self._load_config( self, model_path, "", safetensors, trust_remote_code=trust_remote_code @@ -240,31 +328,70 @@ def from_pretrained( @classmethod def from_quantized( self, - model_path, - model_type, - model_filename="", - max_new_tokens=None, - torch_dtype=torch.float16, - trust_remote_code=True, - safetensors=True, - is_quantized=True, - fuse_layers=False, - use_exllama=False, - use_exllama_v2=False, - version="GEMM", - device_map="balanced", - offload_folder=None, - **config_kwargs, + model_path: Annotated[str, Doc("A Huggingface path or local path to a model.")], + model_type: Annotated[str, Doc("The model type, loaded from config.json.")], + model_filename: Annotated[ + str, Doc("Load a specific model's filename by specifying this argument.") + ] = "", + max_seq_len: Annotated[ + int, + Doc( + "The maximum sequence cached sequence length of the model. Larger values may increase loading time and memory usage." + ), + ] = None, + torch_dtype: Annotated[ + torch.dtype, + Doc( + "The dtype to load the model as. May not work with other values than float16." + ), + ] = torch.float16, + trust_remote_code: Annotated[ + bool, + Doc( + "Useful for Huggingface repositories that have not been integrated into transformers yet." + ), + ] = True, + safetensors: Annotated[ + bool, Doc("Whether to download/load safetensors instead of torch weights.") + ] = True, + fuse_layers: Annotated[ + bool, + Doc( + "Whether to use fused/optimized combination of layers for increased speed." + ), + ] = True, + use_exllama: Annotated[ + bool, Doc("Whether to map the weights to ExLlamaV1 kernels.") + ] = False, + use_exllama_v2: Annotated[ + bool, Doc("Whether to map the weights to ExLlamaV2 kernels.") + ] = False, + device_map: Annotated[ + Union[str, Dict], + Doc( + "A device map that will be passed onto the model loading method from transformers." + ), + ] = "balanced", + offload_folder: Annotated[ + str, + Doc("The folder ot offload the model to."), + ] = None, + **config_kwargs: Annotated[ + Dict, + Doc( + "Additional kwargs that are passed to the config during initialization." 
+ ), + ], ): + """A method for initialization of a quantized model, usually in INT4.""" # [STEP 1-2] Load weights path and configs model_weights_path, config, quant_config = self._load_config( self, model_path, model_filename, safetensors, - version, trust_remote_code, - max_new_tokens=max_new_tokens, + max_seq_len=max_seq_len, **config_kwargs, ) @@ -306,7 +433,7 @@ def from_quantized( if fuse_layers: self.fuse_layers(model) - if quant_config.version == "Marlin": + if quant_config.version == "marlin": model = marlin_post_init(model) elif use_exllama: @@ -316,14 +443,14 @@ def from_quantized( # creates q4 handle and allocates scratch spaces wrt max_input_len and max_batch_size model = exllamav2_post_init( model, - max_input_len=max_new_tokens or 2048, + max_input_len=max_seq_len or 2048, max_batch_size=int(os.getenv("AWQ_BATCH_SIZE", 1)), ) return self( model, model_type, - is_quantized=is_quantized, + is_quantized=True, config=config, quant_config=quant_config, processor=None, @@ -334,9 +461,8 @@ def _load_config( model_path, model_filename, safetensors=True, - version="GEMM", trust_remote_code=True, - max_new_tokens=4096, + max_seq_len=4096, **config_kwargs, ): # [STEP 1] Download model if path is not a directory @@ -359,22 +485,22 @@ def _load_config( quant_config = AwqConfig.from_pretrained(model_path) # Load model config and set max generation length - if max_new_tokens is None and hasattr(self, "max_new_tokens_key"): + if max_seq_len is None and hasattr(self, "max_seq_len_key"): config = AutoConfig.from_pretrained( model_path, trust_remote_code=trust_remote_code, **config_kwargs ) - config.max_new_tokens = getattr(config, self.max_new_tokens_key, 2048) + config.max_seq_len = getattr(config, self.max_seq_len_key, 2048) # To add the generate support for Multi-modal models as well if hasattr(config, "text_config"): - config.text_config.max_new_tokens = getattr( - config, self.max_new_tokens_key, 2048 + config.text_config.max_seq_len = getattr( + config, self.max_seq_len_key, 2048 ) else: - max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens + max_seq_len = 2048 if max_seq_len is None else max_seq_len config = AutoConfig.from_pretrained( model_path, trust_remote_code=trust_remote_code, **config_kwargs ) - config.max_new_tokens = max_new_tokens + config.max_seq_len = max_seq_len return model_weights_path, config, quant_config @@ -383,7 +509,7 @@ def _load_quantized_modules( ): # Real quantization of weights assert not ( - version == "GEMV" and (use_exllama or use_exllama_v2) + version == "gemv" and (use_exllama or use_exllama_v2) ), "Exllama kernels only support GEMM version." 
# Get blocks of model @@ -405,15 +531,15 @@ def _load_quantized_modules( # Replace nn.Linear with WQLinear for name, module in named_linears.items(): - if version == "Marlin": + if version == "marlin": q_linear_module = WQLinear_Marlin elif use_exllama: q_linear_module = WQLinear_Exllama elif use_exllama_v2: q_linear_module = WQLinear_ExllamaV2 - elif version == "GEMM": + elif version == "gemm": q_linear_module = WQLinear_GEMM - elif version == "GEMV": + elif version == "gemv": q_linear_module = WQLinear_GEMV q_linear = q_linear_module.from_linear( diff --git a/awq/models/bloom.py b/awq/models/bloom.py index ec8ab8e8..3260379f 100644 --- a/awq/models/bloom.py +++ b/awq/models/bloom.py @@ -1,38 +1,44 @@ from .base import BaseAWQForCausalLM from transformers.models.bloom.modeling_bloom import BloomForCausalLM, BloomBlock + class BloomAWQForCausalLM(BaseAWQForCausalLM): layer_type = "BloomBlock" @staticmethod def get_model_layers(model: BloomForCausalLM): return model.transformer.h - + @staticmethod def get_act_for_scaling(module: BloomBlock): return dict( is_scalable=True, scale_name="mlp.gelu_impl", scale_layer=module.mlp.gelu_impl, - scale_shape=module.mlp.dense_h_to_4h.out_features + scale_shape=module.mlp.dense_h_to_4h.out_features, ) - + @staticmethod def move_embed(model: BloomForCausalLM, device: str): model.transformer.word_embeddings = model.transformer.word_embeddings.to(device) - model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.to(device) - + model.transformer.word_embeddings_layernorm = ( + model.transformer.word_embeddings_layernorm.to(device) + ) + @staticmethod def get_layers_for_scaling(module: BloomBlock, input_feat, module_kwargs): layers = [] # attention input - layers.append(dict( - prev_op=module.input_layernorm, - layers=[module.self_attention.query_key_value], - inp=input_feat['self_attention.query_key_value'], - module2inspect=module, kwargs=module_kwargs, - )) + layers.append( + dict( + prev_op=module.input_layernorm, + layers=[module.self_attention.query_key_value], + inp=input_feat["self_attention.query_key_value"], + module2inspect=module, + kwargs=module_kwargs, + ) + ) # attention out # Please refer to https://github.com/mit-han-lab/llm-awq/issues/2#issuecomment-1606297469 """ @@ -43,17 +49,22 @@ def get_layers_for_scaling(module: BloomBlock, input_feat, module_kwargs): )) """ # linear 1 - layers.append(dict( - prev_op=module.post_attention_layernorm, - layers=[module.mlp.dense_h_to_4h], - inp=input_feat['mlp.dense_h_to_4h'], - module2inspect=module, kwargs=module_kwargs, - )) + layers.append( + dict( + prev_op=module.post_attention_layernorm, + layers=[module.mlp.dense_h_to_4h], + inp=input_feat["mlp.dense_h_to_4h"], + module2inspect=module, + kwargs=module_kwargs, + ) + ) # linear 2 - layers.append(dict( - prev_op=module.mlp.gelu_impl, - layers=[module.mlp.dense_4h_to_h], - inp=input_feat['mlp.dense_4h_to_h'], - )) + layers.append( + dict( + prev_op=module.mlp.gelu_impl, + layers=[module.mlp.dense_4h_to_h], + inp=input_feat["mlp.dense_4h_to_h"], + ) + ) - return layers \ No newline at end of file + return layers diff --git a/awq/models/falcon.py b/awq/models/falcon.py index 0ddcc990..074609b4 100644 --- a/awq/models/falcon.py +++ b/awq/models/falcon.py @@ -1,5 +1,10 @@ from .base import BaseAWQForCausalLM -from transformers.models.falcon.modeling_falcon import FalconDecoderLayer as OldFalconDecoderLayer, FalconForCausalLM, FalconAttention +from transformers.models.falcon.modeling_falcon import ( + FalconDecoderLayer as 
OldFalconDecoderLayer, + FalconForCausalLM, + FalconAttention, +) + class FalconAWQForCausalLM(BaseAWQForCausalLM): layer_type = "FalconDecoderLayer" @@ -15,64 +20,77 @@ def fuse_layers(model: FalconForCausalLM): @staticmethod def get_model_layers(model: FalconForCausalLM): return model.transformer.h - + @staticmethod def get_act_for_scaling(module: OldFalconDecoderLayer): return dict( is_scalable=True, scale_name="mlp.act", scale_layer=module.mlp.act, - scale_shape=module.mlp.dense_h_to_4h.out_features + scale_shape=module.mlp.dense_h_to_4h.out_features, ) - + @staticmethod def move_embed(model: FalconForCausalLM, device): model.transformer.word_embeddings = model.transformer.word_embeddings.to(device) - + @staticmethod - def get_layers_for_scaling(module: OldFalconDecoderLayer, input_feat, module_kwargs): + def get_layers_for_scaling( + module: OldFalconDecoderLayer, input_feat, module_kwargs + ): layers = [] - + # Falcon 7B (older architecture) if module.config.num_attention_heads == 71: # linear 1 + attention - layers.append(dict( - prev_op=module.input_layernorm, - layers=[module.mlp.dense_h_to_4h, module.self_attention.query_key_value], - inp=input_feat['self_attention.query_key_value'], - module2inspect=module, - kwargs=module_kwargs, - )) + layers.append( + dict( + prev_op=module.input_layernorm, + layers=[ + module.mlp.dense_h_to_4h, + module.self_attention.query_key_value, + ], + inp=input_feat["self_attention.query_key_value"], + module2inspect=module, + kwargs=module_kwargs, + ) + ) # Falcon 40B (newer architecture) else: # linear 1 + attention - layers.append(dict( - prev_op=module.ln_attn, - layers=[module.self_attention.query_key_value], - inp=input_feat['self_attention.query_key_value'], - module2inspect=module, - kwargs=module_kwargs, - )) + layers.append( + dict( + prev_op=module.ln_attn, + layers=[module.self_attention.query_key_value], + inp=input_feat["self_attention.query_key_value"], + module2inspect=module, + kwargs=module_kwargs, + ) + ) # linear 2 - layers.append(dict( - prev_op=module.ln_mlp, - layers=[module.mlp.dense_h_to_4h], - inp=input_feat['mlp.dense_h_to_4h'], - module2inspect=module, - kwargs=module_kwargs, - )) + layers.append( + dict( + prev_op=module.ln_mlp, + layers=[module.mlp.dense_h_to_4h], + inp=input_feat["mlp.dense_h_to_4h"], + module2inspect=module, + kwargs=module_kwargs, + ) + ) return layers + from awq.modules.fused.model import FalconModel from awq.modules.fused.block import FalconDecoderLayer + class FalconFuser: def __init__(self, model: FalconForCausalLM): self.model = model - + def fuse_transformer(self): blocks = [] @@ -88,20 +106,22 @@ def fuse_transformer(self): ln_attn = module.ln_attn ln_mlp = module.ln_mlp new_decoder_arch = True - - blocks.append(FalconDecoderLayer( - hidden_size=module.config.hidden_size, - n_heads=module.config.num_attention_heads, - qkv_layer=module.self_attention.query_key_value, - o_proj=module.self_attention.dense, - mlp=module.mlp, - dev=next(iter(module.state_dict().values())).device, - max_seq_len=self.model.config.max_new_tokens, - input_layernorm=input_layernorm, - ln_attn=ln_attn, - ln_mlp=ln_mlp, - new_decoder_arch=new_decoder_arch - )) + + blocks.append( + FalconDecoderLayer( + hidden_size=module.config.hidden_size, + n_heads=module.config.num_attention_heads, + qkv_layer=module.self_attention.query_key_value, + o_proj=module.self_attention.dense, + mlp=module.mlp, + dev=next(iter(module.state_dict().values())).device, + max_seq_len=self.model.config.max_seq_len, + input_layernorm=input_layernorm, + 
ln_attn=ln_attn, + ln_mlp=ln_mlp, + new_decoder_arch=new_decoder_arch, + ) + ) self.model.transformer = FalconModel( self.model.config.vocab_size, @@ -110,4 +130,4 @@ def fuse_transformer(self): self.model.transformer.ln_f, ) - setattr(self.model.transformer, "blocks", self.model.transformer.blocks) \ No newline at end of file + setattr(self.model.transformer, "blocks", self.model.transformer.blocks) diff --git a/awq/models/gpt_bigcode.py b/awq/models/gpt_bigcode.py index a6dda5ff..f3c02a1f 100644 --- a/awq/models/gpt_bigcode.py +++ b/awq/models/gpt_bigcode.py @@ -1,9 +1,13 @@ from .base import BaseAWQForCausalLM -from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM, GPTBigCodeBlock as OldGptBigCodeBlock +from transformers.models.gpt_bigcode.modeling_gpt_bigcode import ( + GPTBigCodeForCausalLM, + GPTBigCodeBlock as OldGptBigCodeBlock, +) + class GptBigCodeAWQForCausalLM(BaseAWQForCausalLM): layer_type = "GPTBigCodeBlock" - max_new_tokens_key = "n_positions" + max_seq_len_key = "n_positions" @staticmethod def get_model_layers(model: GPTBigCodeForCausalLM): @@ -15,7 +19,7 @@ def get_act_for_scaling(module: OldGptBigCodeBlock): is_scalable=True, scale_name="mlp.act", scale_layer=module.mlp.act, - scale_shape=module.mlp.c_fc.out_features + scale_shape=module.mlp.c_fc.out_features, ) @staticmethod @@ -25,31 +29,37 @@ def move_embed(model: GPTBigCodeForCausalLM, device): model.transformer.drop = model.transformer.drop.to(device) @staticmethod - def get_layers_for_scaling(module:OldGptBigCodeBlock, input_feat, module_kwargs): + def get_layers_for_scaling(module: OldGptBigCodeBlock, input_feat, module_kwargs): layers = [] # attention input - layers.append(dict( - prev_op=module.ln_1, - layers=[module.attn.c_attn], - inp=input_feat['attn.c_attn'], - module2inspect=module.attn, - kwargs=module_kwargs - )) - + layers.append( + dict( + prev_op=module.ln_1, + layers=[module.attn.c_attn], + inp=input_feat["attn.c_attn"], + module2inspect=module.attn, + kwargs=module_kwargs, + ) + ) + # linear 1 - layers.append(dict( - prev_op=module.ln_2, - layers=[module.mlp.c_fc], - inp=input_feat['mlp.c_fc'], - module2inspect=module.mlp - )) + layers.append( + dict( + prev_op=module.ln_2, + layers=[module.mlp.c_fc], + inp=input_feat["mlp.c_fc"], + module2inspect=module.mlp, + ) + ) # linear 2 - layers.append(dict( - prev_op=module.mlp.act, - layers=[module.mlp.c_proj], - inp=input_feat['mlp.c_proj'] - )) + layers.append( + dict( + prev_op=module.mlp.act, + layers=[module.mlp.c_proj], + inp=input_feat["mlp.c_proj"], + ) + ) return layers diff --git a/awq/models/gpt_neox.py b/awq/models/gpt_neox.py index 39eab113..849dedb8 100644 --- a/awq/models/gpt_neox.py +++ b/awq/models/gpt_neox.py @@ -1,14 +1,18 @@ from .base import BaseAWQForCausalLM -from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXLayer, GPTNeoXForCausalLM +from transformers.models.gpt_neox.modeling_gpt_neox import ( + GPTNeoXLayer, + GPTNeoXForCausalLM, +) + class GPTNeoXAWQForCausalLM(BaseAWQForCausalLM): layer_type = "GPTNeoXDecoderLayer" - max_new_tokens_key = "max_position_embeddings" + max_seq_len_key = "max_position_embeddings" @staticmethod def get_model_layers(model: GPTNeoXForCausalLM): return model.gpt_neox.layers - + @staticmethod def get_act_for_scaling(module: GPTNeoXLayer): return dict( @@ -17,21 +21,23 @@ def get_act_for_scaling(module: GPTNeoXLayer): scale_layer=module.mlp.act, scale_shape=module.mlp.dense_h_to_4h.out_features, ) - + @staticmethod def move_embed(model: GPTNeoXForCausalLM, 
device: str): model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(device) - + @staticmethod def get_layers_for_scaling(module: GPTNeoXLayer, input_feat, module_kwargs): layers = [] # attention input - layers.append(dict( - prev_op=module.input_layernorm, - layers=[module.attention.query_key_value], - inp=input_feat['attention.query_key_value'], - )) + layers.append( + dict( + prev_op=module.input_layernorm, + layers=[module.attention.query_key_value], + inp=input_feat["attention.query_key_value"], + ) + ) # attention out # Please refer to https://github.com/mit-han-lab/llm-awq/issues/2#issuecomment-1606297469 @@ -44,17 +50,21 @@ def get_layers_for_scaling(module: GPTNeoXLayer, input_feat, module_kwargs): """ # linear 1 - layers.append(dict( - prev_op=module.post_attention_layernorm, - layers=[module.mlp.dense_h_to_4h], - inp=input_feat['mlp.dense_h_to_4h'], - )) + layers.append( + dict( + prev_op=module.post_attention_layernorm, + layers=[module.mlp.dense_h_to_4h], + inp=input_feat["mlp.dense_h_to_4h"], + ) + ) # linear 2 - layers.append(dict( - prev_op=module.mlp.act, - layers=[module.mlp.dense_4h_to_h], - inp=input_feat['mlp.dense_4h_to_h'], - )) + layers.append( + dict( + prev_op=module.mlp.act, + layers=[module.mlp.dense_4h_to_h], + inp=input_feat["mlp.dense_4h_to_h"], + ) + ) return layers diff --git a/awq/models/gptj.py b/awq/models/gptj.py index 31748cf6..178d49a9 100644 --- a/awq/models/gptj.py +++ b/awq/models/gptj.py @@ -1,53 +1,64 @@ from .base import BaseAWQForCausalLM from transformers.models.gptj.modeling_gptj import GPTJForCausalLM, GPTJBlock + class GPTJAWQForCausalLM(BaseAWQForCausalLM): layer_type = "GPTJBlock" - max_new_tokens_key = "n_positions" + max_seq_len_key = "n_positions" @staticmethod def get_model_layers(model: GPTJForCausalLM): return model.transformer.h - + @staticmethod def get_act_for_scaling(module: GPTJBlock): return dict( is_scalable=True, scale_name="mlp.act", scale_layer=module.mlp.act, - scale_shape=module.mlp.fc_in.out_features + scale_shape=module.mlp.fc_in.out_features, ) - + @staticmethod def move_embed(model: GPTJForCausalLM, device: str): model.transformer.wte = model.transformer.wte.to(device) - + @staticmethod def get_layers_for_scaling(module: GPTJBlock, input_feat, module_kwargs): layers = [] # attention input + linear 1 - layers.append(dict( - prev_op=module.ln_1, - layers=[module.attn.q_proj, - module.attn.k_proj, module.attn.v_proj, module.mlp.fc_in], - inp=input_feat['attn.q_proj'], - module2inspect=module, - kwargs=module_kwargs - )) + layers.append( + dict( + prev_op=module.ln_1, + layers=[ + module.attn.q_proj, + module.attn.k_proj, + module.attn.v_proj, + module.mlp.fc_in, + ], + inp=input_feat["attn.q_proj"], + module2inspect=module, + kwargs=module_kwargs, + ) + ) # attention out - layers.append(dict( - prev_op=module.attn.v_proj, - layers=[module.attn.out_proj], - inp=input_feat['attn.out_proj'], - )) + layers.append( + dict( + prev_op=module.attn.v_proj, + layers=[module.attn.out_proj], + inp=input_feat["attn.out_proj"], + ) + ) # linear 2 - layers.append(dict( - prev_op=module.mlp.act, - layers=[module.mlp.fc_out], - inp=input_feat['mlp.fc_out'], - )) + layers.append( + dict( + prev_op=module.mlp.act, + layers=[module.mlp.fc_out], + inp=input_feat["mlp.fc_out"], + ) + ) - return layers \ No newline at end of file + return layers diff --git a/awq/models/llama.py b/awq/models/llama.py index 7390c58e..be6e8ecc 100644 --- a/awq/models/llama.py +++ b/awq/models/llama.py @@ -6,13 +6,14 @@ from awq.modules.fused.model import 
LlamaLikeModel from transformers.models.llama.modeling_llama import ( LlamaDecoderLayer as OldLlamaDecoderLayer, - LlamaForCausalLM as OldLlamaForCausalLM + LlamaForCausalLM as OldLlamaForCausalLM, ) from awq.modules.fused.norm import FasterTransformerRMSNorm + class LlamaAWQForCausalLM(BaseAWQForCausalLM): layer_type = "LlamaDecoderLayer" - max_new_tokens_key = "max_position_embeddings" + max_seq_len_key = "max_position_embeddings" @staticmethod def fuse_layers(model: OldLlamaForCausalLM): @@ -22,53 +23,63 @@ def fuse_layers(model: OldLlamaForCausalLM): @staticmethod def get_model_layers(model: OldLlamaForCausalLM): return model.model.layers - + @staticmethod def get_act_for_scaling(module: OldLlamaDecoderLayer): - return dict( - is_scalable=False - ) - + return dict(is_scalable=False) + @staticmethod def move_embed(model: OldLlamaForCausalLM, device: str): model.model.embed_tokens = model.model.embed_tokens.to(device) - + @staticmethod def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs): layers = [] # attention input - layers.append(dict( - prev_op=module.input_layernorm, - layers=[module.self_attn.q_proj, - module.self_attn.k_proj, module.self_attn.v_proj], - inp=input_feat['self_attn.q_proj'], - module2inspect=module.self_attn, kwargs=module_kwargs, - )) + layers.append( + dict( + prev_op=module.input_layernorm, + layers=[ + module.self_attn.q_proj, + module.self_attn.k_proj, + module.self_attn.v_proj, + ], + inp=input_feat["self_attn.q_proj"], + module2inspect=module.self_attn, + kwargs=module_kwargs, + ) + ) # attention out # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: - layers.append(dict( - prev_op=module.self_attn.v_proj, - layers=[module.self_attn.o_proj], - inp=input_feat['self_attn.o_proj'], - )) - + layers.append( + dict( + prev_op=module.self_attn.v_proj, + layers=[module.self_attn.o_proj], + inp=input_feat["self_attn.o_proj"], + ) + ) + # linear 1 - layers.append(dict( - prev_op=module.post_attention_layernorm, - layers=[module.mlp.gate_proj, module.mlp.up_proj], - inp=input_feat['mlp.gate_proj'], - module2inspect=module.mlp, - )) + layers.append( + dict( + prev_op=module.post_attention_layernorm, + layers=[module.mlp.gate_proj, module.mlp.up_proj], + inp=input_feat["mlp.gate_proj"], + module2inspect=module.mlp, + ) + ) # linear 2 - layers.append(dict( - prev_op=module.mlp.up_proj, - layers=[module.mlp.down_proj], - inp=input_feat['mlp.down_proj'], - )) + layers.append( + dict( + prev_op=module.mlp.up_proj, + layers=[module.mlp.down_proj], + inp=input_feat["mlp.down_proj"], + ) + ) return layers @@ -78,10 +89,11 @@ def __init__(self, model: OldLlamaForCausalLM): self.model = model self.llama_blocks: List[Tuple[str, OldLlamaDecoderLayer]] = [ - (name, module) for name, module in self.model.named_modules() - if 'LlamaDecoderLayer'.lower() in module.__class__.__name__.lower() + (name, module) + for name, module in self.model.named_modules() + if "LlamaDecoderLayer".lower() in module.__class__.__name__.lower() ] - + def fuse_transformer(self): blocks = [] @@ -92,34 +104,35 @@ def fuse_transformer(self): module, module.self_attn.q_proj, module.self_attn.k_proj, - module.self_attn.v_proj + module.self_attn.v_proj, ) norm_1 = FasterTransformerRMSNorm( - module.input_layernorm.weight, - module.input_layernorm.variance_epsilon + module.input_layernorm.weight, module.input_layernorm.variance_epsilon ) norm_2 = FasterTransformerRMSNorm( 
module.post_attention_layernorm.weight, - module.post_attention_layernorm.variance_epsilon + module.post_attention_layernorm.variance_epsilon, ) - blocks.append(LlamaLikeBlock( - hidden_size=self.model.config.hidden_size, - n_heads=self.model.config.num_attention_heads, - n_kv_heads=self.model.config.num_key_value_heads, - qkv_layer=qkv, - o_proj=module.self_attn.o_proj, - mlp=module.mlp, - norm_1=norm_1, - norm_2=norm_2, - dev=device, - max_seq_len=self.model.config.max_new_tokens, - rope_theta=self.model.config.rope_theta - )) - + blocks.append( + LlamaLikeBlock( + hidden_size=self.model.config.hidden_size, + n_heads=self.model.config.num_attention_heads, + n_kv_heads=self.model.config.num_key_value_heads, + qkv_layer=qkv, + o_proj=module.self_attn.o_proj, + mlp=module.mlp, + norm_1=norm_1, + norm_2=norm_2, + dev=device, + max_seq_len=self.model.config.max_seq_len, + rope_theta=self.model.config.rope_theta, + ) + ) + self.model.model = LlamaLikeModel( self.model.config.vocab_size, blocks, self.model.model.embed_tokens, self.model.model.norm, ) - setattr(self.model.model, "blocks", self.model.model.blocks) \ No newline at end of file + setattr(self.model.model, "blocks", self.model.model.blocks) diff --git a/awq/models/llava.py b/awq/models/llava.py index 5d3a2bfa..c6bd9efa 100644 --- a/awq/models/llava.py +++ b/awq/models/llava.py @@ -7,12 +7,15 @@ from transformers.models.llama.modeling_llama import ( LlamaDecoderLayer as OldLlamaDecoderLayer, ) -from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration as OldLlavaForConditionalGeneration +from transformers.models.llava.modeling_llava import ( + LlavaForConditionalGeneration as OldLlavaForConditionalGeneration, +) from awq.modules.fused.norm import FasterTransformerRMSNorm + class LlavaAWQForCausalLM(BaseAWQForCausalLM): layer_type = "LlamaDecoderLayer" - max_new_tokens_key = "max_position_embeddings" + max_seq_len_key = "max_position_embeddings" @staticmethod def fuse_layers(model: OldLlavaForConditionalGeneration): @@ -22,53 +25,65 @@ def fuse_layers(model: OldLlavaForConditionalGeneration): @staticmethod def get_model_layers(model: OldLlavaForConditionalGeneration): return model.language_model.model.layers - + @staticmethod def get_act_for_scaling(module: OldLlamaDecoderLayer): - return dict( - is_scalable=False - ) - + return dict(is_scalable=False) + @staticmethod def move_embed(model: OldLlavaForConditionalGeneration, device: str): - model.language_model.model.embed_tokens = model.get_input_embeddings().to(device) - + model.language_model.model.embed_tokens = model.get_input_embeddings().to( + device + ) + @staticmethod def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs): layers = [] # attention input - layers.append(dict( - prev_op=module.input_layernorm, - layers=[module.self_attn.q_proj, - module.self_attn.k_proj, module.self_attn.v_proj], - inp=input_feat['self_attn.q_proj'], - module2inspect=module.self_attn, kwargs=module_kwargs, - )) + layers.append( + dict( + prev_op=module.input_layernorm, + layers=[ + module.self_attn.q_proj, + module.self_attn.k_proj, + module.self_attn.v_proj, + ], + inp=input_feat["self_attn.q_proj"], + module2inspect=module.self_attn, + kwargs=module_kwargs, + ) + ) # attention out # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: - layers.append(dict( - prev_op=module.self_attn.v_proj, - layers=[module.self_attn.o_proj], - 
inp=input_feat['self_attn.o_proj'], - )) - + layers.append( + dict( + prev_op=module.self_attn.v_proj, + layers=[module.self_attn.o_proj], + inp=input_feat["self_attn.o_proj"], + ) + ) + # linear 1 - layers.append(dict( - prev_op=module.post_attention_layernorm, - layers=[module.mlp.gate_proj, module.mlp.up_proj], - inp=input_feat['mlp.gate_proj'], - module2inspect=module.mlp, - )) + layers.append( + dict( + prev_op=module.post_attention_layernorm, + layers=[module.mlp.gate_proj, module.mlp.up_proj], + inp=input_feat["mlp.gate_proj"], + module2inspect=module.mlp, + ) + ) # linear 2 - layers.append(dict( - prev_op=module.mlp.up_proj, - layers=[module.mlp.down_proj], - inp=input_feat['mlp.down_proj'], - )) + layers.append( + dict( + prev_op=module.mlp.up_proj, + layers=[module.mlp.down_proj], + inp=input_feat["mlp.down_proj"], + ) + ) return layers @@ -78,10 +93,11 @@ def __init__(self, model: OldLlavaForConditionalGeneration): self.model = model.language_model self.llama_blocks: List[Tuple[str, OldLlamaDecoderLayer]] = [ - (name, module) for name, module in self.model.named_modules() - if 'LlamaDecoderLayer'.lower() in module.__class__.__name__.lower() + (name, module) + for name, module in self.model.named_modules() + if "LlamaDecoderLayer".lower() in module.__class__.__name__.lower() ] - + def fuse_transformer(self): blocks = [] @@ -92,29 +108,30 @@ def fuse_transformer(self): module, module.self_attn.q_proj, module.self_attn.k_proj, - module.self_attn.v_proj + module.self_attn.v_proj, ) norm_1 = FasterTransformerRMSNorm( - module.input_layernorm.weight, - module.input_layernorm.variance_epsilon + module.input_layernorm.weight, module.input_layernorm.variance_epsilon ) norm_2 = FasterTransformerRMSNorm( module.post_attention_layernorm.weight, - module.post_attention_layernorm.variance_epsilon + module.post_attention_layernorm.variance_epsilon, ) - blocks.append(LlamaLikeBlock( - hidden_size=self.model.config.hidden_size, - n_heads=self.model.config.num_attention_heads, - n_kv_heads=self.model.config.num_key_value_heads, - qkv_layer=qkv, - o_proj=module.self_attn.o_proj, - mlp=module.mlp, - norm_1=norm_1, - norm_2=norm_2, - dev=device, - max_seq_len=self.model.config.max_new_tokens - )) - + blocks.append( + LlamaLikeBlock( + hidden_size=self.model.config.hidden_size, + n_heads=self.model.config.num_attention_heads, + n_kv_heads=self.model.config.num_key_value_heads, + qkv_layer=qkv, + o_proj=module.self_attn.o_proj, + mlp=module.mlp, + norm_1=norm_1, + norm_2=norm_2, + dev=device, + max_seq_len=self.model.config.max_seq_len, + ) + ) + self.model.model = LlamaLikeModel( self.model.config.vocab_size, blocks, diff --git a/awq/models/mistral.py b/awq/models/mistral.py index 9ddcf5b2..af4cd31a 100644 --- a/awq/models/mistral.py +++ b/awq/models/mistral.py @@ -6,13 +6,14 @@ from awq.modules.fused.model import LlamaLikeModel from transformers.models.mistral.modeling_mistral import ( MistralDecoderLayer as OldMistralDecoderLayer, - MistralForCausalLM as OldMistralForCausalLM + MistralForCausalLM as OldMistralForCausalLM, ) from awq.modules.fused.norm import FasterTransformerRMSNorm + class MistralAWQForCausalLM(BaseAWQForCausalLM): layer_type = "MistralDecoderLayer" - max_new_tokens_key = "max_position_embeddings" + max_seq_len_key = "max_position_embeddings" @staticmethod def fuse_layers(model: OldMistralForCausalLM): @@ -22,53 +23,65 @@ def fuse_layers(model: OldMistralForCausalLM): @staticmethod def get_model_layers(model: OldMistralForCausalLM): return model.model.layers - + @staticmethod def 
get_act_for_scaling(module: OldMistralDecoderLayer): - return dict( - is_scalable=False - ) - + return dict(is_scalable=False) + @staticmethod def move_embed(model: OldMistralForCausalLM, device: str): model.model.embed_tokens = model.model.embed_tokens.to(device) - + @staticmethod - def get_layers_for_scaling(module: OldMistralDecoderLayer, input_feat, module_kwargs): + def get_layers_for_scaling( + module: OldMistralDecoderLayer, input_feat, module_kwargs + ): layers = [] # attention input - layers.append(dict( - prev_op=module.input_layernorm, - layers=[module.self_attn.q_proj, - module.self_attn.k_proj, module.self_attn.v_proj], - inp=input_feat['self_attn.q_proj'], - module2inspect=module.self_attn, kwargs=module_kwargs, - )) + layers.append( + dict( + prev_op=module.input_layernorm, + layers=[ + module.self_attn.q_proj, + module.self_attn.k_proj, + module.self_attn.v_proj, + ], + inp=input_feat["self_attn.q_proj"], + module2inspect=module.self_attn, + kwargs=module_kwargs, + ) + ) # attention out # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: - layers.append(dict( - prev_op=module.self_attn.v_proj, - layers=[module.self_attn.o_proj], - inp=input_feat['self_attn.o_proj'], - )) - + layers.append( + dict( + prev_op=module.self_attn.v_proj, + layers=[module.self_attn.o_proj], + inp=input_feat["self_attn.o_proj"], + ) + ) + # linear 1 - layers.append(dict( - prev_op=module.post_attention_layernorm, - layers=[module.mlp.gate_proj, module.mlp.up_proj], - inp=input_feat['mlp.gate_proj'], - module2inspect=module.mlp, - )) + layers.append( + dict( + prev_op=module.post_attention_layernorm, + layers=[module.mlp.gate_proj, module.mlp.up_proj], + inp=input_feat["mlp.gate_proj"], + module2inspect=module.mlp, + ) + ) # linear 2 - layers.append(dict( - prev_op=module.mlp.up_proj, - layers=[module.mlp.down_proj], - inp=input_feat['mlp.down_proj'], - )) + layers.append( + dict( + prev_op=module.mlp.up_proj, + layers=[module.mlp.down_proj], + inp=input_feat["mlp.down_proj"], + ) + ) return layers @@ -78,10 +91,11 @@ def __init__(self, model: OldMistralForCausalLM): self.model = model self.mistral_blocks: List[Tuple[str, OldMistralDecoderLayer]] = [ - (name, module) for name, module in self.model.named_modules() - if 'MistralDecoderLayer'.lower() in module.__class__.__name__.lower() + (name, module) + for name, module in self.model.named_modules() + if "MistralDecoderLayer".lower() in module.__class__.__name__.lower() ] - + def fuse_transformer(self): blocks = [] @@ -92,29 +106,30 @@ def fuse_transformer(self): module, module.self_attn.q_proj, module.self_attn.k_proj, - module.self_attn.v_proj + module.self_attn.v_proj, ) norm_1 = FasterTransformerRMSNorm( - module.input_layernorm.weight, - module.input_layernorm.variance_epsilon + module.input_layernorm.weight, module.input_layernorm.variance_epsilon ) norm_2 = FasterTransformerRMSNorm( module.post_attention_layernorm.weight, - module.post_attention_layernorm.variance_epsilon + module.post_attention_layernorm.variance_epsilon, ) - blocks.append(LlamaLikeBlock( - hidden_size=self.model.config.hidden_size, - n_heads=self.model.config.num_attention_heads, - n_kv_heads=self.model.config.num_key_value_heads, - qkv_layer=qkv, - o_proj=module.self_attn.o_proj, - mlp=module.mlp, - norm_1=norm_1, - norm_2=norm_2, - dev=device, - max_seq_len=self.model.config.max_new_tokens - )) - + blocks.append( + LlamaLikeBlock( + 
hidden_size=self.model.config.hidden_size, + n_heads=self.model.config.num_attention_heads, + n_kv_heads=self.model.config.num_key_value_heads, + qkv_layer=qkv, + o_proj=module.self_attn.o_proj, + mlp=module.mlp, + norm_1=norm_1, + norm_2=norm_2, + dev=device, + max_seq_len=self.model.config.max_seq_len, + ) + ) + self.model.model = LlamaLikeModel( self.model.config.vocab_size, blocks, diff --git a/awq/models/mixtral.py b/awq/models/mixtral.py index 8ca8c515..1b7e49dc 100644 --- a/awq/models/mixtral.py +++ b/awq/models/mixtral.py @@ -1,77 +1,95 @@ import tqdm +import torch from typing import List, Tuple from .base import BaseAWQForCausalLM -from awq.utils.fused_utils import fuse_qkv from awq.modules.fused.block import MixtralBlock from awq.modules.fused.model import MixtralModel +from awq.modules.fused.moe import FusedSparseMoeBlock +from awq.utils.fused_utils import fuse_qkv, fuse_linears from transformers.models.mixtral.modeling_mixtral import ( MixtralDecoderLayer as OldMixtralDecoderLayer, - MixtralForCausalLM as OldMixtralForCausalLM + MixtralForCausalLM as OldMixtralForCausalLM, ) +from awq.modules.linear import WQLinear_GEMM from awq.modules.fused.norm import FasterTransformerRMSNorm + class MixtralAWQForCausalLM(BaseAWQForCausalLM): layer_type = "MixtralDecoderLayer" - max_new_tokens_key = "max_position_embeddings" - + max_seq_len_key = "max_position_embeddings" + modules_to_not_convert = ["gate"] + @staticmethod def fuse_layers(model: OldMixtralForCausalLM): fuser = MixtralFuser(model) fuser.fuse_transformer() - + @staticmethod def get_model_layers(model: OldMixtralForCausalLM): return model.model.layers - + @staticmethod def get_act_for_scaling(module): - return dict( - is_scalable=False - ) - + return dict(is_scalable=False) + @staticmethod def move_embed(model: OldMixtralForCausalLM, device: str): model.model.embed_tokens = model.model.embed_tokens.to(device) - + @staticmethod - def get_layers_for_scaling(module: OldMixtralDecoderLayer, input_feat, module_kwargs): + def get_layers_for_scaling( + module: OldMixtralDecoderLayer, input_feat, module_kwargs + ): layers = [] # attention input - layers.append(dict( - prev_op=module.input_layernorm, - layers=[module.self_attn.q_proj, - module.self_attn.k_proj, module.self_attn.v_proj], - inp=input_feat['self_attn.q_proj'], - module2inspect=module.self_attn, kwargs=module_kwargs, - )) + layers.append( + dict( + prev_op=module.input_layernorm, + layers=[ + module.self_attn.q_proj, + module.self_attn.k_proj, + module.self_attn.v_proj, + ], + inp=input_feat["self_attn.q_proj"], + module2inspect=module.self_attn, + kwargs=module_kwargs, + ) + ) # attention out if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: - layers.append(dict( - prev_op=module.self_attn.v_proj, - layers=[module.self_attn.o_proj], - inp=input_feat['self_attn.o_proj'], - )) - + layers.append( + dict( + prev_op=module.self_attn.v_proj, + layers=[module.self_attn.o_proj], + inp=input_feat["self_attn.o_proj"], + ) + ) + # linear in - layers.append(dict( - prev_op=module.post_attention_layernorm, - layers=[ - w for expert in module.block_sparse_moe.experts - for w in [expert.w1, expert.w3] - ], - inp=input_feat['block_sparse_moe'], - module2inspect=module.block_sparse_moe, - )) + layers.append( + dict( + prev_op=module.post_attention_layernorm, + layers=[ + w + for expert in module.block_sparse_moe.experts + for w in [expert.w1, expert.w3] + ], + inp=input_feat["block_sparse_moe"], + module2inspect=module.block_sparse_moe, + ) + ) # linear out 
for i, expert in enumerate(module.block_sparse_moe.experts): - layers.append(dict( - prev_op=expert.w3, - layers=[expert.w2], - inp=input_feat[f'block_sparse_moe.experts.{i}.w2'], - )) + layers.append( + dict( + prev_op=expert.w3, + layers=[expert.w2], + inp=input_feat[f"block_sparse_moe.experts.{i}.w2"], + ) + ) return layers @@ -81,49 +99,89 @@ def __init__(self, model: OldMixtralForCausalLM): self.model = model self.mixtral_blocks: List[Tuple[str, OldMixtralDecoderLayer]] = [ - (name, module) for name, module in self.model.named_modules() - if 'MixtralDecoderLayer'.lower() in module.__class__.__name__.lower() + (name, module) + for name, module in self.model.named_modules() + if "MixtralDecoderLayer".lower() in module.__class__.__name__.lower() ] - + def fuse_transformer(self): blocks = [] module: OldMixtralDecoderLayer for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."): device = next(iter(module.state_dict().values())).device + qkv = fuse_qkv( module, module.self_attn.q_proj, module.self_attn.k_proj, - module.self_attn.v_proj + module.self_attn.v_proj, ) norm_1 = FasterTransformerRMSNorm( - module.input_layernorm.weight, - module.input_layernorm.variance_epsilon + module.input_layernorm.weight, module.input_layernorm.variance_epsilon ) + norm_2 = FasterTransformerRMSNorm( module.post_attention_layernorm.weight, - module.post_attention_layernorm.variance_epsilon + module.post_attention_layernorm.variance_epsilon, + ) + + sparse_moe = module.block_sparse_moe + if isinstance(sparse_moe.experts[0].w1, WQLinear_GEMM) and torch.cuda.device_count() == 1: + fused_w1w3s = [ + fuse_linears( + [ + sparse_moe.experts[i].w1, + sparse_moe.experts[i].w3, + ], + device, + ) + for i in range(len(sparse_moe.experts)) + ] + + stacked_w1w3s = fuse_linears( + fused_w1w3s, device, dim=0, operation=torch.stack + ) + + stacked_w2s = fuse_linears( + [expert.w2 for expert in sparse_moe.experts], + device, + dim=0, + operation=torch.stack, + ) + + sparse_moe = FusedSparseMoeBlock( + top_k=sparse_moe.top_k, + gate=sparse_moe.gate, + ws=stacked_w1w3s, + w2s=stacked_w2s, + ) + + blocks.append( + MixtralBlock( + hidden_size=self.model.config.hidden_size, + n_heads=self.model.config.num_attention_heads, + n_kv_heads=self.model.config.num_key_value_heads, + qkv_layer=qkv, + o_proj=module.self_attn.o_proj, + moe=sparse_moe, + norm_1=norm_1, + norm_2=norm_2, + dev=device, + max_seq_len=self.model.config.max_seq_len, + rope_theta=self.model.config.rope_theta, + ) ) - blocks.append(MixtralBlock( - hidden_size=self.model.config.hidden_size, - n_heads=self.model.config.num_attention_heads, - n_kv_heads=self.model.config.num_key_value_heads, - qkv_layer=qkv, - o_proj=module.self_attn.o_proj, - moe=module.block_sparse_moe, - norm_1=norm_1, - norm_2=norm_2, - dev=device, - max_seq_len=self.model.config.max_new_tokens, - rope_theta=self.model.config.rope_theta - )) + model_norm = FasterTransformerRMSNorm( + self.model.model.norm.weight, + self.model.model.norm.variance_epsilon, + ) + self.model.model = MixtralModel( self.model.config.vocab_size, blocks, self.model.model.embed_tokens, - self.model.model.norm, + model_norm, ) setattr(self.model.model, "blocks", self.model.model.blocks) - diff --git a/awq/models/mpt.py b/awq/models/mpt.py index 5bc8ccd6..b042e162 100644 --- a/awq/models/mpt.py +++ b/awq/models/mpt.py @@ -1,9 +1,10 @@ from .base import BaseAWQForCausalLM from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM + class MptAWQForCausalLM(BaseAWQForCausalLM): 
layer_type = "MPTBlock" - max_new_tokens_key = "max_seq_len" + max_seq_len_key = "max_seq_len" @staticmethod def fuse_layers(model: MptForCausalLM): @@ -13,73 +14,84 @@ def fuse_layers(model: MptForCausalLM): @staticmethod def get_model_layers(model: MptForCausalLM): return model.transformer.blocks - + @staticmethod def get_act_for_scaling(module: OldMptBlock): return dict( is_scalable=True, scale_name="ffn.act", scale_layer=module.ffn.act, - scale_shape=module.ffn.up_proj.out_features + scale_shape=module.ffn.up_proj.out_features, ) - + @staticmethod def move_embed(model: MptForCausalLM, device: str): model.transformer.wte = model.transformer.wte.to(device) model.transformer.emb_drop = model.transformer.emb_drop.to(device) - + @staticmethod def get_layers_for_scaling(module: OldMptBlock, input_feat, module_kwargs): layers = [] - + if module_kwargs.get("output_attentions") is not None: module_kwargs.pop("output_attentions") # attention input - layers.append(dict( - prev_op=module.norm_1, - layers=[module.attn.Wqkv], - inp=input_feat['attn.Wqkv'], - module2inspect=module.attn, - kwargs=module_kwargs - )) + layers.append( + dict( + prev_op=module.norm_1, + layers=[module.attn.Wqkv], + inp=input_feat["attn.Wqkv"], + module2inspect=module.attn, + kwargs=module_kwargs, + ) + ) # attention output - layers.append(dict( - prev_op=module.attn.Wqkv, - layers=[module.attn.out_proj], - inp=input_feat['attn.out_proj'] - )) + layers.append( + dict( + prev_op=module.attn.Wqkv, + layers=[module.attn.out_proj], + inp=input_feat["attn.out_proj"], + ) + ) # linear 1 - layers.append(dict( - prev_op=module.norm_2, - layers=[module.ffn.up_proj], - inp=input_feat['ffn.up_proj'], - module2inspect=module.ffn - )) + layers.append( + dict( + prev_op=module.norm_2, + layers=[module.ffn.up_proj], + inp=input_feat["ffn.up_proj"], + module2inspect=module.ffn, + ) + ) # linear 2 - layers.append(dict( - prev_op=module.ffn.act, - layers=[module.ffn.down_proj], - inp=input_feat['ffn.down_proj'] - )) + layers.append( + dict( + prev_op=module.ffn.act, + layers=[module.ffn.down_proj], + inp=input_feat["ffn.down_proj"], + ) + ) return layers + from typing import List, Tuple from awq.utils.utils import set_module_name from awq.modules.fused.block import MPTBlock from awq.modules.fused.model import MPTModel + class MptFuser: def __init__(self, model: MptForCausalLM): self.model = model self.mpt_blocks: List[Tuple[str, OldMptBlock]] = [ - (name, module) for name, module in self.model.named_modules() - if 'mptblock' in module.__class__.__name__.lower() + (name, module) + for name, module in self.model.named_modules() + if "mptblock" in module.__class__.__name__.lower() ] def fuse_transformer(self): @@ -87,17 +99,19 @@ def fuse_transformer(self): module: OldMptBlock for module in self.model.transformer.blocks: - blocks.append(MPTBlock( - self.model.config.d_model, - self.model.config.n_heads, - module.attn.Wqkv, - module.attn.out_proj, - module.ffn, - module.norm_1, - module.norm_2, - next(iter(module.state_dict().values())).device, - self.model.config.max_new_tokens - )) + blocks.append( + MPTBlock( + self.model.config.d_model, + self.model.config.n_heads, + module.attn.Wqkv, + module.attn.out_proj, + module.ffn, + module.norm_1, + module.norm_2, + next(iter(module.state_dict().values())).device, + self.model.config.max_seq_len, + ) + ) self.model.transformer = MPTModel( self.model.config.vocab_size, @@ -106,4 +120,4 @@ def fuse_transformer(self): self.model.transformer.norm_f, ) - setattr(self.model.transformer, "blocks", 
self.model.transformer.blocks) \ No newline at end of file + setattr(self.model.transformer, "blocks", self.model.transformer.blocks) diff --git a/awq/models/opt.py b/awq/models/opt.py index c622f96b..3209d0c1 100644 --- a/awq/models/opt.py +++ b/awq/models/opt.py @@ -1,59 +1,70 @@ from .base import BaseAWQForCausalLM from transformers.models.opt.modeling_opt import OPTForCausalLM, OPTDecoderLayer + class OptAWQForCausalLM(BaseAWQForCausalLM): layer_type = "OPTDecoderLayer" - max_new_tokens_key = "max_position_embeddings" + max_seq_len_key = "max_position_embeddings" @staticmethod def get_model_layers(model: OPTForCausalLM): return model.model.decoder.layers - + @staticmethod def get_act_for_scaling(module: OPTDecoderLayer): - return dict( - is_scalable=False - ) - + return dict(is_scalable=False) + @staticmethod def move_embed(model: OPTForCausalLM, device: str): model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(device) - model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(device) - + model.model.decoder.embed_positions = model.model.decoder.embed_positions.to( + device + ) + @staticmethod def get_layers_for_scaling(module: OPTDecoderLayer, input_feat, module_kwargs): layers = [] # attention input - layers.append(dict( - prev_op=module.self_attn_layer_norm, - layers=[ - module.self_attn.q_proj, - module.self_attn.k_proj, module.self_attn.v_proj], - inp=input_feat['self_attn.q_proj'], - module2inspect=module.self_attn, - kwargs=module_kwargs, - )) + layers.append( + dict( + prev_op=module.self_attn_layer_norm, + layers=[ + module.self_attn.q_proj, + module.self_attn.k_proj, + module.self_attn.v_proj, + ], + inp=input_feat["self_attn.q_proj"], + module2inspect=module.self_attn, + kwargs=module_kwargs, + ) + ) # attention out - layers.append(dict( - prev_op=module.self_attn.v_proj, - layers=[module.self_attn.out_proj], - inp=input_feat['self_attn.out_proj'], - )) + layers.append( + dict( + prev_op=module.self_attn.v_proj, + layers=[module.self_attn.out_proj], + inp=input_feat["self_attn.out_proj"], + ) + ) # linear 1 - layers.append(dict( - prev_op=module.final_layer_norm, - layers=[module.fc1], - inp=input_feat['fc1'], - )) + layers.append( + dict( + prev_op=module.final_layer_norm, + layers=[module.fc1], + inp=input_feat["fc1"], + ) + ) # linear 2 - layers.append(dict( - prev_op=module.fc1, - layers=[module.fc2], - inp=input_feat['fc2'], - )) + layers.append( + dict( + prev_op=module.fc1, + layers=[module.fc2], + inp=input_feat["fc2"], + ) + ) - return layers \ No newline at end of file + return layers diff --git a/awq/models/qwen.py b/awq/models/qwen.py index 90d02eee..2e699d12 100644 --- a/awq/models/qwen.py +++ b/awq/models/qwen.py @@ -3,7 +3,7 @@ class QwenAWQForCausalLM(BaseAWQForCausalLM): layer_type = "QWenBlock" - max_new_tokens_key = "seq_length" + max_seq_len_key = "seq_length" @staticmethod def get_model_layers(model): diff --git a/awq/models/qwen2.py b/awq/models/qwen2.py index 1d9bc8a1..38994bf5 100644 --- a/awq/models/qwen2.py +++ b/awq/models/qwen2.py @@ -6,14 +6,14 @@ from awq.modules.fused.model import LlamaLikeModel from transformers.models.qwen2.modeling_qwen2 import ( Qwen2DecoderLayer as OldQwen2DecoderLayer, - Qwen2ForCausalLM as OldQwen2ForCausalLM + Qwen2ForCausalLM as OldQwen2ForCausalLM, ) from awq.modules.fused.norm import FasterTransformerRMSNorm class Qwen2AWQForCausalLM(BaseAWQForCausalLM): layer_type = "Qwen2DecoderLayer" - max_new_tokens_key = "max_position_embeddings" + max_seq_len_key = 
"max_position_embeddings" @staticmethod def fuse_layers(model: OldQwen2ForCausalLM): @@ -26,9 +26,7 @@ def get_model_layers(model: OldQwen2ForCausalLM): @staticmethod def get_act_for_scaling(module: OldQwen2DecoderLayer): - return dict( - is_scalable=False - ) + return dict(is_scalable=False) @staticmethod def move_embed(model: OldQwen2ForCausalLM, device: str): @@ -39,37 +37,49 @@ def get_layers_for_scaling(module: OldQwen2DecoderLayer, input_feat, module_kwar layers = [] # attention input - layers.append(dict( - prev_op=module.input_layernorm, - layers=[module.self_attn.q_proj, - module.self_attn.k_proj, module.self_attn.v_proj], - inp=input_feat['self_attn.q_proj'], - module2inspect=module.self_attn, kwargs=module_kwargs, - )) + layers.append( + dict( + prev_op=module.input_layernorm, + layers=[ + module.self_attn.q_proj, + module.self_attn.k_proj, + module.self_attn.v_proj, + ], + inp=input_feat["self_attn.q_proj"], + module2inspect=module.self_attn, + kwargs=module_kwargs, + ) + ) # attention out # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: - layers.append(dict( - prev_op=module.self_attn.v_proj, - layers=[module.self_attn.o_proj], - inp=input_feat['self_attn.o_proj'], - )) + layers.append( + dict( + prev_op=module.self_attn.v_proj, + layers=[module.self_attn.o_proj], + inp=input_feat["self_attn.o_proj"], + ) + ) # linear 1 - layers.append(dict( - prev_op=module.post_attention_layernorm, - layers=[module.mlp.gate_proj, module.mlp.up_proj], - inp=input_feat['mlp.gate_proj'], - module2inspect=module.mlp, - )) + layers.append( + dict( + prev_op=module.post_attention_layernorm, + layers=[module.mlp.gate_proj, module.mlp.up_proj], + inp=input_feat["mlp.gate_proj"], + module2inspect=module.mlp, + ) + ) # linear 2 - layers.append(dict( - prev_op=module.mlp.up_proj, - layers=[module.mlp.down_proj], - inp=input_feat['mlp.down_proj'], - )) + layers.append( + dict( + prev_op=module.mlp.up_proj, + layers=[module.mlp.down_proj], + inp=input_feat["mlp.down_proj"], + ) + ) return layers @@ -79,8 +89,9 @@ def __init__(self, model: OldQwen2ForCausalLM): self.model = model self.qwen2_blocks: List[Tuple[str, OldQwen2DecoderLayer]] = [ - (name, module) for name, module in self.model.named_modules() - if 'Qwen2DecoderLayer'.lower() in module.__class__.__name__.lower() + (name, module) + for name, module in self.model.named_modules() + if "Qwen2DecoderLayer".lower() in module.__class__.__name__.lower() ] def fuse_transformer(self): @@ -93,28 +104,29 @@ def fuse_transformer(self): module, module.self_attn.q_proj, module.self_attn.k_proj, - module.self_attn.v_proj + module.self_attn.v_proj, ) norm_1 = FasterTransformerRMSNorm( - module.input_layernorm.weight, - module.input_layernorm.variance_epsilon + module.input_layernorm.weight, module.input_layernorm.variance_epsilon ) norm_2 = FasterTransformerRMSNorm( module.post_attention_layernorm.weight, - module.post_attention_layernorm.variance_epsilon + module.post_attention_layernorm.variance_epsilon, + ) + blocks.append( + LlamaLikeBlock( + hidden_size=self.model.config.hidden_size, + n_heads=self.model.config.num_attention_heads, + n_kv_heads=self.model.config.num_key_value_heads, + qkv_layer=qkv, + o_proj=module.self_attn.o_proj, + mlp=module.mlp, + norm_1=norm_1, + norm_2=norm_2, + dev=device, + max_seq_len=self.model.config.max_seq_len, + ) ) - blocks.append(LlamaLikeBlock( - hidden_size=self.model.config.hidden_size, - 
n_heads=self.model.config.num_attention_heads, - n_kv_heads=self.model.config.num_key_value_heads, - qkv_layer=qkv, - o_proj=module.self_attn.o_proj, - mlp=module.mlp, - norm_1=norm_1, - norm_2=norm_2, - dev=device, - max_seq_len=self.model.config.max_new_tokens - )) self.model.model = LlamaLikeModel( self.model.config.vocab_size, diff --git a/awq/models/yi.py b/awq/models/yi.py index 7e61dbbc..3237c22d 100644 --- a/awq/models/yi.py +++ b/awq/models/yi.py @@ -6,9 +6,10 @@ from awq.modules.fused.model import LlamaLikeModel from awq.modules.fused.norm import FasterTransformerRMSNorm + class YiAWQForCausalLM(BaseAWQForCausalLM): layer_type = "YiDecoderLayer" - max_new_tokens_key = "max_position_embeddings" + max_seq_len_key = "max_position_embeddings" @staticmethod def fuse_layers(model): @@ -18,53 +19,63 @@ def fuse_layers(model): @staticmethod def get_model_layers(model): return model.model.layers - + @staticmethod def get_act_for_scaling(module): - return dict( - is_scalable=False - ) - + return dict(is_scalable=False) + @staticmethod def move_embed(model, device: str): model.model.embed_tokens = model.model.embed_tokens.to(device) - + @staticmethod def get_layers_for_scaling(module, input_feat, module_kwargs): layers = [] # attention input - layers.append(dict( - prev_op=module.ln1, - layers=[module.self_attn.q_proj, - module.self_attn.k_proj, module.self_attn.v_proj], - inp=input_feat['self_attn.q_proj'], - module2inspect=module.self_attn, kwargs=module_kwargs, - )) + layers.append( + dict( + prev_op=module.ln1, + layers=[ + module.self_attn.q_proj, + module.self_attn.k_proj, + module.self_attn.v_proj, + ], + inp=input_feat["self_attn.q_proj"], + module2inspect=module.self_attn, + kwargs=module_kwargs, + ) + ) # attention out # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: - layers.append(dict( - prev_op=module.self_attn.v_proj, - layers=[module.self_attn.o_proj], - inp=input_feat['self_attn.o_proj'], - )) - + layers.append( + dict( + prev_op=module.self_attn.v_proj, + layers=[module.self_attn.o_proj], + inp=input_feat["self_attn.o_proj"], + ) + ) + # linear 1 - layers.append(dict( - prev_op=module.ln2, - layers=[module.mlp.gate_proj, module.mlp.up_proj], - inp=input_feat['mlp.gate_proj'], - module2inspect=module.mlp, - )) + layers.append( + dict( + prev_op=module.ln2, + layers=[module.mlp.gate_proj, module.mlp.up_proj], + inp=input_feat["mlp.gate_proj"], + module2inspect=module.mlp, + ) + ) # linear 2 - layers.append(dict( - prev_op=module.mlp.up_proj, - layers=[module.mlp.down_proj], - inp=input_feat['mlp.down_proj'], - )) + layers.append( + dict( + prev_op=module.mlp.up_proj, + layers=[module.mlp.down_proj], + inp=input_feat["mlp.down_proj"], + ) + ) return layers @@ -74,10 +85,11 @@ def __init__(self, model): self.model = model self.yi_blocks: List[Tuple[str, object]] = [ - (name, module) for name, module in self.model.named_modules() - if 'YiDecoderLayer'.lower() in module.__class__.__name__.lower() + (name, module) + for name, module in self.model.named_modules() + if "YiDecoderLayer".lower() in module.__class__.__name__.lower() ] - + def fuse_transformer(self): blocks = [] @@ -87,30 +99,30 @@ def fuse_transformer(self): module, module.self_attn.q_proj, module.self_attn.k_proj, - module.self_attn.v_proj + module.self_attn.v_proj, ) norm_1 = FasterTransformerRMSNorm( - module.ln1.weight, - module.ln1.variance_epsilon + module.ln1.weight, module.ln1.variance_epsilon ) 
norm_2 = FasterTransformerRMSNorm( - module.ln2.weight, - module.ln2.variance_epsilon + module.ln2.weight, module.ln2.variance_epsilon ) - blocks.append(LlamaLikeBlock( - hidden_size=self.model.config.hidden_size, - n_heads=self.model.config.num_attention_heads, - n_kv_heads=self.model.config.num_key_value_heads, - qkv_layer=qkv, - o_proj=module.self_attn.o_proj, - mlp=module.mlp, - norm_1=norm_1, - norm_2=norm_2, - dev=device, - max_seq_len=self.model.config.max_new_tokens, - rope_theta=self.model.config.rope_theta - )) - + blocks.append( + LlamaLikeBlock( + hidden_size=self.model.config.hidden_size, + n_heads=self.model.config.num_attention_heads, + n_kv_heads=self.model.config.num_key_value_heads, + qkv_layer=qkv, + o_proj=module.self_attn.o_proj, + mlp=module.mlp, + norm_1=norm_1, + norm_2=norm_2, + dev=device, + max_seq_len=self.model.config.max_seq_len, + rope_theta=self.model.config.rope_theta, + ) + ) + self.model.model = LlamaLikeModel( self.model.config.vocab_size, blocks, diff --git a/awq/modules/act.py b/awq/modules/act.py index 48305086..8ffb0c5a 100644 --- a/awq/modules/act.py +++ b/awq/modules/act.py @@ -1,10 +1,11 @@ import torch.nn as nn + class ScaledActivation(nn.Module): def __init__(self, module, scales): super().__init__() self.act = module self.scales = nn.Parameter(scales.data) - + def forward(self, x): return self.act(x) / self.scales.view(1, 1, -1).to(x.device) diff --git a/awq/modules/fused/attn.py b/awq/modules/fused/attn.py index 081dd8fb..f90fd502 100644 --- a/awq/modules/fused/attn.py +++ b/awq/modules/fused/attn.py @@ -9,6 +9,7 @@ try: import awq_ft_ext + FT_INSTALLED = True except: FT_INSTALLED = False @@ -16,6 +17,7 @@ HF_NEW_CACHE_FORMAT = False import transformers + # https://github.com/huggingface/transformers/pull/26681 introduced a new cache format HF_NEW_CACHE_FORMAT = hasattr(transformers, "cache_utils") if HF_NEW_CACHE_FORMAT: @@ -25,12 +27,12 @@ class RoPE(nn.Module): def __init__(self, hidden_size, n_heads, max_seq_len, device, rope_theta): super(RoPE, self).__init__() - + self.freqs_cis = nn.Parameter( self.precompute_freqs_cis( hidden_size // n_heads, max_seq_len * 2, rope_theta ).to(device), - requires_grad=False + requires_grad=False, ) @staticmethod @@ -58,18 +60,21 @@ def forward(self, xq: torch.Tensor, xk: torch.Tensor, start_pos: int, seqlen: in ) freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] freqs_cis = self.reshape_for_broadcast(freqs_cis, xq_).to(xq_.device) - + xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3) xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3) - + return xq_out.type_as(xq), xk_out.type_as(xk) + class ALiBi(nn.Module): def __init__(self, n_heads, max_seq_len, device, alibi_bias_max=8): super(ALiBi, self).__init__() - + # Initialize ALiBi slopes and bias - slopes, bias = self.build_alibi_bias(n_heads, max_seq_len, alibi_bias_max=alibi_bias_max) + slopes, bias = self.build_alibi_bias( + n_heads, max_seq_len, alibi_bias_max=alibi_bias_max + ) self.slopes = nn.Parameter(slopes.float().to(device), requires_grad=False) self.bias = nn.Parameter(bias.float().to(device), requires_grad=False) @@ -79,27 +84,42 @@ def gen_slopes(n_heads, alibi_bias_max=8): m = torch.arange(1, _n_heads + 1, dtype=torch.float32) m = m.mul(alibi_bias_max / _n_heads) slopes = 1.0 / torch.pow(2, m) - + if _n_heads != n_heads: slopes = torch.cat([slopes[1::2], slopes[::2]])[:n_heads] - + return slopes.view(1, n_heads, 1, 1) @staticmethod def build_alibi_bias(n_heads, seq_len, alibi_bias_max=8, 
dtype=torch.float32): - alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len) + alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view( + 1, 1, 1, seq_len + ) slopes = ALiBi.gen_slopes(n_heads, alibi_bias_max) alibi_bias = alibi_bias * slopes slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1) return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype) - + def forward(self, scores, seqlen): scores += self.bias[..., :seqlen] return scores + class QuantAttentionFused(nn.Module): - def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max_seq_len, - use_alibi=False, attention_shapes=None, rope_theta=10000): + def __init__( + self, + hidden_size, + n_heads, + n_kv_heads, + qkv_layer, + o_proj, + dev, + max_seq_len=2048, + use_alibi=False, + attention_shapes=None, + rope_theta=10000, + **kwargs + ): super().__init__() self.hidden_size = hidden_size self.n_heads = n_heads @@ -111,17 +131,29 @@ def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max self.start_pos = 0 self.use_alibi = use_alibi self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1")) + + if kwargs.get("max_new_tokens") is not None: + max_seq_len = kwargs["max_new_tokens"] + self.max_seq_len = max_seq_len self.is_hf_transformers = False self.rope_theta = rope_theta # attention shapes for self attention self.attention_shapes = get_attention_shapes( - attention_shapes, max_seq_len, self.cache_batch_size, n_heads, n_kv_heads, self.head_dim + attention_shapes, + max_seq_len, + self.cache_batch_size, + n_heads, + n_kv_heads, + self.head_dim, ) # cache store that rolls cache self.cache = WindowedCache( - self.attention_shapes["cache_v"], self.attention_shapes["cache_k"], self.max_seq_len, dev + self.attention_shapes["cache_v"], + self.attention_shapes["cache_k"], + self.max_seq_len, + dev, ) if use_alibi: @@ -133,8 +165,10 @@ def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max self.rope = RoPE(hidden_size, n_heads, max_seq_len, dev, rope_theta) self.rotary_dim = self.head_dim self.is_neox = True - - def forward(self, hidden_states:torch.Tensor, attention_mask=None, *args, **kwargs): + + def forward( + self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs + ): bsz, seqlen, _ = hidden_states.shape # Reallocate cache if batch size changes @@ -147,18 +181,22 @@ def forward(self, hidden_states:torch.Tensor, attention_mask=None, *args, **kwar self.cache_batch_size = bsz # Always reset to 0 - self.start_pos = 0 + self.start_pos = 0 - # In case we re-generate, we need to refresh the starting position - # to 0. We detect it by checking if `past_key_values` is set to None, + # In case we re-generate, we need to refresh the starting position + # to 0. We detect it by checking if `past_key_values` is set to None, # which indicates that we are on the first step of `generate()`. 
# This is only applicable for `transformers` integration - if self.is_hf_transformers and "past_key_value" in kwargs and kwargs["past_key_value"] is None: + if ( + self.is_hf_transformers + and "past_key_value" in kwargs + and kwargs["past_key_value"] is None + ): self.start_pos = 0 xqkv = self.qkv_proj(hidden_states) xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"]) - + xq = self.attention_shapes["xq_slice"](xqkv) xk = self.attention_shapes["xk_slice"](xqkv) xv = self.attention_shapes["xv_slice"](xqkv) @@ -179,21 +217,22 @@ def forward(self, hidden_states:torch.Tensor, attention_mask=None, *args, **kwar .permute(0, 2, 3, 1, 4) .contiguous() ) - + self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen) # Only necessary to retrieve from cache when we are not processing context if seqlen == 1: xv, xk = self.cache.get_kv(bsz, self.start_pos, seqlen, self.head_dim) - keys = xk values = xv if self.n_kv_groups != 0: keys = torch.repeat_interleave(keys, dim=2, repeats=self.n_kv_groups) - values = torch.repeat_interleave(values, dim=2, repeats=self.n_kv_groups) - + values = torch.repeat_interleave( + values, dim=2, repeats=self.n_kv_groups + ) + xq = xq.transpose(1, 2) keys = keys.transpose(1, 2) values = values.transpose(1, 2) @@ -204,7 +243,9 @@ def forward(self, hidden_states:torch.Tensor, attention_mask=None, *args, **kwar # When seqlen is 1, there is nothing else to attend to if attention_mask is not None and seqlen > 1: - scores = scores + attention_mask # (bs, n_local_heads, slen, cache_len + slen) + scores = ( + scores + attention_mask + ) # (bs, n_local_heads, slen, cache_len + slen) scores = F.softmax(scores.float(), dim=-1).type_as(xq) output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim) attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) @@ -215,25 +256,25 @@ def forward(self, hidden_states:torch.Tensor, attention_mask=None, *args, **kwar alibi_slopes = self.alibi.slopes if self.alibi is not None else None attention_weight = awq_ft_ext.single_query_attention( - xq, # query - xk, # key - xv, # value - self.cache.k, # key cache - self.cache.v, # value cache - None, # length per sample - alibi_slopes, # alibi slopes - self.start_pos, # timestep - self.rotary_dim, # rotary embedding dimension - self.rope_theta, # rotary embedding base - self.is_neox, # is neox + xq, # query + xk, # key + xv, # value + self.cache.k, # key cache + self.cache.v, # value cache + None, # length per sample + alibi_slopes, # alibi slopes + self.start_pos, # timestep + self.rotary_dim, # rotary embedding dimension + self.rope_theta, # rotary embedding base + self.is_neox, # is neox ) attention_weight = attention_weight.reshape(bsz, 1, -1) - + attn_output = self.o_proj(attention_weight) self.start_pos += seqlen # past_key_value is replaced with cache_v, cache_k, returning empty data - # we pass a dummy past kv cache for transformers to be able to retrieve the correct info + # we pass a dummy past kv cache for transformers to be able to retrieve the correct info # about past key length past_key_value = [torch.zeros(1, 1, self.start_pos, 1)] diff --git a/awq/modules/fused/block.py b/awq/modules/fused/block.py index 47c061f4..0ffc4b93 100644 --- a/awq/modules/fused/block.py +++ b/awq/modules/fused/block.py @@ -2,10 +2,21 @@ import torch.nn as nn from awq.modules.fused.attn import QuantAttentionFused + class MixtralBlock(nn.Module): def __init__( - self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, - moe, norm_1, norm_2, dev, 
max_seq_len, rope_theta + self, + hidden_size, + n_heads, + n_kv_heads, + qkv_layer, + o_proj, + moe, + norm_1, + norm_2, + dev, + max_seq_len, + rope_theta, ): super().__init__() self.n_heads = n_heads @@ -13,37 +24,62 @@ def __init__( self.hidden_size = hidden_size self.norm_1 = norm_1.to(dev) self.attn = QuantAttentionFused( - self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj, - dev=dev, max_seq_len=max_seq_len, use_alibi=False, rope_theta=rope_theta + self.hidden_size, + self.n_heads, + self.n_kv_heads, + qkv_layer, + o_proj, + dev=dev, + max_seq_len=max_seq_len, + use_alibi=False, + rope_theta=rope_theta, ).to(dev) self.norm_2 = norm_2.to(dev) self.moe = moe self.device = dev def forward( - self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None + self, + hidden_states, + past_key_value, + attn_bias=None, + attention_mask=None, + is_causal=None, ): norm_out = self.norm_1(hidden_states) attn_output, _, past_key_value = self.attn.forward( hidden_states=norm_out, past_key_value=past_key_value, - attention_mask=attention_mask + attention_mask=attention_mask, ) h = hidden_states.to(attn_output.device) + attn_output - out, _ = self.moe.forward(self.norm_2(h)) + out = self.moe.forward(self.norm_2(h)) out = h + out return out, None, past_key_value + class LlamaLikeBlock(nn.Module): """ LlamaLikeBlock is intended to be reused across blocks that have an architecture that closely resembles Llama, e.g. Mistral and Aquila. """ + def __init__( - self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, - mlp, norm_1, norm_2, dev, max_seq_len, rope_theta=10000, use_alibi=False + self, + hidden_size, + n_heads, + n_kv_heads, + qkv_layer, + o_proj, + mlp, + norm_1, + norm_2, + dev, + max_seq_len, + rope_theta=10000, + use_alibi=False, ): super().__init__() self.n_heads = n_heads @@ -51,21 +87,33 @@ def __init__( self.hidden_size = hidden_size self.norm_1 = norm_1.to(dev) self.attn = QuantAttentionFused( - self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj, - dev=dev, max_seq_len=max_seq_len, use_alibi=use_alibi, rope_theta=rope_theta + self.hidden_size, + self.n_heads, + self.n_kv_heads, + qkv_layer, + o_proj, + dev=dev, + max_seq_len=max_seq_len, + use_alibi=use_alibi, + rope_theta=rope_theta, ).to(dev) self.norm_2 = norm_2.to(dev) self.mlp = mlp.to(dev) self.device = dev def forward( - self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None + self, + hidden_states, + past_key_value, + attn_bias=None, + attention_mask=None, + is_causal=None, ): norm_out = self.norm_1(hidden_states) attn_output, _, past_key_value = self.attn.forward( hidden_states=norm_out, past_key_value=past_key_value, - attention_mask=attention_mask + attention_mask=attention_mask, ) h = hidden_states.to(attn_output.device) + attn_output @@ -73,23 +121,46 @@ def forward( return out, None, past_key_value + class MPTBlock(nn.Module): - def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp, norm_1, norm_2, dev, max_seq_len): + def __init__( + self, + hidden_size, + n_heads, + qkv_layer, + o_proj, + mpt_mlp, + norm_1, + norm_2, + dev, + max_seq_len, + ): super().__init__() self.n_heads = n_heads self.n_kv_heads = 0 self.hidden_size = hidden_size self.norm_1 = norm_1 self.attn = QuantAttentionFused( - hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj, - dev=dev, max_seq_len=max_seq_len, use_alibi=True + hidden_size, + self.n_heads, + self.n_kv_heads, + qkv_layer, + o_proj, + dev=dev, + max_seq_len=max_seq_len, + 
use_alibi=True, ).to(dev) self.norm_2 = norm_2 self.ffn = mpt_mlp.to(dev) self.device = dev def forward( - self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None + self, + hidden_states, + past_key_value, + attn_bias=None, + attention_mask=None, + is_causal=None, ): norm_out = self.norm_1(hidden_states) attn_output, _, past_key_value = self.attn.forward( @@ -98,16 +169,29 @@ def forward( attention_mask=attention_mask, position_ids=None, output_attentions=False, - use_cache=True + use_cache=True, ) h = hidden_states.to(attn_output.device) + attn_output out = h + self.ffn.forward(self.norm_2(h)) return out, None, past_key_value + class FalconDecoderLayer(nn.Module): - def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mlp, dev, max_seq_len, - input_layernorm=None, ln_attn=None, ln_mlp=None, new_decoder_arch=True): + def __init__( + self, + hidden_size, + n_heads, + qkv_layer, + o_proj, + mlp, + dev, + max_seq_len, + input_layernorm=None, + ln_attn=None, + ln_mlp=None, + new_decoder_arch=True, + ): super().__init__() self.n_heads = n_heads self.n_kv_heads = 8 if new_decoder_arch else 0 @@ -117,33 +201,52 @@ def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mlp, dev, max_seq_le if new_decoder_arch: attention_shapes = None else: - attention_shapes = self._get_attention_shapes(n_heads, max_seq_len, self.hidden_size // n_heads) - + attention_shapes = self._get_attention_shapes( + n_heads, max_seq_len, self.hidden_size // n_heads + ) + # TODO: Falcon has ALiBi implemented but which model uses it? self.attn = QuantAttentionFused( - hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj, - dev=dev, max_seq_len=max_seq_len, use_alibi=False, - attention_shapes=attention_shapes + hidden_size, + self.n_heads, + self.n_kv_heads, + qkv_layer, + o_proj, + dev=dev, + max_seq_len=max_seq_len, + use_alibi=False, + attention_shapes=attention_shapes, ).to(dev) - + if new_decoder_arch: - self.ln_attn = ln_attn # before attention - self.ln_mlp = ln_mlp # before mlp + self.ln_attn = ln_attn # before attention + self.ln_mlp = ln_mlp # before mlp else: - self.input_layernorm = input_layernorm # before attention - + self.input_layernorm = input_layernorm # before attention + self.mlp = mlp self.device = dev - + def _get_attention_shapes(self, n_heads, max_seq_len, head_dim): batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1")) - + self.attention_shapes = { # following fastertransformer definition - "cache_v": (batch_size, 1, max_seq_len, head_dim,), + "cache_v": ( + batch_size, + 1, + max_seq_len, + head_dim, + ), # 8: pack 8 fp16 in FT, if fp32 then use 4 - "cache_k": (batch_size, 1, head_dim // 8, max_seq_len, 8,), - "xqkv_view": (n_heads+2, head_dim), + "cache_k": ( + batch_size, + 1, + head_dim // 8, + max_seq_len, + 8, + ), + "xqkv_view": (n_heads + 2, head_dim), "xq_slice": lambda xqkv: xqkv[:, :, :-2], "xk_slice": lambda xqkv: xqkv[:, :, [-2]], "xv_slice": lambda xqkv: xqkv[:, :, [-1]], @@ -153,27 +256,32 @@ def _get_attention_shapes(self, n_heads, max_seq_len, head_dim): "xk_reshape": (1, head_dim // 8, 8), "single_xq_view": (n_heads, head_dim), "single_xk_view": (1, head_dim), - "single_xv_view": (1, head_dim) + "single_xv_view": (1, head_dim), } return self.attention_shapes def forward( - self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None + self, + hidden_states, + past_key_value, + attn_bias=None, + attention_mask=None, + is_causal=None, ): if self.new_decoder_arch: layernorm_out = self.ln_attn(hidden_states) 
mlp_layernorm_out = self.ln_mlp(hidden_states) else: layernorm_out = self.input_layernorm(hidden_states) - + attn_output, _, past_key_value = self.attn.forward( hidden_states=layernorm_out, past_key_value=past_key_value, attention_mask=attention_mask, position_ids=None, output_attentions=False, - use_cache=True + use_cache=True, ) h_attn = hidden_states.to(attn_output.device) + attn_output @@ -182,7 +290,7 @@ def forward( h_mlp = self.mlp.forward(mlp_layernorm_out) else: h_mlp = self.mlp.forward(layernorm_out) - + out = h_attn + h_mlp - - return out, None, past_key_value \ No newline at end of file + + return out, None, past_key_value diff --git a/awq/modules/fused/cache.py b/awq/modules/fused/cache.py index cfbaa0d5..87943f59 100644 --- a/awq/modules/fused/cache.py +++ b/awq/modules/fused/cache.py @@ -1,27 +1,34 @@ import torch + class WindowedCache: def __init__(self, cache_v_shape, cache_k_shape, max_seq_len, device): """ - The window size is the same as the max_new_tokens. The window will - automatically roll once max_new_tokens is exceeded. + The window size is the same as the max_seq_len. The window will + automatically roll once max_seq_len is exceeded. """ # [batch_size, n_kv_heads, max_seq_len, head_dim] self.v = torch.zeros(cache_v_shape).to(device).half() # [batch_size, n_kv_heads, head_dim // pack_factor, max_seq_len, pack_factor] self.k = torch.zeros(cache_k_shape).to(device).half() self.max_seq_len = max_seq_len - + def get_kv(self, batch_size, start_pos, seqlen, head_dim): """ Gets the key-value store in correct shapes. """ - xv = self.v[:batch_size, :, : start_pos + seqlen, :].transpose(1, 2).contiguous() - xk = self.k[:batch_size, :, :, : start_pos + seqlen, :].transpose(2, 3).contiguous() + xv = ( + self.v[:batch_size, :, : start_pos + seqlen, :].transpose(1, 2).contiguous() + ) + xk = ( + self.k[:batch_size, :, :, : start_pos + seqlen, :] + .transpose(2, 3) + .contiguous() + ) xk = xk.reshape(xk.shape[:-2] + (head_dim,)).transpose(1, 2).contiguous() return xv, xk - + def update_kv(self, values_store, keys_store, batch_size, start_pos, seqlen): """ Updates the values in the key-value store. 
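# Illustrative sketch, not part of the diff above: minimal CPU-only use of WindowedCache.
# The shapes mirror the comments in __init__ (value cache is [batch, n_kv_heads, max_seq_len, head_dim],
# key cache packs 8 fp16 values as in FasterTransformer); batch size 1, 8 kv-heads, head_dim 128 and
# max_seq_len 2048 are made-up example values, and AutoAWQ must be installed for the import to work.
from awq.modules.fused.cache import WindowedCache

batch_size, n_kv_heads, head_dim, max_seq_len = 1, 8, 128, 2048
cache_v_shape = (batch_size, n_kv_heads, max_seq_len, head_dim)
cache_k_shape = (batch_size, n_kv_heads, head_dim // 8, max_seq_len, 8)

cache = WindowedCache(cache_v_shape, cache_k_shape, max_seq_len, device="cpu")
xv, xk = cache.get_kv(batch_size, start_pos=0, seqlen=16, head_dim=head_dim)
print(xv.shape, xk.shape)  # key/value views covering the first 16 positions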
@@ -41,19 +48,23 @@ def roll_kv_n_steps(self, start_pos, n=100): # Zero out the new part self.v[:, :, -n:, :] = 0 self.k[:, :, :, -n:, :] = 0 - + return start_pos - n - + def to(self, device): self.k = self.k.to(device) self.v = self.v.to(device) - + def increase_batch_size(self, to_bsz): """Dynamically allocate new kv when batch size changes.""" - self.v = torch.zeros(to_bsz, *self.v.shape[1:], dtype=self.v.dtype, device=self.v.device) - self.k = torch.zeros(to_bsz, *self.k.shape[1:], dtype=self.k.dtype, device=self.k.device) + self.v = torch.zeros( + to_bsz, *self.v.shape[1:], dtype=self.v.dtype, device=self.v.device + ) + self.k = torch.zeros( + to_bsz, *self.k.shape[1:], dtype=self.k.dtype, device=self.k.device + ) def decrease_batch_size(self, to_bsz): """Dynamically remove part of cache if batch size changes.""" self.v = self.v[:to_bsz, :, :, :] - self.k = self.k[:to_bsz, :, :, :, :] \ No newline at end of file + self.k = self.k[:to_bsz, :, :, :, :] diff --git a/awq/modules/fused/mlp.py b/awq/modules/fused/mlp.py index afd70d25..9236109b 100644 --- a/awq/modules/fused/mlp.py +++ b/awq/modules/fused/mlp.py @@ -5,26 +5,28 @@ try: import awq_ext # with CUDA kernels + AWQ_INSTALLED = True except: AWQ_INSTALLED = False + class QuantFusedMLP(nn.Module): def __init__( self, gate_proj, down_proj, up_proj, - activation = F.silu, + activation=F.silu, ): super().__init__() - self.register_buffer('gate_proj_qweight', gate_proj.qweight) - self.register_buffer('gate_proj_scales', gate_proj.scales) - self.register_buffer('gate_proj_qzeros', gate_proj.qzeros) - self.register_buffer('up_proj_qweight', up_proj.qweight) - self.register_buffer('up_proj_scales', up_proj.scales) - self.register_buffer('up_proj_qzeros', up_proj.qzeros) + self.register_buffer("gate_proj_qweight", gate_proj.qweight) + self.register_buffer("gate_proj_scales", gate_proj.scales) + self.register_buffer("gate_proj_qzeros", gate_proj.qzeros) + self.register_buffer("up_proj_qweight", up_proj.qweight) + self.register_buffer("up_proj_scales", up_proj.scales) + self.register_buffer("up_proj_qzeros", up_proj.qzeros) self.in_features = gate_proj.in_features self.intermediate_size = gate_proj.out_features @@ -66,17 +68,13 @@ def forward(self, x, routing_weights=None): x = routing_weights * x return x - + class QuantLlamaMLP(QuantFusedMLP): r""" - QuantLlamaMLP class kept for backward compatibilty, in the future, users + QuantLlamaMLP class kept for backward compatibilty, in the future, users should always use `QuantFusedMLP` class instead. 
""" - def __init__( - self, - gate_proj, - down_proj, - up_proj - ): - super().__init__(gate_proj, down_proj, up_proj) \ No newline at end of file + + def __init__(self, gate_proj, down_proj, up_proj): + super().__init__(gate_proj, down_proj, up_proj) diff --git a/awq/modules/fused/model.py b/awq/modules/fused/model.py index 50bd1bb6..c02233f6 100644 --- a/awq/modules/fused/model.py +++ b/awq/modules/fused/model.py @@ -2,8 +2,16 @@ import torch.nn as nn from typing import List from awq.utils import fused_utils -from transformers.modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast -from awq.modules.fused.block import MPTBlock, FalconDecoderLayer, LlamaLikeBlock, MixtralBlock +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + MoeModelOutputWithPast, +) +from awq.modules.fused.block import ( + MPTBlock, + FalconDecoderLayer, + LlamaLikeBlock, + MixtralBlock, +) class MixtralModel(nn.Module): @@ -47,8 +55,10 @@ def forward( h, mask, ) - h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal) - + h, _, past_key_value = layer( + h, None, attention_mask=mask, is_causal=is_causal + ) + h = self.norm(h) return MoeModelOutputWithPast( @@ -65,6 +75,7 @@ class LlamaLikeModel(nn.Module): LlamaLikeModel is intended to be reused across models that have an architecture that closely resembles Llama, e.g. Mistral and Aquila. """ + def __init__(self, vocab_size, blocks, embedding, norm): super().__init__() self.vocab_size = vocab_size @@ -72,12 +83,19 @@ def __init__(self, vocab_size, blocks, embedding, norm): self.blocks: List[LlamaLikeBlock] = nn.ModuleList(blocks) self.norm = norm self.last_forward_num_tokens = 0 - + @torch.inference_mode() - def forward(self, input_ids: torch.Tensor, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs): + def forward( + self, + input_ids: torch.Tensor, + attn_bias=None, + attention_mask=None, + is_causal=None, + *args, + **kwargs, + ): input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids( - input_ids, - self.last_forward_num_tokens + input_ids, self.last_forward_num_tokens ) _bsz, seqlen = input_ids.shape @@ -89,7 +107,7 @@ def forward(self, input_ids: torch.Tensor, attn_bias=None, attention_mask=None, seqlen=seqlen, start_pos=self.blocks[0].attn.start_pos, device=input_ids.device, - type_as=h + type_as=h, ) for layer in self.blocks: @@ -99,14 +117,17 @@ def forward(self, input_ids: torch.Tensor, attn_bias=None, attention_mask=None, mask, ) h, _, past_key_value = layer( - h, - None, - attention_mask=mask, - is_causal=is_causal + h, None, attention_mask=mask, is_causal=is_causal ) h = self.norm(h) - return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=()) + return BaseModelOutputWithPast( + last_hidden_state=h, + past_key_values=past_key_value, + hidden_states=(), + attentions=(), + ) + class MPTModel(nn.Module): def __init__(self, vocab_size, blocks, wte, norm_f): @@ -120,10 +141,17 @@ def __init__(self, vocab_size, blocks, wte, norm_f): self.last_forward_num_tokens = 0 @torch.inference_mode() - def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs): + def forward( + self, + input_ids, + attn_bias=None, + attention_mask=None, + is_causal=None, + *args, + **kwargs, + ): input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids( - input_ids, - self.last_forward_num_tokens + input_ids, self.last_forward_num_tokens ) _bsz, seqlen = input_ids.shape @@ -135,7 
+163,7 @@ def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None seqlen=seqlen, start_pos=self.blocks[0].attn.start_pos, device=input_ids.device, - type_as=h + type_as=h, ) for layer in self.blocks: @@ -145,14 +173,17 @@ def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None mask, ) h, _, past_key_value = layer( - h, - None, - attention_mask=mask, - is_causal=is_causal + h, None, attention_mask=mask, is_causal=is_causal ) h = self.norm_f(h) - return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=()) + return BaseModelOutputWithPast( + last_hidden_state=h, + past_key_values=past_key_value, + hidden_states=(), + attentions=(), + ) + class FalconModel(nn.Module): def __init__(self, vocab_size, blocks, word_embeddings, ln_f): @@ -166,10 +197,17 @@ def __init__(self, vocab_size, blocks, word_embeddings, ln_f): self.last_forward_num_tokens = 0 @torch.inference_mode() - def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs): + def forward( + self, + input_ids, + attn_bias=None, + attention_mask=None, + is_causal=None, + *args, + **kwargs, + ): input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids( - input_ids, - self.last_forward_num_tokens + input_ids, self.last_forward_num_tokens ) _bsz, seqlen = input_ids.shape @@ -181,7 +219,7 @@ def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None seqlen=seqlen, start_pos=self.blocks[0].attn.start_pos, device=input_ids.device, - type_as=h + type_as=h, ) for layer in self.blocks: @@ -191,11 +229,13 @@ def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None mask, ) h, _, past_key_value = layer( - h, - None, - attention_mask=mask, - is_causal=is_causal + h, None, attention_mask=mask, is_causal=is_causal ) h = self.ln_f(h) - return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=()) + return BaseModelOutputWithPast( + last_hidden_state=h, + past_key_values=past_key_value, + hidden_states=(), + attentions=(), + ) diff --git a/awq/modules/fused/moe.py b/awq/modules/fused/moe.py new file mode 100644 index 00000000..b252ce46 --- /dev/null +++ b/awq/modules/fused/moe.py @@ -0,0 +1,461 @@ +import torch +import triton +from typing import Dict +import triton.language as tl + +try: + import awq_ext # with CUDA kernels + + AWQ_INSTALLED = True +except: + AWQ_INSTALLED = False + + +class FusedSparseMoeBlock(torch.nn.Module): + def __init__( + self, + top_k, + gate, + ws, + w2s, + ): + super().__init__() + self.gate = gate + self.top_k = top_k + self.ws = ws + self.w2s = w2s + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + final_hidden_states = apply_moe_weights( + self.ws, + self.w2s, + hidden_states, + router_logits, + self.top_k, + renormalize=True, + ) + + return final_hidden_states.view(batch_size, sequence_length, hidden_dim) + + +def apply_moe_weights( + w1: Dict[str, torch.Tensor], + w2: Dict[str, torch.Tensor], + x: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +) -> torch.Tensor: + FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 1024 + if FP16_MATMUL_HEURISTIC_CONDITION: + dequant_w1 = 
awq_ext.dequantize_weights_cuda( + w1.qweight, w1.scales, w1.qzeros, 0, 0, 0, False + ).permute(0, 2, 1) + dequant_w2 = awq_ext.dequantize_weights_cuda( + w2.qweight, w2.scales, w2.qzeros, 0, 0, 0, False + ).permute(0, 2, 1) + return fused_moe(x, dequant_w1, dequant_w2, gating_output, topk, renormalize) + + topk_weights, topk_ids = fused_topk(gating_output, topk, renormalize) + (sorted_token_ids, expert_ids, num_tokens_post_padded) = moe_align_block_size( + topk_ids, 16, w1.qweight.shape[0] + ) + + x = x.view(x.shape[0], 1, *x.shape[1:]) + + gate_up = awq_ext.grouped_gemm_forward( + x, + w1.qweight, + w1.scales, + w1.qzeros, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + 8, + ) + + out = torch.empty( + (gate_up.shape[:-1] + (gate_up.shape[-1] // 2,)), dtype=x.dtype, device=x.device + ) + awq_ext.silu_and_mul(out, gate_up) + + out = awq_ext.grouped_gemm_forward( + out, + w2.qweight, + w2.scales, + w2.qzeros, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + 8, + ) + + return torch.sum(out, dim=1) + + +@triton.jit +def fused_moe_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can be any shape representing batches and K is the feature dimension of each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is the number of experts, K is the input feature dimension, and N is the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the total number of tokens post padding, topk is the number of times each token is repeated, + and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, repeated topk times and arranged by the expert index they are assigned to. + - expert_ids: A tensor containing the indices of the expert for each block. It determines which expert matrix from B should be used for each block in A. + This kernel performs the multiplication of a token by its corresponding expert matrix as determined by `expert_ids`. The sorting of `sorted_token_ids` + by expert index and padding ensures divisibility by BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. 
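# Illustrative sketch, not part of the diff above: the product this kernel computes, written as
# plain PyTorch with MUL_ROUTED_WEIGHT=False and the block/padding machinery ignored.
# Shapes follow the docstring: A is (M, K) tokens, B is (E, N, K) expert weights, C is (M, top_k, N).
import torch

def moe_matmul_reference(a: torch.Tensor, b: torch.Tensor, topk_ids: torch.Tensor) -> torch.Tensor:
    M, top_k = topk_ids.shape
    c = torch.empty(M, top_k, b.shape[1], dtype=a.dtype, device=a.device)
    for m in range(M):
        for slot in range(top_k):
            # each token row is multiplied by the expert matrix it was routed to
            c[m, slot] = a[m] @ b[topk_ids[m, slot]].T
    return c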
+ pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) + + off_experts = tl.load(expert_ids_ptr + pid_m) + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + ) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +def moe_align_block_size(topk_ids: torch.Tensor, block_size: int, num_experts: int): + """ + Aligns the token distribution across experts to be compatible with block size for matrix multiplication. + + Parameters: + - topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k expert indices for each token. + - block_size: The block size used in block matrix multiplication. + - num_experts: The total number of experts. + + Returns: + - sorted_token_ids: A tensor containing the sorted token indices according to their allocated expert. + - expert_ids: A tensor indicating the assigned expert index for each block. 
+ - num_tokens_post_padded: The total number of tokens after padding, ensuring divisibility by block_size. + + This function pads the number of tokens that each expert needs to process so that it is divisible by block_size. + Padding ensures that during block matrix multiplication, the dimensions align correctly. + + Example: + Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], block_size = 4, and num_experts = 4: + - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, with each expert needing to process 3 tokens. + - As block_size is 4, we pad 1 token for each expert. + - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]. + - Then append padding tokens [12, 12, 12, 12] for each block. + - After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]. + Tokens 12 are non-existent (padding) and are ignored in the subsequent matrix multiplication. + - The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations. + """ + sorted_ids = torch.empty( + (topk_ids.numel() + num_experts * (block_size - 1),), + dtype=torch.int32, + device=topk_ids.device, + ) + expert_ids = torch.empty( + (topk_ids.numel() + num_experts,), dtype=torch.int32, device=topk_ids.device + ) + sorted_ids.fill_(topk_ids.numel()) + num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) + awq_ext.moe_alig_block_size( + topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad + ) + return sorted_ids, expert_ids, num_tokens_post_pad + + +def invoke_fused_moe_kernel( + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, + top_k: int, + config: dict, +): + assert topk_weights.stride(1) == 1 + assert sorted_token_ids.stride(0) == 1 + + grid = lambda META: ( + triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) + * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), + ) + + fused_moe_kernel[grid]( + A, + B, + C, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + B.shape[2], + sorted_token_ids.shape[0], + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16, + **config, + ) + + +def fused_topk( + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +): + """Compute top-k indice and weights from gating logits + + Args: + gating_output (torch.Tensor): The output of the gating operation (before softmax). + topk (int): The number of top-k experts to select. + renormalize (bool): If True, renormalize the top-k weights to sum to 1. + """ + M = gating_output.shape[0] + if torch.version.hip is not None: + # The MoE kernels are not yet supported on ROCm. 
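# Illustrative sketch, not part of the diff above: a pure-Python version of the grouping that the
# moe_align_block_size docstring describes; in the patch the real work is done by awq_ext.moe_alig_block_size.
import torch

def moe_align_block_size_reference(topk_ids: torch.Tensor, block_size: int):
    flat = topk_ids.flatten()
    pad_id = flat.numel()  # padding slots point one past the last real (token, expert) pair
    sorted_ids, expert_ids = [], []
    for expert in sorted(flat.unique().tolist()):
        token_ids = (flat == expert).nonzero(as_tuple=True)[0].tolist()
        token_ids += [pad_id] * (-len(token_ids) % block_size)  # pad to a multiple of block_size
        sorted_ids += token_ids
        expert_ids += [expert] * (len(token_ids) // block_size)
    return sorted_ids, expert_ids, len(sorted_ids)

# Reproduces the docstring example: [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]
print(moe_align_block_size_reference(torch.tensor([[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]]), 4)[0])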
+ routing_weights = torch.softmax(gating_output, dim=-1, dtype=torch.float32) + topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1) + else: + topk_weights = torch.empty( + M, topk, dtype=torch.float32, device=gating_output.device + ) + topk_ids = torch.empty(M, topk, dtype=torch.int32, device=gating_output.device) + token_expert_indicies = torch.empty( + M, topk, dtype=torch.int32, device=gating_output.device + ) + awq_ext.topk_softmax( + topk_weights, + topk_ids, + token_expert_indicies, + gating_output.float(), # TODO(woosuk): Optimize this. + ) + del token_expert_indicies # Not used. Will be used in the future. + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + return topk_weights, topk_ids + + +def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + inplace: bool = True, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. Defaults to False. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + # assert w1.is_contiguous(), "Expert weights1 must be contiguous" + # assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] + M, _ = hidden_states.shape + E, N, _ = w1.shape + + topk_weights, topk_ids = fused_topk(gating_output, topk, renormalize) + + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + + if topk_ids.numel() <= w1.shape[0]: + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } + + intermediate_cache1 = torch.empty( + (M, topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache2 = torch.empty( + (M * topk_ids.shape[1], N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache3 = torch.empty( + (M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, config["BLOCK_SIZE_M"], E + ) + + invoke_fused_moe_kernel( + hidden_states, + w1, + intermediate_cache1, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + topk_ids.shape[1], + config, + ) + + awq_ext.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + + invoke_fused_moe_kernel( + intermediate_cache2, + w2, + intermediate_cache3, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + 
num_tokens_post_padded, + True, + 1, + config, + ) + + if inplace: + return torch.sum( + intermediate_cache3.view(*intermediate_cache3.shape), + dim=1, + out=hidden_states, + ) + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) diff --git a/awq/modules/linear/__init__.py b/awq/modules/linear/__init__.py index d7d79beb..41996f22 100644 --- a/awq/modules/linear/__init__.py +++ b/awq/modules/linear/__init__.py @@ -2,4 +2,4 @@ from .exllamav2 import WQLinear_ExllamaV2 from .gemm import WQLinear_GEMM from .gemv import WQLinear_GEMV -from .marlin import WQLinear_Marlin \ No newline at end of file +from .marlin import WQLinear_Marlin diff --git a/awq/modules/linear/gemm.py b/awq/modules/linear/gemm.py index 01cdccec..aff1e279 100644 --- a/awq/modules/linear/gemm.py +++ b/awq/modules/linear/gemm.py @@ -11,6 +11,7 @@ except: AWQ_INSTALLED = False + # Adapted from https://github.com/compressa-ai/AutoAWQ/tree/dev class WQLinearMMFunction(Function): @staticmethod @@ -24,45 +25,29 @@ def forward( w_bit=4, group_size=128, bias=None, - out_features=0 + out_features=0, ): # The forward pass can use ctx. ctx.save_for_backward(x, qweight, qzeros, scales, bias) ctx.out_features = out_features - out_shape = x.shape[:-1] + (out_features, ) + out_shape = x.shape[:-1] + (out_features,) x = x.to(torch.float16) if AWQ_INSTALLED: - FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0]*x.shape[1] >= 1024 + FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024 if FP16_MATMUL_HEURISTIC_CONDITION: out = awq_ext.dequantize_weights_cuda( - qweight, - scales, - qzeros, - 0, - 0, - 0, - False + qweight, scales, qzeros, 0, 0, 0, False ) out = torch.matmul(x, out) else: out = awq_ext.gemm_forward_cuda( - x.reshape(-1, x.shape[-1]), - qweight, - scales, - qzeros, - 8 + x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, 8 ) else: - out = dequantize_gemm( - qweight, - qzeros, - scales, - w_bit, - group_size - ) + out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size) out = torch.matmul(x, out) out = out + bias if bias is not None else out @@ -71,7 +56,7 @@ def forward( # always want 3D tensor if tensor is 2D if len(out.shape) == 2: out = out.unsqueeze(0) - + return out @staticmethod @@ -79,13 +64,7 @@ def backward(ctx, grad_output): input, qweight, qzeros, scales, bias = ctx.saved_tensors weights = awq_ext.dequantize_weights_cuda( - qweight, - scales, - qzeros, - 1, - 0, - 0, - False + qweight, scales, qzeros, 1, 0, 0, False ) if ctx.needs_input_grad[0]: @@ -98,7 +77,9 @@ def backward(ctx, grad_output): class WQLinear_GEMM(nn.Module): - def __init__(self, w_bit, group_size, in_features, out_features, bias, dev, training=False): + def __init__( + self, w_bit, group_size, in_features, out_features, bias, dev, training=False + ): super().__init__() if w_bit not in [4]: diff --git a/awq/modules/linear/marlin.py b/awq/modules/linear/marlin.py index 40996f03..2db8b7ee 100644 --- a/awq/modules/linear/marlin.py +++ b/awq/modules/linear/marlin.py @@ -54,7 +54,7 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias, dev): self.in_features = in_features self.out_features = out_features self.group_size = group_size if group_size != -1 else in_features - self.max_par = 8 # partitioning for large inputs + self.max_par = 8 # partitioning for large inputs # quick sanity check (make sure aligment) assert self.in_features % self.group_size == 0 diff --git a/awq/quantize/quantizer.py b/awq/quantize/quantizer.py index 2e604547..aa82cd6b 100644 --- a/awq/quantize/quantizer.py +++ 
b/awq/quantize/quantizer.py @@ -117,7 +117,7 @@ def quantize(self): best_device = "cuda:" + str(i % torch.cuda.device_count()) else: best_device = get_best_device() - + self.modules[i] = self.modules[i].to(best_device) common_device = next(self.modules[i].parameters()).device @@ -190,15 +190,15 @@ def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]): linear_layer.weight.data ) - if self.version == "GEMM": + if self.version == "gemm": scales = scales.t().contiguous() zeros = zeros.t().contiguous() q_linear_module = WQLinear_GEMM - elif self.version == "GEMV": + elif self.version == "gemv": q_linear_module = WQLinear_GEMV - elif self.version == "Marlin": + elif self.version == "marlin": q_linear_module = WQLinear_Marlin else: @@ -355,7 +355,9 @@ def _search_best_clip(self, layer, named_linears, input_feat): continue named_linears[name].to(get_best_device()) - max_val = self._compute_best_clip(named_linears[name].weight, input_feat[name]) + max_val = self._compute_best_clip( + named_linears[name].weight, input_feat[name] + ) clip_list.append((name, max_val)) named_linears[name].cpu() @@ -481,7 +483,9 @@ def forward(self, *args, **kwargs): clear_memory() if layer_kwargs.get("attention_mask") is not None: - layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to(best_device) + layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to( + best_device + ) return modules, layer_kwargs, inps diff --git a/awq/quantize/scale.py b/awq/quantize/scale.py index 6bf73c0e..0ee6ea05 100644 --- a/awq/quantize/scale.py +++ b/awq/quantize/scale.py @@ -9,7 +9,14 @@ from transformers.activations import NewGELUActivation, PytorchGELUTanh, GELUActivation allowed_norms = [nn.LayerNorm, LlamaRMSNorm] -allowed_act_fns = [nn.GELU, BloomGelu, NewGELUActivation, PytorchGELUTanh, GELUActivation] +allowed_act_fns = [ + nn.GELU, + BloomGelu, + NewGELUActivation, + PytorchGELUTanh, + GELUActivation, +] + @torch.no_grad() def apply_clip(module, clip_list: Tuple[str, torch.Tensor]): @@ -28,35 +35,40 @@ def apply_scale(module, scales_list, input_feat_dict=None): for prev_op_name, layer_names, scales in scales_list: prev_op = get_op_by_name(module, prev_op_name) layers = [get_op_by_name(module, name) for name in layer_names] - + best_device = get_best_device() prev_op.to(best_device) for layer in layers: layer.to(best_device) scales.to(best_device) - - if isinstance(prev_op, nn.Linear) and type(layers) == list and isinstance(layers[0], nn.Linear): + + if ( + isinstance(prev_op, nn.Linear) + and type(layers) == list + and isinstance(layers[0], nn.Linear) + ): scale_fc_fcs(prev_op, layers, scales) elif isinstance(prev_op, nn.Linear): assert len(layers) == 1 scale_fc_fc(prev_op, layers[0], scales) - elif any(isinstance(prev_op,t) for t in allowed_norms) \ - or 'rmsnorm' in str(prev_op.__class__).lower(): + elif ( + any(isinstance(prev_op, t) for t in allowed_norms) + or "rmsnorm" in str(prev_op.__class__).lower() + ): scale_ln_fcs(prev_op, layers, scales) - elif any(isinstance(prev_op,t) for t in allowed_act_fns): + elif any(isinstance(prev_op, t) for t in allowed_act_fns): new_module = ScaledActivation(prev_op, scales) set_op_by_name(module, prev_op_name, new_module) scale_gelu_fc(prev_op, layers[0], scales) - + else: - raise NotImplementedError( - f"prev_op {type(prev_op)} not supported yet!") - + raise NotImplementedError(f"prev_op {type(prev_op)} not supported yet!") + # apply the scaling to input feat if given; prepare it for clipping - if input_feat_dict is not None: + if input_feat_dict is 
not None: for layer_name in layer_names: # Skip the modules that are not quantized if layer_name in input_feat_dict: @@ -68,15 +80,16 @@ def apply_scale(module, scales_list, input_feat_dict=None): layer.cpu() scales.cpu() + @torch.no_grad() def scale_ln_fcs(ln: nn.Linear, fcs: List[nn.Linear], scales: torch.Tensor): if not isinstance(fcs, list): fcs = [fcs] - + scales = scales.to(ln.weight.device) ln.weight.div_(scales) - if hasattr(ln, 'bias') and ln.bias is not None: + if hasattr(ln, "bias") and ln.bias is not None: ln.bias.div_(scales) for fc in fcs: @@ -88,14 +101,15 @@ def scale_ln_fcs(ln: nn.Linear, fcs: List[nn.Linear], scales: torch.Tensor): for p in fc.parameters(): assert torch.isnan(p).sum() == 0 + @torch.no_grad() def scale_fc_fc(fc1: nn.Linear, fc2: nn.Linear, scales: torch.Tensor): assert isinstance(fc1, nn.Linear) assert isinstance(fc2, nn.Linear) - + scales = scales.to(fc1.weight.device) - fc1.weight[-scales.size(0):].div_(scales.view(-1, 1)) + fc1.weight[-scales.size(0) :].div_(scales.view(-1, 1)) if fc1.bias is not None: fc1.bias.div_(scales.view(-1)) @@ -106,29 +120,31 @@ def scale_fc_fc(fc1: nn.Linear, fc2: nn.Linear, scales: torch.Tensor): for p in fc2.parameters(): assert torch.isnan(p).sum() == 0 + @torch.no_grad() def scale_fc_fcs(fc1: nn.Linear, fcs: List[nn.Linear], scales: torch.Tensor): if not isinstance(fcs, list): fcs = [fcs] - + scales = scales.to(fc1.weight.device) - fc1.weight[-scales.size(0):].div_(scales.view(-1, 1)) + fc1.weight[-scales.size(0) :].div_(scales.view(-1, 1)) if fc1.bias is not None: fc1.bias.div_(scales.view(-1)) for fc in fcs: fc.weight.mul_(scales.view(1, -1)) - + for p in fc1.parameters(): assert torch.isnan(p).sum() == 0 for fc in fcs: for p in fc.parameters(): assert torch.isnan(p).sum() == 0 + @torch.no_grad() def scale_gelu_fc(gelu: allowed_act_fns, fc: nn.Linear, scales: torch.Tensor): - assert any(isinstance(gelu,t) for t in allowed_act_fns) + assert any(isinstance(gelu, t) for t in allowed_act_fns) assert isinstance(fc, nn.Linear) fc.weight.mul_(scales.view(1, -1).to(fc.weight.device)) diff --git a/awq/utils/calib_data.py b/awq/utils/calib_data.py index cc589a34..2408cf3f 100644 --- a/awq/utils/calib_data.py +++ b/awq/utils/calib_data.py @@ -3,33 +3,41 @@ from typing import List, Union from datasets import load_dataset -def get_calib_dataset(data: Union[str, List[str], List[List[int]]] = "pileval", - tokenizer=None, n_samples=512, block_size=512, - split="train", text_column="text"): + +def get_calib_dataset( + data: Union[str, List[str], List[List[int]]] = "pileval", + tokenizer=None, + n_samples=512, + block_size=512, + split="train", + text_column="text", +): if isinstance(data, str): if data == "pileval": dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation") else: dataset = load_dataset(data, split=split) - + dataset = dataset.shuffle(seed=42) elif isinstance(data, list): if isinstance(data[0], str): dataset = [{text_column: text} for text in data] elif isinstance(data[0][0], int): - dataset = data + dataset = data else: raise NotImplementedError( "Either pass a string to a huggingface dataset or a list" "that is preprocessed with one sample of text per element" - " or a list of list of int for tokenized words.") + " or a list of list of int for tokenized words." 
+ ) else: raise NotImplementedError( "Either pass a string to a huggingface dataset or a list" "that is preprocessed with one sample of text per element" - " or a list of list of int for tokenized words.") - + " or a list of list of int for tokenized words." + ) + samples = [] n_run = 0 for data in dataset: @@ -52,4 +60,6 @@ def get_calib_dataset(data: Union[str, List[str], List[List[int]]] = "pileval", cat_samples = torch.cat(samples, dim=1) n_split = cat_samples.shape[1] // block_size logging.debug(f" * Split into {n_split} blocks") - return [cat_samples[:, i*block_size:(i+1)*block_size] for i in range(n_split)] + return [ + cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_split) + ] diff --git a/awq/utils/fused_utils.py b/awq/utils/fused_utils.py index d6f2ce18..64d63947 100644 --- a/awq/utils/fused_utils.py +++ b/awq/utils/fused_utils.py @@ -6,6 +6,7 @@ from awq.modules.linear.exllama import WQLinear_Exllama from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2 + def prepare_correct_devices(next_layer, hidden_states, mask): hidden_states = hidden_states.to(next_layer.device) @@ -13,7 +14,8 @@ def prepare_correct_devices(next_layer, hidden_states, mask): mask = mask.to(next_layer.device) return hidden_states, mask - + + def prepare_cache(blocks, seqlen: int) -> int: for block in blocks: start_pos = block.attn.start_pos @@ -21,12 +23,15 @@ def prepare_cache(blocks, seqlen: int) -> int: # Reset and avoid retaining state when processing context if seqlen > 1 and (will_cache_be_exceeded or start_pos > 0): - block.attn.start_pos = block.attn.cache.roll_kv_n_steps(start_pos, n=start_pos) - - # Slowly roll out old tokens without performance hit if exceeded during decoding + block.attn.start_pos = block.attn.cache.roll_kv_n_steps( + start_pos, n=start_pos + ) + + # Slowly roll out old tokens without performance hit if exceeded during decoding elif seqlen == 1 and will_cache_be_exceeded: block.attn.start_pos = block.attn.cache.roll_kv_n_steps(start_pos, n=100) + def prepare_input_ids(input_ids: torch.Tensor, last_forward_num_tokens: int): # NOTE: from transformers 4.35.0, input_ids includes full context during decoding num_input_tokens = input_ids.shape[-1] @@ -34,25 +39,29 @@ def prepare_input_ids(input_ids: torch.Tensor, last_forward_num_tokens: int): if num_input_tokens != 1: num_new_tokens = num_input_tokens - last_forward_num_tokens - + # after context is processed, slice to latest token if num_new_tokens == 1: input_ids = input_ids[:, -1:] return input_ids, last_forward_num_tokens + num_new_tokens + def prepare_attention_mask(seqlen, start_pos, device, type_as: torch.Tensor): mask = None if seqlen > 1: - mask = torch.full( - (1, 1, seqlen, seqlen), float("-inf"), device=device - ) - mask = torch.triu(mask, diagonal=start_pos+ 1).type_as(type_as) - + mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=device) + mask = torch.triu(mask, diagonal=start_pos + 1).type_as(type_as) + return mask + def fuse_qkv(module, q_proj, k_proj, v_proj): - bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None + bias = ( + torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) + if q_proj.bias is not None + else None + ) if isinstance(q_proj, WQLinear_GEMV): q_linear = WQLinear_GEMV @@ -71,45 +80,110 @@ def fuse_qkv(module, q_proj, k_proj, v_proj): q_proj.in_features, q_proj.out_features + k_proj.out_features + v_proj.out_features, q_proj.bias is not None, - next(iter(module.state_dict().values())).device + 
next(iter(module.state_dict().values())).device, ) if isinstance(q_proj, WQLinear_GEMV): - qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0) - qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0) - qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0) + qkv_layer.qweight = torch.cat( + [q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0 + ) + qkv_layer.qzeros = torch.cat( + [q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0 + ) + qkv_layer.scales = torch.cat( + [q_proj.scales, k_proj.scales, v_proj.scales], dim=0 + ) qkv_layer.split_k_iters = q_proj.split_k_iters elif isinstance(q_proj, WQLinear_GEMM): - qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1) - qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1) - qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) + qkv_layer.qweight = torch.cat( + [q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1 + ) + qkv_layer.qzeros = torch.cat( + [q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1 + ) + qkv_layer.scales = torch.cat( + [q_proj.scales, k_proj.scales, v_proj.scales], dim=1 + ) elif isinstance(q_proj, WQLinear_Exllama): - qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1) - qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1) - qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) + qkv_layer.qweight = torch.cat( + [q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1 + ) + qkv_layer.qzeros = torch.cat( + [q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1 + ) + qkv_layer.scales = torch.cat( + [q_proj.scales, k_proj.scales, v_proj.scales], dim=1 + ) elif isinstance(q_proj, WQLinear_ExllamaV2): - qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1) - qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1) - qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) + qkv_layer.qweight = torch.cat( + [q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1 + ) + qkv_layer.qzeros = torch.cat( + [q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1 + ) + qkv_layer.scales = torch.cat( + [q_proj.scales, k_proj.scales, v_proj.scales], dim=1 + ) elif isinstance(q_proj, WQLinear_Marlin): - qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1) - qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) + qkv_layer.qweight = torch.cat( + [q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1 + ) + qkv_layer.scales = torch.cat( + [q_proj.scales, k_proj.scales, v_proj.scales], dim=1 + ) # workspace is created in post_init - + qkv_layer.bias = bias + for layer in [q_proj, k_proj, v_proj]: + del (layer.qweight, layer.qzeros, layer.scales) + return qkv_layer -def get_attention_shapes(attention_shapes, max_seq_len, cache_batch_size, n_heads, n_kv_heads, head_dim): + +def fuse_linears(linears, device, dim=1, operation=torch.cat): + total_out_features = sum([layer.out_features for layer in linears]) + fused = WQLinear_GEMM( + linears[0].w_bit, + linears[0].group_size, + linears[0].in_features, + total_out_features, + bias=None, + dev=device, + ) + fused.qweight = operation([layer.qweight for layer in linears], dim=dim) + fused.qzeros = operation([layer.qzeros for layer in linears], dim=dim) + fused.scales = 
operation([layer.scales for layer in linears], dim=dim) + + for layer in linears: + del (layer.qweight, layer.qzeros, layer.scales, layer) + + return fused + + +def get_attention_shapes( + attention_shapes, max_seq_len, cache_batch_size, n_heads, n_kv_heads, head_dim +): if attention_shapes is not None: attention_shapes = attention_shapes elif n_kv_heads == 0: attention_shapes = { # following fastertransformer definition - "cache_v": (cache_batch_size, n_heads, max_seq_len, head_dim,), + "cache_v": ( + cache_batch_size, + n_heads, + max_seq_len, + head_dim, + ), # 8: pack 8 fp16 in FT, if fp32 then use 4 - "cache_k": (cache_batch_size, n_heads, head_dim // 8, max_seq_len, 8,), + "cache_k": ( + cache_batch_size, + n_heads, + head_dim // 8, + max_seq_len, + 8, + ), "xqkv_view": (-1, n_heads, head_dim), "xq_slice": lambda xqkv: xqkv[:, :, 0], "xk_slice": lambda xqkv: xqkv[:, :, 1], @@ -120,26 +194,37 @@ def get_attention_shapes(attention_shapes, max_seq_len, cache_batch_size, n_head "xk_reshape": (n_heads, head_dim // 8, 8), "single_xq_view": (n_heads, head_dim), "single_xk_view": (n_heads, head_dim), - "single_xv_view": (n_heads, head_dim) + "single_xv_view": (n_heads, head_dim), } else: attention_shapes = { # following fastertransformer definition - "cache_v": (cache_batch_size, n_kv_heads, max_seq_len, head_dim,), + "cache_v": ( + cache_batch_size, + n_kv_heads, + max_seq_len, + head_dim, + ), # 8: pack 8 fp16 in FT, if fp32 then use 4 - "cache_k": (cache_batch_size, n_kv_heads, head_dim // 8, max_seq_len, 8,), + "cache_k": ( + cache_batch_size, + n_kv_heads, + head_dim // 8, + max_seq_len, + 8, + ), "xqkv_view": (n_heads + n_kv_heads * 2, head_dim), - "xq_slice": lambda xqkv: xqkv[:, :, 0 : n_heads], + "xq_slice": lambda xqkv: xqkv[:, :, 0:n_heads], "xk_slice": lambda xqkv: xqkv[:, :, n_heads : (n_heads + n_kv_heads)], - "xv_slice": lambda xqkv: xqkv[:, :, -n_kv_heads :], + "xv_slice": lambda xqkv: xqkv[:, :, -n_kv_heads:], "xq_view": (n_heads, head_dim), "xk_view": (n_kv_heads, head_dim), "xv_view": (n_kv_heads, head_dim), "xk_reshape": (n_kv_heads, head_dim // 8, 8), "single_xq_view": (n_heads, head_dim), "single_xk_view": (n_kv_heads, head_dim), - "single_xv_view": (n_kv_heads, head_dim) + "single_xv_view": (n_kv_heads, head_dim), } - - return attention_shapes \ No newline at end of file + + return attention_shapes diff --git a/awq/utils/module.py b/awq/utils/module.py index dabfce04..12fff5c0 100644 --- a/awq/utils/module.py +++ b/awq/utils/module.py @@ -1,8 +1,10 @@ import torch.nn as nn + def get_named_linears(module): return {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)} + def get_op_by_name(module, op_name): # get the op by its name relative to the module for name, m in module.named_modules(): @@ -12,10 +14,10 @@ def get_op_by_name(module, op_name): def set_op_by_name(layer, name, new_module): - levels = name.split('.') + levels = name.split(".") if len(levels) > 1: mod_ = layer - for l_idx in range(len(levels)-1): + for l_idx in range(len(levels) - 1): if levels[l_idx].isdigit(): mod_ = mod_[int(levels[l_idx])] else: @@ -43,6 +45,7 @@ def append_str_prefix(x, prefix): else: return x + def exclude_layers_to_not_quantize(linear_layers, modules_to_not_convert): if modules_to_not_convert is None: return linear_layers @@ -51,4 +54,4 @@ def exclude_layers_to_not_quantize(linear_layers, modules_to_not_convert): for name, linear_layer in linear_layers.items(): if not any(key in name for key in modules_to_not_convert): filtered_layers[name] = linear_layer - 
return filtered_layers \ No newline at end of file + return filtered_layers diff --git a/awq/utils/packing_utils.py b/awq/utils/packing_utils.py index e0a60836..d01f8e64 100644 --- a/awq/utils/packing_utils.py +++ b/awq/utils/packing_utils.py @@ -79,6 +79,7 @@ def unpack_reorder_pack(qweight, qzeros, bits): return qweight, qzeros + def dequantize_gemm(qweight, qzeros, scales, bits, group_size): # Unpack the qweight and qzeros tensors iweight, izeros = unpack_awq(qweight, qzeros, bits) @@ -94,4 +95,4 @@ def dequantize_gemm(qweight, qzeros, scales, bits, group_size): izeros = izeros.repeat_interleave(group_size, dim=0) iweight = (iweight - izeros) * scales - return iweight \ No newline at end of file + return iweight diff --git a/awq/utils/parallel.py b/awq/utils/parallel.py index eb4389bc..6d0a1ebd 100644 --- a/awq/utils/parallel.py +++ b/awq/utils/parallel.py @@ -23,6 +23,7 @@ def auto_parallel(args): else: cuda_visible_devices = list(range(8)) os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( - [str(dev) for dev in cuda_visible_devices[:n_gpu]]) + [str(dev) for dev in cuda_visible_devices[:n_gpu]] + ) logging.debug("CUDA_VISIBLE_DEVICES: ", os.environ["CUDA_VISIBLE_DEVICES"]) return cuda_visible_devices diff --git a/awq/utils/quant_utils.py b/awq/utils/quant_utils.py index aaf50735..36ea956a 100644 --- a/awq/utils/quant_utils.py +++ b/awq/utils/quant_utils.py @@ -115,7 +115,7 @@ def dequantize(imatrix, scales, zeros, group_size): ) * scales.repeat_interleave(group_size, dim=0) fmatrix = fmatrix.to(torch.float16) - + return fmatrix diff --git a/awq/utils/utils.py b/awq/utils/utils.py index dc419edd..44075876 100644 --- a/awq/utils/utils.py +++ b/awq/utils/utils.py @@ -8,6 +8,7 @@ def get_module_by_name_suffix(model, module_name: str): if name.endswith(module_name): return module + def simple_dispatch_model(model, device_map): from accelerate.hooks import add_hook_to_module, AlignDevicesHook @@ -18,7 +19,10 @@ def simple_dispatch_model(model, device_map): return model tied_params = accelerate.utils.modeling.find_tied_parameters(model) - if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}: + if set(device_map.values()) == {"cpu"} or set(device_map.values()) == { + "cpu", + "disk", + }: main_device = "cpu" else: main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0] @@ -27,10 +31,14 @@ def simple_dispatch_model(model, device_map): prev_hook = None for idx, (n, d) in enumerate(cpu_offload_group): m = get_module_by_name_suffix(model, n) - _, prev_hook = accelerate.cpu_offload_with_hook(m, execution_device=main_device, prev_module_hook=prev_hook) + _, prev_hook = accelerate.cpu_offload_with_hook( + m, execution_device=main_device, prev_module_hook=prev_hook + ) # set first cpu offload module's prev_module_hook to the last cpu offload module's hook if len(cpu_offload_group) > 1: - get_module_by_name_suffix(model, cpu_offload_group[0][0])._hf_hook.prev_module_hook = prev_hook + get_module_by_name_suffix( + model, cpu_offload_group[0][0] + )._hf_hook.prev_module_hook = prev_hook for n, d in device_map.items(): m = get_module_by_name_suffix(model, n) @@ -43,33 +51,53 @@ def simple_dispatch_model(model, device_map): return model + def set_module_name(model, name, value): - if '.' in name: - parent_name = name.rsplit('.', 1)[0] - child_name = name[len(parent_name) + 1:] + if "." 
in name: + parent_name = name.rsplit(".", 1)[0] + child_name = name[len(parent_name) + 1 :] parent = model.get_submodule(parent_name) else: - parent_name = '' + parent_name = "" parent = model child_name = name setattr(parent, child_name, value) + def clear_memory(weight=None): if weight is not None: del weight gc.collect() torch.cuda.empty_cache() + def compute_memory_used_pct(device): - memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3) - memory_pct = memory_used / (torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)) * 100 + memory_used = torch.cuda.max_memory_allocated(device) / (1024**3) + memory_pct = ( + memory_used + / (torch.cuda.get_device_properties(device).total_memory / (1024**3)) + * 100 + ) return memory_pct + def get_best_device(): if torch.backends.mps.is_available(): - return 'mps' + return "mps" elif torch.cuda.is_available(): - return 'cuda:0' + return "cuda:0" else: - return 'cpu' \ No newline at end of file + return "cpu" + + +def get_lowest_memory_device_index(): + device = None + curr_device_memory_pct = 0 + for device_index in range(torch.cuda.device_count()): + device_memory_pct = compute_memory_used_pct(device_index) + if device is None or device_memory_pct < curr_device_memory_pct: + device = device_index + curr_device_memory_pct = device_memory_pct + + return device diff --git a/docs/examples.md b/docs/examples.md new file mode 100644 index 00000000..af0f9b71 --- /dev/null +++ b/docs/examples.md @@ -0,0 +1,318 @@ +# Examples + +## Basic Quantization + +AWQ performs zero point quantization down to a precision of 4-bit integers. +You can also specify other bit rates like 3-bit, but some of these options may lack kernels +for running inference. + +Notes: + +- Some models like Falcon are only compatible with group size 64. +- To use Marlin, you must specify zero point as False and version as Marlin. + +```python +from awq import AutoAWQForCausalLM +from transformers import AutoTokenizer + +model_path = 'mistralai/Mistral-7B-Instruct-v0.2' +quant_path = 'mistral-instruct-v0.2-awq' +quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" } + +# Load model +model = AutoAWQForCausalLM.from_pretrained( + model_path, **{"low_cpu_mem_usage": True, "use_cache": False} +) +tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + +# Quantize +model.quantize(tokenizer, quant_config=quant_config) + +# Save quantized model +model.save_quantized(quant_path) +tokenizer.save_pretrained(quant_path) + +print(f'Model is quantized and saved at "{quant_path}"') +``` + +### Custom Data + +This includes an example function that loads either wikitext or dolly. +Note that currently all samples longer than 512 tokens are discarded.
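If your own dataset contains many long documents, it can help to pre-filter it before passing it as `calib_data`; the sketch below is only an illustration (the word-count cutoff is an arbitrary heuristic, not an AutoAWQ setting), and the complete custom-data example follows.

```python
def filter_calib_texts(texts, max_words=350):
    # Drop empty samples and ones likely to exceed the 512-token block size.
    return [t for t in texts if 0 < len(t.split()) <= max_words]
```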
+ +```python +from datasets import load_dataset +from awq import AutoAWQForCausalLM +from transformers import AutoTokenizer + +model_path = 'lmsys/vicuna-7b-v1.5' +quant_path = 'vicuna-7b-v1.5-awq' +quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" } + +# Load model +model = AutoAWQForCausalLM.from_pretrained(model_path) +tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + +# Define data loading methods +def load_dolly(): + data = load_dataset('databricks/databricks-dolly-15k', split="train") + + # concatenate data + def concatenate_data(x): + return {"text": x['instruction'] + '\n' + x['context'] + '\n' + x['response']} + + concatenated = data.map(concatenate_data) + return [text for text in concatenated["text"]] + +def load_wikitext(): + data = load_dataset('wikitext', 'wikitext-2-raw-v1', split="train") + return [text for text in data["text"] if text.strip() != '' and len(text.split(' ')) > 20] + +# Quantize +model.quantize(tokenizer, quant_config=quant_config, calib_data=load_wikitext()) + +# Save quantized model +model.save_quantized(quant_path) +tokenizer.save_pretrained(quant_path) + +print(f'Model is quantized and saved at "{quant_path}"') +``` + +### GGUF Export + +This computes AWQ scales and applies them to the model without running real quantization. +This keeps the quality of AWQ because the weights are applied, but skips quantization +in order to make it compatible with other frameworks. + +Step by step: + +- `quantize()`: Compute AWQ scales and apply them +- `save_pretrained()`: Saves a non-quantized model in FP16 +- `convert.py`: Convert the Huggingface FP16 weights to GGUF FP16 weights +- `quantize`: Run GGUF quantization to get real quantized weights, in this case 4-bit. + +```python +import os +import subprocess +from awq import AutoAWQForCausalLM +from transformers import AutoTokenizer + +model_path = 'mistralai/Mistral-7B-v0.1' +quant_path = 'mistral-awq' +llama_cpp_path = '/workspace/llama.cpp' +quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 6, "version": "GEMM" } + +# Load model +# NOTE: pass safetensors=True to load safetensors +model = AutoAWQForCausalLM.from_pretrained( + model_path, **{"low_cpu_mem_usage": True, "use_cache": False} +) +tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + +# Quantize +# NOTE: We avoid packing weights, so you cannot use this model in AutoAWQ +# after quantizing. The saved model is FP16 but has the AWQ scales applied.
+model.quantize( + tokenizer, + quant_config=quant_config, + export_compatible=True +) + +# Save quantized model +model.save_quantized(quant_path) +tokenizer.save_pretrained(quant_path) +print(f'Model is quantized and saved at "{quant_path}"') + +# GGUF conversion +print('Converting model to GGUF...') +llama_cpp_method = "q4_K_M" +convert_cmd_path = os.path.join(llama_cpp_path, "convert.py") +quantize_cmd_path = os.path.join(llama_cpp_path, "quantize") + +if not os.path.exists(llama_cpp_path): + cmd = f"git clone https://github.com/ggerganov/llama.cpp.git {llama_cpp_path} && cd {llama_cpp_path} && make LLAMA_CUBLAS=1 LLAMA_CUDA_F16=1" + subprocess.run([cmd], shell=True, check=True) + +subprocess.run([ + f"python {convert_cmd_path} {quant_path} --outfile {quant_path}/model.gguf" +], shell=True, check=True) + +subprocess.run([ + f"{quantize_cmd_path} {quant_path}/model.gguf {quant_path}/model_{llama_cpp_method}.gguf {llama_cpp_method}" +], shell=True, check=True) +``` + +## Basic Inference + +To run inference, you often want to run with `fuse_layers=True` to get the claimed speedup in AutoAWQ. +Additionally, consider setting `max_seq_len` (default: 2048) as this will be the maximum context that the model can hold. + +Notes: + +- You can specify `use_exllama_v2=True` to enable ExLlamaV2 kernels during inference. + +```python +from awq import AutoAWQForCausalLM +from transformers import AutoTokenizer, TextStreamer + +quant_path = "TheBloke/Mistral-7B-Instruct-v0.2-AWQ" + +# Load model +model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True) +tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True) +streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) + +# Convert prompt to tokens +prompt_template = "[INST] {prompt} [/INST]" + +prompt = "You're standing on the surface of the Earth. "\ + "You walk one mile south, one mile west and one mile north. "\ + "You end up exactly where you started. Where are you?" + +tokens = tokenizer( + prompt_template.format(prompt=prompt), + return_tensors='pt' +).input_ids.cuda() + +# Generate output +generation_output = model.generate( + tokens, + streamer=streamer, + max_new_tokens=512 +) +``` + +### Transformers + +You can also load an AWQ model by using AutoModelForCausalLM, just make sure you have AutoAWQ installed. +Note that not all models will have fused modules when loading from transformers. +See more [documentation here](https://huggingface.co/docs/transformers/main/en/quantization#awq). + +```python +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer + +# NOTE: Must install from PR until merged +# pip install --upgrade git+https://github.com/younesbelkada/transformers.git@add-awq +model_id = "casperhansen/mistral-7b-instruct-v0.1-awq" + +tokenizer = AutoTokenizer.from_pretrained(model_id) +model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + device_map="cuda:0" +) +streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) + +# Convert prompt to tokens +text = "[INST] What are the basic steps to use the Huggingface transformers library? [/INST]" + +tokens = tokenizer( + text, + return_tensors='pt' +).input_ids.cuda() + +# Generate output +generation_output = model.generate( + tokens, + streamer=streamer, + max_new_tokens=512 +) +``` + +### vLLM + +You can also load AWQ models in [vLLM](https://github.com/vllm-project/vllm). 
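For quick offline tests, the synchronous `LLM` entry point is usually sufficient; the sketch below is a minimal example using the same AWQ checkpoint as the streaming example that follows (the prompt and sampling values are illustrative).

```python
from vllm import LLM, SamplingParams

llm = LLM(model="casperhansen/mixtral-instruct-awq", quantization="awq", dtype="float16")
sampling_params = SamplingParams(temperature=0.8, max_tokens=256)

outputs = llm.generate(["[INST] What is AWQ quantization? [/INST]"], sampling_params)
print(outputs[0].outputs[0].text)
```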
+ +```python +import asyncio +from transformers import AutoTokenizer, PreTrainedTokenizer +from vllm import AsyncLLMEngine, SamplingParams, AsyncEngineArgs + +model_path = "casperhansen/mixtral-instruct-awq" + +# prompting +prompt = "You're standing on the surface of the Earth. "\ + "You walk one mile south, one mile west and one mile north. "\ + "You end up exactly where you started. Where are you?", + +prompt_template = "[INST] {prompt} [/INST]" + +# sampling params +sampling_params = SamplingParams( + repetition_penalty=1.1, + temperature=0.8, + max_tokens=512 +) + +# tokenizer +tokenizer = AutoTokenizer.from_pretrained(model_path) + +# async engine args for streaming +engine_args = AsyncEngineArgs( + model=model_path, + quantization="awq", + dtype="float16", + max_model_len=512, + enforce_eager=True, + disable_log_requests=True, + disable_log_stats=True, +) + +async def generate(model: AsyncLLMEngine, tokenizer: PreTrainedTokenizer): + tokens = tokenizer(prompt_template.format(prompt=prompt)).input_ids + + outputs = model.generate( + prompt=prompt, + sampling_params=sampling_params, + request_id=1, + prompt_token_ids=tokens, + ) + + print("\n** Starting generation!\n") + last_index = 0 + + async for output in outputs: + print(output.outputs[0].text[last_index:], end="", flush=True) + last_index = len(output.outputs[0].text) + + print("\n\n** Finished generation!\n") + +if __name__ == '__main__': + model = AsyncLLMEngine.from_engine_args(engine_args) + asyncio.run(generate(model, tokenizer)) +``` + +### LLaVa (multimodal) + +AutoAWQ also supports the LLaVa model. You simply need to load an +AutoProcessor to process the prompt and image to generate inputs for the AWQ model. + +```python +import requests +import torch +from PIL import Image + +from awq import AutoAWQForCausalLM +from transformers import AutoProcessor + +quant_path = "ybelkada/llava-1.5-7b-hf-awq" + +# Load model +model = AutoAWQForCausalLM.from_quantized(quant_path, safetensors=True, device_map={"": 0}) +processor = AutoProcessor.from_pretrained(quant_path) + +prompt = "USER: \nWhat are these?\nASSISTANT:" +image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" + +raw_image = Image.open(requests.get(image_file, stream=True).raw) +inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16) +# Generate output +generation_output = model.generate( + **inputs, + max_new_tokens=512 +) + +print(processor.decode(generation_output[0], skip_special_tokens=True)) +``` \ No newline at end of file diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 00000000..9129030d --- /dev/null +++ b/docs/index.md @@ -0,0 +1,51 @@ +# AutoAWQ + +AutoAWQ pushes ease of use and fast inference speed into one package. In the following documentation, +you will learn how to quantize and run inference. + +Example inference speed (RTX 4090, Ryzen 9 7950X, 64 tokens): + +- Vicuna 7B (GEMV kernel): 198.848 tokens/s +- Mistral 7B (GEMM kernel): 156.317 tokens/s +- Mistral 7B (ExLlamaV2 kernel): 188.865 tokens/s + +## Installation notes + +- Install: `pip install autoawq`. +- Your torch version must match the build version, i.e. you cannot use torch 2.0.1 with a wheel that was built with 2.2.0. +- For AMD GPUs, inference will run through ExLlamaV2 kernels without fused layers. 
You need to pass the following arguments to run with AMD GPUs: + +```python +model = AutoAWQForCausalLM.from_quantized( + ..., + fuse_layers=False, + use_exllama_v2=True +) +``` + +## Supported models + +The detailed support list: + +| Models | Sizes | +| -------- | --------------------------- | +| LLaMA-2 | 7B/13B/70B | +| LLaMA | 7B/13B/30B/65B | +| Mistral | 7B | +| Vicuna | 7B/13B | +| MPT | 7B/30B | +| Falcon | 7B/40B | +| OPT | 125m/1.3B/2.7B/6.7B/13B/30B | +| Bloom | 560m/3B/7B/ | +| GPTJ | 6.7B | +| Aquila | 7B | +| Aquila2 | 7B/34B | +| Yi | 6B/34B | +| Qwen | 1.8B/7B/14B/72B | +| BigCode | 1B/7B/15B | +| GPT NeoX | 20B | +| GPT-J | 6B | +| LLaVa | 7B/13B | +| Mixtral | 8x7B | +| Baichuan | 7B/13B | +| QWen | 1.8B/7B/14/72B | \ No newline at end of file diff --git a/docs/reference/index.md b/docs/reference/index.md new file mode 100644 index 00000000..5439f386 --- /dev/null +++ b/docs/reference/index.md @@ -0,0 +1,6 @@ +# Auto and Base model classes in AutoAWQ + +View the documentation of the main classes of AutoAWQ models below. + +::: awq.models.auto.AutoAWQForCausalLM +::: awq.models.base.BaseAWQForCausalLM diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 00000000..e26da2b3 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,4 @@ +# AutoAWQ examples + +Please see the docs for more thorough examples. In this folder, you will only find the +very basic examples of quantization, inference, and training. \ No newline at end of file diff --git a/examples/awq_to_gguf_quant.py b/examples/awq_to_gguf_quant.py deleted file mode 100644 index 61a77bbf..00000000 --- a/examples/awq_to_gguf_quant.py +++ /dev/null @@ -1,48 +0,0 @@ -import os -import subprocess -from awq import AutoAWQForCausalLM -from transformers import AutoTokenizer - -model_path = 'mistralai/Mistral-7B-v0.1' -quant_path = 'mistral-awq' -llama_cpp_path = '/workspace/llama.cpp' -quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 6, "version": "GEMM" } - -# Load model -# NOTE: pass safetensors=True to load safetensors -model = AutoAWQForCausalLM.from_pretrained( - model_path, **{"low_cpu_mem_usage": True, "use_cache": False} -) -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - -# Quantize -# NOTE: We avoid packing weights, so you cannot use this model in AutoAWQ -# after quantizing. The saved model is FP16 but has the AWQ scales applied. 
-model.quantize( - tokenizer, - quant_config=quant_config, - export_compatible=True -) - -# Save quantized model -model.save_quantized(quant_path) -tokenizer.save_pretrained(quant_path) -print(f'Model is quantized and saved at "{quant_path}"') - -# GGUF conversion -print('Converting model to GGUF...') -llama_cpp_method = "q4_K_M" -convert_cmd_path = os.path.join(llama_cpp_path, "convert.py") -quantize_cmd_path = os.path.join(llama_cpp_path, "quantize") - -if not os.path.exists(llama_cpp_path): - cmd = f"git clone https://github.com/ggerganov/llama.cpp.git {llama_cpp_path} && cd {llama_cpp_path} && make LLAMA_CUBLAS=1 LLAMA_CUDA_F16=1" - subprocess.run([cmd], shell=True, check=True) - -subprocess.run([ - f"python {convert_cmd_path} {quant_path} --outfile {quant_path}/model.gguf" -], shell=True, check=True) - -subprocess.run([ - f"{quantize_cmd_path} {quant_path}/model.gguf {quant_path}/model_{llama_cpp_method}.gguf {llama_cpp_method}" -], shell=True, check=True) diff --git a/examples/basic_transformers.py b/examples/basic_transformers.py deleted file mode 100644 index 397a8d19..00000000 --- a/examples/basic_transformers.py +++ /dev/null @@ -1,30 +0,0 @@ -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer - -# NOTE: Must install from PR until merged -# pip install --upgrade git+https://github.com/younesbelkada/transformers.git@add-awq -model_id = "casperhansen/mistral-7b-instruct-v0.1-awq" - -tokenizer = AutoTokenizer.from_pretrained(model_id) -model = AutoModelForCausalLM.from_pretrained( - model_id, - torch_dtype=torch.float16, - low_cpu_mem_usage=True, - device_map="cuda:0" -) -streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) - -# Convert prompt to tokens -text = "[INST] What are the basic steps to use the Huggingface transformers library? [/INST]" - -tokens = tokenizer( - text, - return_tensors='pt' -).input_ids.cuda() - -# Generate output -generation_output = model.generate( - tokens, - streamer=streamer, - max_new_tokens=512 -) \ No newline at end of file diff --git a/examples/basic_vllm.py b/examples/basic_vllm.py deleted file mode 100644 index d3763790..00000000 --- a/examples/basic_vllm.py +++ /dev/null @@ -1,56 +0,0 @@ -import asyncio -from transformers import AutoTokenizer, PreTrainedTokenizer -from vllm import AsyncLLMEngine, SamplingParams, AsyncEngineArgs - -model_path = "casperhansen/mixtral-instruct-awq" - -# prompting -prompt = "You're standing on the surface of the Earth. "\ - "You walk one mile south, one mile west and one mile north. "\ - "You end up exactly where you started. 
Where are you?", - -prompt_template = "[INST] {prompt} [/INST]" - -# sampling params -sampling_params = SamplingParams( - repetition_penalty=1.1, - temperature=0.8, - max_tokens=512 -) - -# tokenizer -tokenizer = AutoTokenizer.from_pretrained(model_path) - -# async engine args for streaming -engine_args = AsyncEngineArgs( - model=model_path, - quantization="awq", - dtype="float16", - max_model_len=512, - enforce_eager=True, - disable_log_requests=True, - disable_log_stats=True, -) - -async def generate(model: AsyncLLMEngine, tokenizer: PreTrainedTokenizer): - tokens = tokenizer(prompt_template.format(prompt=prompt)).input_ids - - outputs = model.generate( - prompt=prompt, - sampling_params=sampling_params, - request_id=1, - prompt_token_ids=tokens, - ) - - print("\n** Starting generation!\n") - last_index = 0 - - async for output in outputs: - print(output.outputs[0].text[last_index:], end="", flush=True) - last_index = len(output.outputs[0].text) - - print("\n\n** Finished generation!\n") - -if __name__ == '__main__': - model = AsyncLLMEngine.from_engine_args(engine_args) - asyncio.run(generate(model, tokenizer)) \ No newline at end of file diff --git a/examples/benchmark.py b/examples/benchmark.py index 38e28dcd..dd47ea7e 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -98,7 +98,7 @@ def run_round(generator, model_path, quant_file, n_generate, input_ids, batch_si else: model = AutoAWQForCausalLM.from_quantized( model_path, quant_file, fuse_layers=True, - max_new_tokens=n_generate, batch_size=batch_size, + max_seq_len=n_generate, batch_size=batch_size, safetensors=not no_safetensors ) @@ -115,10 +115,6 @@ def run_round(generator, model_path, quant_file, n_generate, input_ids, batch_si successful_generate = False else: raise RuntimeError(ex) - - device = next(model.parameters()).device - memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3) - memory_pct = memory_used / (torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)) * 100 if successful_generate: # number of tokens in context / time for processing context * batch size @@ -128,7 +124,11 @@ def run_round(generator, model_path, quant_file, n_generate, input_ids, batch_si print(f" ** Speed (Prefill): {prefill_tokens_per_second:.2f} tokens/second") print(f" ** Speed (Decode): {decode_tokens_per_second:.2f} tokens/second") - print(f" ** Max Memory (VRAM): {memory_used:.2f} GB ({memory_pct:.2f}%)") + + for device in range(torch.cuda.device_count()): + memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3) + memory_pct = memory_used / (torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)) * 100 + print(f" ** Max Memory (device: {device}): {memory_used:.2f} GB ({memory_pct:.2f}%)") else: prefill_tokens_per_second = 'OOM' decode_tokens_per_second = 'OOM' diff --git a/examples/exllama_generate.py b/examples/exllama_generate.py deleted file mode 100644 index 27ea2f20..00000000 --- a/examples/exllama_generate.py +++ /dev/null @@ -1,28 +0,0 @@ -from awq import AutoAWQForCausalLM -from transformers import AutoTokenizer, TextStreamer - -quant_path = "TheBloke/Mistral-7B-Instruct-v0.1-AWQ" - -# Load model -model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True, use_exllama_v2=True) -tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True) -streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) - -# Convert prompt to tokens -prompt_template = "[INST] {prompt} [/INST]" - -prompt = "You're standing on the surface of the 
Earth. "\ - "You walk one mile south, one mile west and one mile north. "\ - "You end up exactly where you started. Where are you?" - -tokens = tokenizer( - prompt_template.format(prompt=prompt), - return_tensors='pt' -).input_ids.cuda() - -# Generate output -generation_output = model.generate( - tokens, - streamer=streamer, - max_new_tokens=512 -) \ No newline at end of file diff --git a/examples/basic_generate.py b/examples/generate.py similarity index 100% rename from examples/basic_generate.py rename to examples/generate.py diff --git a/examples/llava_generate.py b/examples/llava_generate.py deleted file mode 100644 index 219fd5bb..00000000 --- a/examples/llava_generate.py +++ /dev/null @@ -1,25 +0,0 @@ -import requests -import torch -from PIL import Image - -from awq import AutoAWQForCausalLM -from transformers import AutoProcessor - -quant_path = "ybelkada/llava-1.5-7b-hf-awq" - -# Load model -model = AutoAWQForCausalLM.from_quantized(quant_path, safetensors=True, device_map={"": 0}) -processor = AutoProcessor.from_pretrained(quant_path) - -prompt = "USER: \nWhat are these?\nASSISTANT:" -image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" - -raw_image = Image.open(requests.get(image_file, stream=True).raw) -inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16) -# Generate output -generation_output = model.generate( - **inputs, - max_new_tokens=512 -) - -print(processor.decode(generation_output[0], skip_special_tokens=True)) \ No newline at end of file diff --git a/examples/llava_quant.py b/examples/llava_quant.py deleted file mode 100644 index 371430f1..00000000 --- a/examples/llava_quant.py +++ /dev/null @@ -1,23 +0,0 @@ -import torch -from awq import AutoAWQForCausalLM -from transformers import AutoTokenizer - -model_path = "llava-hf/llava-1.5-7b-hf" -quant_path = "llava-1.5-7b-hf-awq" - -quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version":"GEMM"} - -# Load model -model = AutoAWQForCausalLM.from_pretrained( - model_path, safetensors=True, torch_dtype=torch.float16, device_map="auto" -) -tokenizer = AutoTokenizer.from_pretrained(model_path) - -# Quantize -model.quantize(tokenizer, quant_config=quant_config) - -# Save quantized model -model.save_quantized(quant_path) -tokenizer.save_pretrained(quant_path) - -print(f'Model is quantized and saved at "{quant_path}"') \ No newline at end of file diff --git a/examples/marlin_generate.py b/examples/marlin_generate.py deleted file mode 100644 index 2564b517..00000000 --- a/examples/marlin_generate.py +++ /dev/null @@ -1,33 +0,0 @@ -from awq import AutoAWQForCausalLM -from transformers import AutoTokenizer, TextStreamer - -quant_path = "IlyasMoutawwakil/vicuna-7b-v1.5-awq-marlin" - -# Load model -model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=False) -tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True) -streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) - -# Convert prompt to tokens -prompt_template = """\ -<|system|> - -<|user|> -{prompt} -<|assistant|>""" - -prompt = "You're standing on the surface of the Earth. "\ - "You walk one mile south, one mile west and one mile north. "\ - "You end up exactly where you started. Where are you?" 
- -tokens = tokenizer( - prompt_template.format(prompt=prompt), - return_tensors='pt' -).input_ids.cuda() - -# Generate output -generation_output = model.generate( - tokens, - streamer=streamer, - max_new_tokens=512 -) \ No newline at end of file diff --git a/examples/marlin_quant.py b/examples/marlin_quant.py deleted file mode 100644 index e330de6a..00000000 --- a/examples/marlin_quant.py +++ /dev/null @@ -1,22 +0,0 @@ -from awq import AutoAWQForCausalLM -from transformers import AutoTokenizer - -model_path = 'lmsys/vicuna-7b-v1.5' -quant_path = 'vicuna-7b-v1.5-awq-marlin' -quant_config = { "zero_point": False, "q_group_size": 128, "w_bit": 4, "version": "Marlin" } - -# Load model -# NOTE: pass safetensors=True to load safetensors -model = AutoAWQForCausalLM.from_pretrained( - model_path, **{"low_cpu_mem_usage": True, "use_cache": False} -) -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - -# Quantize -model.quantize(tokenizer, quant_config=quant_config) - -# Save quantized model -model.save_quantized(quant_path) -tokenizer.save_pretrained(quant_path) - -print(f'Model is quantized and saved at "{quant_path}"') \ No newline at end of file diff --git a/examples/mixtral_quant.py b/examples/mixtral_quant.py deleted file mode 100644 index fea92f60..00000000 --- a/examples/mixtral_quant.py +++ /dev/null @@ -1,30 +0,0 @@ -from awq import AutoAWQForCausalLM -from transformers import AutoTokenizer - -model_path = 'mistralai/Mixtral-8x7B-Instruct-v0.1' -quant_path = 'mixtral-instruct-awq' -modules_to_not_convert = ["gate"] -quant_config = { - "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM", - "modules_to_not_convert": modules_to_not_convert -} - -# Load model -# NOTE: pass safetensors=True to load safetensors -model = AutoAWQForCausalLM.from_pretrained( - model_path, safetensors=True, **{"low_cpu_mem_usage": True} -) -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - -# Quantize -model.quantize( - tokenizer, - quant_config=quant_config, - modules_to_not_convert=modules_to_not_convert -) - -# Save quantized model -model.save_quantized(quant_path) -tokenizer.save_pretrained(quant_path) - -print(f'Model is quantized and saved at "{quant_path}"') \ No newline at end of file diff --git a/examples/quant_custom_data.py b/examples/quant_custom_data.py deleted file mode 100644 index 34e43862..00000000 --- a/examples/quant_custom_data.py +++ /dev/null @@ -1,35 +0,0 @@ -from datasets import load_dataset -from awq import AutoAWQForCausalLM -from transformers import AutoTokenizer - -model_path = 'lmsys/vicuna-7b-v1.5' -quant_path = 'vicuna-7b-v1.5-awq' -quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" } - -# Load model -model = AutoAWQForCausalLM.from_pretrained(model_path) -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - -# Define data loading methods -def load_dolly(): - data = load_dataset('databricks/databricks-dolly-15k', split="train") - - # concatenate data - def concatenate_data(x): - return {"text": x['instruction'] + '\n' + x['context'] + '\n' + x['response']} - - concatenated = data.map(concatenate_data) - return [text for text in concatenated["text"]] - -def load_wikitext(): - data = load_dataset('wikitext', 'wikitext-2-raw-v1', split="train") - return [text for text in data["text"] if text.strip() != '' and len(text.split(' ')) > 20] - -# Quantize -model.quantize(tokenizer, quant_config=quant_config, calib_data=load_wikitext()) - -# Save 
quantized model -model.save_quantized(quant_path) -tokenizer.save_pretrained(quant_path) - -print(f'Model is quantized and saved at "{quant_path}"') \ No newline at end of file diff --git a/examples/basic_quant.py b/examples/quantize.py similarity index 83% rename from examples/basic_quant.py rename to examples/quantize.py index 79cf1ab2..13dbb720 100644 --- a/examples/basic_quant.py +++ b/examples/quantize.py @@ -1,12 +1,11 @@ from awq import AutoAWQForCausalLM from transformers import AutoTokenizer -model_path = 'lmsys/vicuna-7b-v1.5' -quant_path = 'vicuna-7b-v1.5-awq' +model_path = 'mistralai/Mistral-7B-Instruct-v0.2' +quant_path = 'mistral-instruct-v0.2-awq' quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" } # Load model -# NOTE: pass safetensors=True to load safetensors model = AutoAWQForCausalLM.from_pretrained( model_path, **{"low_cpu_mem_usage": True, "use_cache": False} ) diff --git a/examples/tinyllama_generate.py b/examples/tinyllama_generate.py deleted file mode 100644 index ba079147..00000000 --- a/examples/tinyllama_generate.py +++ /dev/null @@ -1,36 +0,0 @@ -from awq import AutoAWQForCausalLM -from transformers import AutoTokenizer, TextStreamer - -quant_path = "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ" - -# Load model -model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=False) -tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True) -streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) - -# Convert prompt to tokens -prompt_template = """\ -<|im_start|>system -{system}<|im_end|> -<|im_start|>user -{prompt}<|im_end|> -<|im_start|>assistant -""" - -system = "You are a helpful assistant that answers precisely." - -prompt = "You're standing on the surface of the Earth. "\ - "You walk one mile south, one mile west and one mile north. "\ - "You end up exactly where you started. Where are you?" 
- -tokens = tokenizer( - prompt_template.format(system=system, prompt=prompt), - return_tensors='pt' -).input_ids.to("mps") - -# Generate output -generation_output = model.generate( - tokens, - streamer=streamer, - max_new_tokens=64 -) \ No newline at end of file diff --git a/examples/awq_train.py b/examples/train.py similarity index 91% rename from examples/awq_train.py rename to examples/train.py index 5e8fd0f5..12edb8fb 100644 --- a/examples/awq_train.py +++ b/examples/train.py @@ -10,11 +10,10 @@ def prepare_split(tokenizer): data = datasets.load_dataset("mhenrichsen/alpaca_2k_test", split="train") - prompt_template = "[INST] {system} {prompt} [/INST] {output}" + prompt_template = "[INST] {prompt} [/INST] {output}" def format_prompt(x): return prompt_template.format( - system="", prompt=x["instruction"], output=x["output"] ) @@ -26,7 +25,7 @@ def format_prompt(x): return data -model_path = "ybelkada/opt-125m-awq" +model_path = "TheBloke/Mistral-7B-v0.1-AWQ" # Load model model = AutoAWQForCausalLM.from_quantized(model_path, fuse_layers=False) @@ -56,7 +55,6 @@ def format_prompt(x): optim="adamw_torch", num_train_epochs=1, learning_rate=1e-4, - # fp16=True, evaluation_strategy="no", save_strategy="epoch", save_steps=100, diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 00000000..243ab9ea --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,82 @@ +site_name: AutoAWQ +repo_name: casper-hansen/AutoAWQ +repo_url: https://github.com/casper-hansen/AutoAWQ + +nav: +- index.md +- Examples: examples.md +- Reference: + - reference/index.md + +markdown_extensions: + toc: + permalink: true + markdown.extensions.codehilite: + guess_lang: false + admonition: null + codehilite: null + extra: null + pymdownx.superfences: + custom_fences: + - name: mermaid + class: mermaid + format: !!python/name:pymdownx.superfences.fence_code_format '' + pymdownx.tabbed: + alternate_style: true + pymdownx.tilde: null + attr_list: null + md_in_html: null + +plugins: + search: null + mkdocstrings: + handlers: + python: + paths: [awq] + options: + extensions: + - griffe_typingdoc + show_root_heading: true + show_if_no_docstring: true + inherited_members: true + members_order: source + separate_signature: true + unwrap_annotated: true + filters: + - '!^_' + merge_init_into_class: true + docstring_section_style: spacy + signature_crossrefs: true + show_symbol_type_heading: true + show_symbol_type_toc: true + +theme: + name: material + palette: + - media: '(prefers-color-scheme: light)' + scheme: default + primary: teal + accent: amber + toggle: + icon: material/lightbulb + name: Switch to dark mode + - media: '(prefers-color-scheme: dark)' + scheme: slate + primary: teal + accent: amber + toggle: + icon: material/lightbulb-outline + name: Switch to light mode + features: + - search.suggest + - search.highlight + - content.tabs.link + - navigation.indexes + - content.tooltips + - navigation.path + - content.code.annotate + - content.code.copy + - content.code.select + - navigation.tabs + icon: + repo: fontawesome/brands/github-alt \ No newline at end of file diff --git a/setup.py b/setup.py index dfdc5921..a6dea813 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ def get_kernels_whl_url( return f"https://github.com/casper-hansen/AutoAWQ_kernels/releases/download/v{release_version}/autoawq_kernels-{release_version}+{gpu_system_version}-cp{python_version}-cp{python_version}-{platform}_{architecture}.whl" -AUTOAWQ_VERSION = "0.1.8" +AUTOAWQ_VERSION = "0.2.0" PYPI_BUILD = os.getenv("PYPI_BUILD", "0") == "1" 
CUDA_VERSION = os.getenv("CUDA_VERSION", None) or torch.version.cuda @@ -90,6 +90,7 @@ def get_kernels_whl_url( "tokenizers>=0.12.1", "accelerate", "datasets", + "zstandard", ] try: @@ -101,9 +102,9 @@ def get_kernels_whl_url( # kernels can be downloaded from pypi for cuda+121 only # for everything else, we need to download the wheels from github if not KERNELS_INSTALLED and (CUDA_VERSION or ROCM_VERSION): - if CUDA_VERSION.startswith("12"): + if CUDA_VERSION and CUDA_VERSION.startswith("12"): requirements.append("autoawq-kernels") - elif CUDA_VERSION.startswith("11") or ROCM_VERSION in ["561", "571"]: + elif CUDA_VERSION and CUDA_VERSION.startswith("11") or ROCM_VERSION in ["561", "571"]: gpu_system_version = ( f"cu{CUDA_VERSION}" if CUDA_VERSION else f"rocm{ROCM_VERSION}" ) @@ -130,6 +131,7 @@ def get_kernels_whl_url( install_requires=requirements, extras_require={ "eval": ["lm_eval>=0.4.0", "tabulate", "protobuf", "evaluate", "scipy"], + "dev": ["black", "mkdocstrings-python", "mkdocs-material", "griffe-typingdoc"] }, **common_setup_kwargs, )