enable whisper batch for long sequences (tinygrad#6458)
* long batch +test

* cleanup

* rollback syntactic changes

---------

Co-authored-by: chenyu <[email protected]>
DKormann and chenyuxyz authored Sep 17, 2024
1 parent 7c94241 commit f5dd25d
Showing 2 changed files with 34 additions and 40 deletions.
67 changes: 29 additions & 38 deletions examples/whisper.py
@@ -162,7 +162,7 @@ def pad_or_trim(arr, target_len):
   mel_spec = librosa.filters.mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes

   log_spec = np.log10(np.clip(mel_spec, 1e-10, None))
-  log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
+  log_spec = np.maximum(log_spec, log_spec.max((1,2), keepdims=True) - 8.0)
   log_spec = (log_spec + 4.0) / 4.0

   return log_spec
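
Note on the one-line change above: prep_audio stacks the batch into a single array of shape (batch, n_mels, frames), so reducing max over axes (1, 2) clamps each waveform's dynamic range against its own peak, where the old global max let the loudest item in the batch set the floor for every other item. A minimal numpy sketch of the difference, with the shape assumed from the surrounding code:

import numpy as np

# assume a batch of 2 log-mel spectrograms: (batch=2, n_mels=80, frames=3000)
log_spec = np.log10(np.random.uniform(1e-10, 1.0, size=(2, 80, 3000)))
log_spec[1] -= 5.0  # make the second item much quieter

# old behavior: one global max, so the loud item sets the clamp floor for both
old = np.maximum(log_spec, log_spec.max() - 8.0)

# new behavior: per-item max of shape (2, 1, 1), each item clamped to its own peak
new = np.maximum(log_spec, log_spec.max((1, 2), keepdims=True) - 8.0)

print((old[1] != new[1]).any())  # True: the quiet item is no longer over-clamped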
@@ -241,55 +241,46 @@ def transcribe_waveform(model: Whisper, enc, waveforms, truncate=False):
   Expects an array of shape (N,S) where N is the number waveforms to transcribe in parallel and S is number of 16000Hz samples
   Returns the transcribed text if a single waveform is provided, or an array of transcriptions if multiple are provided
   """
-  N_audio = len(waveforms)
-  log_spec = prep_audio(waveforms, model.batch_size, truncate)
-
-  if log_spec.shape[-1] > FRAMES_PER_SEGMENT and N_audio > 1:
-    # we don't support multi-segment batching because the size of the prompt tokens would be different for each item in the batch
-    # if we really want this feature, we can consider padding or trimming prompt tokens of varying lengths to make them consistent
-    raise Exception("Multi-segment transcription not supported with batch audio input")
+  log_spec = prep_audio(waveforms, model.batch_size, truncate)
+  nsample = model.decoder.max_tokens_to_sample
+
+  def inferloop(ctx: Union[np.ndarray, List[np.ndarray]], encoded_audio):
+    pos, next_tokens = 0, ctx
+    for i in range((nsample-len(start_tokens))*2):
+      next_tokens = model.decoder(Tensor(next_tokens), pos, encoded_audio)[:, -1].argmax(axis=-1).numpy().astype(np.int32).reshape(-1, 1)
+      next_tokens[ctx[:, -1] == eot] = eot
+      ctx = np.concatenate((ctx, next_tokens), axis=1)
+      pos = ctx.shape[-1] - 1
+      if (next_tokens == eot).all(): break
+    return ctx
+
+  def gettexttoks(line): return [tok for tok in line if tok < eot or tok > enc._special_tokens["<|notimestamps|>"]][-nsample+len(start_tokens):]
+
   start_tokens = [enc._special_tokens["<|startoftranscript|>"]]
   if model.is_multilingual:
     # TODO detect language
     language_token = enc._special_tokens["<|startoftranscript|>"] + 1 + tuple(LANGUAGES.keys()).index("en")
     start_tokens.append(language_token)
     start_tokens.append(enc._special_tokens["<|transcribe|>"])
   start_tokens.append(enc._special_tokens["<|notimestamps|>"])
-  transcription_start_index = len(start_tokens)

   eot = enc._special_tokens["<|endoftext|>"]
-  transcription_tokens = [np.array([], dtype=np.int32)] * log_spec.shape[0]
+
+  ctx = np.tile(start_tokens, (model.batch_size,1))
+  transcriptions = [[] for _ in waveforms]

   for curr_frame in range(0, log_spec.shape[-1], FRAMES_PER_SEGMENT):
     encoded_audio = model.encoder.encode(Tensor(log_spec[:, :, curr_frame:curr_frame + FRAMES_PER_SEGMENT]))
-    pos = 0
-    curr_segment_tokens = np.tile(start_tokens, (log_spec.shape[0], 1))
-    if curr_frame > 0:
-      # pass the previously inferred tokens as 'prompt' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
-      prompt = np.concatenate((
-        [enc._special_tokens["<|startofprev|>"]],
-        transcription_tokens[0][-model.decoder.max_tokens_to_sample+1:],
-        start_tokens))
-      curr_segment_tokens = np.tile(prompt, (log_spec.shape[0], 1))
-      transcription_start_index = len(curr_segment_tokens[0])
-
-    for i in range(model.decoder.max_tokens_to_sample):
-      out = model.decoder.forward(Tensor(curr_segment_tokens if i == 0 else curr_segment_tokens[:, -1:]), pos, None if i else encoded_audio)
-      next_tokens = out[:, -1].argmax(axis=-1).numpy().astype(np.int32)
-      next_tokens[curr_segment_tokens[:, -1] == eot] = eot
-      curr_segment_tokens = np.concatenate((curr_segment_tokens, next_tokens.reshape(-1, 1)), axis=1)
-      pos = curr_segment_tokens.shape[-1] - 1
-      if DEBUG >= 1: print(i, list(map(lambda tokens: enc.decode(tokens), curr_segment_tokens)))
-      if (curr_segment_tokens[:, -1] == eot).all():
-        break
-
-    for i, t in enumerate(curr_segment_tokens):
-      eot_index = np.where(t == eot)[0]
-      eot_index = None if len(eot_index) == 0 else eot_index[0]
-      transcription_tokens[i] = np.concatenate((transcription_tokens[i], t[transcription_start_index:eot_index]))
-
-  transcriptions = list(map(lambda tokens: enc.decode(tokens).strip(), transcription_tokens))
-  return transcriptions[:N_audio] if N_audio > 1 else transcriptions[0]
+
+    if all(len(c) == len(ctx[0]) for c in ctx): ctx = inferloop(np.array(ctx), encoded_audio)
+    else: ctx = [inferloop((np.array([c]*model.batch_size)), encoded_audio)[i] for i,c in enumerate(ctx)]
+
+    for i, (res, arr) in enumerate(zip(transcriptions, ctx)):
+      if curr_frame*HOP_LENGTH <= len(waveforms[i]): res.extend(arr[np.where(arr == start_tokens[-1])[0][0]+1:eoti[0] if len(eoti := np.where(arr == eot)[0]) else None])
+    ctx = [[enc._special_tokens['<|startofprev|>']]+gettexttoks(cs)+start_tokens for cs in ctx]
+
+  transcriptions = list(map(lambda tokens: enc.decode(tokens).strip(), transcriptions))
+  return transcriptions if len(transcriptions) > 1 else transcriptions[0]

 CHUNK = 1600
 RECORD_SECONDS = 10
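For readers skimming the diff, the core trick in inferloop above is that a row which has already produced <|endoftext|> gets pinned to that token while the rest of the batch keeps decoding, so a single fixed-shape array can serve items that finish at different times. A self-contained sketch with the decoder replaced by a stub (fake_decoder and the token id are illustrative, not from the repo):

import numpy as np

eot = 50257  # stand-in for enc._special_tokens["<|endoftext|>"]
rng = np.random.default_rng(0)

def fake_decoder(ctx):
  # stub: emit one random token per batch row; the real call is model.decoder(...)
  return rng.integers(0, eot + 1, size=(ctx.shape[0], 1), dtype=np.int32)

def inferloop(ctx, max_steps=16):
  for _ in range(max_steps):
    next_tokens = fake_decoder(ctx)
    next_tokens[ctx[:, -1] == eot] = eot  # finished rows pad themselves with eot
    ctx = np.concatenate((ctx, next_tokens), axis=1)
    if (next_tokens == eot).all(): break  # stop once every row has finished
  return ctx

print(inferloop(np.zeros((3, 1), dtype=np.int32)))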
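The rolling prompt is the piece that previously forced the single-segment restriction: after each segment, gettexttoks keeps only text and timestamp tokens from the prior context, and the next segment decodes from <|startofprev|> plus that tail plus the start tokens, so every row's prompt has a predictable length. A toy sketch with hypothetical token ids (the real ids come from enc._special_tokens):

# hypothetical ids: text tokens < eot, specials in [eot, notimestamps], timestamps above
eot, startofprev, notimestamps = 100, 105, 110
start_tokens = [101, 110]  # e.g. <|startoftranscript|>, <|notimestamps|>
nsample = 8

def gettexttoks(line):
  # keep plain text and timestamp tokens, then truncate to the prompt budget
  kept = [tok for tok in line if tok < eot or tok > notimestamps]
  return kept[-nsample + len(start_tokens):]

prev_ctx = [101, 110, 5, 6, 100, 7, 8, 9]  # last segment's output for one row
next_ctx = [startofprev] + gettexttoks(prev_ctx) + start_tokens
print(next_ctx)  # [105, 5, 6, 7, 8, 9, 101, 110]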
7 changes: 5 additions & 2 deletions test/models/test_whisper.py
@@ -61,8 +61,11 @@ def test_transcribe_long(self):
   @unittest.skipIf(CI, "too long for CI")
   def test_transcribe_long_no_batch(self):
     waveforms = [load_file_waveform(fetch(TEST_FILE_3_URL)), load_file_waveform(TEST_FILE_1)]
-    with self.assertRaises(Exception):
-      transcribe_waveform(self.model, self.enc, waveforms)
+
+    trancriptions = transcribe_waveform(self.model, self.enc, waveforms)
+    self.assertEqual(2, len(trancriptions))
+    self.assertEqual(TRANSCRIPTION_3, trancriptions[0])
+    self.assertEqual(TRANSCRIPTION_1, trancriptions[1])

 if __name__ == '__main__':
   unittest.main()
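
With the test updated above, batched long-form transcription is now the ordinary call path rather than an error. A hypothetical invocation, assuming the example's init_whisper helper keeps its (model_name, batch_size) signature and using placeholder file paths:

from examples.whisper import init_whisper, load_file_waveform, transcribe_waveform

model, enc = init_whisper("tiny.en", batch_size=2)
# two waveforms of different lengths; each is transcribed segment by segment
waveforms = [load_file_waveform("long_a.wav"), load_file_waveform("short_b.wav")]
texts = transcribe_waveform(model, enc, waveforms)  # a list of two strings
print(texts[0])
print(texts[1])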
