From 82d89b4aa5654d46662204958c2177db0c18fe9f Mon Sep 17 00:00:00 2001 From: FermiNet Contributor Date: Wed, 29 May 2024 14:48:33 +0100 Subject: [PATCH] Simplifying the LayerTag Primitive machinary. PiperOrigin-RevId: 638267834 Change-Id: I3beb9051d840acc8de5e6fab0d8cbb7f15002b29 --- ferminet/curvature_tags_and_blocks.py | 88 +++++++++++++-------------- 1 file changed, 41 insertions(+), 47 deletions(-) diff --git a/ferminet/curvature_tags_and_blocks.py b/ferminet/curvature_tags_and_blocks.py index 617ae53..c854fbe 100644 --- a/ferminet/curvature_tags_and_blocks.py +++ b/ferminet/curvature_tags_and_blocks.py @@ -13,7 +13,9 @@ # limitations under the License. """Curvature blocks for FermiNet.""" -from typing import Any, Mapping, Sequence, Set, Tuple +import dataclasses +import functools +from typing import Sequence, Set, Tuple import jax import jax.numpy as jnp import kfac_jax @@ -28,18 +30,19 @@ vmap_matmul = jax.vmap(jnp.matmul, in_axes=(0, 0), out_axes=0) -repeated_dense_tag = kfac_jax.LayerTag("repeated_dense_tag", 1, 1) -qmc_tag = kfac_jax.LayerTag("qmc_tag", 1, 1) +def register_repeated_dense(y, x, w, b, **kwargs): + return kfac_jax.register_dense(y, x, w, b, variant="repeated_dense", **kwargs) -def register_repeated_dense(y, x, w, b): - if b is None: - return repeated_dense_tag.bind(y, x, w) - return repeated_dense_tag.bind(y, x, w, b) +def register_qmc(y, x, w, **kwargs): + return kfac_jax.register_dense(y, x, w, variant="qmc", **kwargs) -def register_qmc(y, x, w, **kwargs): - return qmc_tag.bind(y, x, w, **kwargs) +_dense = kfac_jax.tag_graph_matcher._dense # pylint: disable=protected-access +_repeated_dense_parameter_extractor = functools.partial( + kfac_jax.tag_graph_matcher._dense_parameter_extractor, # pylint: disable=protected-access + variant="repeated_dense", +) class RepeatedDenseBlock(kfac_jax.DenseTwoKroneckerFactored): @@ -52,18 +55,28 @@ def fixed_scale(self) -> Numeric: def update_curvature_matrix_estimate( self, state: kfac_jax.TwoKroneckerFactored.State, - estimation_data: Mapping[str, Sequence[Array]], + estimation_data: kfac_jax.LayerVjpData[Array], ema_old: Numeric, ema_new: Numeric, identity_weight: Numeric, batch_size: int, ) -> kfac_jax.TwoKroneckerFactored.State: - estimation_data = dict(**estimation_data) - x, = estimation_data["inputs"] - dy, = estimation_data["outputs_tangent"] + [x] = estimation_data.primals.inputs + [dy] = estimation_data.tangents.outputs assert x.shape[0] == batch_size - estimation_data["inputs"] = (x.reshape([-1, x.shape[-1]]),) - estimation_data["outputs_tangent"] = (dy.reshape([-1, dy.shape[-1]]),) + + estimation_data = dataclasses.replace( + estimation_data, + primals=dataclasses.replace( + estimation_data.primals, + inputs=(x.reshape([-1, x.shape[-1]]),), + ), + tangents=dataclasses.replace( + estimation_data.tangents, + outputs=(dy.reshape([-1, dy.shape[-1]]),), + ), + ) + batch_size = x.size // x.shape[-1] return super().update_curvature_matrix_estimate( state=state, @@ -90,7 +103,7 @@ def fixed_scale(self) -> Numeric: def update_curvature_matrix_estimate( self, state: kfac_jax.TwoKroneckerFactored.State, - estimation_data: Mapping[str, Sequence[Array]], + estimation_data: kfac_jax.LayerVjpData[Array], ema_old: Numeric, ema_new: Numeric, identity_weight: Numeric, @@ -98,8 +111,8 @@ def update_curvature_matrix_estimate( ) -> kfac_jax.TwoKroneckerFactored.State: del identity_weight - x, = estimation_data["inputs"] - dy, = estimation_data["outputs_tangent"] + [x] = estimation_data.primals.inputs + [dy] = estimation_data.tangents.outputs assert batch_size == x.shape[0] normalizer = x.shape[0] * x.shape[1] # The forward computation is @@ -211,23 +224,6 @@ def multiply_matpower( return (v,) -def _dense(x: Array, params: Sequence[Array]) -> Array: - """Example of a dense layer function.""" - w, *opt_b = params - y = jnp.matmul(x, w) - return y if not opt_b else y + opt_b[0] - - -def _dense_parameter_extractor( - eqns: Sequence[jax.core.JaxprEqn], -) -> Mapping[str, Any]: - """Extracts all parameters from the conv_general_dilated operator.""" - for eqn in eqns: - if eqn.primitive.name == "dot_general": - return dict(**eqn.params) - assert False - - # repeating a dense layer once _repeated_dense1 = jax.vmap(_dense, in_axes=[0, [None, None]]) _repeated_dense2 = jax.vmap(_repeated_dense1, in_axes=[0, [None, None]]) @@ -237,33 +233,33 @@ def _dense_parameter_extractor( # Computation for repeated dense layer repeated_dense1_with_bias_pattern = kfac_jax.tag_graph_matcher.GraphPattern( name="repeated_dense1_with_bias", - tag_primitive=repeated_dense_tag, + tag_primitive=kfac_jax.layers_and_loss_tags.layer_tag, compute_func=_repeated_dense1, - parameters_extractor_func=_dense_parameter_extractor, + parameters_extractor_func=_repeated_dense_parameter_extractor, example_args=[np.zeros([9, 11, 13]), [np.zeros([13, 7]), np.zeros([7])]], ) repeated_dense1_no_bias_pattern = kfac_jax.tag_graph_matcher.GraphPattern( name="repeated_dense1_no_bias", - tag_primitive=repeated_dense_tag, + tag_primitive=kfac_jax.layers_and_loss_tags.layer_tag, compute_func=_repeated_dense1_no_b, - parameters_extractor_func=_dense_parameter_extractor, + parameters_extractor_func=_repeated_dense_parameter_extractor, example_args=[np.zeros([9, 11, 13]), [np.zeros([13, 7])]], ) repeated_dense2_with_bias_pattern = kfac_jax.tag_graph_matcher.GraphPattern( name="repeated_dense2_with_bias", - tag_primitive=repeated_dense_tag, + tag_primitive=kfac_jax.layers_and_loss_tags.layer_tag, compute_func=_repeated_dense2, - parameters_extractor_func=_dense_parameter_extractor, + parameters_extractor_func=_repeated_dense_parameter_extractor, example_args=[np.zeros([8, 9, 11, 13]), [np.zeros([13, 7]), np.zeros([7])]], ) repeated_dense2_no_bias_pattern = kfac_jax.tag_graph_matcher.GraphPattern( name="repeated_dense2_no_bias", - tag_primitive=repeated_dense_tag, + tag_primitive=kfac_jax.layers_and_loss_tags.layer_tag, compute_func=_repeated_dense2_no_b, - parameters_extractor_func=_dense_parameter_extractor, + parameters_extractor_func=_repeated_dense_parameter_extractor, example_args=[np.zeros([8, 9, 11, 13]), [np.zeros([13, 7])]], ) @@ -276,8 +272,6 @@ def _dense_parameter_extractor( kfac_jax.set_default_tag_to_block_ctor( - "repeated_dense_tag", RepeatedDenseBlock -) -kfac_jax.set_default_tag_to_block_ctor( - "qmc_tag", QmcBlockedDense + "repeated_dense", RepeatedDenseBlock ) +kfac_jax.set_default_tag_to_block_ctor("qmc", QmcBlockedDense)