From be31b49b1f8ae03e2c296bfb2f945e7ad471c5e9 Mon Sep 17 00:00:00 2001
From: Zhihao Shan
Date: Tue, 20 Feb 2024 20:09:22 -0800
Subject: [PATCH] SAXML 1.2.0 release

---
 RELEASE.md                                |  5 ++
 saxml/server/pax/lm/BUILD                 |  2 +-
 saxml/server/pax/lm/all_imports.py        |  2 +-
 saxml/server/pax/lm/params/BUILD          |  4 +-
 .../pax/lm/params/{gamma.py => gemma.py}  | 66 +++++++++----------
 saxml/server/pax/lm/transformer_models.py |  6 +-
 saxml/tools/offline_quantize.py           |  4 +-
 saxml/tools/quantization_configs.py       |  8 +--
 saxml/tools/quantization_provider.py      |  4 +-
 9 files changed, 53 insertions(+), 48 deletions(-)
 rename saxml/server/pax/lm/params/{gamma.py => gemma.py} (79%)

diff --git a/RELEASE.md b/RELEASE.md
index 504abfe3..82245519 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -1,3 +1,8 @@
+# Release 1.2.0
+## Major Features and Improvements
+* Open-source model: Gemma.
+* Experimental support for continuous batching.
+
 # Release 1.1.0
 ## Major Features and Improvements
 
