
[Question]: CUDA error: an illegal memory access was encountered when running benchmark_e2e.py #86

Open
lepangdan opened this issue Nov 20, 2024 · 6 comments

@lepangdan

Describe the bug

Hi,
I am running into an issue when executing
python experiments/benchmarks/benchmark_e2e.py --attn_type minference_with_dense --context_window 1_000_000
The error is:

RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

Then I tried setting CUDA_LAUNCH_BLOCKING=1 and got the following error instead:
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
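
For reference, the flag was set as an environment variable on the benchmark command, roughly as follows (the exact invocation here is an assumption):

CUDA_LAUNCH_BLOCKING=1 python experiments/benchmarks/benchmark_e2e.py --attn_type minference_with_dense --context_window 1_000_000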

Would you be able to provide any guidance on the possible causes of this error, or suggest debugging steps? Thanks in advance!

Steps to reproduce

No response

Expected Behavior

No response

Logs

No response

Additional Information

triton version: 2.2.0
torch version: 2.1.1+cu121
CUDA version: 12.2

@lepangdan lepangdan added the bug Something isn't working label Nov 20, 2024
@lepangdan
Author

Addition: a different error is reported with another triton version, 2.1.0:

  • Command: python experiments/benchmarks/benchmark_e2e.py --attn_type minference_with_dense --context_window 1_000_000
  • Error:
run_target_length
loc(callsite("/home/.local/lib/python3.10/site-packages/triton/language/core.py":1398:21 at "/home/far/MInference/minference/ops/flash_attn_triton.py":143:61)): error: 'triton_gpu.cmpf' op requires the same encoding for all operands
Traceback (most recent call last):
  File "/home/far/MInference/experiments/benchmarks/benchmark_e2e.py", line 140, in <module>
    run_target_length(args.context_window, model, args.attn_type)
  File "/home/far/MInference/experiments/benchmarks/benchmark_e2e.py", line 34, in run_target_length
    model(input_ids, attention_mask, use_cache=False)
  File "/home/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/far/MInference/minference/patch.py", line 790, in forward_llama_for_causal_lm
    outputs = self.model(
  File "/home/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/far/MInference/minference/patch.py", line 712, in forward_llama_model
    layer_outputs = decoder_layer(
  File "/home/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/far/MInference/minference/patch.py", line 570, in forward_llama_decoder_layer
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/far/MInference/minference/modules/minference_forward.py", line 549, in forward
    attn_output = self.gather_last_q_vertical_slash_topk_v4(q, k, v, head)
  File "/home/far/MInference/minference/modules/minference_forward.py", line 447, in gather_last_q_vertical_slash_topk_v4
    return dense(q, k, v)
  File "/home/far/MInference/minference/modules/minference_forward.py", line 431, in dense
    return flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1,2), 0.0, softmax_scale=None, causal=q_len != 1).view(bsz, 1, q_len, self.head_dim)
  File "/home/far/MInference/minference/ops/flash_attn_triton.py", line 258, in _flash_attn_triton_decoding
    _fwd_kernel[grid](
  File "/home/.local/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 232, in run
    return self.fn.run(*args, **kwargs)
  File "<string>", line 63, in _fwd_kernel
  File "/home/.local/lib/python3.10/site-packages/triton/compiler/compiler.py", line 476, in compile
    next_module = compile_kernel(module)
  File "/home/.local/lib/python3.10/site-packages/triton/compiler/compiler.py", line 383, in <lambda>
    lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, arch))
  File "/home/.local/lib/python3.10/site-packages/triton/compiler/compiler.py", line 91, in optimize_ttgir
    pm.run(mod)
RuntimeError: PassManager::run failed

minference-0.1.5.post1
triton-2.1.0
torch 2.1.1+cu121

@iofu728 iofu728 self-assigned this Nov 22, 2024
@iofu728
Contributor

iofu728 commented Nov 22, 2024

Hi @lepangdan, thanks for your feedback. This issue seems to be caused by insufficient hardware resources.

Could you please provide details about the type of GPU you are using and the size of your CPU memory?

Additionally, if you need to test with 1M tokens (or any input exceeding 200K) within 80GB of GPU memory, you need to enable the --kv_cache_cpu option. Please refer to https://github.com/microsoft/MInference/tree/main/experiments#end-to-end-benchmark for guidance.

Also, are you able to run inference with 100K or 500K tokens without issues? Please try the following scripts:

# For 1M tokens
python experiments/benchmarks/benchmark_e2e.py --attn_type minference_with_dense --context_window 1_000_000 --kv_cache_cpu

# For 100K tokens
python experiments/benchmarks/benchmark_e2e.py --attn_type minference_with_dense --context_window 100_000

# For 500K tokens
python experiments/benchmarks/benchmark_e2e.py --attn_type minference_with_dense --context_window 500_000 --kv_cache_cpu

Let us know the results so we can further diagnose the issue.

@iofu728 iofu728 added question Further information is requested and removed bug Something isn't working labels Nov 22, 2024
@iofu728 iofu728 changed the title [Bug]: CUDA error: an illegal memory access was encountered when running benchmark_e2e.py [Question]: CUDA error: an illegal memory access was encountered when running benchmark_e2e.py Nov 22, 2024
@lepangdan
Author

lepangdan commented Nov 22, 2024

Hi @iofu728

My hardware configuration: 1x A100 (80 GB GPU memory), 11 CPU cores, 100 GB CPU memory.

During debugging the package versions changed; the current versions are:
torch: 2.1.1+cu121
triton: 2.1.0
transformers: 4.44.2
cuda: 12.2

Results of the tests you suggested:

python experiments/benchmarks/benchmark_e2e.py --attn_type minference_with_dense --context_window 100_000

/home/.local/lib/python3.10/site-packages/transformers/utils/hub.py:127: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.
  warnings.warn(
Downloading shards: 100%|████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 470.31it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████| 4/4 [00:42<00:00, 10.62s/it]
Patched model for MInference load KV Cache to CPU.
run_target_length
loc(callsite("/home/.local/lib/python3.10/site-packages/triton/language/core.py":1398:21 at "/home/far/MInference/minference/ops/flash_attn_triton.py":143:61)): error: 'triton_gpu.cmpf' op requires the same encoding for all operands
Traceback (most recent call last):
  File "/home/far/MInference/experiments/benchmarks/benchmark_e2e.py", line 140, in <module>
    run_target_length(args.context_window, model, args.attn_type)
  File "/home/far/MInference/experiments/benchmarks/benchmark_e2e.py", line 34, in run_target_length
    model(input_ids, attention_mask, use_cache=False)
  File "/home/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/far/MInference/minference/patch.py", line 790, in forward_llama_for_causal_lm
    outputs = self.model(
  File "/home/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/far/MInference/minference/patch.py", line 712, in forward_llama_model
    layer_outputs = decoder_layer(
  File "/home/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/far/MInference/minference/patch.py", line 570, in forward_llama_decoder_layer
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/far/MInference/minference/modules/minference_forward.py", line 656, in forward
    part_o = self.gather_last_q_vertical_slash_topk_v4(part_q, part_k, part_v, head)
  File "/home/far/MInference/minference/modules/minference_forward.py", line 447, in gather_last_q_vertical_slash_topk_v4
    return dense(q, k, v)
  File "/home/far/MInference/minference/modules/minference_forward.py", line 431, in dense
    return flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1,2), 0.0, softmax_scale=None, causal=q_len != 1).view(bsz, 1, q_len, self.head_dim)
  File "/home/far/MInference/minference/ops/flash_attn_triton.py", line 259, in _flash_attn_triton_decoding
    _fwd_kernel[grid](
  File "/home/.local/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 232, in run
    return self.fn.run(*args, **kwargs)
  File "<string>", line 63, in _fwd_kernel
  File "/home/.local/lib/python3.10/site-packages/triton/compiler/compiler.py", line 476, in compile
    next_module = compile_kernel(module)
  File "/home/.local/lib/python3.10/site-packages/triton/compiler/compiler.py", line 383, in <lambda>
    lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, arch))
  File "/home/.local/lib/python3.10/site-packages/triton/compiler/compiler.py", line 91, in optimize_ttgir
    pm.run(mod)
RuntimeError: PassManager::run failed

python experiments/benchmarks/benchmark_e2e.py --attn_type minference_with_dense --context_window 500_000 --kv_cache_cpu

[same traceback as above, omitted]
RuntimeError: PassManager::run failed

python experiments/benchmarks/benchmark_e2e.py --attn_type minference_with_dense --context_window 1_000_000 --kv_cache_cpu

[same traceback as above, omitted]
RuntimeError: PassManager::run failed

Additional information:

  1. No error occurs when running in hf mode: python experiments/benchmarks/benchmark_e2e.py --attn_type hf --context_window 50_000
  2. No error occurs when I reproduce the scripts from other posts about "PassManager::run failed", e.g. https://github.com/pytorch/pytorch/issues/101368#issue-1709080902 and RuntimeError: PassManager::run failed / error: 'tt.reduce' op inferred type(s)... triton-lang/triton#1672 (comment)

I wonder whether the package versions matter. I am not certain; do you have any information on this?

Looking forward to your reply. Thanks in advance.

@lepangdan
Author

lepangdan commented Nov 22, 2024

@iofu728 Adding more test results for your information: hf and minference run as expected, but minference_with_dense always reports 'RuntimeError: PassManager::run failed.' Is it possible that something is wrong in the code path used by minference_with_dense, perhaps as indicated by the error trace in flash_attn_triton.py?

hf:

python experiments/benchmarks/benchmark_e2e.py --attn_type hf --context_window 50_000
hf 50000 9.624406599998474
python experiments/benchmarks/benchmark_e2e.py --attn_type hf --context_window 100_000
torch.cuda.OutOfMemoryError: CUDA out of memory.

minference:

python experiments/benchmarks/benchmark_e2e.py --attn_type minference --context_window 10_000
minference 10000 2.600739598274231
python experiments/benchmarks/benchmark_e2e.py --attn_type minference --context_window 100_000
minference 100000 12.096721625328064
python experiments/benchmarks/benchmark_e2e.py --attn_type minference --context_window 1_000_000
torch.cuda.OutOfMemoryError: CUDA out of memory.
python experiments/benchmarks/benchmark_e2e.py --attn_type minference --context_window 1_000_000 --kv_cache_cpu
minference 1000000 146.8974978685379

@iofu728
Contributor

iofu728 commented Nov 25, 2024

Thank you for the information. However, it's a bit strange that "minference" runs normally while "minference_with_dense" doesn't. Please check whether your local flash_attn is functioning correctly, and consider reinstalling it.

Additionally, triton and flash_attn are not limited to specific versions, so you could try using other versions to see if the issue persists.
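
If helpful, a quick way to sanity-check a local flash_attn install is a standalone script along these lines (a minimal sketch only; the shapes and the flash_attn_func signature assume the flash_attn 2.x interface, not MInference code):

import torch
from flash_attn import flash_attn_func

# Random half-precision tensors on the GPU, laid out as (batch, seqlen, nheads, headdim)
q = torch.randn(1, 1024, 32, 128, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Causal attention; the call should return a tensor of the same shape as q
out = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True)
print(out.shape)  # expected: torch.Size([1, 1024, 32, 128])

If this script fails to import or crashes, the problem is in the flash_attn installation itself rather than in MInference.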

@lepangdan
Author

lepangdan commented Nov 25, 2024

It works after installing flash_attn. Many thanks!

I noticed that the key point is in the minference_forward.py file:

try:
    # Use flash_attn_func from the flash_attn package when it is installed
    from flash_attn import flash_attn_func
except ImportError:
    # Otherwise fall back to the bundled Triton implementation
    from ..ops.flash_attn_triton import _flash_attn_triton_decoding as flash_attn_func

The error mentioned above, RuntimeError: PassManager::run failed, occurred in _flash_attn_triton_decoding. After installing flash_attn, the code switched to flash_attn_func from the flash_attn package, and the error has not occurred since. However, I'm curious about the differences between these two implementations of flash_attn_func: why was the _flash_attn_triton_decoding version introduced, and why did it fail in this context?

Additionally, it might be helpful to emit a warning when the flash_attn import fails and the code silently switches to the Triton version of flash_attn_func, so that the fallback is visible (see the sketch below).
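
A minimal sketch of what that could look like (the warning text and use of the warnings module are my suggestion, not existing MInference code):

import warnings

try:
    from flash_attn import flash_attn_func
except ImportError:
    from ..ops.flash_attn_triton import _flash_attn_triton_decoding as flash_attn_func
    # Make the silent fallback visible to users
    warnings.warn(
        "flash_attn is not installed; falling back to the bundled Triton kernel "
        "(_flash_attn_triton_decoding), which may behave differently.",
        ImportWarning,
    )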
