Commit

Revert the Deberta attention optimisation (#189)
jimypbr authored Oct 7, 2022
1 parent 59c43f3 commit d62c4ae
Showing 1 changed file with 48 additions and 19 deletions.
67 changes: 48 additions & 19 deletions optimum/graphcore/models/deberta/modeling_deberta.py
@@ -50,6 +50,39 @@
logger = logging.get_logger(__name__)


+class FastGatherLastDim(nn.Module):
+    """
+    Custom Op that does a faster specialised version of `gather`
+    on the last dimension of a tensor.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, data, idx, target=None):
+        if poptorch.isRunningOnIpu():
+            if target is None:
+                target = torch.zeros_like(idx).to(data.dtype)
+            else:
+                target = target.type_as(data)
+
+            target.requires_grad_()
+            o = poptorch.custom_op(
+                [data, idx],
+                "FastGatherLastDim",
+                "poptorch.custom_ops",
+                1,
+                example_outputs=[target],
+                attributes={"axis": -1},
+            )
+            return o[0]
+        else:
+            return torch.gather(data, -1, idx)
+
+
+gather_last_dim = FastGatherLastDim()
+
+
class XSoftmax(torch.nn.Module):
def __init__(self, dim):
super().__init__()
@@ -89,22 +122,6 @@ def __init__(self, config):
super().__init__(config)
self.xsoftmax = XSoftmax(-1)

-    def gather_p2c(self, p2c_att):
-        """
-        Optimized position->content gather for disentangled attention
-        """
-        bs, num_attn_heads, seq_len, _ = p2c_att.size()
-        p2c_att_flat = p2c_att.reshape(bs, num_attn_heads, -1)
-        return p2c_att_flat[:, :, seq_len:].unfold(2, seq_len, 2 * seq_len - 1)
-
-    def gather_c2p(self, c2p_att):
-        """
-        Optimized content->position gather for disentangled attention
-        """
-        bs, num_attn_heads, seq_len, _ = c2p_att.size()
-        c2p_att_flat = c2p_att.flip(3).reshape(bs, num_attn_heads, -1)
-        return c2p_att_flat[:, :, seq_len - 1 :].unfold(2, seq_len, 2 * seq_len - 1)
-
def forward(
self,
hidden_states,
@@ -221,19 +238,31 @@ def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embedd
pos_key_layer = self.pos_proj(rel_embeddings)
pos_key_layer = self.transpose_for_scores(pos_key_layer)
c2p_att = torch.matmul(query_layer, pos_key_layer.transpose(-1, -2))
-            c2p_att = self.gather_c2p(c2p_att)
+            c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
+            index = c2p_pos.expand(
+                [query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]
+            )
+            c2p_att = gather_last_dim(c2p_att, index)
score += c2p_att

# position->content
if "p2c" in self.pos_att_type:
pos_query_layer = self.pos_q_proj(rel_embeddings)
pos_query_layer = self.transpose_for_scores(pos_query_layer)
pos_query_layer /= math.sqrt(pos_query_layer.size(-1) * scale_factor)
+            if query_layer.size(-2) != key_layer.size(-2):
+                r_pos = build_relative_position(key_layer.size(-2), key_layer.size(-2), query_layer.device)
+            else:
+                r_pos = relative_pos
+            p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
+            index = p2c_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2))
-            p2c_att = self.gather_p2c(p2c_att).transpose(-1, -2)
+            p2c_att = gather_last_dim(p2c_att, index).transpose(-1, -2)

if query_layer.size(-2) != key_layer.size(-2):
-                p2c_att = self.gather_p2c(p2c_att)
+                pos_index = relative_pos[:, :, :, 0].unsqueeze(-1)
+                index = pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)))
+                p2c_att = gather_last_dim(p2c_att, index)
score += p2c_att

return score
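
For readers without IPU hardware: off the IPU, FastGatherLastDim falls back to a plain torch.gather along the last dimension, so the op can be reasoned about in ordinary PyTorch. A minimal sketch of that fallback path, with illustrative shapes (not code from this repository):

import torch

# Illustrative shapes only: (batch, heads, query_len, 2 * att_span).
data = torch.randn(2, 4, 8, 16)
# Integer indices into the last dimension; leading dims must not exceed those of data.
idx = torch.randint(0, 16, (2, 4, 8, 8))

# Off-IPU, gather_last_dim(data, idx) reduces to this single call.
out = torch.gather(data, -1, idx)
print(out.shape)  # torch.Size([2, 4, 8, 8])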
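
As a rough illustration of what this revert changes in the content->position path: the restored code clamps the relative positions into [0, 2 * att_span - 1] and gathers along the last axis of c2p_att, whereas the removed gather_c2p used a flip/reshape/unfold "skewing" trick to read out the same diagonals without an explicit index tensor. A standalone sketch comparing the two, assuming att_span equals the sequence length and query-minus-key relative positions (made-up sizes, not code from this repository):

import torch

bs, heads, seq_len = 1, 2, 5
att_span = seq_len  # assumed for this sketch
c2p_att = torch.randn(bs, heads, seq_len, 2 * att_span)

# Gather-based indexing, as in the restored code (torch.gather stands in for gather_last_dim).
ids = torch.arange(seq_len)
relative_pos = ids[:, None] - ids[None, :]
c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
index = c2p_pos.expand(bs, heads, seq_len, seq_len)
gathered = torch.gather(c2p_att, -1, index)

# Unfold-based trick, as in the removed gather_c2p.
flat = c2p_att.flip(3).reshape(bs, heads, -1)
unfolded = flat[:, :, seq_len - 1 :].unfold(2, seq_len, 2 * seq_len - 1)

print(torch.equal(gathered, unfolded))  # True under these assumptions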
