Hi, I am trying to serve the GPT-J 6B model on a TPU v3-8 using the saxml framework.
The error occurs during the model conversion from PyTorch to the Pax format that SAX supports. This is the conversion script:
https://github.com/mlcommons/inference_results_v3.1/blob/main/closed/Google/code/gptj-99/convert_gptj_ckpt.py
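For context, the conversion boils down to loading the Hugging Face checkpoint, remapping the weights, and pushing the full state through a pjit'ed identity before saving it as a Pax checkpoint. The following is only a rough sketch of that flow, not the actual mlcommons script; the names convert, jax_states, and pjitted_identity come from the traceback below, everything else is an assumption:

```python
# Rough sketch of the conversion flow (assumed structure, not the real script).
from jax.experimental.pjit import pjit
from transformers import AutoModelForCausalLM


def convert(base_model_path, pax_model_path):
    # Load the HF checkpoint on the host and print each tensor's shape;
    # this produces the "transformer.h.N..." lines in the log below.
    model = AutoModelForCausalLM.from_pretrained(base_model_path)
    jax_states = {}
    for name, tensor in model.state_dict().items():
        print(name, tuple(tensor.shape))
        jax_states[name] = tensor.detach().cpu().numpy()

    # Push the whole state through a pjit'ed identity so each array becomes a
    # device-backed array before it is written out as a Pax checkpoint. This is
    # the step that fails: without a sharding spec, the ~22.5 GB of fp32 weights
    # have to fit in a single core's HBM.
    pjitted_identity = pjit(lambda x: x)
    jax_states_gda = pjitted_identity(jax_states)
    # ... write jax_states_gda out as a Pax checkpoint under pax_model_path ...
```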
The admin and model servers are running correctly; I have even confirmed that they are communicating by running a sample test query.
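The test query was along these lines (a minimal sketch only; the client import path, the Model/LM/Generate methods, and the /sax/test/gptj model id are assumptions on my part):

```python
# Hypothetical sanity-check query against the running SAX servers (assumed API).
from saxml.client.python import sax

model = sax.Model("/sax/test/gptj")   # model id registered with the admin server (assumed)
lm = model.LM()
print(lm.Generate("The capital of France is"))
```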
The model pickle file is only 22.7 GB, so it should fit on the TPU cluster. Any idea?
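For what it's worth, the figures in the error below line up with the full fp32 state being materialized on a single core (a rough back-of-the-envelope, assuming the usual ~6.05B parameter count for GPT-J):

```python
# Back-of-the-envelope check, assuming ~6.05B parameters held as float32.
params = 6.05e9                    # approximate GPT-J 6B parameter count (assumed)
fp32_gib = params * 4 / 2**30      # ~22.5 GiB, matching the 22.54G "arguments" below
per_core_hbm_gib = 15.48           # per-core HBM reported in the error below
print(fp32_gib, per_core_hbm_gib)  # the unsharded state exceeds one core's HBM
```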
The environment:
pip3 install accelerate
pip3 install torch
pip3 install transformers
pip install paxml==1.1.0 (although I have also built it from its git repo)
2024-01-03 05:23:41.411871: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib
Loading the base model from EleutherAI/gpt-j-6b
transformer.wte.weight (50400, 4096)
transformer.h.0.ln_1.weight (4096,)
transformer.h.0.ln_1.bias (4096,)
transformer.h.0.attn.k_proj.weight (4096, 4096)
transformer.h.0.attn.v_proj.weight (4096, 4096)
transformer.h.0.attn.q_proj.weight (4096, 4096)
transformer.h.0.attn.out_proj.weight (4096, 4096)
transformer.h.0.mlp.fc_in.weight (16384, 4096)
transformer.h.0.mlp.fc_in.bias (16384,)
transformer.h.0.mlp.fc_out.weight (4096, 16384)
transformer.h.0.mlp.fc_out.bias (4096,)
transformer.h.1.ln_1.weight (4096,)
transformer.h.1.ln_1.bias (4096,)
transformer.h.1.attn.k_proj.weight (4096, 4096)
transformer.h.1.attn.v_proj.weight (4096, 4096)
transformer.h.1.attn.q_proj.weight (4096, 4096)
transformer.h.1.attn.out_proj.weight (4096, 4096)
transformer.h.1.mlp.fc_in.weight (16384, 4096)
transformer.h.1.mlp.fc_in.bias (16384,)
transformer.h.1.mlp.fc_out.weight (4096, 16384)
transformer.h.1.mlp.fc_out.bias (4096,)
transformer.h.2.ln_1.weight (4096,)
transformer.h.2.ln_1.bias (4096,)
transformer.h.2.attn.k_proj.weight (4096, 4096)
transformer.h.2.attn.v_proj.weight (4096, 4096)
transformer.h.2.attn.q_proj.weight (4096, 4096)
transformer.h.2.attn.out_proj.weight (4096, 4096)
transformer.h.2.mlp.fc_in.weight (16384, 4096)
transformer.h.2.mlp.fc_in.bias (16384,)
transformer.h.2.mlp.fc_out.weight (4096, 16384)
transformer.h.2.mlp.fc_out.bias (4096,)
transformer.h.3.ln_1.weight (4096,)
transformer.h.3.ln_1.bias (4096,)
transformer.h.3.attn.k_proj.weight (4096, 4096)
transformer.h.3.attn.v_proj.weight (4096, 4096)
transformer.h.3.attn.q_proj.weight (4096, 4096)
transformer.h.3.attn.out_proj.weight (4096, 4096)
transformer.h.3.mlp.fc_in.weight (16384, 4096)
transformer.h.3.mlp.fc_in.bias (16384,)
transformer.h.3.mlp.fc_out.weight (4096, 16384)
transformer.h.3.mlp.fc_out.bias (4096,)
transformer.h.4.ln_1.weight (4096,)
transformer.h.4.ln_1.bias (4096,)
transformer.h.4.attn.k_proj.weight (4096, 4096)
transformer.h.4.attn.v_proj.weight (4096, 4096)
transformer.h.4.attn.q_proj.weight (4096, 4096)
transformer.h.4.attn.out_proj.weight (4096, 4096)
transformer.h.4.mlp.fc_in.weight (16384, 4096)
transformer.h.4.mlp.fc_in.bias (16384,)
transformer.h.4.mlp.fc_out.weight (4096, 16384)
transformer.h.4.mlp.fc_out.bias (4096,)
transformer.h.5.ln_1.weight (4096,)
transformer.h.5.ln_1.bias (4096,)
transformer.h.5.attn.k_proj.weight (4096, 4096)
transformer.h.5.attn.v_proj.weight (4096, 4096)
transformer.h.5.attn.q_proj.weight (4096, 4096)
transformer.h.5.attn.out_proj.weight (4096, 4096)
transformer.h.5.mlp.fc_in.weight (16384, 4096)
transformer.h.5.mlp.fc_in.bias (16384,)
transformer.h.5.mlp.fc_out.weight (4096, 16384)
transformer.h.5.mlp.fc_out.bias (4096,)
transformer.h.6.ln_1.weight (4096,)
transformer.h.6.ln_1.bias (4096,)
transformer.h.6.attn.k_proj.weight (4096, 4096)
transformer.h.6.attn.v_proj.weight (4096, 4096)
transformer.h.6.attn.q_proj.weight (4096, 4096)
transformer.h.6.attn.out_proj.weight (4096, 4096)
transformer.h.6.mlp.fc_in.weight (16384, 4096)
transformer.h.6.mlp.fc_in.bias (16384,)
transformer.h.6.mlp.fc_out.weight (4096, 16384)
transformer.h.6.mlp.fc_out.bias (4096,)
transformer.h.7.ln_1.weight (4096,)
transformer.h.7.ln_1.bias (4096,)
transformer.h.7.attn.k_proj.weight (4096, 4096)
transformer.h.7.attn.v_proj.weight (4096, 4096)
transformer.h.7.attn.q_proj.weight (4096, 4096)
transformer.h.7.attn.out_proj.weight (4096, 4096)
transformer.h.7.mlp.fc_in.weight (16384, 4096)
transformer.h.7.mlp.fc_in.bias (16384,)
transformer.h.7.mlp.fc_out.weight (4096, 16384)
transformer.h.7.mlp.fc_out.bias (4096,)
transformer.h.8.ln_1.weight (4096,)
transformer.h.8.ln_1.bias (4096,)
transformer.h.8.attn.k_proj.weight (4096, 4096)
transformer.h.8.attn.v_proj.weight (4096, 4096)
transformer.h.8.attn.q_proj.weight (4096, 4096)
transformer.h.8.attn.out_proj.weight (4096, 4096)
transformer.h.8.mlp.fc_in.weight (16384, 4096)
transformer.h.8.mlp.fc_in.bias (16384,)
transformer.h.8.mlp.fc_out.weight (4096, 16384)
transformer.h.8.mlp.fc_out.bias (4096,)
transformer.h.9.ln_1.weight (4096,)
transformer.h.9.ln_1.bias (4096,)
transformer.h.9.attn.k_proj.weight (4096, 4096)
transformer.h.9.attn.v_proj.weight (4096, 4096)
transformer.h.9.attn.q_proj.weight (4096, 4096)
transformer.h.9.attn.out_proj.weight (4096, 4096)
transformer.h.9.mlp.fc_in.weight (16384, 4096)
transformer.h.9.mlp.fc_in.bias (16384,)
transformer.h.9.mlp.fc_out.weight (4096, 16384)
transformer.h.9.mlp.fc_out.bias (4096,)
transformer.h.10.ln_1.weight (4096,)
transformer.h.10.ln_1.bias (4096,)
transformer.h.10.attn.k_proj.weight (4096, 4096)
transformer.h.10.attn.v_proj.weight (4096, 4096)
transformer.h.10.attn.q_proj.weight (4096, 4096)
transformer.h.10.attn.out_proj.weight (4096, 4096)
transformer.h.10.mlp.fc_in.weight (16384, 4096)
transformer.h.10.mlp.fc_in.bias (16384,)
transformer.h.10.mlp.fc_out.weight (4096, 16384)
transformer.h.10.mlp.fc_out.bias (4096,)
transformer.h.11.ln_1.weight (4096,)
transformer.h.11.ln_1.bias (4096,)
transformer.h.11.attn.k_proj.weight (4096, 4096)
transformer.h.11.attn.v_proj.weight (4096, 4096)
transformer.h.11.attn.q_proj.weight (4096, 4096)
transformer.h.11.attn.out_proj.weight (4096, 4096)
transformer.h.11.mlp.fc_in.weight (16384, 4096)
transformer.h.11.mlp.fc_in.bias (16384,)
transformer.h.11.mlp.fc_out.weight (4096, 16384)
transformer.h.11.mlp.fc_out.bias (4096,)
transformer.h.12.ln_1.weight (4096,)
transformer.h.12.ln_1.bias (4096,)
transformer.h.12.attn.k_proj.weight (4096, 4096)
transformer.h.12.attn.v_proj.weight (4096, 4096)
transformer.h.12.attn.q_proj.weight (4096, 4096)
transformer.h.12.attn.out_proj.weight (4096, 4096)
transformer.h.12.mlp.fc_in.weight (16384, 4096)
transformer.h.12.mlp.fc_in.bias (16384,)
transformer.h.12.mlp.fc_out.weight (4096, 16384)
transformer.h.12.mlp.fc_out.bias (4096,)
transformer.h.13.ln_1.weight (4096,)
transformer.h.13.ln_1.bias (4096,)
transformer.h.13.attn.k_proj.weight (4096, 4096)
transformer.h.13.attn.v_proj.weight (4096, 4096)
transformer.h.13.attn.q_proj.weight (4096, 4096)
transformer.h.13.attn.out_proj.weight (4096, 4096)
transformer.h.13.mlp.fc_in.weight (16384, 4096)
transformer.h.13.mlp.fc_in.bias (16384,)
transformer.h.13.mlp.fc_out.weight (4096, 16384)
transformer.h.13.mlp.fc_out.bias (4096,)
transformer.h.14.ln_1.weight (4096,)
transformer.h.14.ln_1.bias (4096,)
transformer.h.14.attn.k_proj.weight (4096, 4096)
transformer.h.14.attn.v_proj.weight (4096, 4096)
transformer.h.14.attn.q_proj.weight (4096, 4096)
transformer.h.14.attn.out_proj.weight (4096, 4096)
transformer.h.14.mlp.fc_in.weight (16384, 4096)
transformer.h.14.mlp.fc_in.bias (16384,)
transformer.h.14.mlp.fc_out.weight (4096, 16384)
transformer.h.14.mlp.fc_out.bias (4096,)
transformer.h.15.ln_1.weight (4096,)
transformer.h.15.ln_1.bias (4096,)
transformer.h.15.attn.k_proj.weight (4096, 4096)
transformer.h.15.attn.v_proj.weight (4096, 4096)
transformer.h.15.attn.q_proj.weight (4096, 4096)
transformer.h.15.attn.out_proj.weight (4096, 4096)
transformer.h.15.mlp.fc_in.weight (16384, 4096)
transformer.h.15.mlp.fc_in.bias (16384,)
transformer.h.15.mlp.fc_out.weight (4096, 16384)
transformer.h.15.mlp.fc_out.bias (4096,)
transformer.h.16.ln_1.weight (4096,)
transformer.h.16.ln_1.bias (4096,)
transformer.h.16.attn.k_proj.weight (4096, 4096)
transformer.h.16.attn.v_proj.weight (4096, 4096)
transformer.h.16.attn.q_proj.weight (4096, 4096)
transformer.h.16.attn.out_proj.weight (4096, 4096)
transformer.h.16.mlp.fc_in.weight (16384, 4096)
transformer.h.16.mlp.fc_in.bias (16384,)
transformer.h.16.mlp.fc_out.weight (4096, 16384)
transformer.h.16.mlp.fc_out.bias (4096,)
transformer.h.17.ln_1.weight (4096,)
transformer.h.17.ln_1.bias (4096,)
transformer.h.17.attn.k_proj.weight (4096, 4096)
transformer.h.17.attn.v_proj.weight (4096, 4096)
transformer.h.17.attn.q_proj.weight (4096, 4096)
transformer.h.17.attn.out_proj.weight (4096, 4096)
transformer.h.17.mlp.fc_in.weight (16384, 4096)
transformer.h.17.mlp.fc_in.bias (16384,)
transformer.h.17.mlp.fc_out.weight (4096, 16384)
transformer.h.17.mlp.fc_out.bias (4096,)
transformer.h.18.ln_1.weight (4096,)
transformer.h.18.ln_1.bias (4096,)
transformer.h.18.attn.k_proj.weight (4096, 4096)
transformer.h.18.attn.v_proj.weight (4096, 4096)
transformer.h.18.attn.q_proj.weight (4096, 4096)
transformer.h.18.attn.out_proj.weight (4096, 4096)
transformer.h.18.mlp.fc_in.weight (16384, 4096)
transformer.h.18.mlp.fc_in.bias (16384,)
transformer.h.18.mlp.fc_out.weight (4096, 16384)
transformer.h.18.mlp.fc_out.bias (4096,)
transformer.h.19.ln_1.weight (4096,)
transformer.h.19.ln_1.bias (4096,)
transformer.h.19.attn.k_proj.weight (4096, 4096)
transformer.h.19.attn.v_proj.weight (4096, 4096)
transformer.h.19.attn.q_proj.weight (4096, 4096)
transformer.h.19.attn.out_proj.weight (4096, 4096)
transformer.h.19.mlp.fc_in.weight (16384, 4096)
transformer.h.19.mlp.fc_in.bias (16384,)
transformer.h.19.mlp.fc_out.weight (4096, 16384)
transformer.h.19.mlp.fc_out.bias (4096,)
transformer.h.20.ln_1.weight (4096,)
transformer.h.20.ln_1.bias (4096,)
transformer.h.20.attn.k_proj.weight (4096, 4096)
transformer.h.20.attn.v_proj.weight (4096, 4096)
transformer.h.20.attn.q_proj.weight (4096, 4096)
transformer.h.20.attn.out_proj.weight (4096, 4096)
transformer.h.20.mlp.fc_in.weight (16384, 4096)
transformer.h.20.mlp.fc_in.bias (16384,)
transformer.h.20.mlp.fc_out.weight (4096, 16384)
transformer.h.20.mlp.fc_out.bias (4096,)
transformer.h.21.ln_1.weight (4096,)
transformer.h.21.ln_1.bias (4096,)
transformer.h.21.attn.k_proj.weight (4096, 4096)
transformer.h.21.attn.v_proj.weight (4096, 4096)
transformer.h.21.attn.q_proj.weight (4096, 4096)
transformer.h.21.attn.out_proj.weight (4096, 4096)
transformer.h.21.mlp.fc_in.weight (16384, 4096)
transformer.h.21.mlp.fc_in.bias (16384,)
transformer.h.21.mlp.fc_out.weight (4096, 16384)
transformer.h.21.mlp.fc_out.bias (4096,)
transformer.h.22.ln_1.weight (4096,)
transformer.h.22.ln_1.bias (4096,)
transformer.h.22.attn.k_proj.weight (4096, 4096)
transformer.h.22.attn.v_proj.weight (4096, 4096)
transformer.h.22.attn.q_proj.weight (4096, 4096)
transformer.h.22.attn.out_proj.weight (4096, 4096)
transformer.h.22.mlp.fc_in.weight (16384, 4096)
transformer.h.22.mlp.fc_in.bias (16384,)
transformer.h.22.mlp.fc_out.weight (4096, 16384)
transformer.h.22.mlp.fc_out.bias (4096,)
transformer.h.23.ln_1.weight (4096,)
transformer.h.23.ln_1.bias (4096,)
transformer.h.23.attn.k_proj.weight (4096, 4096)
transformer.h.23.attn.v_proj.weight (4096, 4096)
transformer.h.23.attn.q_proj.weight (4096, 4096)
transformer.h.23.attn.out_proj.weight (4096, 4096)
transformer.h.23.mlp.fc_in.weight (16384, 4096)
transformer.h.23.mlp.fc_in.bias (16384,)
transformer.h.23.mlp.fc_out.weight (4096, 16384)
transformer.h.23.mlp.fc_out.bias (4096,)
transformer.h.24.ln_1.weight (4096,)
transformer.h.24.ln_1.bias (4096,)
transformer.h.24.attn.k_proj.weight (4096, 4096)
transformer.h.24.attn.v_proj.weight (4096, 4096)
transformer.h.24.attn.q_proj.weight (4096, 4096)
transformer.h.24.attn.out_proj.weight (4096, 4096)
transformer.h.24.mlp.fc_in.weight (16384, 4096)
transformer.h.24.mlp.fc_in.bias (16384,)
transformer.h.24.mlp.fc_out.weight (4096, 16384)
transformer.h.24.mlp.fc_out.bias (4096,)
transformer.h.25.ln_1.weight (4096,)
transformer.h.25.ln_1.bias (4096,)
transformer.h.25.attn.k_proj.weight (4096, 4096)
transformer.h.25.attn.v_proj.weight (4096, 4096)
transformer.h.25.attn.q_proj.weight (4096, 4096)
transformer.h.25.attn.out_proj.weight (4096, 4096)
transformer.h.25.mlp.fc_in.weight (16384, 4096)
transformer.h.25.mlp.fc_in.bias (16384,)
transformer.h.25.mlp.fc_out.weight (4096, 16384)
transformer.h.25.mlp.fc_out.bias (4096,)
transformer.h.26.ln_1.weight (4096,)
transformer.h.26.ln_1.bias (4096,)
transformer.h.26.attn.k_proj.weight (4096, 4096)
transformer.h.26.attn.v_proj.weight (4096, 4096)
transformer.h.26.attn.q_proj.weight (4096, 4096)
transformer.h.26.attn.out_proj.weight (4096, 4096)
transformer.h.26.mlp.fc_in.weight (16384, 4096)
transformer.h.26.mlp.fc_in.bias (16384,)
transformer.h.26.mlp.fc_out.weight (4096, 16384)
transformer.h.26.mlp.fc_out.bias (4096,)
transformer.h.27.ln_1.weight (4096,)
transformer.h.27.ln_1.bias (4096,)
transformer.h.27.attn.k_proj.weight (4096, 4096)
transformer.h.27.attn.v_proj.weight (4096, 4096)
transformer.h.27.attn.q_proj.weight (4096, 4096)
transformer.h.27.attn.out_proj.weight (4096, 4096)
transformer.h.27.mlp.fc_in.weight (16384, 4096)
transformer.h.27.mlp.fc_in.bias (16384,)
transformer.h.27.mlp.fc_out.weight (4096, 16384)
transformer.h.27.mlp.fc_out.bias (4096,)
transformer.ln_f.weight (4096,)
transformer.ln_f.bias (4096,)
lm_head.weight (50400, 4096)
lm_head.bias (50400,)
Saving the pax model to pax_6b
Traceback (most recent call last):
File "/home/arghyajoy627/convert_gptj_ckpt.py", line 192, in
convert(args.base_model_path, args.pax_model_path)
File "/home/arghyajoy627/convert_gptj_ckpt.py", line 176, in convert
jax_states_gda = pjitted_identity(jax_states)
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 248, in cache_miss
outs, out_flat, out_tree, args_flat = _python_pjit_helper(
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 195, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **params)
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/core.py", line 2591, in bind
return self.bind_with_trace(top_trace, args, params)
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/core.py", line 362, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/core.py", line 816, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 1246, in _pjit_call_impl
compiled = _pjit_lower(
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2836, in compile
self._executable = UnloadedMeshExecutable.from_hlo(
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 3048, in from_hlo
xla_executable = dispatch.compile_or_get_cached(
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 526, in compile_or_get_cached
return backend_compile(backend, serialized_computation, compile_options,
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 471, in backend_compile
return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 22.54G of 15.48G hbm. Exceeded hbm capacity by 7.06G.
Total hbm usage >= 23.06G:
reserved 530.00M
program 4.0K
arguments 22.54G
Output size 22.54G; shares 0B with arguments.
Program hbm requirement 4.0K:
global 4.0K
Largest program allocations in hbm:
Size: 4.0K
Shape: u32[8,128]{1,0}
Unpadded size: 4.0K
XLA label: constant literal
Allocation type: global
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/home/arghyajoy627/convert_gptj_ckpt.py", line 192, in
convert(args.base_model_path, args.pax_model_path)
File "/home/arghyajoy627/convert_gptj_ckpt.py", line 176, in convert
jax_states_gda = pjitted_identity(jax_states)
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 22.54G of 15.48G hbm. Exceeded hbm capacity by 7.06G.
Total hbm usage >= 23.06G:
reserved 530.00M
program 4.0K
arguments 22.54G
Output size 22.54G; shares 0B with arguments.
Program hbm requirement 4.0K:
global 4.0K
Largest program allocations in hbm:
Size: 4.0K
Shape: u32[8,128]{1,0}
Unpadded size: 4.0K
XLA label: constant literal
Allocation type: global
@zhihaoshan-google