From 55293ded898132b30b00fbfd0372ed8e1788045f Mon Sep 17 00:00:00 2001 From: jfrery Date: Fri, 22 Nov 2024 11:48:21 +0100 Subject: [PATCH] chore: fix coverage --- src/concrete/ml/torch/hybrid_model.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/concrete/ml/torch/hybrid_model.py b/src/concrete/ml/torch/hybrid_model.py index e8bbf6c34..aeacef262 100644 --- a/src/concrete/ml/torch/hybrid_model.py +++ b/src/concrete/ml/torch/hybrid_model.py @@ -412,13 +412,14 @@ def _replace_modules(self): if is_pure_linear_layer: module = self.private_modules[module_name] # Use weight shape instead of in/out_features - if hasattr(module, "weight"): - input_dim = module.weight.shape[ - 1 - ] # Input dimension is second dimension for Linear layers - output_dim = module.weight.shape[0] # Output dimension is first dimension - else: - input_dim = output_dim = 0 + input_dim, output_dim = ( + ( + module.weight.shape[1], + module.weight.shape[0], + ) + if hasattr(module, "weight") + else (0, 0) + ) is_pure_linear_layer = ( is_pure_linear_layer and input_dim >= 512 and output_dim >= 512 @@ -474,8 +475,7 @@ def forward(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor: if fhe_mode in (HybridFHEMode.EXECUTE, HybridFHEMode.REMOTE, HybridFHEMode.DISABLE): # Initialize executor only if not already done - if self.executor is None: - self.executor = GLWELinearLayerExecutor() + self.executor = self.executor or GLWELinearLayerExecutor() # Generate keys only if needed and not already done if fhe_mode != HybridFHEMode.DISABLE and self.executor.private_key is None: