forked from ZhangXInFD/SpeechTokenizer
Commit 8620dfe (1 parent: b49a95a)
Showing 15 changed files with 1,817 additions and 6 deletions.
@@ -1,19 +1,90 @@
# SpeechTokenizer: Unified Speech Tokenizer for Speech Large Language Models

<a href='https://0nutation.github.io/SpeechTokenizer.github.io/'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://github.com/0nutation/SpeechTokenizer'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>
<a href='https://github.com/ZhangXInFD/SpeechTokenizer'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://github.com/0nutation/SpeechTokenizer'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>

## Introduction
SpeechTokenizer is a unified speech tokenizer for speech large language models. SpeechTokenizer adopts the Encoder-Decoder architecture with residual vector quantization (RVQ). Unifying semantic and acoustic tokens, SpeechTokenizer disentangles different aspects of speech information hierarchically across different RVQ layers. Leveraging SpeechTokenizer, we construct a Unified Speech Language Model (USLM).
This is the code for the SpeechTokenizer presented in [SpeechTokenizer: Unified Speech Tokenizer for Speech Large Language Models](https://0nutation.github.io/SpeechTokenizer.github.io/). SpeechTokenizer is a unified speech tokenizer for speech large language models that adopts an encoder-decoder architecture with residual vector quantization (RVQ). Unifying semantic and acoustic tokens, SpeechTokenizer disentangles different aspects of speech information hierarchically across different RVQ layers. Specifically, the code indices output by the first RVQ quantizer can be considered semantic tokens, and the outputs of the remaining quantizers can be regarded as acoustic tokens, which supplement the information lost by the first quantizer. We provide the following model:
* A model operating at 16 kHz on monophonic speech, trained on LibriSpeech with the average representation across all HuBERT layers as the semantic teacher.
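
To make the layer hierarchy described above concrete, here is a minimal toy sketch of residual quantization. It is not the SpeechTokenizer implementation: it quantizes a single scalar instead of learned encoder frames, and the codebooks and the helper names (`nearest`, `rvq_encode`, `rvq_decode`) are hypothetical. It only illustrates that each layer quantizes what the previous layers left over, so later layers act as supplements to the first.

```python
# Toy residual vector quantization on a scalar (illustrative assumption only;
# the real model quantizes encoder frames with trained vector codebooks).

def nearest(codebook, value):
    """Return the index of the codebook entry closest to value."""
    return min(range(len(codebook)), key=lambda i: abs(codebook[i] - value))

def rvq_encode(codebooks, value):
    """Quantize layer by layer; each layer sees the residual of the previous ones."""
    codes, residual = [], value
    for codebook in codebooks:
        idx = nearest(codebook, residual)
        codes.append(idx)
        residual -= codebook[idx]
    return codes

def rvq_decode(codebooks, codes):
    """Sum the selected entries of the first len(codes) layers."""
    return sum(cb[i] for cb, i in zip(codebooks, codes))

# Hypothetical codebooks: a coarse first layer, finer later layers.
codebooks = [
    [-1.0, 0.0, 1.0],
    [-0.25, 0.0, 0.25],
    [-0.0625, 0.0, 0.0625],
]

x = 0.8
codes = rvq_encode(codebooks, x)
coarse = rvq_decode(codebooks, codes[:1])  # first layer only
full = rvq_decode(codebooks, codes)        # all layers
print(codes, coarse, full)
```

In this toy setting, decoding from the first code alone gives a coarse approximation, and adding the remaining codes reduces the reconstruction error.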

<br>
<br>
<p align="center">
    <img src="images/overview.png" width="95%"> <br>
    Overview
</p>
<p align="center">
    <img src="images/speechtokenizer_framework.jpg" width="95%"> <br>
    The SpeechTokenizer framework.
</p>
<br>

<b> Welcome to try our [SLMTokBench](https://github.com/0nutation/SLMTokBench) </b>
<b> and we will also open source our [USLM](https://github.com/0nutation/USLM) !! </b>

<b> Code and models will be available soon!! </b>
Welcome to try our [SLMTokBench](https://github.com/0nutation/SLMTokBench)
and we will also open source our [USLM](https://github.com/0nutation/USLM) !!

## Samples

Samples are provided on [our demo page](https://github.com/0nutation/SpeechTokenizer).

## Installation

SpeechTokenizer requires Python>=3.8 and a reasonably recent version of PyTorch.
To install SpeechTokenizer, run the following from a local clone of this repository:
```bash
# pip install -U speechtokenizer
# First git clone the repo locally, then:
cd SpeechTokenizer
pip install .
```
## Usage
### Model storage
[model list]()
### Load model
```python
from speechtokenizer import SpeechTokenizer

# Replace these placeholder paths with the actual config and checkpoint files
config_path = '/path/config.json'
ckpt_path = '/path/SpeechTokenizer.pt'
model = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path)
model.eval()
```
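`load_from_checkpoint` parses the JSON config into a dict, and the `SpeechTokenizer` constructor shown later in this commit looks up every setting with `config.get(...)`, so a missing key silently becomes `None` and may only fail deep inside the encoder or quantizer. As a hedged sketch, the check below lists the keys that constructor reads and reports any that are absent. The helper name `missing_config_keys` and the example values are hypothetical, not part of the package.

```python
import json

# Keys that SpeechTokenizer.__init__ reads via config.get(...) in this commit.
EXPECTED_KEYS = {
    "n_filters", "dimension", "strides", "lstm_layers", "bidirectional",
    "dilation_base", "residual_kernel_size", "n_residual_layers",
    "activation", "sample_rate", "n_q", "semantic_dimension", "codebook_size",
}

def missing_config_keys(config: dict) -> list:
    """Return the expected keys that are absent from config, sorted by name."""
    return sorted(EXPECTED_KEYS - set(config))

# Hypothetical, deliberately incomplete config used only for demonstration.
example = json.loads('{"sample_rate": 16000, "n_q": 8, "dimension": 1024}')
missing = missing_config_keys(example)
print(missing)
```

Running such a check on the loaded dict before constructing the model turns a late, hard-to-read failure into an explicit list of missing settings.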
### Extracting discrete representations
```python
import torchaudio
import torch

# Load and pre-process speech waveform
wav, sr = torchaudio.load('<SPEECH_FILE_PATH>')
if sr != model.sample_rate:
    wav = torchaudio.functional.resample(wav, sr, model.sample_rate)
wav = wav.unsqueeze(0)  # add a batch dimension: (B, channels, T)

# Extract discrete codes from SpeechTokenizer
with torch.no_grad():
    codes = model.encode(wav)  # codes: (n_q, B, T)

semantic_tokens = codes[0, :, :]
acoustic_tokens = codes[1:, :, :]
```
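The time axis `T` of `codes` counts encoder frames rather than audio samples. As a rough sketch, assuming the encoder downsamples by the product of the `strides` entry in the model config and rounds up (both are assumptions, not something this README states), the frame count can be estimated as below. The stride values and the helper name `expected_frames` are hypothetical.

```python
import math

def expected_frames(num_samples: int, strides) -> int:
    """Estimate the number of code frames for a waveform of num_samples samples.

    Assumes the total hop size is the product of the strides and that the
    final partial frame is kept (ceil), which may differ from the real model.
    """
    hop = math.prod(strides)
    return math.ceil(num_samples / hop)

# Hypothetical strides; read the real values from the model's config.json.
strides = [8, 5, 4, 2]
one_second = expected_frames(16000, strides)  # one second of 16 kHz audio
print(one_second)
```

Comparing this estimate with `codes.shape[-1]` is a quick sanity check that the waveform was resampled to `model.sample_rate` before encoding.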

### Decoding discrete representations
```python
# Decoding from the first quantizer through the i-th quantizer
wav = model.decode(codes[:(i + 1)])  # wav: (B, 1, T)

# Decoding from the i-th quantizer through the j-th quantizer
# (st tells the RVQ which quantizer the first row of codes belongs to)
wav = model.decode(codes[i:(j + 1)], st=i)

# Concatenating semantic tokens and acoustic tokens and then decoding
semantic_tokens = ...  # (..., B, T)
acoustic_tokens = ...  # (..., B, T)
wav = model.decode(torch.cat([semantic_tokens, acoustic_tokens], dim=0))
```

## Citation
If you use this code or its results in your paper, please cite our work as:

## License
The code in this repository is released under the Apache 2.0 license as found in the
[LICENSE](LICENSE) file.
@@ -0,0 +1,48 @@
from pathlib import Path
from setuptools import setup

NAME = 'speechtokenizer'
DESCRIPTION = 'Unified speech tokenizer for speech language model'
URL = 'https://github.com/ZhangXInFD/SpeechTokenizer'
EMAIL = '[email protected]'
AUTHOR = 'Xin Zhang, Don Zhang, Simin Li, Yaqian Zhou, Xipeng Qiu'
REQUIRES_PYTHON = '>=3.8.0'


# Read __version__ from the package source without importing the package
with open('speechtokenizer/__init__.py') as f:
    for line in f:
        line = line.strip()
        if '__version__' in line:
            context = {}
            exec(line, context)
            VERSION = context['__version__']

HERE = Path(__file__).parent

# Use the README as the long description, falling back to the short one
try:
    with open(HERE / "README.md", encoding='utf-8') as f:
        long_description = '\n' + f.read()
except FileNotFoundError:
    long_description = DESCRIPTION

setup(
    name=NAME,
    version=VERSION,
    description=DESCRIPTION,
    long_description=long_description,
    long_description_content_type='text/markdown',
    author=AUTHOR,
    author_email=EMAIL,
    python_requires=REQUIRES_PYTHON,
    url=URL,
    packages=['speechtokenizer', 'speechtokenizer.quantization', 'speechtokenizer.modules'],
    # extras_require={
    #     'dev': ['flake8', 'mypy', 'pdoc3'],
    # },
    install_requires=['numpy', 'torch', 'torchaudio', 'einops'],
    include_package_data=True,
    license='Apache License 2.0',
    classifiers=[
        'Topic :: Multimedia :: Sound/Audio',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        # Registered trove classifier for the Apache 2.0 license
        'License :: OSI Approved :: Apache Software License',
    ])
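
The setup script above obtains the version by running `exec` on the line of `speechtokenizer/__init__.py` that contains `__version__`. As an alternative sketch, not what this commit does, the same value can be read with a regular expression so that no code from the package file is executed. The names `VERSION_RE` and `read_version` are hypothetical, and the demonstration writes a temporary file that mirrors the package's `__init__.py`.

```python
import re
import tempfile
from pathlib import Path

# Matches a line such as: __version__ = '0.1.1'
VERSION_RE = re.compile(r'^__version__\s*=\s*["\']([^"\']+)["\']', re.MULTILINE)

def read_version(init_path) -> str:
    """Return the version string assigned to __version__ in init_path."""
    text = Path(init_path).read_text(encoding="utf-8")
    match = VERSION_RE.search(text)
    if match is None:
        raise RuntimeError(f"__version__ not found in {init_path}")
    return match.group(1)

# Demonstration on a temporary copy of the package's __init__.py contents.
with tempfile.TemporaryDirectory() as tmp:
    init_file = Path(tmp) / "__init__.py"
    init_file.write_text(
        "from .model import SpeechTokenizer\n\n__version__ = '0.1.1'\n",
        encoding="utf-8",
    )
    demo_version = read_version(init_file)

print(demo_version)
```

Unlike the loop in the setup script, this variant raises a clear error when no version line is found instead of leaving `VERSION` undefined.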
@@ -0,0 +1,3 @@
from .model import SpeechTokenizer

__version__ = '0.1.1'
@@ -0,0 +1,182 @@
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 15:47:55 2023
@author: zhangxin
"""

from .modules.seanet import SEANetEncoder, SEANetDecoder
from .quantization import ResidualVectorQuantizer
import torch.nn as nn
from einops import rearrange
import torch

class SpeechTokenizer(nn.Module):
    def __init__(self, config):
        '''
        Parameters
        ----------
        config : dict
            Model config, as loaded from the JSON configuration file.
        '''
        super().__init__()
        self.encoder = SEANetEncoder(n_filters=config.get('n_filters'),
                                     dimension=config.get('dimension'),
                                     ratios=config.get('strides'),
                                     lstm=config.get('lstm_layers'),
                                     bidirectional=config.get('bidirectional'),
                                     dilation_base=config.get('dilation_base'),
                                     residual_kernel_size=config.get('residual_kernel_size'),
                                     n_residual_layers=config.get('n_residual_layers'),
                                     activation=config.get('activation'))
        self.sample_rate = config.get('sample_rate')
        self.n_q = config.get('n_q')
        # Project quantized features to the semantic teacher's dimension if they differ
        if config.get('dimension') != config.get('semantic_dimension'):
            self.transform = nn.Linear(config.get('dimension'), config.get('semantic_dimension'))
        else:
            self.transform = nn.Identity()
        self.quantizer = ResidualVectorQuantizer(dimension=config.get('dimension'), n_q=config.get('n_q'), bins=config.get('codebook_size'))
        self.decoder = SEANetDecoder(n_filters=config.get('n_filters'),
                                     dimension=config.get('dimension'),
                                     ratios=config.get('strides'),
                                     lstm=config.get('lstm_layers'),
                                     bidirectional=False,
                                     dilation_base=config.get('dilation_base'),
                                     residual_kernel_size=config.get('residual_kernel_size'),
                                     n_residual_layers=config.get('n_residual_layers'),
                                     activation=config.get('activation'))

    @classmethod
    def load_from_checkpoint(cls,
                             config_path: str,
                             ckpt_path: str):
        '''
        Parameters
        ----------
        config_path : str
            Path of model configuration file.
        ckpt_path : str
            Path of model checkpoint.
        Returns
        -------
        model : SpeechTokenizer
            SpeechTokenizer model.
        '''
        import json
        with open(config_path) as f:
            cfg = json.load(f)
        model = cls(cfg)
        params = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(params)
        return model

    def forward(self,
                x: torch.Tensor,
                n_q: int=None,
                layers: list=[0]):
        '''
        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape: (batch, channels, timesteps).
        n_q : int, optional
            Number of quantizers in RVQ used to encode. The default is all layers.
        layers : list[int], optional
            RVQ layers whose quantized results should be returned. The default is the first layer.
        Returns
        -------
        o : torch.Tensor
            Output wavs. Shape: (batch, channels, timesteps).
        commit_loss : torch.Tensor
            Commitment loss from residual vector quantizers.
        feature : torch.Tensor
            Output of RVQ's first layer after self.transform. Shape: (batch, timesteps, semantic_dimension)
        '''
        n_q = n_q if n_q else self.n_q
        e = self.encoder(x)
        quantized, codes, commit_loss, quantized_list = self.quantizer(e, n_q=n_q, layers=layers)
        feature = rearrange(quantized_list[0], 'b d t -> b t d')
        feature = self.transform(feature)
        o = self.decoder(quantized)
        return o, commit_loss, feature

    def forward_feature(self,
                        x: torch.Tensor,
                        layers: list=None):
        '''
        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape should be (batch, channels, timesteps).
        layers : list[int], optional
            RVQ layers whose quantized results should be returned. The default is all layers.
        Returns
        -------
        quantized_list : list[torch.Tensor]
            Quantized outputs of the requested layers.
        '''
        e = self.encoder(x)
        layers = layers if layers else list(range(self.n_q))
        quantized, codes, commit_loss, quantized_list = self.quantizer(e, layers=layers)
        return quantized_list

    def encode(self,
               x: torch.Tensor,
               n_q: int=None,
               st: int=None):
        '''
        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape: (batch, channels, timesteps).
        n_q : int, optional
            Number of quantizers in RVQ used to encode. The default is all layers.
        st : int, optional
            Start quantizer index in RVQ. The default is 0.
        Returns
        -------
        codes : torch.Tensor
            Output indices for each quantizer. Shape: (n_q, batch, timesteps)
        '''
        e = self.encoder(x)
        if st is None:
            st = 0
        n_q = n_q if n_q else self.n_q
        codes = self.quantizer.encode(e, n_q=n_q, st=st)
        return codes

    def decode(self,
               codes: torch.Tensor,
               st: int=0):
        '''
        Parameters
        ----------
        codes : torch.Tensor
            Indices for each quantizer. Shape: (n_q, batch, timesteps).
        st : int, optional
            Start quantizer index in RVQ. The default is 0.
        Returns
        -------
        o : torch.Tensor
            Waveforms reconstructed from codes. Shape: (batch, channels, timesteps)
        '''
        quantized = self.quantizer.decode(codes, st=st)
        o = self.decoder(quantized)
        return o
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Torch modules."""

# flake8: noqa
from .conv import (
    pad1d,
    unpad1d,
    NormConv1d,
    NormConvTranspose1d,
    NormConv2d,
    NormConvTranspose2d,
    SConv1d,
    SConvTranspose1d,
)
from .lstm import SLSTM
from .seanet import SEANetEncoder, SEANetDecoder