diff --git a/saxml/server/pax/lm/BUILD b/saxml/server/pax/lm/BUILD
index 9ec4d9d4..6969f828 100644
--- a/saxml/server/pax/lm/BUILD
+++ b/saxml/server/pax/lm/BUILD
@@ -17,7 +17,7 @@ pytype_strict_library(
     deps = [
         "//saxml/server:servable_model_registry",
         "//saxml/server/pax/lm/params:c4",
-        "//saxml/server/pax/lm/params:gamma",
+        "//saxml/server/pax/lm/params:gemma",
         "//saxml/server/pax/lm/params:gptj",
         "//saxml/server/pax/lm/params:lm_cloud",
     ],
diff --git a/saxml/server/pax/lm/all_imports.py b/saxml/server/pax/lm/all_imports.py
index 157b1136..f032cc2c 100644
--- a/saxml/server/pax/lm/all_imports.py
+++ b/saxml/server/pax/lm/all_imports.py
@@ -19,7 +19,7 @@
 from saxml.server.pax.lm.params import lm_cloud
 from saxml.server.pax.lm.params import c4
 from saxml.server.pax.lm.params import gptj
-from saxml.server.pax.lm.params import gamma
+from saxml.server.pax.lm.params import gemma
 
 # Experimental models.
 # Specify the registry root.
diff --git a/saxml/server/pax/lm/params/BUILD b/saxml/server/pax/lm/params/BUILD
index 8835e14a..62b465ee 100644
--- a/saxml/server/pax/lm/params/BUILD
+++ b/saxml/server/pax/lm/params/BUILD
@@ -108,8 +108,8 @@ pytype_strict_library(
 )
 
 pytype_strict_library(
-    name = "gamma",
-    srcs = ["gamma.py"],
+    name = "gemma",
+    srcs = ["gemma.py"],
     srcs_version = "PY3",
     deps = [
         "//saxml/server:servable_model_registry",
diff --git a/saxml/server/pax/lm/params/gamma.py b/saxml/server/pax/lm/params/gemma.py
similarity index 79%
rename from saxml/server/pax/lm/params/gamma.py
rename to saxml/server/pax/lm/params/gemma.py
index 73b32879..3966f6b0 100644
--- a/saxml/server/pax/lm/params/gamma.py
+++ b/saxml/server/pax/lm/params/gemma.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Serving model parameters for Gamma."""
+"""Serving model parameters for Gemma."""
 
 # OSS import placeholder
 from typing import List
@@ -32,10 +32,10 @@
 
 @servable_model_registry.register
 @template.make_servable(template.ServingTemplate)
-class GammaBase(base_experiment.BaseExperiment):
-  """Gamma Transformer LM configuration."""
+class GemmaBase(base_experiment.BaseExperiment):
+  """Gemma Transformer LM configuration."""
 
-  SPM_MODEL = 'gs://cloud-tpu-inference-public/sax-tokenizers/gamma/gamma-tokenizer.model'
+  SPM_MODEL = 'gs://cloud-tpu-inference-public/sax-tokenizers/gemma/gemma-tokenizer.model'
   SOS_ID = 2
   EOS_ID = 1
   GENERATE_ONLY = True  # No need to compute loss.
@@ -85,7 +85,7 @@ def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]:
     else:
       task_p.model = pax_fiddle.Config(layers.LanguageModel, name='xformer_lm')
     model_p = task_p.model
-    model_p.lm_tpl = transformer_models.gamma(
+    model_p.lm_tpl = transformer_models.gemma(
        vocab_size=self.VOCAB_SIZE,
        model_dims=self.MODEL_DIMS,
        hidden_dims=self.HIDDEN_DIMS,
@@ -116,8 +116,8 @@ def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]:
 
 
 @servable_model_registry.register
-class Gamma2BFP16(GammaBase):
-  """Gamma2B model."""
+class Gemma2BFP16(GemmaBase):
+  """Gemma 2B model."""
 
   NUM_LAYERS = 18
   VOCAB_SIZE = 256128
@@ -148,15 +148,15 @@ def serving_mesh_shape(cls):
 
 
 @servable_model_registry.register
-class Gamma2BFP16Exp(Gamma2BFP16):
+class Gemma2BFP16Exp(Gemma2BFP16):
   BATCH_SIZE = 1
   NUM_CACHE_SLOTS = 256
   MAX_LIVE_BATCHES = 256 * 4  # BATCH_SIZE is always 1 in this case.
 
 
 @servable_model_registry.register
-class Gamma7BFP16(GammaBase):
-  """Gamma7B model."""
+class Gemma7BFP16(GemmaBase):
+  """Gemma 7B model."""
 
   NUM_LAYERS = 28
   VOCAB_SIZE = 256128
@@ -187,15 +187,15 @@ def serving_mesh_shape(cls):
 
 
 @servable_model_registry.register
-class Gamma7BFP16Exp(Gamma7BFP16):
+class Gemma7BFP16Exp(Gemma7BFP16):
   BATCH_SIZE = 1
   NUM_CACHE_SLOTS = 32
   MAX_LIVE_BATCHES = 32 * 4  # BATCH_SIZE is always 1 in this case.
 
 
 @servable_model_registry.register
-class Gamma2BFP16With8Replicas(Gamma2BFP16):
-  """Gamma2B model on v5e-8 with 8 replications."""
+class Gemma2BFP16With8Replicas(Gemma2BFP16):
+  """Gemma 2B model on v5e-8 with 8 replicas."""
 
   @classmethod
   def serving_mesh_shape(cls):
@@ -205,8 +205,8 @@ def serving_mesh_shape(cls):
 
 
 @servable_model_registry.register
-class Gamma2BFP16With4Replicas(Gamma2BFP16):
-  """Gamma2B model on v4-8 or v5e-4 both with 4 replications."""
+class Gemma2BFP16With4Replicas(Gemma2BFP16):
+  """Gemma 2B model on v4-8 or v5e-4, each with 4 replicas."""
 
   @classmethod
   def serving_mesh_shape(cls):
@@ -216,8 +216,8 @@ def serving_mesh_shape(cls):
 
 
 @servable_model_registry.register
-class Gamma7BFP16With2Replicas(Gamma7BFP16):
-  """Gamma7B model on v5e-8 with 2 replications."""
+class Gemma7BFP16With2Replicas(Gemma7BFP16):
+  """Gemma 7B model on v5e-8 with 2 replicas."""
 
   @classmethod
   def serving_mesh_shape(cls):
@@ -228,19 +228,19 @@ def serving_mesh_shape(cls):
 
 @servable_model_registry.register
 @quantization.for_transformer(quantize_on_the_fly=False)
-class Gamma2BInt8(Gamma2BFP16):
-  """Gamma2B model with int8 weight quantization."""
+class Gemma2BInt8(Gemma2BFP16):
+  """Gemma 2B model with int8 weight quantization."""
 
 
 @servable_model_registry.register
 @quantization.for_transformer(quantize_on_the_fly=False)
-class Gamma7BInt8(Gamma7BFP16):
-  """Gamma7B model with int8 weight quantization."""
+class Gemma7BInt8(Gemma7BFP16):
+  """Gemma 7B model with int8 weight quantization."""
 
 
 @servable_model_registry.register
-class Gamma2BInt8With8Replicas(Gamma2BInt8):
-  """Gamma2B model with int8 quantization on v4-8 or v5e-8 with 8 replications."""
+class Gemma2BInt8With8Replicas(Gemma2BInt8):
+  """Gemma 2B model with int8 quantization on v4-8 or v5e-8 with 8 replicas."""
 
   @classmethod
   def serving_mesh_shape(cls):
@@ -250,8 +250,8 @@ def serving_mesh_shape(cls):
 
 
 @servable_model_registry.register
-class Gamma7BInt8With2Replicas(Gamma7BInt8):
-  """Gamma7B model with int8 quantization on v4-8 or v5e-8 with 2 replications."""
+class Gemma7BInt8With2Replicas(Gemma7BInt8):
+  """Gemma 7B model with int8 quantization on v4-8 or v5e-8 with 2 replicas."""
 
   @classmethod
   def serving_mesh_shape(cls):
@@ -261,30 +261,30 @@ def serving_mesh_shape(cls):
 
 
 @servable_model_registry.register
-class Gamma2BFP16Test(Gamma2BFP16):
-  """Gamma2B model for testing without ckpt."""
+class Gemma2BFP16Test(Gemma2BFP16):
+  """Gemma 2B model for testing without a checkpoint."""
 
   test_mode = True
 
 
 @servable_model_registry.register
-class Gamma7BFP16Test(Gamma7BFP16):
-  """Gamma7B model for testing without ckpt."""
+class Gemma7BFP16Test(Gemma7BFP16):
+  """Gemma 7B model for testing without a checkpoint."""
 
   test_mode = True
 
 
 @servable_model_registry.register
 @quantization.for_transformer(quantize_on_the_fly=False)
-class Gamma2BInt8Test(Gamma2BInt8):
-  """Gamma2B model with int8 weight quantization for testing without ckpt."""
+class Gemma2BInt8Test(Gemma2BInt8):
+  """Gemma 2B model with int8 weight quantization for testing without a checkpoint."""
 
   test_mode = True
 
 
 @servable_model_registry.register
 @quantization.for_transformer(quantize_on_the_fly=False)
-class Gamma7BInt8Test(Gamma7BInt8):
-  """Gamma7B model with int8 weight quantization for testing without ckpt."""
+class Gemma7BInt8Test(Gemma7BInt8):
+  """Gemma 7B model with int8 weight quantization for testing without a checkpoint."""
 
   test_mode = True
diff --git a/saxml/server/pax/lm/transformer_models.py b/saxml/server/pax/lm/transformer_models.py
index 1218761d..dc3371d8 100644
--- a/saxml/server/pax/lm/transformer_models.py
+++ b/saxml/server/pax/lm/transformer_models.py
@@ -19,7 +19,7 @@
 from saxml.server.pax.lm import layers as sax_layers
 
 
-def gamma(
+def gemma(
     vocab_size,
     model_dims,
     hidden_dims,
@@ -28,7 +28,7 @@ def gamma(
     dim_per_head,
     use_mqa,
 ) -> pax_fiddle.Config[layers.TransformerLm]:
-  """Create a TransformerLm config(template) for Gamma model family.
+  """Create a TransformerLm config (template) for the Gemma model family.
 
   Args:
     vocab_size: Size of vocabulary.
@@ -40,7 +40,7 @@ def gamma(
     use_mqa: Whether use Multi-Query Attention.
 
   Returns:
-    TransformerLm for Gamma.
+    TransformerLm for Gemma.
""" model_p = pax_fiddle.Config(layers.TransformerLm) model_p.vocab_size = vocab_size diff --git a/saxml/tools/offline_quantize.py b/saxml/tools/offline_quantize.py index 6950ca22..43062570 100644 --- a/saxml/tools/offline_quantize.py +++ b/saxml/tools/offline_quantize.py @@ -56,8 +56,8 @@ def parse_known_args(argv): default='gptj', choices=[ 'gptj', - 'gamma2b', - 'gamma7b', + 'gemma2b', + 'gemma7b', 'llama2-70b-weight-linear-only-int8', ], help='Quantization Config.', diff --git a/saxml/tools/quantization_configs.py b/saxml/tools/quantization_configs.py index 2f7ae29e..55d28a99 100644 --- a/saxml/tools/quantization_configs.py +++ b/saxml/tools/quantization_configs.py @@ -79,8 +79,8 @@ class QuantizationConfigsGPTJStacked(QuantizationConfigs): } -class QuantizationConfigsGamma2B(QuantizationConfigs): - """Quantization config for Gamma 2B.""" +class QuantizationConfigsGemma2B(QuantizationConfigs): + """Quantization config for Gemma 2B.""" factor = 1.0 configs = { @@ -94,8 +94,8 @@ class QuantizationConfigsGamma2B(QuantizationConfigs): } -class QuantizationConfigsGamma7B(QuantizationConfigsGPTJ): - """Quantization config for Gamma 7B.""" +class QuantizationConfigsGemma7B(QuantizationConfigsGPTJ): + """Quantization config for Gemma 7B.""" class QuantizationConfigsLLaMA70BWeightLinearOnlyInt8(QuantizationConfigs): diff --git a/saxml/tools/quantization_provider.py b/saxml/tools/quantization_provider.py index aa1f1491..73768459 100644 --- a/saxml/tools/quantization_provider.py +++ b/saxml/tools/quantization_provider.py @@ -18,8 +18,8 @@ NAME_TO_CONFIG = { 'gptj': quantization_configs.QuantizationConfigsGPTJ(), - 'gamma2b': quantization_configs.QuantizationConfigsGamma2B(), - 'gamma7b': quantization_configs.QuantizationConfigsGamma7B(), + 'gemma2b': quantization_configs.QuantizationConfigsGemma2B(), + 'gemma7b': quantization_configs.QuantizationConfigsGemma7B(), 'llama2-70b-weight-linear-only-int8': ( quantization_configs.QuantizationConfigsLLaMA70BWeightLinearOnlyInt8() ),