Skip to content

Commit

Permalink
Simplifying the LayerTag Primitive machinery.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 638267834
Change-Id: I3beb9051d840acc8de5e6fab0d8cbb7f15002b29
  • Loading branch information
FermiNet Contributor authored and jsspencer committed Jun 4, 2024
1 parent 41ff8d1 commit 82d89b4
Showing 1 changed file with 41 additions and 47 deletions.
88 changes: 41 additions & 47 deletions ferminet/curvature_tags_and_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
# limitations under the License.

"""Curvature blocks for FermiNet."""
from typing import Any, Mapping, Sequence, Set, Tuple
import dataclasses
import functools
from typing import Sequence, Set, Tuple
import jax
import jax.numpy as jnp
import kfac_jax
Expand All @@ -28,18 +30,19 @@
vmap_matmul = jax.vmap(jnp.matmul, in_axes=(0, 0), out_axes=0)


repeated_dense_tag = kfac_jax.LayerTag("repeated_dense_tag", 1, 1)
qmc_tag = kfac_jax.LayerTag("qmc_tag", 1, 1)
def register_repeated_dense(y, x, w, b, **kwargs):
return kfac_jax.register_dense(y, x, w, b, variant="repeated_dense", **kwargs)


def register_repeated_dense(y, x, w, b):
if b is None:
return repeated_dense_tag.bind(y, x, w)
return repeated_dense_tag.bind(y, x, w, b)
def register_qmc(y, x, w, **kwargs):
return kfac_jax.register_dense(y, x, w, variant="qmc", **kwargs)


def register_qmc(y, x, w, **kwargs):
return qmc_tag.bind(y, x, w, **kwargs)
_dense = kfac_jax.tag_graph_matcher._dense # pylint: disable=protected-access
_repeated_dense_parameter_extractor = functools.partial(
kfac_jax.tag_graph_matcher._dense_parameter_extractor, # pylint: disable=protected-access
variant="repeated_dense",
)


class RepeatedDenseBlock(kfac_jax.DenseTwoKroneckerFactored):
Expand All @@ -52,18 +55,28 @@ def fixed_scale(self) -> Numeric:
def update_curvature_matrix_estimate(
self,
state: kfac_jax.TwoKroneckerFactored.State,
estimation_data: Mapping[str, Sequence[Array]],
estimation_data: kfac_jax.LayerVjpData[Array],
ema_old: Numeric,
ema_new: Numeric,
identity_weight: Numeric,
batch_size: int,
) -> kfac_jax.TwoKroneckerFactored.State:
estimation_data = dict(**estimation_data)
x, = estimation_data["inputs"]
dy, = estimation_data["outputs_tangent"]
[x] = estimation_data.primals.inputs
[dy] = estimation_data.tangents.outputs
assert x.shape[0] == batch_size
estimation_data["inputs"] = (x.reshape([-1, x.shape[-1]]),)
estimation_data["outputs_tangent"] = (dy.reshape([-1, dy.shape[-1]]),)

estimation_data = dataclasses.replace(
estimation_data,
primals=dataclasses.replace(
estimation_data.primals,
inputs=(x.reshape([-1, x.shape[-1]]),),
),
tangents=dataclasses.replace(
estimation_data.tangents,
outputs=(dy.reshape([-1, dy.shape[-1]]),),
),
)

