
Shared memory out of resource. Need 135200M memory! #63

Closed
DESEOUMAIGA opened this issue Nov 6, 2023 · 14 comments

Comments

@DESEOUMAIGA

A note for anyone testing this project: the demo needs more shared memory than my GPU provides. The kernel asks for about 132 KB per block:

'triton.runtime.autotuner.OutOfResources: out of resource: shared memory, Required: 135200, Hardware limit: 101376. Reducing block sizes or num_stages may help.'

@cpuhrsch
Contributor

Hello @DESEOUMAIGA,

Thank you for opening the issue. Do you mind sharing a bit more detail about the environment and GPU type you ran this on? Also, are you trying to reproduce the experiments or are you trying to use this in an end-to-end context?

Thank you,
Christian

@Chris-toe-pher

After fixing some tensor dtype issues (float32 to float16, ...), I can confirm the same error:

raise OutOfResources(self.shared, max_shared, "shared memory")
triton.runtime.autotuner.OutOfResources: out of resource: shared memory, Required: 133136, Hardware limit: 101376. Reducing block sizes or num_stages may help.

RTX A6000 / 48GB.

@cjpurackal

I'm getting the same error:

triton.runtime.autotuner.OutOfResources: out of resource: shared memory, Required: 135200, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.

even after trying to reduce the block sizes and `num_stages` in flash_4.py.
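
For context, this is roughly what I tried: shrinking the tile sizes and `num_stages` in the autotune configs. The kernel signature and the `BLOCK_M`/`BLOCK_N`/`N_CTX` names below are illustrative assumptions, not the exact names used in flash_4.py:

```python
import triton
import triton.language as tl

# Smaller tiles and fewer pipeline stages lower the per-block shared-memory
# footprint. All names here are illustrative, not copied from flash_4.py.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_stages=2, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_stages=2, num_warps=4),
        triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_stages=1, num_warps=4),
    ],
    key=["N_CTX"],
)
@triton.jit
def _attn_fwd_kernel(Q, K, V, Out, N_CTX,
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    ...  # kernel body omitted
```

Even with the smaller configs, the required shared memory still exceeded the hardware limit for me.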

@rbavery

rbavery commented Nov 17, 2023

I'm running this on an NVIDIA 3090 (24 GB) and getting the same error when running the amg example.

I set up a fresh conda environment with python 3.10 and followed the install instructions for sam-fast with pip.

OutOfResources: out of resource: shared memory, Required: 135200, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.

@cpuhrsch
Contributor

Hey all, thanks for giving this project a go! Since this project was optimized for A100s, the tuned kernel assumes more shared memory than is available on many other GPUs. I'll push a fix to rerun auto-tuning for non-A100 GPUs.
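
For context, this error is about per-block shared memory on the streaming multiprocessor, not GPU RAM: an A100 (compute capability 8.0) allows up to roughly 163 KB of shared memory per block, while Ampere parts like the 3090, A6000 and A10G (compute capability 8.6) top out around 99 KB, which is the 101376-byte hardware limit in the tracebacks above. A quick way to check what your device has (a sketch; the KB figures are taken from the CUDA programming guide rather than queried from the driver):

```python
import torch

props = torch.cuda.get_device_properties(0)
cc = (props.major, props.minor)
print(f"{props.name}: compute capability {cc[0]}.{cc[1]}")

# Approximate opt-in per-block shared-memory ceilings by compute capability
# (values from the CUDA programming guide; assumption, not read from the driver):
limits_kb = {(8, 0): 163, (8, 6): 99, (8, 9): 99, (9, 0): 227}
print(f"~{limits_kb.get(cc, 'unknown')} KB shared memory per block")
```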

@cpuhrsch
Contributor

Ok, I tried to address this in #67. Please try again and let me know if it doesn't work :)

@agunapal

@cpuhrsch Thanks. I tried it on A10G. It works!

@rbavery

rbavery commented Nov 18, 2023

I tried this again on my 24 GB NVIDIA 3090, and it works when the custom flash attention kernel is disabled. Thank you!

I still get an error with the kernel enabled when running the amg example. I think this is because I haven't run the experiment in experiments/ to generate a new kernel config for my GPU; trying that now.

--> segment_anything_fast/flash_4.py, lines 296-298:
    if key not in BEST_CONFIGS:
        print("key ", key, " not found. Running autotune. This might take a while.")
        import functools

TypeError: argument of type 'NoneType' is not iterable
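
For anyone who hits this before a fix lands: the TypeError comes from evaluating `key not in BEST_CONFIGS` while `BEST_CONFIGS` is `None`. A minimal local workaround (a sketch, assuming `BEST_CONFIGS` is the config cache shown in the snippet above) is to also fall into the autotune branch when the cache hasn't been populated at all:

```python
# Sketch of a one-line local edit around flash_4.py line 296: also trigger
# autotuning when the config cache is None, instead of raising
# "argument of type 'NoneType' is not iterable".
if BEST_CONFIGS is None or key not in BEST_CONFIGS:
    print("key ", key, " not found. Running autotune. This might take a while.")
```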

@rbavery

rbavery commented Nov 18, 2023

When I run the experiment script to (I think?) regenerate the Triton kernel config, I get an out-of-memory error. It looks like there's no new file in configs/ after running the experiments script, just the A100 config.

python run_experiments.py 16 vit_b ../ ../../segment-anything experiments_data --run-experiments --num-workers 8 --capture_output False 
loading annotations into memory...
Done (t=0.45s)
creating index...
index created!
  0%|                                                                                                                  | 0/310 [00:07<?, ?it/s]

CUDA out of memory. Tried to allocate 12.00 GiB. GPU 0 has a total capacity of 23.69 GiB of which 3.05 GiB is free. Process 11488 has 4.28 GiB memory in use. Including non-PyTorch memory, this process has 16.27 GiB memory in use. Of the allocated memory 14.46 GiB is allocated by PyTorch, and 265.91 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

 File "/home/rave/segment-anything/segment_anything/modeling/image_encoder.py", line 358, in add_decomposed_rel_pos
    attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
  File "/home/rave/segment-anything/segment_anything/modeling/image_encoder.py", line 234, in forward
    attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
  File "/home/rave/segment-anything/segment_anything/modeling/image_encoder.py", line 174, in forward
    x = self.attn(x)
  File "/home/rave/segment-anything/segment_anything/modeling/image_encoder.py", line 112, in forward
    x = blk(x)
  File "/home/rave/segment-anything-fast/experiments/eval_combo.py", line 133, in build_results_batch
    features_batch = encoder(input_image_batch)
  File "/home/rave/segment-anything-fast/experiments/eval_combo.py", line 199, in build_results
    _ = batch_runner(predictor, batch, batch_size, pad_input_image_batch)
  File "/home/rave/segment-anything-fast/experiments/eval_combo.py", line 228, in identity_runner
    return fn(*args, **kwargs)
  File "/home/rave/segment-anything-fast/experiments/eval_combo.py", line 409, in run
    results, avg_ms_per_img, num_batches, num_images = runner(build_results,
  File "/home/rave/segment-anything-fast/experiments/eval_combo.py", line 446, in <module>
    fire.Fire(run)
torch.cuda.OutOfMemoryError: CUDA out of memory. 

@cpuhrsch
Contributor

Hello @rbavery - I just pushed #73, hope that resolves your issue.

For the OOM error, it's entirely possible you don't have enough GPU memory for batch size 16. Can you try 8 or 4, or maybe even 1? As in, run `python run_experiments.py 8 [...]` instead of `python run_experiments.py 16 [...]`.
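
If a smaller batch size still runs out of memory because of fragmentation, you can also try the allocator setting that the error message itself suggests, either by exporting `PYTORCH_CUDA_ALLOC_CONF` in the shell before running the script or early in the script. A sketch (the 128 MB value is just an example, not a tuned number):

```python
import os

# Must take effect before the first CUDA allocation, e.g. at the very top of
# the experiment script. Caps split-block size to reduce fragmentation.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128")

import torch  # imported after setting the allocator config

print(torch.cuda.is_available())
```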

Just so you don't think I forgot about this: starting today, I'm on vacation until November 27th. I'll review again once I'm back, or in between if I find the time.

Thank you!
Christian

@rbavery

rbavery commented Nov 29, 2023

Thanks! I'm running this now with a smaller batch size and it's working, and the BEST_CONFIGS fix works as well.

@cpuhrsch
Contributor

cpuhrsch commented Dec 1, 2023

I'll close this issue now since the problems reported here seem to be addressed, but please reopen if that's not the case!

@cpuhrsch cpuhrsch closed this as completed Dec 1, 2023
@PranavPanicker

@cpuhrsch Hey there! I'm encountering a similar error when running a Phi-3 instruct LLM-based model (from Hugging Face) on a single GPU of an NVIDIA DGX node (80 GB A100s). I can't figure out how to fix it; any help from your end would be great!

@YuhengHuang42

> @cpuhrsch Hey there! I'm encountering a similar error when running a Phi-3 instruct LLM-based model (from Hugging Face) on a single GPU of an NVIDIA DGX node (80 GB A100s). [...]

I encountered the same problem. Just update the triton version according to https://huggingface.co/microsoft/Phi-3-small-8k-instruct

Install tiktoken (0.6.0) and triton (2.3.0).
