handle xla graph breaks for layout
iammosespaulr committed Dec 18, 2024
1 parent 806e562 commit 06f755f
Showing 2 changed files with 40 additions and 10 deletions.
28 changes: 22 additions & 6 deletions surya/model/common/adetr/decoder.py
@@ -4,14 +4,25 @@
import torch
import torch.utils.checkpoint
from torch import nn
from transformers import PretrainedConfig
from transformers.utils import ModelOutput

from transformers import PreTrainedModel
from transformers import PretrainedConfig, PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import BaseModelOutputWithNoAttention
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.utils import ModelOutput

from surya.settings import settings

try:
    import torch_xla.core.xla_model as xm
except ImportError:
    pass


def mark_step():
    if settings.TORCH_DEVICE == 'xla':
        xm.mark_step()


_MAX_SQRT_GRADIENT = 1000.0

@@ -20,6 +31,7 @@ class WrappedEmbedding(nn.Embedding):
    def forward(self, input_ids, *args, **kwargs):
        return super().forward(input_ids)


class SuryaADETRDecoderRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
@@ -70,7 +82,7 @@ def forward(self, x, position_ids, seq_len=None):
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


@@ -167,6 +179,7 @@ def forward(
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        mark_step()
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
@@ -175,6 +188,7 @@
            dropout_p=self.attention_dropout if self.training else 0.0,
            scale=self.head_dim**-0.5,
        )
        mark_step()

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
@@ -258,6 +272,7 @@ def forward(
                position_mask[:, :, :, :current_cache_position + 1] = False
                causal_mask = torch.where(position_mask, torch.finfo(causal_mask.dtype).min, causal_mask)

        mark_step()
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
@@ -266,6 +281,7 @@
            dropout_p=self.attention_dropout if self.training else 0.0,
            scale=self.head_dim**-0.5,
        )
        mark_step()

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
@@ -577,4 +593,4 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask
        return causal_mask
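
Both attention paths above now bracket torch.nn.functional.scaled_dot_product_attention with mark_step(). On an XLA backend, xm.mark_step() flushes the lazily traced graph, compiling and executing everything recorded up to that point, so the attention call lands in its own small graph instead of forcing a break inside a much larger one; on other devices the helper does nothing. Below is a minimal, self-contained sketch of that pattern, not surya's actual module: the attend helper and tensor shapes are illustrative, and this mark_step guards on the torch_xla import rather than on settings.TORCH_DEVICE so the sketch runs without surya installed.

import torch
import torch.nn.functional as F

try:
    import torch_xla.core.xla_model as xm
except ImportError:
    xm = None  # not running on XLA; mark_step() becomes a no-op


def mark_step():
    # On XLA, flush the lazy graph: compile and run everything traced so far.
    if xm is not None:
        xm.mark_step()


def attend(q, k, v):
    mark_step()  # close off whatever graph was built before attention
    out = F.scaled_dot_product_attention(q, k, v, scale=q.shape[-1] ** -0.5)
    mark_step()  # execute the attention graph on its own
    return out


device = xm.xla_device() if xm is not None else "cpu"
q = torch.randn(1, 8, 16, 64, device=device)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 16, 64, device=device)
v = torch.randn(1, 8, 16, 64, device=device)
print(attend(q, k, v).shape)  # torch.Size([1, 8, 16, 64])
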
22 changes: 18 additions & 4 deletions surya/model/common/donut/encoder.py
@@ -6,12 +6,24 @@
import torch
import torch.utils.checkpoint
from torch import nn

from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from transformers.utils import ModelOutput

from surya.model.recognition.config import DonutSwinConfig
from surya.settings import settings

try:
    import torch_xla.core.xla_model as xm
except ImportError:
    pass


def mark_step():
    if settings.TORCH_DEVICE == 'xla':
        xm.mark_step()


_EXPECTED_OUTPUT_SHAPE = [1, 49, 1024]

@@ -334,7 +346,7 @@ def transpose_for_scores(self, x):
    def transpose_kv_for_scores(self, x, repeats):
        new_x_shape = x.size()[:-1] + (self.num_kv_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        x = x.repeat(1, 1, repeats, 1) # repeat the values for each key-value head to match query dim
        x = x.repeat(1, 1, repeats, 1)  # repeat the values for each key-value head to match query dim
        return x.permute(0, 2, 1, 3).contiguous()

    def forward(
@@ -365,6 +377,7 @@ def forward(
            attention_mask = attention_mask.repeat(repeat_count, 1, 1).unsqueeze(1)
            attention_mask = attention_mask + relative_position_bias

        mark_step()
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
@@ -373,6 +386,7 @@
            dropout_p=self.dropout_p if self.training else 0.0,
            scale=self.attention_head_size**-0.5,
        )
        mark_step()

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, dim, num_channels)
@@ -697,7 +711,7 @@ def __init__(self, config, grid_size):
                    depth=config.depths[i_layer],
                    num_heads=config.num_heads[i_layer],
                    num_kv_heads=config.num_kv_heads[i_layer] if hasattr(config, "num_kv_heads") else config.num_heads[i_layer],
                    drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
                    drop_path=dpr[sum(config.depths[:i_layer]): sum(config.depths[: i_layer + 1])],
                    downsample=DonutSwinPatchMerging if (i_layer < self.num_layers - 1) else None,
                )
                for i_layer in range(self.num_layers)
@@ -806,4 +820,4 @@ def _init_weights(self, module):
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            module.weight.data.fill_(1.0)
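
The encoder gets the same treatment: the torch_xla import guard, the mark_step() helper, and mark_step() calls around the windowed-attention SDPA. One way to see what the extra graph cuts do in practice, assuming a working XLA device and torch_xla install, is to compare XLA's debug metrics for a forward pass with and without the calls; metrics_report() is part of torch_xla.debug.metrics, though the exact entries in the report vary by torch_xla version.

import torch_xla.debug.metrics as met

# Run a forward pass on the XLA device first, then print the report; comparing
# runs with and without the mark_step() calls shows how the attention ops are
# being partitioned into graphs (see the CompileTime and ExecuteTime entries).
print(met.metrics_report())
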
