Make the injectable dot_general optional and not pre-initialized, so that the functional interceptor works.

PiperOrigin-RevId: 573054116
IvyZX authored and Flax Authors committed Oct 13, 2023
1 parent 342d12a commit c3859e8
Showing 3 changed files with 18 additions and 10 deletions.
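The change applies one pattern across all three files: the deprecated injectable callables (dot_general, conv_general_dilated, and the attention variants) no longer default to the lax implementation at class-definition time; they default to None, and the lax fallback is chosen at call time after the *_cls factory and the instance-level callable have been checked. A minimal sketch of the resulting resolution order, using a simplified module (TinyDense is illustrative, not part of Flax):

from typing import Any, Optional

import jax
import jax.numpy as jnp
from jax import lax
import flax.linen as nn


class TinyDense(nn.Module):
  """Illustrative analogue of nn.Dense showing the new resolution order."""
  features: int
  dot_general: Optional[Any] = None   # deprecated-style injectable callable, now None
  dot_general_cls: Any = None         # preferred factory-style injection

  @nn.compact
  def __call__(self, x):
    kernel = self.param(
        'kernel', nn.initializers.lecun_normal(), (x.shape[-1], self.features)
    )
    # Factory first, then the (deprecated) instance callable, then the lax default.
    if self.dot_general_cls is not None:
      dot_general = self.dot_general_cls()
    elif self.dot_general is not None:
      dot_general = self.dot_general
    else:
      dot_general = lax.dot_general
    return dot_general(x, kernel, (((x.ndim - 1,), (0,)), ((), ())))


x = jnp.ones((2, 4))
model = TinyDense(features=3)
y = model.apply(model.init(jax.random.PRNGKey(0), x), x)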
4 changes: 2 additions & 2 deletions flax/linen/attention.py
@@ -244,8 +244,8 @@ class MultiHeadDotProductAttention(Module):
   decode: bool = False
   normalize_qk: bool = False
   # Deprecated, will be removed.
-  qkv_dot_general: DotGeneralT = lax.dot_general
-  out_dot_general: DotGeneralT = lax.dot_general
+  qkv_dot_general: Optional[DotGeneralT] = None
+  out_dot_general: Optional[DotGeneralT] = None
   qkv_dot_general_cls: Any = None
   out_dot_general_cls: Any = None

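For attention, the supported way to inject a custom contraction remains the *_cls factory fields, which this commit leaves untouched. A usage sketch (the factory name and the precision choice below are illustrative assumptions, not part of this commit):

import functools

import jax
import jax.numpy as jnp
from jax import lax
import flax.linen as nn


# Illustrative factory: returns a dot_general pinned to highest precision.
def high_precision_dot_general_factory():
  return functools.partial(lax.dot_general, precision=lax.Precision.HIGHEST)


attn = nn.MultiHeadDotProductAttention(
    num_heads=4,
    qkv_dot_general_cls=high_precision_dot_general_factory,
)
x = jnp.ones((2, 8, 16))
variables = attn.init(jax.random.PRNGKey(0), x, x)
y = attn.apply(variables, x, x)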
6 changes: 4 additions & 2 deletions flax/linen/experimental/layers_with_named_axes.py
@@ -73,7 +73,7 @@ class Dense(nn.Module):
   )
   kernel_axes: Tuple[str, ...] = ()
   # Deprecated. Will be removed.
-  dot_general: DotGeneralT = lax.dot_general
+  dot_general: Optional[DotGeneralT] = None
   dot_general_cls: Any = None

   @nn.compact
@@ -98,8 +98,10 @@ def __call__(self, inputs: Array) -> Array:

     if self.dot_general_cls is not None:
       dot_general = self.dot_general_cls()
-    else:
+    elif self.dot_general is not None:
       dot_general = self.dot_general
+    else:
+      dot_general = lax.dot_general
     y = dot_general(
         inputs,
         kernel,
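The same three-way resolution (factory, then the deprecated instance callable, then lax.dot_general) is repeated below in DenseGeneral, Dense, and _Conv. As a quick check of the precedence, a sketch with nn.Dense (the wrapped callables here are illustrative):

import jax
import jax.numpy as jnp
from jax import lax
import flax.linen as nn

calls = []


def counting_dot_general(*args, **kwargs):
  calls.append('instance')
  return lax.dot_general(*args, **kwargs)


def counting_factory():
  def dg(*args, **kwargs):
    calls.append('factory')
    return lax.dot_general(*args, **kwargs)
  return dg


x = jnp.ones((1, 4))
# If both are supplied, the factory wins; the deprecated field is ignored.
model = nn.Dense(features=2, dot_general=counting_dot_general,
                 dot_general_cls=counting_factory)
model.apply(model.init(jax.random.PRNGKey(0), x), x)
assert calls and calls[0] == 'factory'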
18 changes: 12 additions & 6 deletions flax/linen/linear.py
@@ -98,7 +98,7 @@ class DenseGeneral(Module):
   )
   precision: PrecisionLike = None
   # Deprecated. Will be removed.
-  dot_general: DotGeneralT = lax.dot_general
+  dot_general: Optional[DotGeneralT] = None
   dot_general_cls: Any = None

   @compact
@@ -178,8 +178,10 @@ def bias_init_wrap(rng, shape, dtype=jnp.float32):

     if self.dot_general_cls is not None:
       dot_general = self.dot_general_cls()
-    else:
+    elif self.dot_general is not None:
       dot_general = self.dot_general
+    else:
+      dot_general = lax.dot_general
     out = dot_general(
         inputs,
         kernel,
@@ -218,7 +220,7 @@ class Dense(Module):
       initializers.zeros_init()
   )
   # Deprecated. Will be removed.
-  dot_general: DotGeneralT = lax.dot_general
+  dot_general: Optional[DotGeneralT] = None
   dot_general_cls: Any = None

   @compact
@@ -247,8 +249,10 @@ def __call__(self, inputs: Array) -> Array:

     if self.dot_general_cls is not None:
       dot_general = self.dot_general_cls()
-    else:
+    elif self.dot_general is not None:
       dot_general = self.dot_general
+    else:
+      dot_general = lax.dot_general
     y = dot_general(
         inputs,
         kernel,
@@ -350,7 +354,7 @@ class _Conv(Module):
       initializers.zeros_init()
   )
   # Deprecated. Will be removed.
-  conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated
+  conv_general_dilated: Optional[ConvGeneralDilatedT] = None
   conv_general_dilated_cls: Any = None

   @property
@@ -466,8 +470,10 @@ def maybe_broadcast(
       # create the unshared convolution kernel.
       if self.conv_general_dilated_cls is not None:
         conv_general_dilated = self.conv_general_dilated_cls()
-      else:
+      elif self.conv_general_dilated is not None:
         conv_general_dilated = self.conv_general_dilated
+      else:
+        conv_general_dilated = lax.conv_general_dilated
       conv_output_shape = eval_shape(
           lambda lhs, rhs: conv_general_dilated(  # pylint: disable=g-long-lambda
               lhs=lhs,
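The commit title refers to the functional interceptor; a minimal sketch of that style of method interception with nn.intercept_methods, assuming this is the mechanism the message refers to:

import jax
import jax.numpy as jnp
import flax.linen as nn

model = nn.Dense(features=3)
x = jnp.ones((1, 4))
variables = model.init(jax.random.PRNGKey(0), x)


def log_interceptor(next_fun, args, kwargs, context):
  # Log every intercepted module method, then delegate unchanged.
  print(f'{type(context.module).__name__}.{context.method_name}')
  return next_fun(*args, **kwargs)


with nn.intercept_methods(log_interceptor):
  y = model.apply(variables, x)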

