From c3859e87d79a2bf0f857158f8a7596e3cdc3efbe Mon Sep 17 00:00:00 2001
From: Ivy Zheng
Date: Thu, 12 Oct 2023 17:20:40 -0700
Subject: [PATCH] Make the injectable `dot_general` optional and not
 pre-initialized so that functional interceptor works.

PiperOrigin-RevId: 573054116
---
 flax/linen/attention.py                        |  4 ++--
 .../experimental/layers_with_named_axes.py     |  6 ++++--
 flax/linen/linear.py                           | 18 ++++++++++++------
 3 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/flax/linen/attention.py b/flax/linen/attention.py
index 575620efcd..efb34b510a 100644
--- a/flax/linen/attention.py
+++ b/flax/linen/attention.py
@@ -244,8 +244,8 @@ class MultiHeadDotProductAttention(Module):
   decode: bool = False
   normalize_qk: bool = False
   # Deprecated, will be removed.
-  qkv_dot_general: DotGeneralT = lax.dot_general
-  out_dot_general: DotGeneralT = lax.dot_general
+  qkv_dot_general: Optional[DotGeneralT] = None
+  out_dot_general: Optional[DotGeneralT] = None
   qkv_dot_general_cls: Any = None
   out_dot_general_cls: Any = None
 
diff --git a/flax/linen/experimental/layers_with_named_axes.py b/flax/linen/experimental/layers_with_named_axes.py
index 25ee5e6276..24fde145ba 100644
--- a/flax/linen/experimental/layers_with_named_axes.py
+++ b/flax/linen/experimental/layers_with_named_axes.py
@@ -73,7 +73,7 @@ class Dense(nn.Module):
   )
   kernel_axes: Tuple[str, ...] = ()
   # Deprecated. Will be removed.
-  dot_general: DotGeneralT = lax.dot_general
+  dot_general: Optional[DotGeneralT] = None
   dot_general_cls: Any = None
 
   @nn.compact
@@ -98,8 +98,10 @@ def __call__(self, inputs: Array) -> Array:
 
     if self.dot_general_cls is not None:
       dot_general = self.dot_general_cls()
-    else:
+    elif self.dot_general is not None:
       dot_general = self.dot_general
+    else:
+      dot_general = lax.dot_general
     y = dot_general(
       inputs,
       kernel,
diff --git a/flax/linen/linear.py b/flax/linen/linear.py
index 4a23bd3c1c..bee28a8055 100644
--- a/flax/linen/linear.py
+++ b/flax/linen/linear.py
@@ -98,7 +98,7 @@ class DenseGeneral(Module):
   )
   precision: PrecisionLike = None
   # Deprecated. Will be removed.
-  dot_general: DotGeneralT = lax.dot_general
+  dot_general: Optional[DotGeneralT] = None
   dot_general_cls: Any = None
 
   @compact
@@ -178,8 +178,10 @@ def bias_init_wrap(rng, shape, dtype=jnp.float32):
 
     if self.dot_general_cls is not None:
       dot_general = self.dot_general_cls()
-    else:
+    elif self.dot_general is not None:
       dot_general = self.dot_general
+    else:
+      dot_general = lax.dot_general
     out = dot_general(
       inputs,
       kernel,
@@ -218,7 +220,7 @@ class Dense(Module):
     initializers.zeros_init()
   )
   # Deprecated. Will be removed.
-  dot_general: DotGeneralT = lax.dot_general
+  dot_general: Optional[DotGeneralT] = None
   dot_general_cls: Any = None
 
   @compact
@@ -247,8 +249,10 @@ def __call__(self, inputs: Array) -> Array:
 
     if self.dot_general_cls is not None:
       dot_general = self.dot_general_cls()
-    else:
+    elif self.dot_general is not None:
       dot_general = self.dot_general
+    else:
+      dot_general = lax.dot_general
     y = dot_general(
       inputs,
       kernel,
@@ -350,7 +354,7 @@ class _Conv(Module):
     initializers.zeros_init()
   )
   # Deprecated. Will be removed.
-  conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated
+  conv_general_dilated: Optional[ConvGeneralDilatedT] = None
   conv_general_dilated_cls: Any = None
 
   @property
@@ -466,8 +470,10 @@ def maybe_broadcast(
       # create the unshared convolution kernel.
       if self.conv_general_dilated_cls is not None:
         conv_general_dilated = self.conv_general_dilated_cls()
-      else:
+      elif self.conv_general_dilated is not None:
         conv_general_dilated = self.conv_general_dilated
+      else:
+        conv_general_dilated = lax.conv_general_dilated
       conv_output_shape = eval_shape(
         lambda lhs, rhs: conv_general_dilated(  # pylint: disable=g-long-lambda
           lhs=lhs,
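
Usage note (not part of the patch): a minimal sketch of the resolution order this change introduces, assuming Flax at or after this revision. The LoggingDotGeneral class and the shapes below are illustrative assumptions, not library code; the point is that the injectable field now defaults to None and the layer falls back to lax.dot_general at call time, while an injected dot_general_cls is instantiated inside __call__ rather than being baked into the dataclass field default.

    # Sketch: exercising the dot_general_cls -> dot_general -> lax.dot_general
    # fallback chain added by this patch. LoggingDotGeneral is hypothetical.
    import jax
    import jax.numpy as jnp
    from jax import lax
    import flax.linen as nn


    class LoggingDotGeneral:
      """Hypothetical injectable wrapper; the layer calls dot_general_cls()."""

      def __call__(self, lhs, rhs, dimension_numbers, precision=None, **kwargs):
        # A real interceptor could record shapes or swap in a quantized matmul;
        # this sketch just delegates to the default primitive.
        return lax.dot_general(
            lhs, rhs, dimension_numbers, precision=precision, **kwargs
        )


    x = jnp.ones((1, 4))

    # Default path: both injection fields are None, so the layer uses
    # lax.dot_general internally.
    plain = nn.Dense(features=3)
    params = plain.init(jax.random.PRNGKey(0), x)

    # Injected path: dot_general_cls takes precedence over dot_general.
    injected = nn.Dense(features=3, dot_general_cls=LoggingDotGeneral)

    y0 = plain.apply(params, x)
    y1 = injected.apply(params, x)  # same params; only the matmul hook differs
    print(jnp.allclose(y0, y1))  # True

Passing a plain callable via the dot_general field still works and, per the new elif branch, sits between the two cases above in priority.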