Skip to content

Commit

Permalink
add notebook for numerical verification
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhiyuLi-goog committed Nov 15, 2024
1 parent 061abd8 commit f6f6c2d
Show file tree
Hide file tree
Showing 2 changed files with 249 additions and 1 deletion.
2 changes: 1 addition & 1 deletion MaxText/layers/linears.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,6 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel):
softmax_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1).astype(self.dtype)
# shape of top_k_weights & top_k_indices: (batch, sequence, num_experts_per_tok)
top_k_weights, top_k_indices = jax.lax.top_k(softmax_probs, self.num_experts_per_tok)
top_k_weights /= top_k_weights.sum(-1, keepdims=True)
matmul_precision = lax.Precision(self.config.matmul_precision)

if self.config.capacity_factor > 0:
Expand Down Expand Up @@ -605,6 +604,7 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel):
)
return output, loss
else:
top_k_weights /= top_k_weights.sum(-1, keepdims=True)
weights = self.reshape_and_update_weights(top_k_weights, top_k_indices)
inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed"))
with jax.named_scope("wi_0"):
Expand Down
248 changes: 248 additions & 0 deletions MaxText/scratch_code/mixtral-numerical-verification.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "bce1951a-8eef-4842-a70f-987b85a3240f",
"metadata": {},
"outputs": [],
"source": [
"# installation\n",
"!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu\n",
"!pip3 install tokenizers -U\n",
"!pip3 install transformers -U"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9769e847-d838-473d-8d32-1061b3e0f1c8",
"metadata": {},
"outputs": [],
"source": [
"# go to maxtext/MaxText for library import\n",
"\n",
"current_dir = %pwd\n",
"working_dir = current_dir.replace(\"scratch_code\", \"\") \n",
"%cd $working_dir"
]
},
{
"cell_type": "markdown",
"id": "f1c108fc-d739-471d-9c64-c08151845f06",
"metadata": {},
"source": [
"# one layer mixtral model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cf8eee59-295e-41f4-8c09-d2177b410ddc",
"metadata": {},
"outputs": [],
"source": [
"import pyconfig\n",
"from transformers.models.mixtral.configuration_mixtral import MixtralConfig\n",
"\n",
"pyconfig.initialize(\n",
" [None, \"configs/base.yml\"],\n",
" base_emb_dim=4096,\n",
" base_num_query_heads=32,\n",
" base_num_kv_heads=8,\n",
" base_mlp_dim=14336,\n",
" base_num_decoder_layers=1, # 1 layer for simplicity\n",
" head_dim=128,\n",
" mlp_activations=[\"silu\",\"linear\"],\n",
" vocab_size=32000,\n",
" enable_dropout=False,\n",
" logits_via_embedding=False,\n",
" normalization_layer_epsilon=1.0e-5,\n",
" num_experts=8,\n",
" num_experts_per_tok=2,\n",
" rope_max_timescale=1_000_000,\n",
" decoder_block=\"mistral\",\n",
" run_name=\"moe_test\",\n",
" enable_checkpointing=False,\n",
" dtype=\"bfloat16\",\n",
" weight_dtype=\"bfloat16\",\n",
" megablox=True, # or False\n",
" max_target_length=4,\n",
" per_device_batch_size=1,\n",
" capacity_factor=-1,\n",
" scan_layers=False,\n",
")\n",
"config_maxtext = pyconfig.config\n",
"\n",
"config_hf = MixtralConfig(\n",
" vocab_size=config_maxtext.vocab_size,\n",
" hidden_size=config_maxtext.emb_dim,\n",
" intermediate_size=config_maxtext.mlp_dim,\n",
" num_hidden_layers=config_maxtext.num_decoder_layers, \n",
" num_attention_heads=config_maxtext.base_num_query_heads,\n",
" num_key_value_heads=config_maxtext.num_kv_heads,\n",
" rms_norm_eps=config_maxtext.normalization_layer_epsilon,\n",
" rope_theta=config_maxtext.rope_max_timescale,\n",
" attention_dropout=0.0,\n",
" num_experts_per_tok=config_maxtext.num_experts_per_tok,\n",
" num_local_experts=config_maxtext.num_experts,\n",
" tie_word_embeddings=config_maxtext.logits_via_embedding,\n",
" output_router_logits=False,\n",
" router_aux_loss_coef=0.001,\n",
" router_jitter_noise=0.0,\n",
" torch_dtype=\"bfloat16\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c94c857a-2efd-48f3-9669-aef926329cbd",
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoModelForCausalLM, set_seed\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from layers.models import Transformer\n",
"import max_utils\n",
"from jax.sharding import Mesh\n",
"\n",
"# ensure the same model initialization\n",
"set_seed(0)\n",
"\n",
"model_hf = AutoModelForCausalLM.from_config(config_hf)\n",
"\n",
"devices_array = max_utils.create_device_mesh(config_maxtext)\n",
"mesh = Mesh(devices_array, config_maxtext.mesh_axes)\n",
"prng_key = jax.random.PRNGKey(1234)\n",
"model_maxtext = Transformer(config=config_maxtext, mesh=mesh, quant=None)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "707df022-ec37-44b3-b203-5f938151c6ca",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"input_np = {\n",
" 'inputs': np.random.randint(0, config_maxtext.vocab_size, size=(int(config_maxtext.per_device_batch_size), config_maxtext.max_target_length)),\n",
" 'inputs_position': np.tile(np.arange(config_maxtext.max_target_length), (int(config_maxtext.per_device_batch_size), 1)),\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "baca50fb-28f2-48b1-b4f5-0145ac6cfe38",
"metadata": {},
"outputs": [],
"source": [
"state_maxtext = model_maxtext.init({'params': prng_key, 'dropout': prng_key, 'aqt': prng_key},\n",
" jnp.array(input_np['inputs']),\n",
" jnp.array(input_np['inputs_position']),\n",
" enable_dropout=config_maxtext.enable_dropout,\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "74e8353b-b87a-4c5e-9a7c-138052249250",
"metadata": {},
"outputs": [],
"source": [
"import torch \n",
"from flax import linen as nn\n",
"\n",
"state_map = {\n",
" \"['params']['decoder']['decoder_norm']['scale'].value\": (\"model.norm.weight\", lambda x: x), \n",
" \"['params']['decoder']['layers_0']['MoeBlock_0']['gate']['kernel'].value\": (\"model.layers.0.block_sparse_moe.gate.weight\", lambda x: x.T),\n",
" \"['params']['decoder']['layers_0']['MoeBlock_0']['wi_0'].value\": (\"model.layers.0.block_sparse_moe.experts.<exp_idx>.w1.weight\", lambda *x: torch.stack(*x, dim=0).transpose(1,2)),\n",
" \"['params']['decoder']['layers_0']['MoeBlock_0']['wi_1'].value\": (\"model.layers.0.block_sparse_moe.experts.<exp_idx>.w3.weight\", lambda *x: torch.stack(*x, dim=0).transpose(1,2)),\n",
" \"['params']['decoder']['layers_0']['MoeBlock_0']['wo'].value\": (\"model.layers.0.block_sparse_moe.experts.<exp_idx>.w2.weight\", lambda *x: torch.stack(*x, dim=0).transpose(1,2)),\n",
" \"['params']['decoder']['layers_0']['post_self_attention_layer_norm']['scale'].value\": (\"model.layers.0.post_attention_layernorm.weight\", lambda x: x),\n",
" \"['params']['decoder']['layers_0']['pre_self_attention_layer_norm']['scale'].value\": (\"model.layers.0.input_layernorm.weight\", lambda x:x),\n",
" \"['params']['decoder']['layers_0']['self_attention']['key']['kernel'].value\": (\"model.layers.0.self_attn.k_proj.weight\", lambda x:x.T.reshape(config_hf.hidden_size, config_hf.num_key_value_heads, config_maxtext.head_dim)),\n",
" \"['params']['decoder']['layers_0']['self_attention']['out']['kernel'].value\": (\"model.layers.0.self_attn.o_proj.weight\", lambda x:x.T.reshape(config_hf.num_attention_heads, config_maxtext.head_dim, config_hf.hidden_size)),\n",
" \"['params']['decoder']['layers_0']['self_attention']['query']['kernel'].value\": (\"model.layers.0.self_attn.q_proj.weight\", lambda x:x.T.reshape(config_hf.hidden_size, config_hf.num_attention_heads, config_maxtext.head_dim) / np.sqrt(config_maxtext.head_dim)),\n",
" \"['params']['decoder']['layers_0']['self_attention']['value']['kernel'].value\": (\"model.layers.0.self_attn.v_proj.weight\", lambda x:x.T.reshape(config_hf.hidden_size, config_hf.num_key_value_heads, config_maxtext.head_dim)),\n",
" \"['params']['decoder']['logits_dense']['kernel'].value\": (\"lm_head.weight\", lambda x:x.T),\n",
" \"['params']['token_embedder']['embedding'].value\": (\"model.embed_tokens.weight\", lambda x:x),\n",
" }\n",
"\n",
"state_hf = model_hf.state_dict()\n",
"def map_fn(key_path, value):\n",
" key_path_str = jax.tree_util.keystr(key_path)\n",
" torch_key, transform_fn = state_map[key_path_str]\n",
" if \"<exp_idx>\" in torch_key:\n",
" torch_tensors = [state_hf[torch_key.replace(\"<exp_idx>\", str(i))] for i in range(config_hf.num_local_experts)]\n",
" else:\n",
" torch_tensors = state_hf[torch_key]\n",
" \n",
" torch_tensors = transform_fn(torch_tensors)\n",
"\n",
" assert value.shape == torch_tensors.shape, f\"{key_path_str}, {value.shape}, {torch_tensors.shape}\"\n",
" new_value = jnp.array(torch_tensors.to(torch.float32).numpy(), dtype=value.dtype)\n",
" if isinstance(value, nn.LogicallyPartitioned):\n",
" new_value = value.replace_boxed(new_value)\n",
" return new_value\n",
"\n",
"loaded_state_maxtext = jax.tree_util.tree_map_with_path(map_fn, state_maxtext)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d1f88708-c3a6-4b95-bc51-94adfebdf2aa",
"metadata": {},
"outputs": [],
"source": [
"logits_hf = model_hf(torch.from_numpy(input_np['inputs'])).logits.detach()\n",
"\n",
"logits_maxtext = model_maxtext.apply(\n",
" loaded_state_maxtext,\n",
" input_np['inputs'],\n",
" input_np['inputs_position'],\n",
" enable_dropout=False,\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1207375a-b92c-4a8c-975a-21f2f027d91e",
"metadata": {},
"outputs": [],
"source": [
"# currently, pass the following tests in both \"megablox=True\" & \"megablox=False capacity_factor=-1\"\n",
"\n",
"np.testing.assert_allclose(np.array(logits_maxtext), logits_hf.numpy(), rtol=1e-1, atol=1e-1)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit f6f6c2d

Please sign in to comment.