Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add brevitas channel-wise support #807

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ def pytest_addoption(parser):
)

parser.addoption(
"--no-flaky", action="store_true", default=False, help="Don't run known flaky tests."
"--no-flaky",
action="store_true",
default=False,
help="Don't run known flaky tests.",
)


Expand Down
13 changes: 11 additions & 2 deletions src/concrete/ml/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,11 @@ def fuse_matmul_bias_to_gemm(onnx_model: onnx.ModelProto):
# Create a GEMM node which combines the MatMul and Add operations
gemm_node = helper.make_node(
"Gemm", # op_type
[matmul_node.input[0], matmul_node.input[1], bias_other_input_node_name], # inputs
[
matmul_node.input[0],
matmul_node.input[1],
bias_other_input_node_name,
], # inputs
[add_node.output[0]], # outputs
name="Gemm_Node",
alpha=1.0,
Expand Down Expand Up @@ -149,9 +153,14 @@ def get_equivalent_numpy_forward_from_torch(

arguments = list(inspect.signature(torch_module.forward).parameters)

if isinstance(dummy_input, torch.Tensor):
dummy_input = dummy_input.to("cpu")
else:
dummy_input = tuple(elt.to("cpu") for elt in dummy_input)

# Export to ONNX
torch.onnx.export(
torch_module,
torch_module.to("cpu"),
dummy_input,
str(output_onnx_file_path),
opset_version=OPSET_VERSION_FOR_ONNX_EXPORT,
Comment on lines +156 to 166
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need the module and the inputs to be on CPU for the exporter to work properly

Expand Down
212 changes: 201 additions & 11 deletions src/concrete/ml/pytest/torch_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,15 @@
import brevitas.nn as qnn
import numpy
import torch
from brevitas.quant import Int8ActPerTensorFloat, Int8WeightPerTensorFloat, IntBias
from brevitas.core.restrict_val import FloatRestrictValue, RestrictValueType
from brevitas.quant import (
Int8AccumulatorAwareWeightQuant,
Int8AccumulatorAwareZeroCenterWeightQuant,
Int8ActPerTensorFloat,
Int8WeightPerTensorFloat,
IntBias,
Uint8ActPerTensorFloat,
)
from torch import nn
from torch.nn.utils import prune

Expand Down Expand Up @@ -38,7 +46,7 @@ def forward(self, x, y):
return x + y + self.value, (x - y) ** 2


class SimpleNet(torch.nn.Module):
class SimpleNet(nn.Module):
"""Fake torch model used to generate some onnx."""

def __init__(self) -> None:
Expand Down Expand Up @@ -292,7 +300,7 @@ def forward(self, x):
return x


class NetWithLoops(torch.nn.Module):
class NetWithLoops(nn.Module):
"""Torch model, where we reuse some elements in a loop.

Torch model, where we reuse some elements in a loop in the forward and don't expect the
Expand Down Expand Up @@ -538,7 +546,7 @@ def step(x, bias):
return x


class NetWithConcatUnsqueeze(torch.nn.Module):
class NetWithConcatUnsqueeze(nn.Module):
"""Torch model to test the concat and unsqueeze operators."""

def __init__(self, activation_function, input_output, n_fc_layers):
Expand Down Expand Up @@ -1004,6 +1012,7 @@ def __init__(self, use_conv, use_qat, inp_size, n_bits):
layer_obj = self.mixing_layer

layer_obj.weight.data = torch.from_numpy(np_weights).float()
assert layer_obj.bias is not None
layer_obj.bias.data = torch.rand(size=(1,))

def forward(self, x):
Expand Down Expand Up @@ -1216,12 +1225,12 @@ def forward(self, x):
# for example a 4d tensor NCHW, padded with [1, 2, 2, 3] is padded
# along the last 2 dimensions, with 1 cell to the left and 2 to the right (dimension 4: W)
# and 2 cells at the top and 3 at the bottom (dimension 3: H)
x = torch.nn.functional.pad(x, (3, 2))
x = torch.nn.functional.pad(x, (1, 2, 3, 4))
x = nn.functional.pad(x, (3, 2))
x = nn.functional.pad(x, (1, 2, 3, 4))

# Concrete ML only supports padding on the last two dimensions as this is the
# most common setting
x = torch.nn.functional.pad(x, (1, 1, 2, 2, 0, 0, 0, 0))
x = nn.functional.pad(x, (1, 1, 2, 2, 0, 0, 0, 0))
return x


Expand Down Expand Up @@ -1340,7 +1349,12 @@ class ConcatFancyIndexing(nn.Module):
"""Concat with fancy indexing."""

def __init__(
self, input_shape, hidden_shape, output_shape, n_bits: int = 4, n_blocks: int = 3
self,
input_shape,
hidden_shape,
output_shape,
n_bits: int = 4,
n_blocks: int = 3,
) -> None:
"""Torch Model.

Expand All @@ -1361,7 +1375,10 @@ def __init__(

self.quant_2 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
self.fc2 = qnn.QuantLinear(
hidden_shape * self.n_blocks, hidden_shape, bias=True, weight_bit_width=n_bits
hidden_shape * self.n_blocks,
hidden_shape,
bias=True,
weight_bit_width=n_bits,
)

self.quant_3 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
Expand Down Expand Up @@ -1393,7 +1410,7 @@ def forward(self, x):
return x


class PartialQATModel(torch.nn.Module):
class PartialQATModel(nn.Module):
"""A model with a QAT Module."""

def __init__(self, input_shape: int, output_shape: int, n_bits: int):
Expand Down Expand Up @@ -1442,7 +1459,7 @@ def forward(self, input1):
return output


class ManualLogisticRegressionTraining(torch.nn.Module):
class ManualLogisticRegressionTraining(nn.Module):
"""PyTorch module for performing SGD training."""

def __init__(self, learning_rate=0.1):
Expand Down Expand Up @@ -1665,3 +1682,176 @@ def forward(self, x):
x = self.relu(x)
x = self.linear(x)
return x


# pylint: disable-next=too-many-ancestors
class CommonIntWeightPerChannelQuant(Int8WeightPerTensorFloat):
"""CommonIntWeightPerChannelQuant."""

scaling_per_output_channel = True
Comment on lines +1688 to +1691
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The per-channel quantizer from Brevitas



# pylint: disable-next=too-many-ancestors
class CommonIntAccumulatorAwareWeightQuant(Int8AccumulatorAwareWeightQuant):
"""CommonIntAccumulatorAwareWeightQuant."""

restrict_scaling_impl = FloatRestrictValue # backwards compatibility
bit_width = None


# pylint: disable-next=too-many-ancestors
class CommonIntAccumulatorAwareZeroCenterWeightQuant(Int8AccumulatorAwareZeroCenterWeightQuant):
"""CommonIntAccumulatorAwareZeroCenterWeightQuant."""

bit_width = None


# pylint: disable-next=too-many-ancestors
class CommonUintActQuant(Uint8ActPerTensorFloat):
"""CommonUintActQuant."""

bit_width = None
restrict_scaling_type = RestrictValueType.LOG_FP


def weight_init(layer: nn.Module):
"""Initialize layer weights.

Arguments:
layer (nn.Module): a conv2d layer
"""

if isinstance(layer, nn.Conv2d):
nn.init.kaiming_normal_(layer.weight, nn.init.calculate_gain("relu"))
if layer.bias is not None:
layer.bias.data.zero_()


# pylint: disable-next=too-many-instance-attributes
class FloatLeNet(nn.Module):
"""Floating point LeNet."""

def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=0)
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.relu1 = nn.ReLU(inplace=True)

self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0)
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.relu2 = nn.ReLU(inplace=True)

self.fc1 = nn.Linear(400, 120, bias=True)
self.relu3 = nn.ReLU()
self.fc2 = nn.Linear(120, 84, bias=True)
self.relu4 = nn.ReLU()
self.fc3 = nn.Linear(84, 10, bias=True)

self.apply(weight_init)

def forward(self, x: torch.Tensor):
"""Forward function.

Arguments:
x (torch.Tensor): input image

Returns:
Neural network prediction
"""
x = self.pool1(self.relu1(self.conv1(x)))
x = self.pool2(self.relu2(self.conv2(x)))
x = torch.flatten(x, 1)
x = self.relu3(self.fc1(x))
x = self.relu4(self.fc2(x))
x = self.fc3(x)
return x


# pylint: disable-next=too-many-instance-attributes
class QuantLeNet(FloatLeNet):
"""Quantized LeNet with per-channel quantization."""

def __init__(
self,
weight_bit_width=4,
act_bit_width=4,
acc_bit_width=32,
weight_quant=CommonIntAccumulatorAwareWeightQuant,
):
super().__init__()

self.conv1 = qnn.QuantConv2d(
bias=False,
in_channels=1,
out_channels=6,
kernel_size=5,
stride=1,
padding=0,
input_bit_width=act_bit_width,
input_quant=CommonUintActQuant,
weight_accumulator_bit_width=acc_bit_width,
weight_bit_width=weight_bit_width,
weight_restrict_scaling_type=RestrictValueType.LOG_FP,
weight_quant=weight_quant,
)
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.relu1 = qnn.QuantReLU(
inplace=True, act_quant=CommonUintActQuant, bit_width=act_bit_width
)

self.conv2 = qnn.QuantConv2d(
bias=False,
in_channels=6,
out_channels=16,
kernel_size=5,
stride=1,
padding=0,
input_bit_width=act_bit_width,
input_quant=CommonUintActQuant,
weight_accumulator_bit_width=acc_bit_width,
weight_bit_width=weight_bit_width,
weight_restrict_scaling_type=RestrictValueType.LOG_FP,
weight_quant=weight_quant,
)
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.relu2 = qnn.QuantReLU(
inplace=True, act_quant=CommonUintActQuant, bit_width=act_bit_width
)

self.fc1 = qnn.QuantLinear(
400,
120,
bias=True,
input_bit_width=act_bit_width,
input_quant=CommonUintActQuant,
weight_accumulator_bit_width=acc_bit_width,
weight_bit_width=weight_bit_width,
weight_restrict_scaling_type=RestrictValueType.LOG_FP,
weight_quant=weight_quant,
)
self.relu3 = qnn.QuantReLU(act_quant=CommonUintActQuant, bit_width=act_bit_width)
self.fc2 = qnn.QuantLinear(
120,
84,
bias=True,
input_bit_width=act_bit_width,
input_quant=CommonUintActQuant,
weight_accumulator_bit_width=acc_bit_width,
weight_bit_width=weight_bit_width,
weight_restrict_scaling_type=RestrictValueType.LOG_FP,
weight_quant=weight_quant,
)
self.relu4 = qnn.QuantReLU(act_quant=CommonUintActQuant, bit_width=act_bit_width)
self.fc3 = qnn.QuantLinear(
84,
10,
bias=True,
input_bit_width=act_bit_width,
input_quant=CommonUintActQuant,
weight_accumulator_bit_width=acc_bit_width,
weight_bit_width=weight_bit_width,
weight_restrict_scaling_type=RestrictValueType.LOG_FP,
weight_quant=weight_quant,
)

self.apply(weight_init)
Comment on lines +1731 to +1857
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A LeNet provided by a user

9 changes: 7 additions & 2 deletions src/concrete/ml/quantization/base_quantized_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Callable, Dict, List, Optional, Set, TextIO, Tuple, Type, Union, cast

import numpy
import numpy.typing as npt

from concrete import fhe

Expand Down Expand Up @@ -122,6 +123,7 @@ def __init__(
input_quant_opts: Optional[QuantizationOptions] = None,
**attrs,
) -> None:

self.n_bits = n_bits_output

if input_quant_opts is not None:
Expand Down Expand Up @@ -916,7 +918,7 @@ def can_fuse(self) -> bool:
def make_output_quant_parameters(
self,
q_values: Union[numpy.ndarray, Any],
scale: numpy.float64,
scale: npt.NDArray[numpy.float64],
RomanBredehoft marked this conversation as resolved.
Show resolved Hide resolved
zero_point: Union[int, float, numpy.ndarray],
) -> QuantizedArray:
"""Build a quantized array from quantized integer results of the op and quantization params.
Expand Down Expand Up @@ -1016,6 +1018,9 @@ def cnp_round(
# Rounding to low bit-width with approximate can cause issues with overflow protection
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4345
x = fhe.round_bit_pattern(
x, lsbs_to_remove=lsbs_value, exactness=exactness, overflow_protection=False
x,
lsbs_to_remove=lsbs_value,
exactness=exactness,
overflow_protection=False,
)
return x
Loading
Loading