diff --git a/fjformer/__init__.py b/fjformer/__init__.py index 36e88f6..6900051 100644 --- a/fjformer/__init__.py +++ b/fjformer/__init__.py @@ -26,4 +26,4 @@ JaxRNG, GenerateRNG, init_rng, next_rng, count_num_params ) -__version__ = '0.0.8' +__version__ = '0.0.9' diff --git a/fjformer/bits/calibration.py b/fjformer/bits/calibration.py index b52e167..d6224be 100644 --- a/fjformer/bits/calibration.py +++ b/fjformer/bits/calibration.py @@ -14,6 +14,8 @@ """Quantization calibration methods.""" import abc +from typing import Union + import flax.struct import jax.numpy as jnp @@ -27,7 +29,7 @@ def get_bound(self, x, shared_axes) -> jnp.ndarray: @flax.struct.dataclass class ConstantCalibration(Calibration): - bound: jnp.ndarray | float + bound: Union[jnp.ndarray, float] def get_bound(self, x, shared_axes) -> jnp.ndarray: """Calibration.""" diff --git a/fjformer/bits/config.py b/fjformer/bits/config.py index 0abb4c7..583aec7 100644 --- a/fjformer/bits/config.py +++ b/fjformer/bits/config.py @@ -271,8 +271,8 @@ def dot_general_make( def fully_quantized( *, - fwd_bits: int | None = 8, - bwd_bits: int | None = 8, + fwd_bits: Union[int, None] = 8, + bwd_bits: Union[int, None] = 8, use_fwd_quant: bool = True, use_stochastic_rounding: Optional[bool] = True, # Typically we have (but it's a caller's responsibility to check): @@ -332,9 +332,9 @@ def fully_quantized( def config_v3( *, - fwd_bits: int | None, - dlhs_bits: int | None, - drhs_bits: int | None, + fwd_bits: Union[int, None], + dlhs_bits: Union[int, None], + drhs_bits: Union[int, None], use_dummy_static_bound: bool = False, rng_type: str = 'jax.uniform', # 'custom-1' dlhs_local_q: Optional[LocalQ] = None, diff --git a/fjformer/bits/q_flax.py b/fjformer/bits/q_flax.py index 39fcd69..d3df438 100644 --- a/fjformer/bits/q_flax.py +++ b/fjformer/bits/q_flax.py @@ -21,6 +21,7 @@ from . import int_numerics import flax.linen as nn import jax.numpy as jnp +from typing import Optional, Union class Freezer(nn.Module, config.Preprocess): @@ -53,8 +54,8 @@ def __call__(self, inputs): class QDotGeneral(nn.Module): """A layer that can be injected into flax.nn.Dense, etc.""" - cfg: config.DotGeneral | None = None - prng_name: str | None = 'params' + cfg: Optional[Union[config.DotGeneral, None]] = None + prng_name: Optional[Union[str, None]] = None @nn.compact def __call__( @@ -81,8 +82,8 @@ def __call__( class QEinsum(nn.Module): """Quantized Einsum class for model injection.""" - cfg: config.DotGeneral | None = None - prng_name: str | None = 'params' + cfg: Optional[Union[config.DotGeneral, None]] = None + prng_name: Optional[Union[str, None]] = None @nn.compact def __call__(self, eqn, lhs, rhs): @@ -128,14 +129,14 @@ def set_lhs_quant_mode( def config_v4( *, - fwd_bits: int | None, - dlhs_bits: int | None, - drhs_bits: int | None, + fwd_bits: Union[int, None], + dlhs_bits: Union[int, None], + drhs_bits: Union[int, None], # The dummy static bound flag is for performance benchmarking. use_dummy_static_bound: bool = False, rng_type: str = 'jax.uniform', # 'custom-1' - dlhs_local_q: config.LocalQ | None = None, - drhs_local_q: config.LocalQ | None = None, + dlhs_local_q: Union[config.LocalQ, None] = None, + drhs_local_q: Union[config.LocalQ, None] = None, fwd_accumulator_dtype: ... = jnp.int32, dlhs_accumulator_dtype: ... = jnp.int32, drhs_accumulator_dtype: ... = jnp.int32, @@ -145,7 +146,7 @@ def config_v4( ) -> config.DotGeneral: """Version 4 of user-visible AQT config.""" - def tensor_config(bits: int | None) -> config.Tensor: + def tensor_config(bits: Union[int, None]) -> config.Tensor: assert bits is None or bits >= 2, 'Need at least 2 bits.' if bits is None: numerics = config.NoNumerics() diff --git a/setup.py b/setup.py index d4bd27f..c026279 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ setuptools.setup( name="fjformer", - version='0.0.8', + version='0.0.9', author="Erfan Zare Chavoshi", author_email="erfanzare82@yahoo.com", long_description=long_description,