batch_size = x.size // x.shape[-1]
return super().update_curvature_matrix_estimate(
state=state,
Expand All @@ -90,16 +103,16 @@ def fixed_scale(self) -> Numeric:
def update_curvature_matrix_estimate(
self,
state: kfac_jax.TwoKroneckerFactored.State,
estimation_data: Mapping[str, Sequence[Array]],
estimation_data: kfac_jax.LayerVjpData[Array],
ema_old: Numeric,
ema_new: Numeric,
identity_weight: Numeric,
batch_size: int,
) -> kfac_jax.TwoKroneckerFactored.State:
del identity_weight

x, = estimation_data["inputs"]
dy, = estimation_data["outputs_tangent"]
[x] = estimation_data.primals.inputs
[dy] = estimation_data.tangents.outputs
assert batch_size == x.shape[0]
normalizer = x.shape[0] * x.shape[1]
# The forward computation is
Expand Down Expand Up @@ -211,23 +224,6 @@ def multiply_matpower(
return (v,)


def _dense(x: Array, params: Sequence[Array]) -> Array:
"""Example of a dense layer function."""
w, *opt_b = params
y = jnp.matmul(x, w)
return y if not opt_b else y + opt_b[0]


def _dense_parameter_extractor(
eqns: Sequence[jax.core.JaxprEqn],
) -> Mapping[str, Any]:
"""Extracts all parameters from the conv_general_dilated operator."""
for eqn in eqns:
if eqn.primitive.name == "dot_general":
return dict(**eqn.params)
assert False


# repeating a dense layer once
_repeated_dense1 = jax.vmap(_dense, in_axes=[0, [None, None]])
_repeated_dense2 = jax.vmap(_repeated_dense1, in_axes=[0, [None, None]])
Expand All @@ -237,33 +233,33 @@ def _dense_parameter_extractor(
# Computation for repeated dense layer
repeated_dense1_with_bias_pattern = kfac_jax.tag_graph_matcher.GraphPattern(
name="repeated_dense1_with_bias",
tag_primitive=repeated_dense_tag,
tag_primitive=kfac_jax.layers_and_loss_tags.layer_tag,
compute_func=_repeated_dense1,
parameters_extractor_func=_dense_parameter_extractor,
parameters_extractor_func=_repeated_dense_parameter_extractor,
example_args=[np.zeros([9, 11, 13]), [np.zeros([13, 7]), np.zeros([7])]],
)

repeated_dense1_no_bias_pattern = kfac_jax.tag_graph_matcher.GraphPattern(
name="repeated_dense1_no_bias",
tag_primitive=repeated_dense_tag,
tag_primitive=kfac_jax.layers_and_loss_tags.layer_tag,
compute_func=_repeated_dense1_no_b,
parameters_extractor_func=_dense_parameter_extractor,
parameters_extractor_func=_repeated_dense_parameter_extractor,
example_args=[np.zeros([9, 11, 13]), [np.zeros([13, 7])]],
)

repeated_dense2_with_bias_pattern = kfac_jax.tag_graph_matcher.GraphPattern(
name="repeated_dense2_with_bias",
tag_primitive=repeated_dense_tag,
tag_primitive=kfac_jax.layers_and_loss_tags.layer_tag,
compute_func=_repeated_dense2,
parameters_extractor_func=_dense_parameter_extractor,
parameters_extractor_func=_repeated_dense_parameter_extractor,
example_args=[np.zeros([8, 9, 11, 13]), [np.zeros([13, 7]), np.zeros([7])]],
)

repeated_dense2_no_bias_pattern = kfac_jax.tag_graph_matcher.GraphPattern(
name="repeated_dense2_no_bias",
tag_primitive=repeated_dense_tag,
tag_primitive=kfac_jax.layers_and_loss_tags.layer_tag,
compute_func=_repeated_dense2_no_b,
parameters_extractor_func=_dense_parameter_extractor,
parameters_extractor_func=_repeated_dense_parameter_extractor,
example_args=[np.zeros([8, 9, 11, 13]), [np.zeros([13, 7])]],
)

Expand All @@ -276,8 +272,6 @@ def _dense_parameter_extractor(


kfac_jax.set_default_tag_to_block_ctor(
"repeated_dense_tag", RepeatedDenseBlock
)
kfac_jax.set_default_tag_to_block_ctor(
"qmc_tag", QmcBlockedDense
"repeated_dense", RepeatedDenseBlock
)
kfac_jax.set_default_tag_to_block_ctor("qmc", QmcBlockedDense)

0 comments on commit 82d89b4

Please sign in to comment.