Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix commitment loss #367

Merged
merged 3 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

Adding stereo models.

Fixed the commitment loss, which was until now only applied to the first RVQ layer.


## [1.1.0] - 2023-11-06

Expand Down
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ AudioCraft requires Python 3.9, PyTorch 2.0.0. To install AudioCraft, you can ru
```shell
# Best to make sure you have torch installed first, in particular before installing xformers.
# Don't run this if you already have PyTorch installed.
python -m pip install 'torch>=2.0'
python -m pip install 'torch==2.1.0'
# You might need the following before trying to install the packages
python -m pip install setuptools wheel
# Then proceed to one of the following
python -m pip install -U audiocraft # stable release
python -m pip install -U git+https://[email protected]/facebookresearch/audiocraft#egg=audiocraft # bleeding edge
Expand Down
2 changes: 1 addition & 1 deletion audiocraft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@
# flake8: noqa
from . import data, modules, models

__version__ = '1.2.0a1'
__version__ = '1.2.0a2'
5 changes: 5 additions & 0 deletions audiocraft/quantization/core_vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,11 +371,16 @@ def forward(self, x, n_q: tp.Optional[int] = None):

for i, layer in enumerate(self.layers[:n_q]):
quantized, indices, loss = layer(residual)
quantized = quantized.detach()
residual = residual - quantized
quantized_out = quantized_out + quantized
all_indices.append(indices)
all_losses.append(loss)

if self.training:
# Solving subtle bug with STE and RVQ: https://github.com/facebookresearch/encodec/issues/25
quantized_out = x + (quantized_out - x).detach()

out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
return quantized_out, out_indices, out_losses

Expand Down
4 changes: 3 additions & 1 deletion tests/quantization/test_vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
class TestResidualVectorQuantizer:

def test_rvq(self):
x = torch.randn(1, 16, 2048)
x = torch.randn(1, 16, 2048, requires_grad=True)
vq = ResidualVectorQuantizer(n_q=8, dimension=16, bins=8)
res = vq(x, 1.)
assert res.x.shape == torch.Size([1, 16, 2048])
res.x.sum().backward()
assert torch.allclose(x.grad.data, torch.ones(1))
Loading