SAXML 1.2.0 release
Zhihao Shan committed Feb 21, 2024
1 parent b28c180 commit be31b49
Showing 9 changed files with 53 additions and 48 deletions.
5 changes: 5 additions & 0 deletions RELEASE.md
@@ -1,3 +1,8 @@
+# Release 1.2.0
+## Major Features and Improvements
+* Open-source model: Gemma.
+* Continuous Batching experimental support.
+
 # Release 1.1.0
 
 ## Major Features and Improvements
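For orientation, a minimal client-side sketch of querying a served Gemma model; this is not part of the commit. It assumes a running SAX cluster with a published Gemma config and uses the sax Python client from this repository; the model path '/sax/test/gemma2b' is hypothetical.

# Hypothetical usage sketch (not part of this commit). Assumes a SAX admin
# server and a model server are running and a Gemma config has been published
# under the (made-up) name '/sax/test/gemma2b'.
from saxml.client.python import sax

model = sax.Model('/sax/test/gemma2b')
lm = model.LM()
# Continuous batching (experimental in 1.2.0) is a server-side setting made
# through the model config; this client call is unchanged either way.
for text, score in lm.Generate('Why is the sky blue?'):
  print(text, score)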
2 changes: 1 addition & 1 deletion saxml/server/pax/lm/BUILD
@@ -17,7 +17,7 @@ pytype_strict_library(
     deps = [
         "//saxml/server:servable_model_registry",
         "//saxml/server/pax/lm/params:c4",
-        "//saxml/server/pax/lm/params:gamma",
+        "//saxml/server/pax/lm/params:gemma",
         "//saxml/server/pax/lm/params:gptj",
         "//saxml/server/pax/lm/params:lm_cloud",
     ],
2 changes: 1 addition & 1 deletion saxml/server/pax/lm/all_imports.py
@@ -19,7 +19,7 @@
 from saxml.server.pax.lm.params import lm_cloud
 from saxml.server.pax.lm.params import c4
 from saxml.server.pax.lm.params import gptj
-from saxml.server.pax.lm.params import gamma
+from saxml.server.pax.lm.params import gemma
 # Experimental models.
 
 # Specify the registry root.
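These imports exist for their side effects: loading a params module executes its @servable_model_registry.register decorators, which is what makes the configs discoverable, so the gamma-to-gemma swap here is what re-registers the renamed classes. The same pattern in isolation:

# Import for side effects only: loading the module runs the
# @servable_model_registry.register decorators on the Gemma config classes.
from saxml.server.pax.lm.params import gemma  # noqa: F401 (intentionally unused)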
4 changes: 2 additions & 2 deletions saxml/server/pax/lm/params/BUILD
@@ -108,8 +108,8 @@ pytype_strict_library(
 )
 
 pytype_strict_library(
-    name = "gamma",
-    srcs = ["gamma.py"],
+    name = "gemma",
+    srcs = ["gemma.py"],
     srcs_version = "PY3",
     deps = [
         "//saxml/server:servable_model_registry",
66 changes: 33 additions & 33 deletions saxml/server/pax/lm/params/{gamma.py → gemma.py}
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Serving model parameters for Gamma."""
+"""Serving model parameters for Gemma."""
 
 # OSS import placeholder
 from typing import List
@@ -32,10 +32,10 @@

 @servable_model_registry.register
 @template.make_servable(template.ServingTemplate)
-class GammaBase(base_experiment.BaseExperiment):
-  """Gamma Transformer LM configuration."""
+class GemmaBase(base_experiment.BaseExperiment):
+  """Gemma Transformer LM configuration."""
 
-  SPM_MODEL = 'gs://cloud-tpu-inference-public/sax-tokenizers/gamma/gamma-tokenizer.model'
+  SPM_MODEL = 'gs://cloud-tpu-inference-public/sax-tokenizers/gemma/gemma-tokenizer.model'
   SOS_ID = 2
   EOS_ID = 1
   GENERATE_ONLY = True  # No need to compute loss.
@@ -85,7 +85,7 @@ def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]:
     else:
       task_p.model = pax_fiddle.Config(layers.LanguageModel, name='xformer_lm')
     model_p = task_p.model
-    model_p.lm_tpl = transformer_models.gamma(
+    model_p.lm_tpl = transformer_models.gemma(
         vocab_size=self.VOCAB_SIZE,
         model_dims=self.MODEL_DIMS,
         hidden_dims=self.HIDDEN_DIMS,
@@ -116,8 +116,8 @@ def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]:


 @servable_model_registry.register
-class Gamma2BFP16(GammaBase):
-  """Gamma2B model."""
+class Gemma2BFP16(GemmaBase):
+  """Gemma2B model."""
 
   NUM_LAYERS = 18
   VOCAB_SIZE = 256128
@@ -148,15 +148,15 @@ def serving_mesh_shape(cls):


 @servable_model_registry.register
-class Gamma2BFP16Exp(Gamma2BFP16):
+class Gemma2BFP16Exp(Gemma2BFP16):
   BATCH_SIZE = 1
   NUM_CACHE_SLOTS = 256
   MAX_LIVE_BATCHES = 256 * 4  # BATCH_SIZE is always 1 in this case.
 
 
 @servable_model_registry.register
-class Gamma7BFP16(GammaBase):
-  """Gamma7B model."""
+class Gemma7BFP16(GemmaBase):
+  """Gemma7B model."""
 
   NUM_LAYERS = 28
   VOCAB_SIZE = 256128
@@ -187,15 +187,15 @@ def serving_mesh_shape(cls):


 @servable_model_registry.register
-class Gamma7BFP16Exp(Gamma7BFP16):
+class Gemma7BFP16Exp(Gemma7BFP16):
   BATCH_SIZE = 1
   NUM_CACHE_SLOTS = 32
   MAX_LIVE_BATCHES = 32 * 4  # BATCH_SIZE is always 1 in this case.
 
 
 @servable_model_registry.register
-class Gamma2BFP16With8Replicas(Gamma2BFP16):
-  """Gamma2B model on v5e-8 with 8 replications."""
+class Gemma2BFP16With8Replicas(Gemma2BFP16):
+  """Gemma2B model on v5e-8 with 8 replications."""
 
   @classmethod
   def serving_mesh_shape(cls):
@@ -205,8 +205,8 @@ def serving_mesh_shape(cls):


 @servable_model_registry.register
-class Gamma2BFP16With4Replicas(Gamma2BFP16):
-  """Gamma2B model on v4-8 or v5e-4 both with 4 replications."""
+class Gemma2BFP16With4Replicas(Gemma2BFP16):
+  """Gemma2B model on v4-8 or v5e-4 both with 4 replications."""
 
   @classmethod
   def serving_mesh_shape(cls):
@@ -216,8 +216,8 @@ def serving_mesh_shape(cls):


 @servable_model_registry.register
-class Gamma7BFP16With2Replicas(Gamma7BFP16):
-  """Gamma7B model on v5e-8 with 2 replications."""
+class Gemma7BFP16With2Replicas(Gemma7BFP16):
+  """Gemma7B model on v5e-8 with 2 replications."""
 
   @classmethod
   def serving_mesh_shape(cls):
@@ -228,19 +228,19 @@ def serving_mesh_shape(cls):

 @servable_model_registry.register
 @quantization.for_transformer(quantize_on_the_fly=False)
-class Gamma2BInt8(Gamma2BFP16):
-  """Gamma2B model with int8 weight quantization."""
+class Gemma2BInt8(Gemma2BFP16):
+  """Gemma2B model with int8 weight quantization."""
 
 
 @servable_model_registry.register
 @quantization.for_transformer(quantize_on_the_fly=False)
-class Gamma7BInt8(Gamma7BFP16):
-  """Gamma7B model with int8 weight quantization."""
+class Gemma7BInt8(Gemma7BFP16):
+  """Gemma7B model with int8 weight quantization."""
 
 
 @servable_model_registry.register
-class Gamma2BInt8With8Replicas(Gamma2BInt8):
-  """Gamma2B model with int8 quantization on v4-8 or v5e-8 with 8 replications."""
+class Gemma2BInt8With8Replicas(Gemma2BInt8):
+  """Gemma2B model with int8 quantization on v4-8 or v5e-8 with 8 replications."""
 
   @classmethod
   def serving_mesh_shape(cls):
@@ -250,8 +250,8 @@ def serving_mesh_shape(cls):


 @servable_model_registry.register
-class Gamma7BInt8With2Replicas(Gamma7BInt8):
-  """Gamma7B model with int8 quantization on v4-8 or v5e-8 with 2 replications."""
+class Gemma7BInt8With2Replicas(Gemma7BInt8):
+  """Gemma7B model with int8 quantization on v4-8 or v5e-8 with 2 replications."""
 
   @classmethod
   def serving_mesh_shape(cls):
@@ -261,30 +261,30 @@ def serving_mesh_shape(cls):


 @servable_model_registry.register
-class Gamma2BFP16Test(Gamma2BFP16):
-  """Gamma2B model for testing without ckpt."""
+class Gemma2BFP16Test(Gemma2BFP16):
+  """Gemma2B model for testing without ckpt."""
 
   test_mode = True
 
 
 @servable_model_registry.register
-class Gamma7BFP16Test(Gamma7BFP16):
-  """Gamma7B model for testing without ckpt."""
+class Gemma7BFP16Test(Gemma7BFP16):
+  """Gemma7B model for testing without ckpt."""
 
   test_mode = True
 
 
 @servable_model_registry.register
 @quantization.for_transformer(quantize_on_the_fly=False)
-class Gamma2BInt8Test(Gamma2BInt8):
-  """Gamma2B model with int8 weight quantization for testing without ckpt."""
+class Gemma2BInt8Test(Gemma2BInt8):
+  """Gemma2B model with int8 weight quantization for testing without ckpt."""
 
   test_mode = True
 
 
 @servable_model_registry.register
 @quantization.for_transformer(quantize_on_the_fly=False)
-class Gamma7BInt8Test(Gamma7BInt8):
-  """Gamma7B model with int8 weight quantization for testing without ckpt."""
+class Gemma7BInt8Test(Gemma7BInt8):
+  """Gemma7B model with int8 weight quantization for testing without ckpt."""
 
   test_mode = True
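The *Exp variants in this file are the "Continuous Batching experimental support" called out in the release notes: BATCH_SIZE stays at 1 while NUM_CACHE_SLOTS caps how many sequences decode concurrently and MAX_LIVE_BATCHES caps queued work. A user config following the same pattern might look like the sketch below; the subclass and its values are hypothetical, not part of this commit.

# Hypothetical config mirroring Gemma2BFP16Exp above; values are illustrative.
from saxml.server import servable_model_registry
from saxml.server.pax.lm.params import gemma


@servable_model_registry.register
class MyGemma2BContinuous(gemma.Gemma2BFP16):
  """Gemma 2B with experimental continuous batching (illustrative values)."""

  BATCH_SIZE = 1  # the per-step batch stays 1 under continuous batching
  NUM_CACHE_SLOTS = 128  # max concurrent sequences; Gemma2BFP16Exp uses 256
  MAX_LIVE_BATCHES = 128 * 4  # mirrors the slots * 4 ratio used in this file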
6 changes: 3 additions & 3 deletions saxml/server/pax/lm/transformer_models.py
@@ -19,7 +19,7 @@
 from saxml.server.pax.lm import layers as sax_layers
 
 
-def gamma(
+def gemma(
     vocab_size,
     model_dims,
     hidden_dims,
@@ -28,7 +28,7 @@ def gamma(
     dim_per_head,
     use_mqa,
 ) -> pax_fiddle.Config[layers.TransformerLm]:
-  """Create a TransformerLm config(template) for Gamma model family.
+  """Create a TransformerLm config(template) for Gemma model family.
 
   Args:
     vocab_size: Size of vocabulary.
@@ -40,7 +40,7 @@ def gamma(
     use_mqa: Whether use Multi-Query Attention.
 
   Returns:
-    TransformerLm for Gamma.
+    TransformerLm for Gemma.
   """
   model_p = pax_fiddle.Config(layers.TransformerLm)
   model_p.vocab_size = vocab_size
4 changes: 2 additions & 2 deletions saxml/tools/offline_quantize.py
@@ -56,8 +56,8 @@ def parse_known_args(argv):
     default='gptj',
     choices=[
         'gptj',
-        'gamma2b',
-        'gamma7b',
+        'gemma2b',
+        'gemma7b',
         'llama2-70b-weight-linear-only-int8',
     ],
     help='Quantization Config.',
8 changes: 4 additions & 4 deletions saxml/tools/quantization_configs.py
@@ -79,8 +79,8 @@ class QuantizationConfigsGPTJStacked(QuantizationConfigs):
 }
 
 
-class QuantizationConfigsGamma2B(QuantizationConfigs):
-  """Quantization config for Gamma 2B."""
+class QuantizationConfigsGemma2B(QuantizationConfigs):
+  """Quantization config for Gemma 2B."""
 
   factor = 1.0
   configs = {
@@ -94,8 +94,8 @@ class QuantizationConfigsGamma2B(QuantizationConfigs):
 }
 
 
-class QuantizationConfigsGamma7B(QuantizationConfigsGPTJ):
-  """Quantization config for Gamma 7B."""
+class QuantizationConfigsGemma7B(QuantizationConfigsGPTJ):
+  """Quantization config for Gemma 7B."""
 
 
 class QuantizationConfigsLLaMA70BWeightLinearOnlyInt8(QuantizationConfigs):
4 changes: 2 additions & 2 deletions saxml/tools/quantization_provider.py
@@ -18,8 +18,8 @@

 NAME_TO_CONFIG = {
     'gptj': quantization_configs.QuantizationConfigsGPTJ(),
-    'gamma2b': quantization_configs.QuantizationConfigsGamma2B(),
-    'gamma7b': quantization_configs.QuantizationConfigsGamma7B(),
+    'gemma2b': quantization_configs.QuantizationConfigsGemma2B(),
+    'gemma7b': quantization_configs.QuantizationConfigsGemma7B(),
     'llama2-70b-weight-linear-only-int8': (
         quantization_configs.QuantizationConfigsLLaMA70BWeightLinearOnlyInt8()
     ),
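To trace the rename end to end, here is a minimal sketch that resolves the new keys through the NAME_TO_CONFIG mapping above; it uses plain dict access, since whether the module also exposes a lookup helper is not visible in this diff.

# Minimal sketch: resolve a quantization config by the same name that
# offline_quantize.py accepts in its choices list.
from saxml.tools import quantization_configs
from saxml.tools import quantization_provider

config = quantization_provider.NAME_TO_CONFIG['gemma7b']
assert isinstance(config, quantization_configs.QuantizationConfigsGemma7B)
# 'gemma7b' reuses the GPT-J layer mapping via inheritance, per the
# QuantizationConfigsGemma7B(QuantizationConfigsGPTJ) hunk above.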
