Torch native q conformance
daniil-lyakhov committed Oct 23, 2024
1 parent 16f3b97 commit d117e35
Showing 3 changed files with 33 additions and 9 deletions.
2 changes: 1 addition & 1 deletion tests/post_training/pipelines/base.py
@@ -408,7 +408,7 @@ def save_compressed_model(self) -> None:
             self.path_compressed_ir = self.output_model_dir / "model.xml"
             ov.serialize(ov_model, self.path_compressed_ir)
         elif self.backend == BackendType.FX_TORCH:
-            exported_model = torch.export.export(self.compressed_model, (self.dummy_tensor,))
+            exported_model = torch.export.export(self.model, (self.dummy_tensor,))
             ov_model = ov.convert_model(exported_model, example_input=self.dummy_tensor.cpu(), input=self.input_size)
             self.path_compressed_ir = self.output_model_dir / "model.xml"
             ov.serialize(ov_model, self.path_compressed_ir)
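For context, a minimal standalone sketch of the torch.export-to-OpenVINO path this branch exercises; the toy module, input shape, and output file name are illustrative assumptions rather than part of the pipeline:

import openvino as ov
import torch


class ToyNet(torch.nn.Module):  # hypothetical stand-in for self.model
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        return torch.relu(self.conv(x))


model = ToyNet().eval()
dummy_tensor = torch.ones(1, 3, 224, 224)

# Capture the model as an ExportedProgram, convert it to an ov.Model, and save the IR.
exported_model = torch.export.export(model, (dummy_tensor,))
ov_model = ov.convert_model(exported_model, example_input=dummy_tensor)
ov.serialize(ov_model, "model.xml")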
33 changes: 32 additions & 1 deletion tests/post_training/pipelines/image_classification_base.py
@@ -76,7 +76,8 @@ def process_result(request, userdata):
     def _validate_torch_compile(
         self, val_loader: torch.utils.data.DataLoader, predictions: np.ndarray, references: np.ndarray
     ):
-        compiled_model = torch.compile(self.compressed_model, backend="openvino")
+        # compiled_model = torch.compile(self.compressed_model, backend="openvino")
+        compiled_model = torch.compile(self.compressed_model)
         for i, (images, target) in enumerate(val_loader):
             # W/A for memory leaks when using torch DataLoader and OpenVINO
             pred = compiled_model(images)
@@ -104,3 +105,33 @@ def _validate(self):

         self.run_info.metric_name = "Acc@1"
         self.run_info.metric_value = acc_top1
+
+    def _compress_torch_native(self):
+        import os
+
+        os.environ["TORCHINDUCTOR_FREEZING"] = "1"
+
+        from torch.ao.quantization.quantize_pt2e import convert_pt2e
+        from torch.ao.quantization.quantize_pt2e import prepare_pt2e
+        from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
+        from torch.ao.quantization.quantizer.x86_inductor_quantizer import get_default_x86_inductor_quantization_config
+
+        quantizer = X86InductorQuantizer()
+        quantizer.set_global(get_default_x86_inductor_quantization_config())
+
+        prepared_model = prepare_pt2e(self.model, quantizer)
+        for data in self.calibration_dataset.get_inference_data():
+            prepared_model(data)
+        self.compressed_model = convert_pt2e(prepared_model)
+
+    def _compress_nncf_pt2e(self):
+        pass
+
+    def _compress(self):
+        """
+        Quantize self.model
+        """
+        if self.backend != BackendType.FX_TORCH:
+            super()._compress()
+        else:
+            self._compress_torch_native()
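For reference, a minimal end-to-end sketch of the PT2E flow that the new _compress_torch_native helper follows. The toy module, the input shape, and the use of torch.export.export_for_training as the capture step are assumptions for illustration (older PyTorch releases capture with capture_pre_autograd_graph instead); the pipeline itself already holds a captured module in self.model:

import os

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
    get_default_x86_inductor_quantization_config,
)

os.environ["TORCHINDUCTOR_FREEZING"] = "1"  # as in _compress_torch_native: enables Inductor freezing/constant folding


class ToyNet(torch.nn.Module):  # hypothetical stand-in for the pipeline's model
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        return torch.relu(self.conv(x))


example_inputs = (torch.randn(1, 3, 224, 224),)
captured = torch.export.export_for_training(ToyNet().eval(), example_inputs).module()

# Insert observers per the default x86 Inductor config, calibrate, then convert to q/dq ops.
quantizer = X86InductorQuantizer()
quantizer.set_global(get_default_x86_inductor_quantization_config())
prepared = prepare_pt2e(captured, quantizer)
prepared(*example_inputs)  # calibration; the pipeline loops over its calibration dataset here
quantized = convert_pt2e(prepared)

# Lower through torch.compile / Inductor and run once to trigger compilation.
compiled = torch.compile(quantized)
_ = compiled(*example_inputs)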
7 changes: 0 additions & 7 deletions tests/post_training/test_quantize_conformance.py
@@ -75,11 +75,6 @@ def fixture_run_benchmark_app(pytestconfig):
     return pytestconfig.getoption("benchmark")


-@pytest.fixture(scope="session", name="validate_in_backend")
-def fixture_validate_in_backend(pytestconfig):
-    return pytestconfig.getoption("validate_in_backend")
-
-
 @pytest.fixture(scope="session", name="extra_columns")
 def fixture_extra_columns(pytestconfig):
     return pytestconfig.getoption("extra_columns")
@@ -267,7 +262,6 @@ def test_ptq_quantization(
     run_torch_cuda_backend: bool,
     subset_size: Optional[int],
     run_benchmark_app: bool,
-    validate_in_backend: bool,
     capsys: pytest.CaptureFixture,
     extra_columns: bool,
     memory_monitor: bool,
@@ -295,7 +289,6 @@
         "data_dir": data_dir,
         "no_eval": no_eval,
         "run_benchmark_app": run_benchmark_app,
-        "validate_in_backend": validate_in_backend,
         "batch_size": batch_size,
         "memory_monitor": memory_monitor,
     }
