Description
Experimental/Whisper notebook (speedup.ipynb) is not working.
When I run the unmodified notebook on an RTX 4080/4090 (i.e., with the large-v2 model), it takes a long time to 'optimize', and at some point it starts printing the following messages:
[2023-07-25 16:24:47,436] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)
function: 'forward' (kernl/.venv/lib/python3.10/site-packages/transformers/models/whisper/modeling_whisper.py:564)
reasons: ___check_obj_id(past_key_value, 94111367005152)
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.
and then many of these follow:
kernl/.venv/lib/python3.10/site-packages/torch/cuda/graphs.py:79: UserWarning: The CUDA Graph is empty. This ususally means that the graph was attempted to be captured on wrong device or stream. (Triggered internally at ../aten/src/ATen/cuda/CUDAGraph.cpp:191.)
super().capture_end()
I tried increasing cache_size_limit up to 1024; the only effect is that it takes much longer before the cache-size warning is printed, and the final outcome is the same.
The final outcome: there is no speedup at all; moreover, the 'optimized' variant is usually slower (evidently because it was not fully optimized).
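My reading of the `___check_obj_id(past_key_value, ...)` reason is that the compiled code is guarded on the identity of the cache object, and generation passes a fresh object at every decoding step. The toy model below is plain Python and is not Dynamo's actual implementation (`ToyCompileCache` and its fields are hypothetical names); it only illustrates why, under that assumption, raising the limit would delay the warning without avoiding it.

```python
class ToyCompileCache:
    """Toy stand-in for a compile cache whose entries are guarded on object identity."""

    def __init__(self, cache_size_limit: int = 64) -> None:
        self.cache_size_limit = cache_size_limit
        self.guarded_ids: set[int] = set()
        self.recompiles = 0
        self.fallbacks = 0

    def call(self, past_key_value: object) -> str:
        key = id(past_key_value)
        if key in self.guarded_ids:
            return "cache hit"
        if len(self.guarded_ids) >= self.cache_size_limit:
            # Limit reached: stop compiling and run the slow path instead.
            self.fallbacks += 1
            return "limit hit"
        # Guard failed on a new object identity: compile one more variant.
        self.guarded_ids.add(key)
        self.recompiles += 1
        return "recompiled"


cache = ToyCompileCache(cache_size_limit=64)
live = []  # keep every object alive so that ids are never reused
for step in range(100):
    past_key_value = (object(),)  # a brand-new tuple at every step
    live.append(past_key_value)
    cache.call(past_key_value)

print(cache.recompiles, cache.fallbacks)  # → 64 36
```

With a limit of 1024 the same loop would simply recompile more times before falling back, which matches the longer wait I observed.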
Steps to reproduce
Repro steps are pretty much the same as in the experimental/whisper/README.md:
DOCKER_BUILDKIT=1 docker build -t kernl .
docker run --rm -it --gpus all -v $(pwd):/kernl kernl
apt install libsndfile1-dev # used by a Python audio dependency
pip install datasets soundfile librosa jupyter notebook
jupyter nbconvert --execute --clear-output experimental/whisper/speedup.ipynb --log-level=10
Or, this script could be used:
import time
import torch
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from kernl.model_optimization import optimize_model
torch.set_float32_matmul_precision("high")
torch._dynamo.config.cache_size_limit = 64 # 1024
#torch._dynamo.config.dynamic_shapes = True
max_len = 50 # we do not expect more than 50 tokens per audio.
num_beams = 5
model_name = "openai/whisper-large-v2" # "openai/whisper-tiny"
# audio_dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") # small dataset for tests
audio_dataset = load_dataset("librispeech_asr", "clean", split="test")
def get_tokens(item: dict[str, dict]) -> torch.Tensor:
    tensor = processor(item["audio"]["array"], return_tensors="pt", sampling_rate=16_000).input_features
    return tensor.cuda()
processor = WhisperProcessor.from_pretrained(model_name)
inputs_warmup = get_tokens(audio_dataset[0])
model = WhisperForConditionalGeneration.from_pretrained(model_name).to("cuda").eval()
MAX_ITER = 100
timings_original = list()
transcriptions = list()
with torch.inference_mode(), torch.autocast(dtype=torch.float16, cache_enabled=True, device_type="cuda"):
    # warmup
    model.generate(inputs_warmup, min_length=max_len, max_length=max_len, num_beams=num_beams, do_sample=False)
    torch.cuda.synchronize()
    i = 0
    for audio in audio_dataset:
        inputs = get_tokens(audio)
        torch.cuda.synchronize()
        start = time.time()
        predicted_ids = model.generate(inputs, min_length=1, max_length=max_len, num_beams=num_beams, do_sample=False)
        torch.cuda.synchronize()
        timings_original.append(time.time() - start)
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[0]
        print(f'{i}: {transcription}')
        transcriptions.append(transcription)
        i = i + 1
        if i == MAX_ITER:
            break
len_audio_dataset = i
assert len_audio_dataset == len(transcriptions) == len(timings_original)
@staticmethod
def fix_reorder_cache(past, beam_idx):
    reordered_past = ()
    for layer_past in past:
        reordered_past += (
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
        )
    return reordered_past
WhisperForConditionalGeneration._reorder_cache = fix_reorder_cache
print('###################################')
# uncomment 2 following lines and comment the third one to use vanilla torch.compile instead of Kernl
# model.model.decoder.forward_original = model.model.decoder.forward
# model.model.decoder.forward = torch.compile(model.model.decoder.forward_original, mode="reduce-overhead")
optimize_model(model.model.decoder)
nb_diff = 0
timings_optimized = list()
with torch.inference_mode(), torch.autocast(dtype=torch.float16, cache_enabled=True, device_type="cuda"):
    start = time.time()
    model.generate(inputs_warmup, min_length=max_len, max_length=max_len, num_beams=num_beams, do_sample=False)
    torch.cuda.synchronize()
    print(f"time to warmup: {(time.time() - start)/60:.2f}min")
    i = 0
    for original_model_transcription, audio in zip(transcriptions, audio_dataset):
        inputs = get_tokens(audio)
        torch.cuda.synchronize()
        start = time.time()
        predicted_ids = model.generate(inputs, min_length=1, max_length=max_len, num_beams=num_beams, do_sample=False)
        torch.cuda.synchronize()
        timings_optimized.append(time.time() - start)
        optimized_transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[0]
        print(f'{i}: {optimized_transcription}')
        nb_diff += original_model_transcription != optimized_transcription
        i = i + 1
        if i == MAX_ITER:
            break
original_mins = sum(timings_original) / 60
optimized_mins = sum(timings_optimized) / 60
speedup = original_mins / optimized_mins
print(f"Kernl speedup: {speedup:.1f}X ({optimized_mins:.1f} VS {original_mins:.1f} min)")
print(f"# different outputs: {nb_diff}/{len_audio_dataset} ({nb_diff / len_audio_dataset * 100:.2f}%)")
print("\nmemory footprint:")
print(f"* allocated: {torch.cuda.memory_allocated(0) / 1024 / 1024 / 1024:.1f}GB")
print(f"* reserved: {torch.cuda.memory_reserved(0) / 1024 / 1024 / 1024:.1f}GB")
print(f"* max reserved: {torch.cuda.max_memory_reserved(0) / 1024 / 1024 / 1024:.1f}GB")
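The script reports only the ratio of total times, which can hide cases where individual samples get slower. A small helper along these lines could report per-sample figures as well. It is a sketch using only the standard library; `TimingSummary` and `summarize_timings` are hypothetical names, and the numbers in the usage lines are made up for illustration, not measurements from this issue.

```python
import statistics
from typing import NamedTuple


class TimingSummary(NamedTuple):
    total_speedup: float   # sum(original) / sum(optimized), as in the script
    median_speedup: float  # median of the per-sample ratios
    slower_samples: int    # samples where the optimized run took longer


def summarize_timings(original: list[float], optimized: list[float]) -> TimingSummary:
    """Compare two equally long lists of per-sample wall-clock timings (seconds)."""
    if not original or len(original) != len(optimized):
        raise ValueError("need two non-empty timing lists of equal length")
    ratios = [before / after for before, after in zip(original, optimized)]
    return TimingSummary(
        total_speedup=sum(original) / sum(optimized),
        median_speedup=statistics.median(ratios),
        slower_samples=sum(ratio < 1.0 for ratio in ratios),
    )


# Made-up timings, for illustration only.
summary = summarize_timings([0.60, 0.70, 0.65], [0.50, 0.80, 0.65])
print(summary)
```

In the script above it could be called as `summarize_timings(timings_original, timings_optimized)` after both loops have finished.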
Expected Behavior
I expected to see the promised speedup in action.
Full log of running the script above (with just 100 samples):
Found cached dataset librispeech_asr (/home/artyom/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/cff5df6e7955c80a67f80e27e7e655de71c689e2d2364bece785b972acb37fe7)
[2023-07-25 16:39:39,587] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)
function: 'forward' (/home/artyom/kernl/.venv/lib/python3.10/site-packages/transformers/models/whisper/modeling_whisper.py:564)
reasons: ___check_obj_id(past_key_value, 94076869202912)
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.
[2023-07-25 16:40:23,679] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)
function: 'forward' (/home/artyom/kernl/.venv/lib/python3.10/site-packages/transformers/models/whisper/modeling_whisper.py:346)
reasons: tensor 'past_key_value[0]' strides mismatch at index 0. expected 2560, actual 3840
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.
/home/artyom/kernl/.venv/lib/python3.10/site-packages/torch/cuda/graphs.py:79: UserWarning: The CUDA Graph is empty. This ususally means that the graph was attempted to be captured on wrong device or stream. (Triggered internally at ../aten/src/ATen/cuda/CUDAGraph.cpp:191.)
super().capture_end()
[the two warning lines above repeat 63 more times]
[2023-07-25 16:41:03,368] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)
function: '_shape' (/home/artyom/kernl/.venv/lib/python3.10/site-packages/transformers/models/whisper/modeling_whisper.py:342)
reasons: ___check_obj_id(self, 140211086954640)
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.
[2023-07-25 16:41:37,820] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)
function: '__setitem__' (/home/artyom/kernl/.venv/lib/python3.10/site-packages/transformers/utils/generic.py:328)
reasons: tensor 'self.past_key_values[0][0]' strides mismatch at index 0. expected 1280, actual 42240
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.
###################################
time to warmup: 3.43min
Kernl speedup: 1.1X (1.0 VS 1.1 min)
# different outputs: 0/100 (0.00%)
memory footprint:
* allocated: 6.4GB
* reserved: 10.1GB
* max reserved: 12.1GB
Actual Behavior
It prints warnings about hitting cache_size_limit and about empty CUDA graphs, and no meaningful speedup is reported (often the 'optimized' version is slower than the original).
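Since the log is long and repetitive, a short script can tally which functions hit the limit, for which reason, and how many empty-graph warnings appear. This is a rough sketch: the regular expressions are written against the lines shown in this report and may not match other PyTorch versions, and `summarize_log` is a hypothetical helper, not part of Kernl or PyTorch. Long numbers such as object ids are replaced with `<id>` so that identical kinds of guard failure group together.

```python
import re
from collections import Counter

# Patterns follow the log lines shown in this report (an assumption about the format).
FUNCTION_RE = re.compile(r"^\s*function: '(?P<name>[^']+)' \((?P<path>.+):(?P<line>\d+)\)\s*$")
REASON_RE = re.compile(r"^\s*reasons:\s*(?P<reason>.+?)\s*$")
LONG_NUMBER_RE = re.compile(r"\d{6,}")


def summarize_log(text: str) -> tuple[Counter, int]:
    """Return (limit hits keyed by (function, reason), number of empty-graph warnings)."""
    limit_hits: Counter = Counter()
    empty_graphs = 0
    current_function = None
    for line in text.splitlines():
        if "The CUDA Graph is empty" in line:
            empty_graphs += 1
            continue
        match = FUNCTION_RE.match(line)
        if match:
            filename = match["path"].rsplit("/", 1)[-1]
            current_function = f"{match['name']} ({filename}:{match['line']})"
            continue
        match = REASON_RE.match(line)
        if match and current_function is not None:
            reason = LONG_NUMBER_RE.sub("<id>", match["reason"])
            limit_hits[(current_function, reason)] += 1
            current_function = None
    return limit_hits, empty_graphs


sample = """\
[2023-07-25 16:24:47,436] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)
function: 'forward' (kernl/.venv/lib/python3.10/site-packages/transformers/models/whisper/modeling_whisper.py:564)
reasons: ___check_obj_id(past_key_value, 94111367005152)
kernl/.venv/lib/python3.10/site-packages/torch/cuda/graphs.py:79: UserWarning: The CUDA Graph is empty. This ususally means that the graph was attempted to be captured on wrong device or stream. (Triggered internally at ../aten/src/ATen/cuda/CUDAGraph.cpp:191.)
super().capture_end()
"""
hits, empty = summarize_log(sample)
for (function, reason), count in hits.items():
    print(f"{count}x {function}: {reason}")
print(f"empty CUDA graph warnings: {empty}")
```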