FP8 example (openvinotoolkit#3062)
On top of openvinotoolkit#3049

### Changes

- Added FP8 example.

### Reason for changes

- Examples coverage.

### Related tickets

- 155923

### Tests

- ubuntu test_examples 627 - passed
- windows test-examples 288 - passed
- GA Test examples 135 - passed

---------

Co-authored-by: Alexander Kozlov <[email protected]>
KodiaqQ and AlexKoff88 authored Nov 27, 2024
1 parent f61aa89 commit 2db9fb9
Showing 6 changed files with 182 additions and 1 deletion.
26 changes: 26 additions & 0 deletions examples/llm_compression/openvino/smollm2_360m_fp8/README.md
@@ -0,0 +1,26 @@
# Large Language Models FP8 Compression Example

This example demonstrates how to apply static FP8 quantization to the [HuggingFaceTB/SmolLM2-360M-Instruct](https://huggingface.co/HuggingFaceTB/SmolLM2-360M-Instruct) model. FP8-quantized models can be useful for evaluation and early hardware enablement purposes.
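
For intuition, static FP8 quantization collects statistics on a calibration dataset and then maps tensors onto the FP8 grid through calibrated scales. Below is a minimal numerical sketch of that mapping, assuming the OCP E4M3 format (3 mantissa bits, minimum normal exponent -6, maximum normal value 448); the `quantize_dequantize_fp8_e4m3` helper is hypothetical and only simulates the rounding, it is not NNCF's implementation:

```python
import numpy as np


def quantize_dequantize_fp8_e4m3(x: np.ndarray, scale: float) -> np.ndarray:
    # Scale the tensor so its range fits the E4M3 grid (max normal value is 448)
    y = np.clip(x / scale, -448.0, 448.0)
    # Snap each value to the nearest representable E4M3 number:
    # 3 mantissa bits, minimum normal exponent of -6 (smaller values are subnormal)
    exponent = np.floor(np.log2(np.maximum(np.abs(y), 2.0**-6)))
    step = 2.0 ** (exponent - 3)  # spacing between neighboring E4M3 values
    y = np.round(y / step) * step
    # Rescale back to the original range, as inference would after dequantization
    return y * scale


# A calibrated scale would come from dataset statistics; here it is derived from one tensor
x = np.random.randn(4, 8).astype(np.float32)
scale = np.abs(x).max() / 448.0
print(np.abs(x - quantize_dequantize_fp8_e4m3(x, scale)).max())
```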

## Prerequisites

To use this example:

- Create a separate Python* environment and activate it: `python3 -m venv nncf_env && source nncf_env/bin/activate`
- Install the dependencies (the last command installs NNCF from the repository root):

```bash
pip install -U pip
pip install -r requirements.txt
pip install ../../../../
```

## Run Example

To run the example:

```bash
python main.py
```

It automatically downloads the dataset and the baseline model, quantizes the model to FP8, and saves the result.
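
A run should finish with output along the following lines. The optimized answers below match the reference values used by the example tests; the exact non-optimized answers are omitted here and may vary with hardware and package versions:

```
Non-optimized model outputs:
{'What is the capital of France?': ..., ...}

Optimized model outputs:
{'What is the capital of France?': 'Paris.', 'What is the highest mountain in the Alps?': 'Mont Blanc.', 'What is the largest city in Canada?': 'Toronto.', 'What is the most visited city in Japan?': 'Tokyo.'}
```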
128 changes: 128 additions & 0 deletions examples/llm_compression/openvino/smollm2_360m_fp8/main.py
@@ -0,0 +1,128 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

import datasets
import numpy as np
import openvino as ov
from optimum.intel.openvino import OVModelForCausalLM
from transformers import AutoTokenizer

import nncf


def transform_fn(data, model, tokenizer):
    tokenized_text = tokenizer(data["text"], return_tensors="np")
    input_ids = tokenized_text["input_ids"]
    attention_mask = tokenized_text["attention_mask"]

    inputs = {}
    inputs["input_ids"] = input_ids
    inputs["attention_mask"] = attention_mask
    # Position ids are the cumulative attention mask; padded positions get a dummy value of 1
    position_ids = np.cumsum(attention_mask, axis=1) - 1
    position_ids[attention_mask == 0] = 1

    # Feed empty past-key-value tensors to the KV-cache inputs of the stateless model
    batch_size = input_ids.shape[0]
    for input_name in model.key_value_input_names:
        model_inputs = model.model.input(input_name)
        shape = model_inputs.get_partial_shape()
        shape[0] = batch_size
        # Zero out the sequence-length dimension, whichever axis it occupies
        if shape[2].is_dynamic:
            shape[2] = 0
        else:
            shape[1] = 0
        inputs[input_name] = ov.Tensor(model_inputs.get_element_type(), shape.get_shape())

    inputs["position_ids"] = position_ids
    return inputs


def generate_answers(questions, model, tokenizer, max_new_tokens=50):
    messages = [
        {"role": "system", "content": "You are a chatbot who always responds as short as possible."},
        {"role": "user", "content": "What is the capital of Spain?"},
        {"role": "assistant", "content": "Madrid."},
    ]
    answers_by_questions = {}
    # Drop the cached inference request so the current model is recompiled on the next generate() call
    model.request = None

    for question in questions:
        messages.append({"role": "user", "content": question})
        input_ids = tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        ).to(device=model.device)
        input_len = len(input_ids[0])

        output = model.generate(input_ids, max_new_tokens=max_new_tokens, do_sample=False)[0]
        answer = tokenizer.decode(output[input_len:], skip_special_tokens=True)
        answers_by_questions[question] = answer
        messages.append({"role": "assistant", "content": answer})

    model.request = None
    return answers_by_questions


def main():
    MODEL_ID = "HuggingFaceTB/SmolLM2-360M-Instruct"
    OUTPUT_DIR = "smollm2_360m_compressed"

    dataset = datasets.load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    # Filtering to remove empty samples from the dataset
    dataset = dataset.filter(lambda example: len(example["text"]) > 1)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # Export a stateless FP32 model so that the KV-cache inputs stay exposed for calibration
    model = OVModelForCausalLM.from_pretrained(
        MODEL_ID,
        export=True,
        load_in_8bit=False,
        compile=False,
        stateful=False,
        ov_config={"INFERENCE_PRECISION_HINT": "f32"},
    )

    questions = [
        "What is the capital of France?",
        "What is the highest mountain in the Alps?",
        "What is the largest city in Canada?",
        "What is the most visited city in Japan?",
    ]

    answers_by_questions = generate_answers(questions, model, tokenizer)
    print(f"Non-optimized model outputs:\n{answers_by_questions}\n")

    quantization_dataset = nncf.Dataset(dataset, partial(transform_fn, model=model, tokenizer=tokenizer))

    model.model = nncf.quantize(
        model.model,
        calibration_dataset=quantization_dataset,
        # Only the PERFORMANCE preset is supported in combination with the FP8 quantization mode
        preset=nncf.QuantizationPreset.PERFORMANCE,
        mode=nncf.QuantizationMode.FP8_E4M3,
        model_type=nncf.ModelType.TRANSFORMER,
        # The SmoothQuant algorithm is not needed for FP8 quantization, so it is disabled with a negative alpha
        advanced_parameters=nncf.AdvancedQuantizationParameters(
            smooth_quant_alphas=nncf.AdvancedSmoothQuantParameters(matmul=-1)
        ),
    )
    model.save_pretrained(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)

    model = OVModelForCausalLM.from_pretrained(
        OUTPUT_DIR, ov_config={"DYNAMIC_QUANTIZATION_GROUP_SIZE": "0", "INFERENCE_PRECISION_HINT": "f32"}
    )
    answers_by_questions = generate_answers(questions, model, tokenizer)
    print(f"Optimized model outputs:\n{answers_by_questions}\n")
    return answers_by_questions


if __name__ == "__main__":
    main()
5 changes: 5 additions & 0 deletions examples/llm_compression/openvino/smollm2_360m_fp8/requirements.txt
@@ -0,0 +1,5 @@
datasets
openvino==2024.5
optimum-intel[openvino]
transformers
onnx<1.16.2
3 changes: 2 additions & 1 deletion tests/cross_fw/examples/.test_durations
@@ -13,5 +13,6 @@
"tests/cross_fw/examples/test_examples.py::test_examples[post_training_quantization_torch_ssd300_vgg16]": 231.613,
"tests/cross_fw/examples/test_examples.py::test_examples[quantization_aware_training_torch_anomalib]": 478.797,
"tests/cross_fw/examples/test_examples.py::test_examples[quantization_aware_training_torch_resnet18]": 1251.144,
"tests/cross_fw/examples/test_examples.py::test_examples[post_training_quantization_torch_fx_resnet18]": 412.243
"tests/cross_fw/examples/test_examples.py::test_examples[post_training_quantization_torch_fx_resnet18]": 412.243,
"tests/cross_fw/examples/test_examples.py::test_examples[fp8_llm_quantization]": 229.69
}
13 changes: 13 additions & 0 deletions tests/cross_fw/examples/example_scope.json
@@ -260,5 +260,18 @@
"int8_model_size": 5.677968978881836,
"model_compression_rate": 3.7654144877995197
}
},
"fp8_llm_quantization": {
"backend": "openvino",
"requirements": "examples/llm_compression/openvino/smollm2_360m_fp8/requirements.txt",
"cpu": "Intel(R) Core(TM) i9-10980XE CPU @ 3.00GHz",
"accuracy_metrics": {
"answers": [
"Paris.",
"Mont Blanc.",
"Toronto.",
"Tokyo."
]
}
}
}
8 changes: 8 additions & 0 deletions tests/cross_fw/examples/run_example.py
@@ -184,6 +184,14 @@ def llm_compression_synthetic() -> Dict[str, float]:
    return {"word_count": len(result.split())}


def fp8_llm_quantization() -> Dict[str, float]:
    from examples.llm_compression.openvino.smollm2_360m_fp8.main import main as fp8_llm_quantization_main

    result = fp8_llm_quantization_main()

    return {"answers": list(result.values())}


def post_training_quantization_torch_fx_resnet18():
    from examples.post_training_quantization.torch_fx.resnet18.main import main as resnet18_main

