Skip to content

Commit

Permalink
chore: support arbitrary msb
Browse files Browse the repository at this point in the history
  • Loading branch information
fd0r committed Jun 5, 2024
1 parent b11d99d commit a221d28
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 15 deletions.
37 changes: 31 additions & 6 deletions src/concrete/ml/common/preprocessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,7 @@ def decompose_1_bit_tlu(
bounds: Tuple[int, int],
rounding_function: Callable = round_bit_pattern,
n_jumps_limit: Optional[int] = None,
msbs_to_keep=1,
):
assert (
rounding_function.__name__ == "round_bit_pattern"
Expand Down Expand Up @@ -907,7 +908,7 @@ def decompose_1_bit_tlu(
# Populate coefficients and offsets
for threshold_index, (threshold, coef) in enumerate(zip(thresholds_selected, tlu_coefs)):
# Compute the offset
offset = threshold + (2 ** (bit_width(input_range)))
offset = threshold
offsets_to_apply[best_indexes + (threshold_index,)] = offset

# todo: get the proper value here: must be the result of f(x-1) - f(x) or smth like that
Expand All @@ -918,14 +919,33 @@ def decompose_1_bit_tlu(
acc_size = bit_width(scale_up(input_range, scaling_factor=1, bias=offset))
max_acc_size = max(max_acc_size, acc_size)

msbs_to_keep = 1
lsbs_to_remove = max_acc_size - msbs_to_keep
rounded = round_bit_pattern(

if rounding_function.__name__ == "round_bit_pattern":
offsets_to_apply += 2 ** (lsbs_to_remove - 1)

# Sanity check
rounded = rounding_function(
(subgraph_inputs[..., np.newaxis] - offsets_to_apply).astype(np.int64),
lsbs_to_remove=lsbs_to_remove,
lsbs_to_remove=int(lsbs_to_remove),
)
pred = ((rounded >= 0).astype(np.int64) * coefficients).sum(axis=-1) + base
assert (subgraph_outputs == pred).all()

if (subgraph_outputs != pred).any():
# TODO: DEBUG
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
fig.suptitle("Debug")
slice_index = tuple(
[slice(0, subgraph_inputs.shape[0])] + [0 for _ in subgraph_inputs.shape[1:]]
)
ax.plot(subgraph_inputs[slice_index], pred[slice_index], label="debug", linestyle="--")
ax.plot(subgraph_inputs[slice_index], subgraph_outputs[slice_index], label="reference")
plt.legend()
plt.savefig("debug.png")
plt.close("all")
raise ValueError(f"{(subgraph_outputs == pred).mean()=} != 1")

return (
max_number_of_thresholds,
Expand Down Expand Up @@ -1603,7 +1623,10 @@ def apply(self, graph: Graph) -> None:
# no rounding already -> Put TLU-1bit and then InsertRounding in CIFAR
class TLU1bitDecomposition(GraphProcessor):
def __init__(
self, n_jumps_limit: int = 4, exactness: Exactness = Exactness.APPROXIMATE
self,
n_jumps_limit: int = 4,
exactness: Exactness = Exactness.APPROXIMATE,
msbs_to_keep=1,
) -> None:
super().__init__()
self.exactness = exactness
Expand All @@ -1612,6 +1635,7 @@ def __init__(
self.rounding_function = round_bit_pattern
self.overflow_protection = True
self.verbose = True
self.msbs_to_keep = msbs_to_keep

def apply(self, graph: Graph) -> None:
"""Apply the TLU optimization to a Graph for all TLUs.
Expand Down Expand Up @@ -1689,6 +1713,7 @@ def apply(self, graph: Graph) -> None:
shape_=shape_,
bounds=(int(min_bound), int(max_bound)),
n_jumps_limit=self.n_jumps_limit,
msbs_to_keep=self.msbs_to_keep,
)
)
assert isinstance(lsbs_to_remove, int)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def wrapper(*args, **kwargs):
tlu_optimizer = TLU1bitDecomposition(
n_jumps_limit=2,
exactness=exactness,
msbs_to_keep=4,
)
rounding = InsertRounding(6, exactness=exactness)

Expand All @@ -99,7 +100,7 @@ def wrapper(*args, **kwargs):
use_insecure_key_cache=True,
insecure_key_cache_location=KEYGEN_CACHE_DIR,
additional_pre_processors=[
# tlu_optimizer,
tlu_optimizer,
rounding,
],
fhe_simulation=SIMULATE_ONLY,
Expand All @@ -114,7 +115,6 @@ def wrapper(*args, **kwargs):
torch_model,
x,
configuration=configuration,
# rounding_threshold_bits={"method": Exactness.APPROXIMATE, "n_bits": 6},
p_error=P_ERROR,
)
assert isinstance(quantized_numpy_module, QuantizedModule)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def evaluate(torch_model, cml_model, device, num_workers):

# Import and load the CIFAR test dataset (following bnn_pynq_train.py)
test_set = get_test_set(dataset="CIFAR10", datadir=CURRENT_DIR / ".datasets/")
# test_set = Subset(test_set, np.arange(128))
test_set = Subset(test_set, np.arange(128))
test_loader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=num_workers)

torch_top_1_batches = []
Expand Down Expand Up @@ -121,13 +121,10 @@ def main(args):
# Eval mode
model.eval()

exactness=fhe.Exactness.APPROXIMATE
exactness = fhe.Exactness.APPROXIMATE
# Multi-parameter strategy is used in order to speed-up the FHE executions
tlu_optimizer = TLU1bitDecomposition(
n_jumps_limit=2,
exactness=exactness,
)
insert_rounding = InsertRounding(6,exactness=exactness)
tlu_optimizer = TLU1bitDecomposition(n_jumps_limit=2, exactness=exactness, msbs_to_keep=4)
insert_rounding = InsertRounding(6, exactness=exactness)
# debug_processor = Debug()
cfg = Configuration(
verbose=True,
Expand Down

0 comments on commit a221d28

Please sign in to comment.