From 4d3f68ed3ce69889d7a985653d3dd6cc1816c2c6 Mon Sep 17 00:00:00 2001
From: Yao Matrix
Date: Fri, 30 Aug 2024 15:09:18 +0800
Subject: [PATCH] add optimum-intel ipex backend into benchmark (#250)

Signed-off-by: YAO Matrix
---
 .github/workflows/test_cli_cpu_ipex.yaml     |  48 +++++++
 README.md                                    |   1 +
 docker/cpu/Dockerfile                        |   6 +-
 examples/ipex_llama.yaml                     |  37 +++++
 optimum_benchmark/__init__.py                |   2 +
 optimum_benchmark/backends/__init__.py       |   2 +
 optimum_benchmark/backends/ipex/__init__.py  |   0
 optimum_benchmark/backends/ipex/backend.py   | 131 ++++++++++++++++++
 optimum_benchmark/backends/ipex/config.py    |  37 +++++
 optimum_benchmark/backends/ipex/utils.py     |  10 ++
 optimum_benchmark/cli.py                     |   2 +
 optimum_benchmark/import_utils.py            |   5 +-
 setup.py                                     |   3 +-
 .../cpu_inference_ipex_text_decoders.yaml    |  11 ++
 .../cpu_inference_ipex_text_encoders.yaml    |  11 ++
 15 files changed, 302 insertions(+), 4 deletions(-)
 create mode 100644 .github/workflows/test_cli_cpu_ipex.yaml
 create mode 100644 examples/ipex_llama.yaml
 create mode 100644 optimum_benchmark/backends/ipex/__init__.py
 create mode 100644 optimum_benchmark/backends/ipex/backend.py
 create mode 100644 optimum_benchmark/backends/ipex/config.py
 create mode 100644 optimum_benchmark/backends/ipex/utils.py
 create mode 100644 tests/configs/cpu_inference_ipex_text_decoders.yaml
 create mode 100644 tests/configs/cpu_inference_ipex_text_encoders.yaml

diff --git a/.github/workflows/test_cli_cpu_ipex.yaml b/.github/workflows/test_cli_cpu_ipex.yaml
new file mode 100644
index 00000000..f86a6f4a
--- /dev/null
+++ b/.github/workflows/test_cli_cpu_ipex.yaml
@@ -0,0 +1,48 @@
+name: CLI CPU IPEX Tests
+
+on:
+  workflow_dispatch:
+  push:
+    branches:
+      - main
+    paths:
+      - .github/workflows/test_cli_cpu_ipex.yaml
+      - "optimum_benchmark/**"
+      - "docker/**"
+      - "tests/**"
+      - "setup.py"
+  pull_request:
+    branches:
+      - main
+    paths:
+      - .github/workflows/test_cli_cpu_ipex.yaml
+      - "optimum_benchmark/**"
+      - "docker/**"
+      - "tests/**"
+      - "setup.py"
+
+concurrency:
+  cancel-in-progress: true
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+
+jobs:
+  run_cli_cpu_ipex_tests:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+
+      - name: Set up Python 3.10
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.10"
+
+      - name: Install requirements
+        run: |
+          pip install --upgrade pip
+          pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+          pip install -e .[testing,ipex,diffusers,timm]
+
+      - name: Run tests
+        run: pytest -s -k "cli and cpu and ipex"
diff --git a/README.md b/README.md
index e2f824d3..f577e716 100644
--- a/README.md
+++ b/README.md
@@ -96,6 +96,7 @@ pip install -e .
 Depending on the backends you want to use, you can install `optimum-benchmark` with the following extras:
 
 - PyTorch (default): `pip install optimum-benchmark`
+- IPEX: `pip install optimum-benchmark[ipex]`
 - OpenVINO: `pip install optimum-benchmark[openvino]`
 - Torch-ORT: `pip install optimum-benchmark[torch-ort]`
 - OnnxRuntime: `pip install optimum-benchmark[onnxruntime]`
diff --git a/docker/cpu/Dockerfile b/docker/cpu/Dockerfile
index c2bdf67e..169a6f74 100644
--- a/docker/cpu/Dockerfile
+++ b/docker/cpu/Dockerfile
@@ -22,10 +22,12 @@ ENV PATH="/home/user/.local/bin:${PATH}"
 RUN apt-get update && apt-get install -y --no-install-recommends \
     libgl1-mesa-dev libglib2.0-0 \
     sudo build-essential git bash-completion \
-    python3.10 python3-pip python3.10-dev && \
+    python3.10 python3-pip python3.10-dev google-perftools && \
     apt-get clean && rm -rf /var/lib/apt/lists/* && \
     update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1 && \
-    pip install --no-cache-dir --upgrade pip setuptools wheel
+    pip install --no-cache-dir --upgrade pip setuptools wheel
+
+ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4"
 
 # Install PyTorch
 ARG TORCH_VERSION=stable
diff --git a/examples/ipex_llama.yaml b/examples/ipex_llama.yaml
new file mode 100644
index 00000000..0ff4c4df
--- /dev/null
+++ b/examples/ipex_llama.yaml
@@ -0,0 +1,37 @@
+defaults:
+  - benchmark
+  - scenario: inference
+  - launcher: process
+  - backend: ipex
+  - _base_
+  - _self_
+
+name: ipex_llama
+
+launcher:
+  numactl: true
+  numactl_kwargs:
+    cpunodebind: 0
+    membind: 0
+
+scenario:
+  latency: true
+  memory: true
+
+  warmup_runs: 10
+  iterations: 10
+  duration: 10
+
+  input_shapes:
+    batch_size: 1
+    sequence_length: 256
+  generate_kwargs:
+    max_new_tokens: 32
+    min_new_tokens: 32
+
+backend:
+  device: cpu
+  export: true
+  no_weights: true
+  torch_dtype: bfloat16
+  model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
diff --git a/optimum_benchmark/__init__.py b/optimum_benchmark/__init__.py
index 83845b28..264e8e38 100644
--- a/optimum_benchmark/__init__.py
+++ b/optimum_benchmark/__init__.py
@@ -1,5 +1,6 @@
 from .backends import (
     BackendConfig,
+    IPEXConfig,
     INCConfig,
     LlamaCppConfig,
     LLMSwarmConfig,
@@ -24,6 +25,7 @@
     "BenchmarkReport",
     "EnergyStarConfig",
     "InferenceConfig",
+    "IPEXConfig",
     "INCConfig",
     "InlineConfig",
     "LauncherConfig",
diff --git a/optimum_benchmark/backends/__init__.py b/optimum_benchmark/backends/__init__.py
index e78bb2e7..8ee6e155 100644
--- a/optimum_benchmark/backends/__init__.py
+++ b/optimum_benchmark/backends/__init__.py
@@ -1,6 +1,7 @@
 from .config import BackendConfig
 from .llama_cpp.config import LlamaCppConfig
 from .llm_swarm.config import LLMSwarmConfig
+from .ipex.config import IPEXConfig
 from .neural_compressor.config import INCConfig
 from .onnxruntime.config import ORTConfig
 from .openvino.config import OVConfig
@@ -13,6 +14,7 @@
 __all__ = [
     "PyTorchConfig",
     "ORTConfig",
+    "IPEXConfig",
     "OVConfig",
     "TorchORTConfig",
     "TRTLLMConfig",
diff --git a/optimum_benchmark/backends/ipex/__init__.py b/optimum_benchmark/backends/ipex/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/optimum_benchmark/backends/ipex/backend.py b/optimum_benchmark/backends/ipex/backend.py
new file mode 100644
index 00000000..9bb8c065
--- /dev/null
+++ b/optimum_benchmark/backends/ipex/backend.py
@@ -0,0 +1,131 @@
+import inspect
+from collections import OrderedDict
+from tempfile import TemporaryDirectory
+from typing import Any, Dict
+
+import torch
+from hydra.utils import get_class
+
+from ...generators.dataset_generator import DatasetGenerator
+from ...import_utils import is_accelerate_available, is_torch_distributed_available
+from ...task_utils import TEXT_GENERATION_TASKS
+from ..base import Backend
+from ..transformers_utils import fast_weights_init
+from .config import IPEXConfig
+from .utils import TASKS_TO_IPEXMODEL
+
+if is_accelerate_available():
+    from accelerate import Accelerator
+
+if is_torch_distributed_available():
+    import torch.distributed
+
+
+class IPEXBackend(Backend[IPEXConfig]):
+    NAME: str = "ipex"
+
+    def __init__(self, config: IPEXConfig) -> None:
+        super().__init__(config)
+
+        if self.config.task in TASKS_TO_IPEXMODEL:
+            self.ipexmodel_class = get_class(TASKS_TO_IPEXMODEL[self.config.task])
+            self.logger.info(f"\t+ Using IPEXModel class {self.ipexmodel_class.__name__}")
+        else:
+            raise NotImplementedError(f"IPEXBackend does not support task {self.config.task}")
+
+
+    def load(self) -> None:
+        self.logger.info("\t+ Creating backend temporary directory")
+        self.tmpdir = TemporaryDirectory()
+
+        if self.config.no_weights:
+            self.logger.info("\t+ Creating no weights IPEXModel")
+            self.create_no_weights_model()
+            self.logger.info("\t+ Loading no weights IPEXModel")
+            self._load_ipexmodel_with_no_weights()
+        else:
+            self.logger.info("\t+ Loading pretrained IPEXModel")
+            self._load_ipexmodel_from_pretrained()
+
+        self.tmpdir.cleanup()
+
+    def _load_automodel_from_pretrained(self) -> None:
+        self.pretrained_model = self.automodel_loader.from_pretrained(self.config.model, **self.config.model_kwargs)
+
+    def _load_automodel_with_no_weights(self) -> None:
+        original_model, self.config.model = self.config.model, self.no_weights_model
+
+        with fast_weights_init():
+            self._load_automodel_from_pretrained()
+
+        self.logger.info("\t+ Tying model weights")
+        self.pretrained_model.tie_weights()
+
+        self.config.model = original_model
+
+    def _load_ipexmodel_from_pretrained(self) -> None:
+        self.pretrained_model = self.ipexmodel_class.from_pretrained(
+            self.config.model,
+            export=self.config.export,
+            device=self.config.device,
+            **self.config.model_kwargs,
+            **self.automodel_kwargs,
+        )
+
+    def _load_ipexmodel_with_no_weights(self) -> None:
+        with fast_weights_init():
+            original_model, self.config.model = self.config.model, self.no_weights_model
+            original_export, self.config.export = self.config.export, True
+            self.logger.info("\t+ Loading no weights IPEXModel")
+            self._load_ipexmodel_from_pretrained()
+            self.config.export = original_export
+            self.config.model = original_model
+
+    @property
+    def automodel_kwargs(self) -> Dict[str, Any]:
+        kwargs = {}
+
+        if self.config.torch_dtype is not None:
+            kwargs["torch_dtype"] = getattr(torch, self.config.torch_dtype)
+
+        self.logger.debug(f"\t+ Automodel kwargs: {kwargs}")
+
+        return kwargs
+
+    @property
+    def is_dp_distributed(self) -> bool:
+        return is_torch_distributed_available() and torch.distributed.is_initialized()
+
+    def prepare_input_shapes(self, input_shapes: Dict[str, Any]) -> Dict[str, Any]:
+        if self.is_dp_distributed:
+            if input_shapes["batch_size"] % torch.distributed.get_world_size() != 0:
+                raise ValueError(
+                    f"Batch size {input_shapes['batch_size']} must be divisible by "
+                    f"data parallel world size {torch.distributed.get_world_size()}"
+                )
+            # distributing batch size across processes
+            input_shapes["batch_size"] //= torch.distributed.get_world_size()
+
+        # registering input shapes for usage during model reshaping
+        self.input_shapes = input_shapes
+
+        return input_shapes
+
+    def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        if self.is_dp_distributed:
+            with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
+                inputs = process_inputs
+
+        return inputs
+
+    def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
+        return self.pretrained_model.forward(**inputs, **kwargs)
+
+    def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
+        return self.pretrained_model.generate(**inputs, **kwargs)
+
+    def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
+        return self.pretrained_model.generate(**inputs, **kwargs)
+
+    def call(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
+        return self.pretrained_model(**inputs, **kwargs)
diff --git a/optimum_benchmark/backends/ipex/config.py b/optimum_benchmark/backends/ipex/config.py
new file mode 100644
index 00000000..3beffe87
--- /dev/null
+++ b/optimum_benchmark/backends/ipex/config.py
@@ -0,0 +1,37 @@
+from dataclasses import dataclass, field
+from typing import Any, Dict, Optional
+
+from ...import_utils import ipex_version
+from ..config import BackendConfig
+
+TORCH_DTYPES = ["bfloat16", "float16", "float32", "auto"]
+
+@dataclass
+class IPEXConfig(BackendConfig):
+    name: str = "ipex"
+    version: Optional[str] = ipex_version()
+    _target_: str = "optimum_benchmark.backends.ipex.backend.IPEXBackend"
+
+    # load options
+    no_weights: bool = False
+    torch_dtype: Optional[str] = None
+
+    # export options
+    export: bool = True
+
+    def __post_init__(self):
+        super().__post_init__()
+
+        self.device = self.device.lower()
+        if self.device not in ["cpu", "gpu"]:
+            raise ValueError(f"IPEXBackend only supports CPU and GPU devices, got {self.device}")
+
+        if self.model_kwargs.get("torch_dtype", None) is not None:
+            raise ValueError(
+                "`torch_dtype` is an explicit argument in the IPEX backend config. "
+                "Please remove it from the `model_kwargs` and set it in the backend config directly."
+            )
+
+        if self.torch_dtype is not None and self.torch_dtype not in TORCH_DTYPES:
+            raise ValueError(f"`torch_dtype` must be one of {TORCH_DTYPES}. Got {self.torch_dtype} instead.")
+
diff --git a/optimum_benchmark/backends/ipex/utils.py b/optimum_benchmark/backends/ipex/utils.py
new file mode 100644
index 00000000..4f98834c
--- /dev/null
+++ b/optimum_benchmark/backends/ipex/utils.py
@@ -0,0 +1,10 @@
+TASKS_TO_IPEXMODEL = {
+    "fill-mask": "optimum.intel.IPEXModelForMaskedLM",
+    "text-generation": "optimum.intel.IPEXModelForCausalLM",
+    "text-classification": "optimum.intel.IPEXModelForSequenceClassification",
+    "token-classification": "optimum.intel.IPEXModelForTokenClassification",
+    "question-answering": "optimum.intel.IPEXModelForQuestionAnswering",
+    "image-classification": "optimum.intel.IPEXModelForImageClassification",
+    "audio-classification": "optimum.intel.IPEXModelForAudioClassification",
+}
+
diff --git a/optimum_benchmark/cli.py b/optimum_benchmark/cli.py
index 57c6b054..769340d2 100644
--- a/optimum_benchmark/cli.py
+++ b/optimum_benchmark/cli.py
@@ -10,6 +10,7 @@
     Benchmark,
     BenchmarkConfig,
     EnergyStarConfig,
+    IPEXConfig,
     INCConfig,
     InferenceConfig,
     InlineConfig,
@@ -36,6 +37,7 @@
 # benchmark configuration
 cs.store(name="benchmark", node=BenchmarkConfig)
 # backends configurations
+cs.store(group="backend", name=IPEXConfig.name, node=IPEXConfig)
 cs.store(group="backend", name=OVConfig.name, node=OVConfig)
 cs.store(group="backend", name=PyTorchConfig.name, node=PyTorchConfig)
 cs.store(group="backend", name=ORTConfig.name, node=ORTConfig)
diff --git a/optimum_benchmark/import_utils.py b/optimum_benchmark/import_utils.py
index e731bc74..cde0ceb1 100644
--- a/optimum_benchmark/import_utils.py
+++ b/optimum_benchmark/import_utils.py
@@ -15,6 +15,7 @@
 _pynvml_available = importlib.util.find_spec("pynvml") is not None
 _torch_distributed_available = importlib.util.find_spec("torch.distributed") is not None
 _onnxruntime_available = importlib.util.find_spec("onnxruntime") is not None
+_ipex_available = importlib.util.find_spec("intel_extension_for_pytorch") is not None
 _openvino_available = importlib.util.find_spec("openvino") is not None
 _neural_compressor_available = importlib.util.find_spec("neural_compressor") is not None
 _codecarbon_available = importlib.util.find_spec("codecarbon") is not None
@@ -157,11 +158,13 @@ def onnxruntime_version():
     except importlib.metadata.PackageNotFoundError:
         return None
 
-
 def openvino_version():
     if _openvino_available:
         return importlib.metadata.version("openvino")
 
+def ipex_version():
+    if _ipex_available:
+        return importlib.metadata.version("intel_extension_for_pytorch")
 
 def neural_compressor_version():
     if _neural_compressor_available:
diff --git a/setup.py b/setup.py
index d5abc58c..b85ae9a6 100644
--- a/setup.py
+++ b/setup.py
@@ -13,7 +13,7 @@
 except Exception as error:
     assert False, "Error: Could not open '%s' due %s\n" % (filepath, error)
 
-MIN_OPTIMUM_VERSION = "1.16.0"
+MIN_OPTIMUM_VERSION = "1.18.0"
 INSTALL_REQUIRES = [
     # HF dependencies
     "transformers",
@@ -69,6 +69,7 @@
     "quality": ["ruff"],
     "testing": ["pytest", "hydra-joblib-launcher"],
     # optimum backends
+    "ipex": [f"optimum[ipex]>={MIN_OPTIMUM_VERSION}"],
     "openvino": [f"optimum[openvino,nncf]>={MIN_OPTIMUM_VERSION}"],
     "onnxruntime": [f"optimum[onnxruntime]>={MIN_OPTIMUM_VERSION}"],
     "onnxruntime-gpu": [f"optimum[onnxruntime-gpu]>={MIN_OPTIMUM_VERSION}"],
diff --git a/tests/configs/cpu_inference_ipex_text_decoders.yaml b/tests/configs/cpu_inference_ipex_text_decoders.yaml
new file mode 100644
index 00000000..e2580170
--- /dev/null
+++ b/tests/configs/cpu_inference_ipex_text_decoders.yaml
@@ -0,0 +1,11 @@
+defaults:
+  # order of inheritance, last one overrides previous ones
+  - _base_ # inherits from base config
+  - _cpu_ # inherits from cpu config
+  - _inference_ # inherits from inference config
+  - _text_decoders_ # inherits from text decoders config
+  - _no_weights_ # inherits from no weights config
+  - _self_ # hydra 1.1 compatibility
+  - override backend: ipex
+
+name: cpu_inference_ipex_text_decoders
diff --git a/tests/configs/cpu_inference_ipex_text_encoders.yaml b/tests/configs/cpu_inference_ipex_text_encoders.yaml
new file mode 100644
index 00000000..ffe61f34
--- /dev/null
+++ b/tests/configs/cpu_inference_ipex_text_encoders.yaml
@@ -0,0 +1,11 @@
+defaults:
+  # order of inheritance, last one overrides previous ones
+  - _base_ # inherits from base config
+  - _cpu_ # inherits from cpu config
+  - _inference_ # inherits from inference config
+  - _text_encoders_ # inherits from text encoders sweep config
+  - _no_weights_ # inherits from no weights config
+  - _self_ # hydra 1.1 compatibility
+  - override backend: ipex
+
+name: cpu_inference_ipex_text_encoders
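
Usage sketch (an assumption for illustration, relying on the standard optimum-benchmark Hydra CLI entry point and a local install with the ipex extra, e.g. pip install -e .[ipex]): the new example config added above can be launched from the repository root with

    optimum-benchmark --config-dir examples/ --config-name ipex_llama

and the new test configs are exercised with the same command the workflow runs, pytest -s -k "cli and cpu and ipex".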