testing HQQ [not for land]
Summary:

for eval with --limit 5:
wikitext: {'word_perplexity,none': 11.49343838017535, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.6110947678444059, 'byte_perplexity_stderr,none':

for the full eval:
...

Test Plan: sh run.sh

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: e1564ea867790825ad8a00c8de8a672a349b8a48
Pull Request resolved: #155
HDCharles committed Apr 9, 2024
1 parent 095b222 commit 55b9f6e
Showing 3 changed files with 77 additions and 2 deletions.
11 changes: 10 additions & 1 deletion generate.py
@@ -224,7 +224,16 @@ def _load_model(checkpoint_path, device, precision, use_tp):
         simple_quantizer = WeightOnlyInt8QuantHandler(model)
         model = simple_quantizer.convert_for_runtime()

-    if "int4" in str(checkpoint_path):
+    if "int4-hqq" in str(checkpoint_path):
+        print("Using int4 weight-only HQQ quantization.")
+        from quantize import WeightOnlyInt4HqqQuantHandler
+        path_comps = checkpoint_path.name.split(".")
+        assert path_comps[-3].startswith("g")
+        assert path_comps[-2] in device, "weight packed format mismatch, please rerun quantize.py!"
+        groupsize = int(path_comps[-3][1:])
+        quantizer = WeightOnlyInt4HqqQuantHandler(model, groupsize=groupsize)
+        model = quantizer._convert_for_runtime()
+    elif "int4" in str(checkpoint_path):
         print("Using int4 weight-only quantization!")
         path_comps = checkpoint_path.name.split(".")
         assert path_comps[-3].startswith("g")
41 changes: 40 additions & 1 deletion quantize.py
@@ -519,6 +519,33 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             self.weight, self.scales_and_zeros, self.out_features, self.groupsize
         )

+# TODO a hacky placeholder class
+class WeightOnlyInt4HqqQuantHandler:
+    def __init__(self, mod, groupsize):
+        self.mod = mod
+        self.groupsize = groupsize
+
+    def _create_quantized_state_dict(self):
+        from hqq.core.quantize import Quantizer  # TODO maybe torchao
+
+        for m in self.mod.modules():
+            for name, child in m.named_children():
+                if isinstance(child, torch.nn.Linear):
+                    child.weight = torch.nn.Parameter(
+                        Quantizer.dequantize(
+                            *Quantizer.quantize(
+                                child.weight,
+                                nbits=4,
+                                group_size=self.groupsize,
+                                axis=1,
+                            )
+                        )
+                    )
+
+        return WeightOnlyInt4QuantHandler(self.mod, self.groupsize).create_quantized_state_dict()
+
+    def _convert_for_runtime(self):
+        return WeightOnlyInt4GPTQQuantHandler(self.mod, self.groupsize).convert_for_runtime(use_cuda=True)

 def quantize(
     checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
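The handler above fake-quantizes every nn.Linear weight with HQQ (quantize immediately followed by dequantize) and then reuses the existing WeightOnlyInt4QuantHandler packing plus the GPTQ runtime conversion. As a rough mental model only, and not HQQ's actual algorithm (HQQ additionally optimizes the zero-points and scales), a plain min/max group-wise 4-bit round trip looks like this:

import torch

def fake_quantize_4bit_minmax(w: torch.Tensor, group_size: int = 32) -> torch.Tensor:
    # Group-wise affine quantization to 4 bits, then dequantization back to float.
    # Illustration only: HQQ solves for better zero-points/scales than plain min/max.
    out_features, in_features = w.shape
    wg = w.reshape(out_features, in_features // group_size, group_size)
    w_min = wg.amin(dim=-1, keepdim=True)
    w_max = wg.amax(dim=-1, keepdim=True)
    scale = (w_max - w_min).clamp(min=1e-8) / 15.0   # 15 = 2**4 - 1 quantization levels
    q = ((wg - w_min) / scale).round().clamp(0, 15)  # integer codes in [0, 15]
    return (q * scale + w_min).reshape(out_features, in_features)

w = torch.randn(64, 128)
w_dq = fake_quantize_4bit_minmax(w, group_size=32)
print((w - w_dq).abs().max())  # small, group-dependent quantization error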
@@ -592,6 +619,18 @@ def quantize(
         dir_name = checkpoint_path.parent
         base_name = checkpoint_path.name
         new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.{device}.pth")
+
+    elif mode == 'int4-hqq':
+        print("Quantizing model weights for int4 using HQQ")
+        quant_handler = WeightOnlyInt4HqqQuantHandler(model, groupsize)
+        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+        assert tokenizer_path.is_file(), tokenizer_path
+        tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
+
+        quantized_state_dict = quant_handler._create_quantized_state_dict()
+        dir_name = checkpoint_path.parent
+        base_name = checkpoint_path.name
+        new_base_name = base_name.replace('.pth', f"{label}int4-hqq.g{groupsize}.{device}.pth")
     else:
         raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gptq]")

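The new branch mirrors the int4-gptq path: it builds the fake-quantized, packed state dict and saves it under a name that encodes the mode, group size, and device, which is exactly what the filename parsing in generate.py above expects. A small sketch of the naming, assuming label defaults to '_' (the eval commands in run.sh below expect model_int4-hqq.g32.cuda.pth):

base_name = "model.pth"
label, groupsize, device = "_", 32, "cuda"  # assumed values for illustration
new_base_name = base_name.replace('.pth', f"{label}int4-hqq.g{groupsize}.{device}.pth")
print(new_base_name)  # model_int4-hqq.g32.cuda.pth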
@@ -606,7 +645,7 @@ def quantize(
     import argparse
     parser = argparse.ArgumentParser(description='Quantize a model.')
     parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.')
-    parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform')
+    parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq', 'int4-hqq'], help='type of quantization to perform')
     parser.add_argument('--groupsize', type=int, default=32, help='Group size for int4 quantization.')
     parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
     parser.add_argument('--calibration_limit', type=int, default=1000, help='number of samples to use for gptq calibration')
27 changes: 27 additions & 0 deletions run.sh
@@ -0,0 +1,27 @@
export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf

# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --compile # working
# echo "base"
# export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5
# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-gptq.g32.cuda.pth --tasks wikitext --limit 5


python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-hqq
# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-hqq.g32.cuda.pth --compile
python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-hqq.g32.cuda.pth --tasks wikitext

python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --compile
python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext

# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5
# broken

# export MODEL_REPO=meta-llama/Llama-2-70b-chat-hf
# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5
# ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth

# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5
