WIP: Begin to add Contextual positional encoding #1645

Open

wants to merge 2 commits into base: master
26 changes: 26 additions & 0 deletions egs/librispeech/ASR/zipformer/test_cope.py
@@ -0,0 +1,26 @@
#!/usr/bin/env python3

import torch
from zipformer import ContextualPositionalEncoding


def test():
embed_dim = 5
npos_max = 10

cope = ContextualPositionalEncoding(embed_dim=embed_dim, npos_max=npos_max)
q = torch.rand(2, 3, npos_max, embed_dim)

qk = torch.rand(2, 3, npos_max, npos_max)

p = cope(q=q, qk=qk)
print(p.shape)


def main():
test()


if __name__ == "__main__":
torch.manual_seed(20240703)
main()
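For reference, with the shapes used in this test the final print statement should show torch.Size([2, 3, 10, 10]): the gathered logits inherit the shape of qk (here npos_max == time2 == 10); see the forward() implementation in zipformer.py below.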
91 changes: 90 additions & 1 deletion egs/librispeech/ASR/zipformer/zipformer.py
@@ -95,6 +95,7 @@ class Zipformer2(EncoderInterface):
context chunks for causal training; will be rounded to a number of
chunks. Must not be less than cnn_module_kernel (after factoring in
rounding and downsampling); an error will be thrown if this is violated.
use_cope (bool): If True, use contextual positional encoding (CoPE).
"""

def __init__(
@@ -116,6 +117,7 @@ def __init__(
causal: bool = False,
chunk_size: Tuple[int] = [-1],
left_context_frames: Tuple[int] = [-1],
use_cope: bool = False,
) -> None:
super(Zipformer2, self).__init__()

@@ -183,6 +185,7 @@ def _to_tuple(x):
warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
use_cope=use_cope,
)

if downsampling_factor[i] != 1:
@@ -1017,6 +1020,7 @@ def __init__(
warmup_end: float,
initial_layerdrop_rate: float = 0.5,
final_layerdrop_rate: float = 0.05,
use_cope: bool = False,
) -> None:
super().__init__()
self.encoder_pos = CompactRelPositionalEncoding(
@@ -1372,6 +1376,87 @@ def forward(self, src: Tensor) -> Tensor:
return src


class ContextualPositionalEncoding(torch.nn.Module):
"""
This class implements the following paper:
Contextual Position Encoding: Learning to Count What's Important
https://arxiv.org/abs/2405.18719

Args:
embed_dim: Embedding dimension.
npos_max: The maximum context size.
"""

def __init__(self, embed_dim: int, npos_max: int):
super().__init__()
self.npos_max = npos_max
self.embedding = nn.Embedding(
num_embeddings=npos_max,
embedding_dim=embed_dim,
)

def forward(self, q: torch.Tensor, qk: torch.Tensor) -> torch.Tensor:
"""
Args:
q (torch.Tensor): A tensor of shape (head, batch, time1, query_head_dim)
qk (torch.Tensor): A tensor of shape (head, batch, time1, time2)
Returns:
Return a tensor of shape (head, batch, time1, time2); the output
inherits the shape of qk, since gather() uses the positions as indexes.

Note that the implementation assumes time1 == time2 and
npos_max <= time2, which is reasonable for a streaming ASR encoder
where only self-attention is used.
"""
# The implementation in Listing 1 on page 13 of the paper does not use
# a mask to ensure that only gates[:, :, i, j] with j < i contributes.
#
# Here we fix that by introducing a mask.
mask = torch.triu(
torch.full(
(qk.size(3), qk.size(3)), True, dtype=torch.bool, device=qk.device
),
diagonal=0,
)
#
# if qk.size(3) is 4, mask is
#
# tensor([[ True, True, True, True],
# [False, True, True, True],
# [False, False, True, True],
# [False, False, False, True]])
#
# mask[i, j] is True if j >= i
gates = torch.sigmoid(qk)

# We don't use an in-place operation here for the sake of autograd
gates = gates.masked_fill(mask, 0)

# cumsum() is an inclusive sum in PyTorch
pos = gates.flip(-1).cumsum(dim=-1).flip(-1) # (head, batch, time1, time2)
# pos[:, :, i, j] should be 0 for j >= i
# pos[:, :, i, j] contains the position between i and j. If gates
# is a 0-1 matrix, then pos[:, :, i, j] equals i - j (for j < i).
# Note: The paper says on page 4 that it equals i - j + 1 instead of
# i - j; the difference arises because we mask out the diagonal gate
# gates[:, :, i, i], which the paper includes in the sum.

pos = pos.clamp(max=self.npos_max - 1)
pos_ceil = pos.ceil().long()
pos_floor = pos.floor().long()

# We assume query_head_dim equals embed_dim

logits_int = torch.matmul(
q, self.embedding.weight.t()
) # (head, batch, time1, npos_max)

# We assume that npos_max <= time2
logits_ceil = logits_int.gather(-1, pos_ceil)
logits_floor = logits_int.gather(-1, pos_floor)

w = pos - pos_floor

# Note: The code in the paper on page 13 is correct,
# while the description in equation (5) on page 4 is wrong.
return logits_ceil * w + logits_floor * (1 - w)

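As a sanity check (not part of the PR), here is a minimal standalone sketch of the two key steps above: the flipped cumsum that turns gates into positions, and the floor/ceil interpolation of the position-embedding logits. With a 0-1 gate matrix it reproduces pos[i, j] == i - j for j < i:

import torch

time2 = 4
gates = torch.ones(time2, time2)
# Same mask as in forward() above: True for j >= i.
mask = torch.triu(torch.ones(time2, time2, dtype=torch.bool), diagonal=0)
gates = gates.masked_fill(mask, 0)  # gates[i, j] is 1 for j < i, else 0

# Reversed inclusive cumsum: pos[i, j] = sum over k >= j of gates[i, k].
pos = gates.flip(-1).cumsum(dim=-1).flip(-1)
print(pos)
# tensor([[0., 0., 0., 0.],
#         [1., 0., 0., 0.],
#         [2., 1., 0., 0.],
#         [3., 2., 1., 0.]])
# i.e. pos[i, j] == i - j for j < i, and 0 elsewhere.

# Interpolation between the two nearest integer position embeddings,
# as in the return statement of forward(); embed_dim = 5 is arbitrary.
npos_max = 4
embedding = torch.nn.Embedding(npos_max, 5)
q = torch.rand(time2, 5)
logits_int = torch.matmul(q, embedding.weight.t())  # (time2, npos_max)
pos = pos.clamp(max=npos_max - 1)
w = pos - pos.floor()
out = logits_int.gather(-1, pos.ceil().long()) * w + logits_int.gather(
    -1, pos.floor().long()
) * (1 - w)
print(out.shape)  # torch.Size([4, 4]), the same shape as the gate matrix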

class CompactRelPositionalEncoding(torch.nn.Module):
"""
Relative positional encoding module. This version is "compact" meaning it is able to encode
@@ -1609,7 +1694,11 @@ def forward(
k = x[..., query_dim : 2 * query_dim]
# p is the position-encoding query
p = x[..., 2 * query_dim :]
assert p.shape[-1] == num_heads * pos_head_dim, (p.shape[-1], num_heads, pos_head_dim)
assert p.shape[-1] == num_heads * pos_head_dim, (
p.shape[-1],
num_heads,
pos_head_dim,
)

q = self.copy_query(q) # for diagnostics only, does nothing.
k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass.