Add support for Torch XLA (TPU support) #264

Open
wants to merge 10 commits into base: dev
6 changes: 4 additions & 2 deletions surya/detection.py
@@ -27,6 +27,7 @@ def get_batch_size():
batch_size = 36
return batch_size


def pad_to_batch_size(tensor, batch_size):
current_batch_size = tensor.shape[0]
if current_batch_size >= batch_size:
@@ -37,6 +38,7 @@ def pad_to_batch_size(tensor, batch_size):

return F.pad(tensor, padding, mode='constant', value=0)


def batch_detection(
images: List,
model: EfficientViTForSemanticSegmentation,
@@ -86,7 +88,7 @@ def batch_detection(
if static_cache:
batch = pad_to_batch_size(batch, batch_size)

with torch.inference_mode():
with settings.INFERENCE_MODE():
pred = model(pixel_values=batch)

logits = pred.logits
@@ -95,7 +97,7 @@ def batch_detection(
if current_shape != correct_shape:
logits = F.interpolate(logits, size=correct_shape, mode='bilinear', align_corners=False)

logits = logits.cpu().detach().numpy().astype(np.float32)
logits = logits.to(torch.float32).cpu().detach().numpy()
preds = []
for i, (idx, height) in enumerate(zip(split_index, split_heights)):
# If our current prediction length is below the image idx, that means we have a new image
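Note on this file's changes: the diff swaps torch.inference_mode() for settings.INFERENCE_MODE(), but the settings change itself is not part of this excerpt. A minimal sketch of what such a setting might look like, assuming it falls back to no_grad on XLA where inference-mode tensors are not well supported (the Settings class below is an illustrative stand-in, not the actual surya.settings code):

import torch

class Settings:  # illustrative stand-in for surya.settings.Settings
    TORCH_DEVICE_MODEL: str = "xla"  # or "cuda" / "cpu" / "mps"

    @property
    def INFERENCE_MODE(self):
        # On XLA, fall back to no_grad; elsewhere keep the cheaper inference_mode context.
        if self.TORCH_DEVICE_MODEL == "xla":
            return torch.no_grad
        return torch.inference_mode

settings = Settings()

with settings.INFERENCE_MODE():
    pass  # run the model forward pass here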
29 changes: 25 additions & 4 deletions surya/layout.py
@@ -1,14 +1,15 @@
from typing import List

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

from tqdm import tqdm

from surya.input.slicing import ImageSlicer
from surya.model.layout.config import ID_TO_LABEL
from surya.postprocessing.heatmap import clean_boxes, intersects_other_boxes
from surya.schema import LayoutResult, LayoutBox
from surya.schema import LayoutBox, LayoutResult
from surya.settings import settings


@@ -63,6 +64,17 @@ def find_pause_items(preds):
return pause_sequence


def pad_to_batch_size(tensor, batch_size):
current_batch_size = tensor.shape[0]
if current_batch_size >= batch_size:
return tensor

pad_size = batch_size - current_batch_size
padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size)

return F.pad(tensor, padding, mode='constant', value=0)


def batch_layout_detection(images: List, model, processor, batch_size=None, top_k=5) -> List[LayoutResult]:
assert all([isinstance(image, Image.Image) for image in images])
if batch_size is None:
@@ -100,7 +112,6 @@ def batch_layout_detection(images: List, model, processor, batch_size=None, top_

batch_pixel_values = model_inputs["pixel_values"]
batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=model.dtype).to(model.device)

pause_token = [model.config.decoder.pause_token_id] * 7
start_token = [model.config.decoder.bos_token_id] * 7
batch_decoder_input = [
@@ -109,18 +120,26 @@ def batch_layout_detection(images: List, model, processor, batch_size=None, top_
]
batch_decoder_input = torch.tensor(np.stack(batch_decoder_input, axis=0), dtype=torch.long, device=model.device)
inference_token_count = batch_decoder_input.shape[1]
if settings.LAYOUT_STATIC_CACHE:
batch_pixel_values = pad_to_batch_size(batch_pixel_values, batch_size)
batch_decoder_input = pad_to_batch_size(batch_decoder_input, batch_size)

decoder_position_ids = torch.ones_like(batch_decoder_input[0, :, 0], dtype=torch.int64, device=model.device).cumsum(0) - 1
model.decoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)

batch_predictions = [[] for _ in range(current_batch_size)]

with torch.inference_mode():
with settings.INFERENCE_MODE():
encoder_hidden_states = model.encoder(pixel_values=batch_pixel_values)[0]

token_count = 0
all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=model.device)

if settings.LAYOUT_STATIC_CACHE:
# Pad inputs to max batch size for static cache
encoder_hidden_states = pad_to_batch_size(encoder_hidden_states, batch_size)
batch_decoder_input = pad_to_batch_size(batch_decoder_input, batch_size)

while token_count < settings.LAYOUT_MAX_BOXES:
is_prefill = token_count == 0
return_dict = model.decoder(
@@ -148,6 +167,8 @@ def batch_layout_detection(images: List, model, processor, batch_size=None, top_
break

batch_decoder_input = torch.cat([box_preds.unsqueeze(1), class_preds.unsqueeze(1).unsqueeze(1)], dim=-1)
if settings.LAYOUT_STATIC_CACHE:
batch_decoder_input = pad_to_batch_size(batch_decoder_input, batch_size)

for j, (pred, status) in enumerate(zip(batch_decoder_input, all_done)):
if not status:
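A quick sanity check of the new pad_to_batch_size helper (illustrative only, not part of the diff): F.pad reads the padding tuple from the last dimension backwards, so (0, 0) * (tensor.dim() - 1) leaves every trailing dimension untouched and the final (0, pad_size) pads only the batch dimension.

import torch
import torch.nn.functional as F

x = torch.randn(3, 7, 4)                           # batch of 3
pad_size = 8 - x.shape[0]                          # pad up to a static batch of 8
padding = (0, 0) * (x.dim() - 1) + (0, pad_size)   # -> (0, 0, 0, 0, 0, 5)
padded = F.pad(x, padding, mode='constant', value=0)
print(padded.shape)                                # torch.Size([8, 7, 4]); rows 3..7 are zeros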
30 changes: 24 additions & 6 deletions surya/model/common/adetr/decoder.py
@@ -4,14 +4,25 @@
import torch
import torch.utils.checkpoint
from torch import nn
from transformers import PretrainedConfig
from transformers.utils import ModelOutput

from transformers import PreTrainedModel
from transformers import PretrainedConfig, PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import BaseModelOutputWithNoAttention
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.utils import ModelOutput

from surya.settings import settings

try:
import torch_xla.core.xla_model as xm
except ImportError:
pass


def mark_step():
if settings.TORCH_DEVICE_MODEL == 'xla':
xm.mark_step()


_MAX_SQRT_GRADIENT = 1000.0

@@ -20,6 +31,7 @@ class WrappedEmbedding(nn.Embedding):
def forward(self, input_ids, *args, **kwargs):
return super().forward(input_ids)


class SuryaADETRDecoderRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
@@ -70,7 +82,7 @@ def forward(self, x, position_ids, seq_len=None):
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)


@@ -167,6 +179,7 @@ def forward(
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

mark_step()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
@@ -175,6 +188,7 @@
dropout_p=self.attention_dropout if self.training else 0.0,
scale=self.head_dim**-0.5,
)
mark_step()

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
@@ -247,6 +261,7 @@ def forward(
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

mark_step()
causal_mask = attention_mask
if attention_mask is not None:
# Mask is batch, head, seq_len, kv_len
@@ -255,9 +270,11 @@
if current_cache_position and self.static_cache:
# Mask out future cache positions
position_mask = torch.ones_like(causal_mask, dtype=torch.bool, device=causal_mask.device)
mark_step()
position_mask[:, :, :, :current_cache_position + 1] = False
causal_mask = torch.where(position_mask, torch.finfo(causal_mask.dtype).min, causal_mask)

mark_step()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
@@ -266,6 +283,7 @@
dropout_p=self.attention_dropout if self.training else 0.0,
scale=self.head_dim**-0.5,
)
mark_step()

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
@@ -577,4 +595,4 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

return causal_mask
return causal_mask
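For context on the mark_step() calls placed around scaled_dot_product_attention: torch_xla builds a lazy graph, and xm.mark_step() cuts that graph so the pending ops are compiled and dispatched to the device, keeping individual XLA graphs small; on non-XLA devices the helper is a no-op. A minimal standalone sketch of the same pattern (everything except xm.mark_step itself is illustrative, including the module-level TORCH_DEVICE_MODEL constant standing in for settings.TORCH_DEVICE_MODEL):

import torch

try:
    import torch_xla.core.xla_model as xm
except ImportError:
    xm = None

TORCH_DEVICE_MODEL = "xla"  # assumption: mirrors settings.TORCH_DEVICE_MODEL

def mark_step():
    # Cut the lazy XLA graph so accumulated ops get compiled and executed;
    # do nothing on eager CPU/CUDA backends.
    if TORCH_DEVICE_MODEL == "xla" and xm is not None:
        xm.mark_step()

def attention(q, k, v):
    mark_step()  # flush whatever was traced before the attention kernel
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    mark_step()  # flush the attention subgraph itself
    return out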
24 changes: 20 additions & 4 deletions surya/model/common/donut/encoder.py
@@ -6,12 +6,24 @@
import torch
import torch.utils.checkpoint
from torch import nn

from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from transformers.utils import ModelOutput

from surya.model.recognition.config import DonutSwinConfig
from surya.settings import settings

try:
import torch_xla.core.xla_model as xm
except ImportError:
pass


def mark_step():
if settings.TORCH_DEVICE_MODEL == 'xla':
xm.mark_step()


_EXPECTED_OUTPUT_SHAPE = [1, 49, 1024]

@@ -334,7 +346,7 @@ def transpose_for_scores(self, x):
def transpose_kv_for_scores(self, x, repeats):
new_x_shape = x.size()[:-1] + (self.num_kv_heads, self.attention_head_size)
x = x.view(new_x_shape)
x = x.repeat(1, 1, repeats, 1) # repeat the values for each key-value head to match query dim
x = x.repeat(1, 1, repeats, 1) # repeat the values for each key-value head to match query dim
return x.permute(0, 2, 1, 3).contiguous()

def forward(
@@ -365,6 +377,7 @@ def forward(
attention_mask = attention_mask.repeat(repeat_count, 1, 1).unsqueeze(1)
attention_mask = attention_mask + relative_position_bias

mark_step()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
@@ -373,6 +386,7 @@
dropout_p=self.dropout_p if self.training else 0.0,
scale=self.attention_head_size**-0.5,
)
mark_step()

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, dim, num_channels)
@@ -381,6 +395,8 @@ def forward(
return outputs

# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput


class DonutSwinSelfOutput(nn.Module):
def __init__(self, config, dim):
super().__init__()
@@ -697,7 +713,7 @@ def __init__(self, config, grid_size):
depth=config.depths[i_layer],
num_heads=config.num_heads[i_layer],
num_kv_heads=config.num_kv_heads[i_layer] if hasattr(config, "num_kv_heads") else config.num_heads[i_layer],
drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
drop_path=dpr[sum(config.depths[:i_layer]): sum(config.depths[: i_layer + 1])],
downsample=DonutSwinPatchMerging if (i_layer < self.num_layers - 1) else None,
)
for i_layer in range(self.num_layers)
@@ -806,4 +822,4 @@ def _init_weights(self, module):
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
module.weight.data.fill_(1.0)
10 changes: 7 additions & 3 deletions surya/model/detection/model.py
@@ -35,7 +35,10 @@ def load_model(checkpoint=settings.DETECTOR_MODEL_CHECKPOINT, device=settings.TO
torch._dynamo.config.suppress_errors = False

print(f"Compiling detection model {checkpoint} on device {device} with dtype {dtype}")
model = torch.compile(model)
if device == 'xla':
model = torch.compile(model, backend='openxla')
else:
model = torch.compile(model)

print(f"Loaded detection model {checkpoint} on device {device} with dtype {dtype}")
return model
@@ -73,6 +76,7 @@ def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1) -> int:
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
return padding


class ConvNormAct(nn.Module):
def __init__(
self,
@@ -726,7 +730,7 @@ def forward(self, encoder_hidden_states: torch.FloatTensor) -> torch.Tensor:
all_hidden_states = ()
for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.linear_c):
height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
encoder_hidden_state = mlp(encoder_hidden_state) # Output is B, HW, C
encoder_hidden_state = mlp(encoder_hidden_state) # Output is B, HW, C
# Permute to B, C, HW
encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1)
encoder_hidden_state = encoder_hidden_state.reshape(batch_size, -1, height, width)
@@ -805,4 +809,4 @@ def forward(
loss=None,
logits=logits,
hidden_states=encoder_hidden_states
)
)
8 changes: 6 additions & 2 deletions surya/model/layout/model.py
@@ -25,8 +25,12 @@ def load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT, device=settings.TORC
torch._dynamo.config.suppress_errors = False

print(f"Compiling layout model {checkpoint} on device {device} with dtype {dtype}")
model.encoder = torch.compile(model.encoder)
model.decoder = torch.compile(model.decoder)
if device == 'xla':
model.encoder = torch.compile(model.encoder, backend='openxla')
model.decoder = torch.compile(model.decoder, backend='openxla')
else:
model.encoder = torch.compile(model.encoder)
model.decoder = torch.compile(model.decoder)

print(f"Loaded layout model {checkpoint} on device {device} with dtype {dtype}")
return model
14 changes: 9 additions & 5 deletions surya/model/recognition/model.py
@@ -50,11 +50,15 @@ def load_model(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT, device=settings
torch._dynamo.config.cache_size_limit = 16
torch._dynamo.config.suppress_errors = False


print(f"Compiling recognition model {checkpoint} on device {device} with dtype {dtype}")
model.encoder = torch.compile(model.encoder)
model.decoder = torch.compile(model.decoder)
model.text_encoder = torch.compile(model.text_encoder)
if device == 'xla':
model.encoder = torch.compile(model.encoder, backend='openxla')
model.decoder = torch.compile(model.decoder, backend='openxla')
model.text_encoder = torch.compile(model.text_encoder, backend='openxla')
else:
model.encoder = torch.compile(model.encoder)
model.decoder = torch.compile(model.decoder)
model.text_encoder = torch.compile(model.text_encoder)

print(f"Loaded recognition model {checkpoint} on device {device} with dtype {dtype}")
return model
return model
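The same device check around torch.compile now appears in the detection, layout, and recognition loaders. A small helper could factor it out; the sketch below is not part of this PR, and compile_for_device is a hypothetical name.

import torch

def compile_for_device(module, device: str):
    # torch.compile defaults to the inductor backend, which targets CPU/GPU;
    # on XLA devices the 'openxla' backend provided by torch_xla is needed.
    if device == 'xla':
        return torch.compile(module, backend='openxla')
    return torch.compile(module)

# usage inside a load_model function:
# model.encoder = compile_for_device(model.encoder, device)
# model.decoder = compile_for_device(model.decoder, device)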