Add main components
ZhangXInFD committed Aug 31, 2023
1 parent b49a95a commit 8620dfe
Showing 15 changed files with 1,817 additions and 6 deletions.
83 changes: 77 additions & 6 deletions README.md
@@ -1,19 +1,90 @@
# SpeechTokenizer: Unified Speech Tokenizer for Speech Large Language Models

<a href='https://github.com/ZhangXInFD/SpeechTokenizer'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://github.com/0nutation/SpeechTokenizer'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>

## Introduction
This is the code for the SpeechTokenizer model presented in [SpeechTokenizer: Unified Speech Tokenizer for Speech Large Language Models](https://0nutation.github.io/SpeechTokenizer.github.io/). SpeechTokenizer is a unified speech tokenizer for speech large language models, which adopts an encoder-decoder architecture with residual vector quantization (RVQ). By unifying semantic and acoustic tokens, SpeechTokenizer disentangles different aspects of speech information hierarchically across the RVQ layers. Specifically, the code indices output by the first RVQ quantizer can be considered semantic tokens, while the outputs of the remaining quantizers can be regarded as acoustic tokens, which supplement the information lost by the first quantizer (a toy sketch of this residual scheme follows the model list below). We provide our models:
* A model operating at 16 kHz on monophonic speech, trained on LibriSpeech, with the average representation across all HuBERT layers as the semantic teacher.

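To make the hierarchy concrete, here is a toy sketch of residual vector quantization in plain PyTorch. It is an illustration only, with made-up random codebooks, not the repository's `ResidualVectorQuantizer`: each layer quantizes the residual left by the previous layers, so the first layer's codes carry the coarsest content and later layers refine it.

```python
# Toy sketch of residual vector quantization (RVQ). Illustration only;
# codebooks are random, not the trained ones shipped with SpeechTokenizer.
import torch

def nearest_code(x, codebook):
    # x: (frames, dim); codebook: (bins, dim) -> indices: (frames,)
    dists = torch.cdist(x, codebook)  # pairwise L2 distances
    return dists.argmin(dim=-1)

torch.manual_seed(0)
n_q, bins, dim, frames = 4, 16, 8, 10
codebooks = [torch.randn(bins, dim) for _ in range(n_q)]

x = torch.randn(frames, dim)  # stand-in for the encoder output
residual, quantized, codes = x, torch.zeros_like(x), []
for cb in codebooks:
    idx = nearest_code(residual, cb)  # codes for this RVQ layer
    q = cb[idx]
    quantized = quantized + q         # running reconstruction
    residual = residual - q           # what is left for later layers
    codes.append(idx)

# codes[0] ~ "semantic" layer; codes[1:] ~ "acoustic" refinement layers
print(torch.stack(codes).shape)       # (n_q, frames)
```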
<br>
<br>
<p align="center">
<img src="images/overview.png" width="95%"> <br>
Overview
</p>
<p align="center">
<img src="images/speechtokenizer_framework.jpg" width="95%"> <br>
The SpeechTokenizer framework.
</p>
<br>

Welcome to try our [SLMTokBench](https://github.com/0nutation/SLMTokBench), and we will also open-source our [USLM](https://github.com/0nutation/USLM)!



## Samples

Samples are provided on [our demo page](https://github.com/0nutation/SpeechTokenizer).

## Installation

SpeechTokenizer requires Python >= 3.8 and a reasonably recent version of PyTorch.
To install SpeechTokenizer, you can install it from PyPI or build it from this repository:
```bash
# Install from PyPI
# pip install -U speechtokenizer

# Or install from source
git clone https://github.com/ZhangXInFD/SpeechTokenizer.git
cd SpeechTokenizer
pip install .
```
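A quick import check (the version string comes from `speechtokenizer/__init__.py` at this commit):

```python
import speechtokenizer

print(speechtokenizer.__version__)  # '0.1.1' at this commit
```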
## Usage
### Model storage
[model list]()
### Load model
```python
from speechtokenizer import SpeechTokenizer

config_path = '/path/config.json'
ckpt_path = '/path/SpeechTokenizer.pt'
model = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path)
model.eval()
```
### Extracting discrete representations
```python
import torchaudio
import torch

# Load and pre-process speech waveform
wav, sr = torchaudio.load('<SPEECH_FILE_PATH>')
if sr != model.sample_rate:
    wav = torchaudio.functional.resample(wav, sr, model.sample_rate)
wav = wav.unsqueeze(0)  # add batch dimension: (1, channels, timesteps)

# Extract discrete codes from SpeechTokenizer
with torch.no_grad():
    codes = model.encode(wav)  # codes: (n_q, B, T)

semantic_tokens = codes[0, :, :]   # codes from the first RVQ layer
acoustic_tokens = codes[1:, :, :]  # codes from the remaining RVQ layers
```

### Decoding discrete representations
```python
# Decoding from the first quantizer up to the i-th quantizer
wav = model.decode(codes[:(i + 1)])  # wav: (B, 1, T)

# Decoding from the i-th quantizer up to the j-th quantizer
# (st tells the model which RVQ layer codes[0] comes from)
wav = model.decode(codes[i: (j + 1)], st=i)

# Concatenating semantic tokens and acoustic tokens, then decoding
semantic_tokens = ... # (..., B, T)
acoustic_tokens = ... # (..., B, T)
wav = model.decode(torch.cat([semantic_tokens, acoustic_tokens], axis=0))
```
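To listen to any of the reconstructions above, the decoded batch can be written back to disk with `torchaudio`; the output path below is a placeholder:

```python
# Save the reconstructed waveform (assumes batch size 1, so squeeze the batch dim)
torchaudio.save('<OUTPUT_PATH>', wav.squeeze(0).detach().cpu(), model.sample_rate)
```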

## Citation
If you use this code or its results in your paper, please cite our work:

## License
The code in this repository is released under the Apache 2.0 license as found in the
[LICENSE](LICENSE) file.
Binary file added images/speechtokenizer_framework.jpg
48 changes: 48 additions & 0 deletions setup.py
@@ -0,0 +1,48 @@
from pathlib import Path
from setuptools import setup

NAME = 'speechtokenizer'
DESCRIPTION = 'Unified speech tokenizer for speech language model'
URL = 'https://github.com/ZhangXInFD/SpeechTokenizer'
EMAIL = '[email protected]'
AUTHOR = 'Xin Zhang, Dong Zhang, Simin Li, Yaqian Zhou, Xipeng Qiu'
REQUIRES_PYTHON = '>=3.8.0'


# Read the package version out of speechtokenizer/__init__.py
for line in open('speechtokenizer/__init__.py'):
    line = line.strip()
    if '__version__' in line:
        context = {}
        exec(line, context)
        VERSION = context['__version__']

HERE = Path(__file__).parent

try:
    with open(HERE / "README.md", encoding='utf-8') as f:
        long_description = '\n' + f.read()
except FileNotFoundError:
    long_description = DESCRIPTION

setup(
    name=NAME,
    version=VERSION,
    description=DESCRIPTION,
    long_description=long_description,
    long_description_content_type='text/markdown',
    author=AUTHOR,
    author_email=EMAIL,
    python_requires=REQUIRES_PYTHON,
    url=URL,
    packages=['speechtokenizer', 'speechtokenizer.quantization', 'speechtokenizer.modules'],
    # extras_require={
    #     'dev': ['flake8', 'mypy', 'pdoc3'],
    # },
    install_requires=['numpy', 'torch', 'torchaudio', 'einops'],
    include_package_data=True,
    license='Apache License 2.0',
    classifiers=[
        'Topic :: Multimedia :: Sound/Audio',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        # valid trove classifier for the Apache 2.0 license
        'License :: OSI Approved :: Apache Software License',
    ])
3 changes: 3 additions & 0 deletions speechtokenizer/__init__.py
@@ -0,0 +1,3 @@
from .model import SpeechTokenizer

__version__ = '0.1.1'
182 changes: 182 additions & 0 deletions speechtokenizer/model.py
@@ -0,0 +1,182 @@
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 15:47:55 2023
@author: zhangxin
"""

from .modules.seanet import SEANetEncoder, SEANetDecoder
from .quantization import ResidualVectorQuantizer
import torch.nn as nn
from einops import rearrange
import torch

class SpeechTokenizer(nn.Module):
    def __init__(self, config):
        '''
        Parameters
        ----------
        config : dict
            Model configuration.
        '''
        super().__init__()
        self.encoder = SEANetEncoder(n_filters=config.get('n_filters'),
                                     dimension=config.get('dimension'),
                                     ratios=config.get('strides'),
                                     lstm=config.get('lstm_layers'),
                                     bidirectional=config.get('bidirectional'),
                                     dilation_base=config.get('dilation_base'),
                                     residual_kernel_size=config.get('residual_kernel_size'),
                                     n_residual_layers=config.get('n_residual_layers'),
                                     activation=config.get('activation'))
        self.sample_rate = config.get('sample_rate')
        self.n_q = config.get('n_q')
        # Project to the semantic teacher's dimension only when they differ
        if config.get('dimension') != config.get('semantic_dimension'):
            self.transform = nn.Linear(config.get('dimension'), config.get('semantic_dimension'))
        else:
            self.transform = nn.Identity()
        self.quantizer = ResidualVectorQuantizer(dimension=config.get('dimension'), n_q=config.get('n_q'), bins=config.get('codebook_size'))
        self.decoder = SEANetDecoder(n_filters=config.get('n_filters'),
                                     dimension=config.get('dimension'),
                                     ratios=config.get('strides'),
                                     lstm=config.get('lstm_layers'),
                                     bidirectional=False,
                                     dilation_base=config.get('dilation_base'),
                                     residual_kernel_size=config.get('residual_kernel_size'),
                                     n_residual_layers=config.get('n_residual_layers'),
                                     activation=config.get('activation'))

    @classmethod
    def load_from_checkpoint(cls,
                             config_path: str,
                             ckpt_path: str):
        '''
        Parameters
        ----------
        config_path : str
            Path of the model configuration file.
        ckpt_path : str
            Path of the model checkpoint.

        Returns
        -------
        model : SpeechTokenizer
            SpeechTokenizer model.
        '''
        import json
        with open(config_path) as f:
            cfg = json.load(f)
        model = cls(cfg)
        params = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(params)
        return model


    def forward(self,
                x: torch.tensor,
                n_q: int = None,
                layers: list = [0]):
        '''
        Parameters
        ----------
        x : torch.tensor
            Input wavs. Shape: (batch, channels, timesteps).
        n_q : int, optional
            Number of RVQ quantizers used to encode. Defaults to all layers.
        layers : list[int], optional
            Indices of the RVQ layers whose quantized output should be returned. Defaults to the first layer.

        Returns
        -------
        o : torch.tensor
            Output wavs. Shape: (batch, channels, timesteps).
        commit_loss : torch.tensor
            Commitment loss from the residual vector quantizer.
        feature : torch.tensor
            Output of the RVQ's first layer, projected to the semantic dimension. Shape: (batch, timesteps, dimension).
        '''
        n_q = n_q if n_q else self.n_q
        e = self.encoder(x)
        quantized, codes, commit_loss, quantized_list = self.quantizer(e, n_q=n_q, layers=layers)
        feature = rearrange(quantized_list[0], 'b d t -> b t d')
        feature = self.transform(feature)
        o = self.decoder(quantized)
        return o, commit_loss, feature

    def forward_feature(self,
                        x: torch.tensor,
                        layers: list = None):
        '''
        Parameters
        ----------
        x : torch.tensor
            Input wavs. Shape: (batch, channels, timesteps).
        layers : list[int], optional
            Indices of the RVQ layers whose quantized output should be returned. Defaults to all layers.

        Returns
        -------
        quantized_list : list[torch.tensor]
            Quantized outputs of the requested layers.
        '''
        e = self.encoder(x)
        layers = layers if layers else list(range(self.n_q))
        quantized, codes, commit_loss, quantized_list = self.quantizer(e, layers=layers)
        return quantized_list

    def encode(self,
               x: torch.tensor,
               n_q: int = None,
               st: int = None):
        '''
        Parameters
        ----------
        x : torch.tensor
            Input wavs. Shape: (batch, channels, timesteps).
        n_q : int, optional
            Number of RVQ quantizers used to encode. Defaults to all layers.
        st : int, optional
            Index of the first RVQ quantizer to use. Defaults to 0.

        Returns
        -------
        codes : torch.tensor
            Output indices for each quantizer. Shape: (n_q, batch, timesteps).
        '''
        e = self.encoder(x)
        if st is None:
            st = 0
        n_q = n_q if n_q else self.n_q
        codes = self.quantizer.encode(e, n_q=n_q, st=st)
        return codes

    def decode(self,
               codes: torch.tensor,
               st: int = 0):
        '''
        Parameters
        ----------
        codes : torch.tensor
            Indices for each quantizer. Shape: (n_q, batch, timesteps).
        st : int, optional
            Index of the RVQ quantizer that produced codes[0]. Defaults to 0.

        Returns
        -------
        o : torch.tensor
            Waveforms reconstructed from codes. Shape: (batch, channels, timesteps).
        '''
        quantized = self.quantizer.decode(codes, st=st)
        o = self.decoder(quantized)
        return o
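As a usage note for the class above, here is a minimal, hypothetical smoke test of `forward()` (the paths are placeholders): it returns the reconstruction, the RVQ commitment loss, and the projected first-layer feature that serves as the semantic target.

```python
import torch
from speechtokenizer import SpeechTokenizer

# Placeholder paths; substitute a real config and checkpoint.
model = SpeechTokenizer.load_from_checkpoint('/path/config.json', '/path/SpeechTokenizer.pt')
model.eval()

wav = torch.randn(1, 1, model.sample_rate)  # one second of dummy mono audio
with torch.no_grad():
    recon, commit_loss, feature = model(wav)  # defaults: all quantizers, feature from layer 0
print(recon.shape, feature.shape)  # (B, 1, T) and (B, T', semantic_dimension)
```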
21 changes: 21 additions & 0 deletions speechtokenizer/modules/__init__.py
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Torch modules."""

# flake8: noqa
from .conv import (
pad1d,
unpad1d,
NormConv1d,
NormConvTranspose1d,
NormConv2d,
NormConvTranspose2d,
SConv1d,
SConvTranspose1d,
)
from .lstm import SLSTM
from .seanet import SEANetEncoder, SEANetDecoder
