Skip to content

Commit

Permalink
examples : add embd_to_audio to tts-outetts.py [no ci] (ggerganov#11235)
Browse files Browse the repository at this point in the history
This commit contains a suggestion for adding the missing embd_to_audio
function from tts.cpp to tts-outetts.py. This introduces a dependency on
numpy, which I was not sure is acceptable (only PyTorch was mentioned in
the referenced PR).

Also the README has been updated with instructions to run the example
with llama-server and the python script.

Refs: ggerganov#10784 (comment)
  • Loading branch information
danbev authored Jan 15, 2025
1 parent f446c2c commit 0ccd7f3
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 2 deletions.
37 changes: 37 additions & 0 deletions examples/tts/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,40 @@ play the audio:
$ aplay output.wav
```

### Running the example with llama-server
Running this example with `llama-server` is also possible and requires two
server instances to be started. One will serve the LLM model and the other
will serve the voice decoder model.

The LLM model server can be started with the following command:
```console
$ ./build/bin/llama-server -m ./models/outetts-0.2-0.5B-q8_0.gguf --port 8020
```

And the voice decoder model server can be started using:
```console
$ ./build/bin/llama-server -m ./models/wavtokenizer-large-75-f16.gguf --port 8021 --embeddings --pooling none
```

Then we can run [tts-outetts.py](tts-outetts.py) to generate the audio.

First create a virtual environment for python and install the required
dependencies (this is only required to be done once):
```console
$ python3 -m venv venv
$ source venv/bin/activate
(venv) pip install requests numpy
```

And then run the python script using:
```console
(venv) python ./examples/tts/tts-outetts.py http://localhost:8020 http://localhost:8021 "Hello world"
spectrogram generated: n_codes: 90, n_embd: 1282
converting to audio ...
audio generated: 28800 samples
audio written to file "output.wav"
```
And to play the audio we can again use aplay or any other media player:
```console
$ aplay output.wav
```
128 changes: 126 additions & 2 deletions examples/tts/tts-outetts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,121 @@
#import struct
import requests
import re
import struct
import numpy as np
from concurrent.futures import ThreadPoolExecutor


def fill_hann_window(size, periodic=True):
    """Return a Hann window with ``size`` samples.

    With ``periodic`` set, the window is built one sample longer and the
    final sample is dropped; otherwise the symmetric ``np.hanning`` window
    is returned unchanged.
    """
    if not periodic:
        return np.hanning(size)
    extended = np.hanning(size + 1)
    return extended[:-1]


def irfft(n_fft, complex_input):
    """Inverse real FFT of ``complex_input`` producing ``n_fft`` real samples."""
    spectrum = np.asarray(complex_input)
    return np.fft.irfft(spectrum, n=n_fft)


def fold(buffer, n_out, n_win, n_hop, n_pad):
    """Overlap-add consecutive frames of ``buffer`` into one signal.

    ``buffer`` holds back-to-back frames of ``n_win`` samples each. Frame
    ``i`` is added into the output starting at sample ``i * n_hop``. The
    output has ``n_out`` samples before ``n_pad`` samples are trimmed from
    both ends (no trimming when ``n_pad`` is not positive).
    """
    folded = np.zeros(n_out)
    frame_count = len(buffer) // n_win

    for idx in range(frame_count):
        src = idx * n_win
        dst = idx * n_hop
        folded[dst:dst + n_win] += buffer[src:src + n_win]

    if n_pad > 0:
        return folded[n_pad:-n_pad]
    return folded


def process_frame(args):
    """Synthesise one windowed time-domain frame.

    ``args`` is a tuple ``(index, n_fft, spectra, window)``: the row
    ``spectra[index]`` is inverse-FFT'd to ``n_fft`` samples and multiplied
    by ``window``. Returns ``(windowed_frame, window ** 2)``; the squared
    window is what the caller accumulates as the normalisation envelope.
    """
    frame_idx, fft_size, spectra, window = args
    windowed = irfft(fft_size, spectra[frame_idx]) * window
    return windowed, window * window


def embd_to_audio(embd, n_codes, n_embd, n_thread=4):
    """Convert vocoder embeddings into an audio waveform.

    Python port of ``embd_to_audio()`` from ``tts.cpp``. Each of the
    ``n_codes`` embedding rows is split in two halves: the first
    ``n_embd // 2`` values are log-magnitudes and the next ``n_embd // 2``
    values are phases of one spectrogram frame. Every frame is inverse
    FFT'd, Hann-windowed and overlap-added, and the sum is normalised by
    the overlap-added squared window.

    Args:
        embd: Flat or nested sequence holding ``n_codes * n_embd`` values.
        n_codes: Number of spectrogram frames.
        n_embd: Number of values per frame.
        n_thread: Number of worker threads used for the per-frame inverse
            FFTs.

    Returns:
        1-D NumPy array with the reconstructed samples (``n_codes * 320``
        samples with the fixed hop/window sizes used here).

    Raises:
        ValueError: If ``embd`` cannot be reshaped to ``(n_codes, n_embd)``.
    """
    embd = np.asarray(embd, dtype=np.float32).reshape(n_codes, n_embd)

    n_fft = 1280
    n_hop = 320
    n_win = 1280
    n_pad = (n_win - n_hop) // 2
    n_out = (n_codes - 1) * n_hop + n_win

    hann = fill_hann_window(n_fft, True)

    half_embd = n_embd // 2

    # Build the complex spectrogram with array operations instead of
    # element-by-element Python loops. Magnitudes are capped at 1e2.
    mag = np.clip(np.exp(embd[:, :half_embd]), 0, 1e2)
    phi = embd[:, half_embd:2 * half_embd]

    # One extra (zero) bin is kept per frame, as in the scalar version.
    S = np.zeros((n_codes, half_embd + 1), dtype=np.complex64)
    S[:, :half_embd] = mag * np.exp(1j * phi)

    res = np.zeros(n_codes * n_fft)
    hann2_buffer = np.zeros(n_codes * n_fft)

    with ThreadPoolExecutor(max_workers=n_thread) as executor:
        args = [(l, n_fft, S, hann) for l in range(n_codes)]
        results = list(executor.map(process_frame, args))

    # Lay the frames out back to back; fold() does the overlap-add.
    for l, (frame, hann2) in enumerate(results):
        res[l * n_fft:(l + 1) * n_fft] = frame
        hann2_buffer[l * n_fft:(l + 1) * n_fft] = hann2

    audio = fold(res, n_out, n_win, n_hop, n_pad)
    env = fold(hann2_buffer, n_out, n_win, n_hop, n_pad)

    # Normalise by the window envelope, skipping samples where it is
    # (numerically) zero to avoid dividing by ~0.
    mask = env > 1e-10
    audio[mask] /= env[mask]

    return audio


def save_wav(filename, audio_data, sample_rate):
    """Write mono 16-bit PCM audio to a WAV file.

    Args:
        filename: Path of the file to create or overwrite.
        audio_data: Samples as a NumPy array or any sequence of numbers.
            Values are scaled by 32767 and clipped to the int16 range, so
            the nominal input range is [-1.0, 1.0].
        sample_rate: Sampling rate in Hz stored in the header.
    """
    num_channels = 1
    bits_per_sample = 16
    bytes_per_sample = bits_per_sample // 8

    # Coerce to an array first: multiplying a plain Python list by 32767
    # would repeat the list instead of scaling the samples.
    samples = np.asarray(audio_data)
    scaled = np.clip(samples * 32767, -32768, 32767)
    # RIFF/WAVE data is little-endian; force that byte order explicitly so
    # the payload matches the '<' header on any host.
    pcm_data = scaled.astype('<i2')

    # Size the header from the bytes actually written.
    data_size = pcm_data.nbytes
    byte_rate = sample_rate * num_channels * bytes_per_sample
    block_align = num_channels * bytes_per_sample
    chunk_size = 36 + data_size  # 36 = size of header minus first 8 bytes

    header = struct.pack(
        '<4sI4s4sIHHIIHH4sI',
        b'RIFF',
        chunk_size,
        b'WAVE',
        b'fmt ',
        16,  # fmt chunk size
        1,   # audio format (PCM)
        num_channels,
        sample_rate,
        byte_rate,
        block_align,
        bits_per_sample,
        b'data',
        data_size,
    )

    with open(filename, 'wb') as f:
        f.write(header)
        f.write(pcm_data.tobytes())


def process_text(text: str):
text = re.sub(r'\d+(\.\d+)?', lambda x: x.group(), text.lower()) # TODO this needs to be fixed
Expand Down Expand Up @@ -170,6 +285,15 @@ def process_text(text: str):
print('spectrogram generated: n_codes: %d, n_embd: %d' % (n_codes, n_embd))

# post-process the spectrogram to convert to audio
print('converting to audio ...')
audio = embd_to_audio(embd, n_codes, n_embd)
print('audio generated: %d samples' % len(audio))

filename = "output.wav"
sample_rate = 24000 # sampling rate

# zero out first 0.25 seconds (derived from sample_rate so the two stay in sync)
audio[:sample_rate // 4] = 0.0

save_wav(filename, audio, sample_rate)
print('audio written to file "%s"' % filename)

0 comments on commit 0ccd7f3

Please sign in to comment.