Skip to content

Commit

Permalink
more efficient to select out the frozen codes and then do projection …
Browse files Browse the repository at this point in the history
…in `indices_to_codes` for SimVQ
  • Loading branch information
lucidrains committed Nov 11, 2024
1 parent a0e8f2c commit 8f5b428
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "vector-quantize-pytorch"
version = "1.20.3"
version = "1.20.4"
description = "Vector Quantization - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down
7 changes: 4 additions & 3 deletions vector_quantize_pytorch/sim_vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
if not exists(codebook_transform):
codebook_transform = nn.Linear(dim, dim, bias = False)

self.codebook_to_codes = codebook_transform
self.code_transform = codebook_transform

self.register_buffer('frozen_codebook', codebook)

Expand All @@ -72,15 +72,16 @@ def __init__(

@property
def codebook(self):
return self.codebook_to_codes(self.frozen_codebook)
return self.code_transform(self.frozen_codebook)

def indices_to_codes(
self,
indices
):
implicit_codebook = self.codebook

quantized = get_at('[c] d, b ... -> b ... d', implicit_codebook, indices)
frozen_codes = get_at('[c] d, b ... -> b ... d', self.frozen_codebook, indices)
quantized = self.code_transform(frozen_codes)

if self.accept_image_fmap:
quantized = rearrange(quantized, 'b ... d -> b d ...')
Expand Down

0 comments on commit 8f5b428

Please sign in to comment.