Skip to content

Commit

Permalink
chore: fix hybrid model glwe lora mlp
Browse files (browse the repository at this point in the history)
  • Loading branch information
jfrery committed Nov 14, 2024
1 parent e396438 commit ace90dd
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 29 deletions.
1 change: 1 addition & 0 deletions .gitleaksignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ f41de03048a9ed27946b875e81b34138bb4bb17b:use_case_examples/training/analyze.ipyn
e2904473898ddd325f245f4faca526a0e9520f49:builders/Dockerfile.zamalang-env:generic-api-key:5
7d5e885816f1f1e432dd94da38c5c8267292056a:docs/advanced_examples/XGBRegressor.ipynb:aws-access-token:1026
25c5e7abaa7382520af3fb7a64266e193b1f6a59:poetry.lock:square-access-token:6401
eebd4bea78f6dd2361baa7f94f68ae4cba8b9fe8:tests/deployment/test_deployment.py:generic-api-key:20
41 changes: 12 additions & 29 deletions src/concrete/ml/torch/hybrid_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

# pylint: disable=too-many-lines
import ast
import contextvars
import io
import sys
import time
Expand Down Expand Up @@ -102,13 +101,6 @@ def convert_conv1d_to_linear(layer_or_module):
return layer_or_module


# This module member is instantiated by the Hybrid FHE model
# when hybrid FHE forward is called and the GLWE backend is available
_optimized_linear_executor: contextvars.ContextVar[Optional[GLWELinearLayerExecutor]] = (
contextvars.ContextVar("optimized_linear_executor")
)


# pylint: disable-next=too-many-instance-attributes
class RemoteModule(nn.Module):
"""A wrapper class for the modules to be evaluated remotely with FHE."""
Expand Down Expand Up @@ -136,6 +128,7 @@ def __init__(
self.model_name: Optional[str] = model_name
self.verbose = verbose
self.optimized_linear_execution = optimized_linear_execution
self.executor: Optional[GLWELinearLayerExecutor] = None

def init_fhe_client(
self, path_to_client: Optional[Path] = None, path_to_keys: Optional[Path] = None
Expand Down Expand Up @@ -252,15 +245,10 @@ def forward(self, x: torch.Tensor) -> Union[torch.Tensor, QuantTensor]:
}:
assert self.private_q_module is not None

try:
optimized_linear_layer_executor = _optimized_linear_executor.get()
except LookupError:
optimized_linear_layer_executor = None

if optimized_linear_layer_executor:
if self.executor:
# Delegate to the optimized GLWE executor
y = torch.Tensor(
optimized_linear_layer_executor.forward(
self.executor.forward(
x.detach().numpy(), self.private_q_module, self.fhe_local_mode
)
)
Expand All @@ -269,6 +257,7 @@ def forward(self, x: torch.Tensor) -> Union[torch.Tensor, QuantTensor]:
y = torch.Tensor(
self.private_q_module.forward(x.detach().numpy(), fhe=self.fhe_local_mode.value)
)

elif self.fhe_local_mode == HybridFHEMode.CALIBRATE:
# Calling torch + gathering calibration data
assert self.private_module is not None
Expand All @@ -278,14 +267,7 @@ def forward(self, x: torch.Tensor) -> Union[torch.Tensor, QuantTensor]:

elif self.fhe_local_mode == HybridFHEMode.REMOTE: # pragma:no cover
# Remote call
try:
optimized_linear_layer_executor = _optimized_linear_executor.get()
except LookupError:
optimized_linear_layer_executor = None

assert optimized_linear_layer_executor is None, (
"Remote optimized linear layers " "are not yet implemented"
)
assert self.executor is None, "Remote optimized linear layers are not yet implemented"
y = self.remote_call(x)

elif self.fhe_local_mode == HybridFHEMode.TORCH:
Expand Down Expand Up @@ -400,6 +382,7 @@ def __init__(
self.configuration: Optional[Configuration] = None
self.model_name = model_name
self.verbose = verbose
self.executor: Optional[GLWELinearLayerExecutor] = None

self._replace_modules()

Expand Down Expand Up @@ -461,6 +444,7 @@ def forward(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor:

# Validate the FHE mode
fhe_mode = HybridFHEMode(fhe)
self.executor = None

if _HAS_GLWE_BACKEND and self._all_layers_are_pure_linear:
if fhe_mode == HybridFHEMode.SIMULATE:
Expand All @@ -476,17 +460,16 @@ def forward(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor:

# Loading keys from a file could be done here, and the
# keys could be passed as arguments to the Executor
executor = GLWELinearLayerExecutor()

self.executor = GLWELinearLayerExecutor()
if fhe_mode != HybridFHEMode.DISABLE:
executor.keygen()
self.executor.keygen()

_optimized_linear_executor.set(executor)
# Update executor for all remote modules
for module in self.remote_modules.values():
module.executor = self.executor

result = self.model(x)

_optimized_linear_executor.set(None)

return result

def __call__(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor:
Expand Down

0 comments on commit ace90dd

Please sign in to comment.