Skip to content

Commit

Permalink
TorchDynamo Compatability
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasWilkinson committed Sep 13, 2024
1 parent 84cfdb2 commit c452a86
Showing 1 changed file with 8 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,14 @@ def apply_weights(self,
def _transform_param(self, layer: torch.nn.Module, name: Optional[str],
fn: Callable) -> None:
if name is not None and getattr(layer, name, None) is not None:
replace_parameter(layer, name, fn(getattr(layer, name)))

old_param = getattr(layer, name)
new_param = fn(old_param)
# replace the parameter with torch.nn.Parameter for TorchDynamo
# compatibility
replace_parameter(
layer, name,
torch.nn.Parameter(new_param.data, requires_grad=False))

def _get_weight_params(
self, layer: torch.nn.Module
Expand Down

0 comments on commit c452a86

Please sign in to comment.