Skip to content

Commit

Permalink
chore: fix coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
jfrery committed Nov 22, 2024
1 parent 7d69721 commit 55293de
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions src/concrete/ml/torch/hybrid_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,13 +412,14 @@ def _replace_modules(self):
if is_pure_linear_layer:
module = self.private_modules[module_name]
# Use weight shape instead of in/out_features
if hasattr(module, "weight"):
input_dim = module.weight.shape[
1
] # Input dimension is second dimension for Linear layers
output_dim = module.weight.shape[0] # Output dimension is first dimension
else:
input_dim = output_dim = 0
input_dim, output_dim = (
(
module.weight.shape[1],
module.weight.shape[0],
)
if hasattr(module, "weight")
else (0, 0)
)

is_pure_linear_layer = (
is_pure_linear_layer and input_dim >= 512 and output_dim >= 512
Expand Down Expand Up @@ -474,8 +475,7 @@ def forward(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor:

if fhe_mode in (HybridFHEMode.EXECUTE, HybridFHEMode.REMOTE, HybridFHEMode.DISABLE):
# Initialize executor only if not already done
if self.executor is None:
self.executor = GLWELinearLayerExecutor()
self.executor = self.executor or GLWELinearLayerExecutor()

# Generate keys only if needed and not already done
if fhe_mode != HybridFHEMode.DISABLE and self.executor.private_key is None:
Expand Down

0 comments on commit 55293de

Please sign in to comment.