
Commit

Fix residual flow
VikParuchuri committed Dec 12, 2024
1 parent edeea3d commit fb6d4db
Showing 5 changed files with 50 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -33,7 +33,7 @@ opencv-python = "^4.9.0.80"
 tabulate = "^0.9.0"
 filetype = "^1.2.0"
 ftfy = "^6.1.3"
-pdftext = "^0.3.18"
+pdftext = "~0.3.18"
 
 [tool.poetry.group.dev.dependencies]
 jupyter = "^1.0.0"
3 changes: 2 additions & 1 deletion surya/layout.py
@@ -182,7 +182,8 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L
                     prediction["pause_tokens"] = last_prediction["pause_tokens"]
                     prediction["token"].fill_(model.decoder.config.pause_token_id)
                     batch_decoder_input[j, :] = model.decoder.config.pause_token_id
-                elif intersects_other_boxes(prediction["polygon"], [p["polygon"] for p in batch_predictions[j]], thresh=.4):
+                elif intersects_other_boxes(prediction["polygon"], [p["polygon"] for p in batch_predictions[j]], thresh=.4) and \
+                        model.decoder.config.max_pause_tokens > 0:
                     prediction["paused"] = True
                     prediction["pause_tokens"] = 1
                     prediction["token"].fill_(model.decoder.config.pause_token_id)
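
For readers skimming the hunk: the only behavioral change is the extra max_pause_tokens > 0 requirement before a prediction is put into the paused state. Below is a minimal, self-contained sketch of that guard, with a toy stand-in for the repo's intersects_other_boxes helper; the function names here are illustrative, not surya's API.

def _overlaps_any(bbox, others, thresh=0.4):
    # Toy stand-in for intersects_other_boxes: axis-aligned overlap check,
    # true if any earlier box covers more than `thresh` of this box's area.
    x1, y1, x2, y2 = bbox
    area = max(0, x2 - x1) * max(0, y2 - y1)
    for ox1, oy1, ox2, oy2 in others:
        ix = max(0, min(x2, ox2) - max(x1, ox1))
        iy = max(0, min(y2, oy2) - max(y1, oy1))
        if area and (ix * iy) / area > thresh:
            return True
    return False

def should_pause(bbox, earlier_bboxes, max_pause_tokens, thresh=0.4):
    # Mirrors the new elif: overlap alone is no longer enough --
    # the decoder config must also allow pause tokens.
    return max_pause_tokens > 0 and _overlaps_any(bbox, earlier_bboxes, thresh)

# With the layout config below now setting max_pause_tokens=0, this is always False:
# should_pause((0, 0, 10, 10), [(5, 0, 15, 10)], max_pause_tokens=0)  -> False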
43 changes: 43 additions & 0 deletions surya/model/common/adetr/decoder.py
@@ -353,7 +353,50 @@ def __init__(self, config, layer_idx, static_cache=False, max_boxes=None):
         self.channel_pre_norm = SuryaADETRDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.mlp_block = SuryaADETRDecoderMlp(config)
 
+        self.double_residual_flow = getattr(config, "double_residual_flow", False)
+
+    def forward(
+        self,
+        activations: torch.Tensor,
+        position_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
+        encoder_attention_mask: torch.Tensor = None,
+        cache_position: torch.Tensor = None,
+        use_cache: bool = None,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+        if self.double_residual_flow:
+            return self.double_res_forward(
+                activations, position_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache
+            )
+
+        hidden_states = activations
+        if self.cross_attn_block is not None:
+            # Do cross-attention on encoder outputs
+            cross_attn_inputs = self.cross_pre_norm(hidden_states)
+            cross_attn_path = self.cross_attn_block(
+                cross_attn_inputs, position_ids, encoder_hidden_states, attention_mask, encoder_attention_mask
+            )
+            hidden_states = cross_attn_path + hidden_states
+
+        if self.temporal_block is not None:
+            temporal_inputs = self.temporal_pre_norm(
+                hidden_states)  # RMSNorm introduces slight slight differences
+            temporal_path = self.temporal_block(
+                temporal_inputs, position_ids, attention_mask, cache_position=cache_position,
+                use_cache=use_cache, window_attn=self.window_attn
+            )
+
+            hidden_states = temporal_path + hidden_states
+
+        block_input = hidden_states
+        hidden_states = self.channel_pre_norm(block_input)
+        hidden_states = self.mlp_block(hidden_states)
+
+        hidden_states = hidden_states + block_input
+        return hidden_states
+
     def double_res_forward(
         self,
         activations: torch.Tensor,
         position_ids: torch.Tensor,
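
The new forward, taken when double_residual_flow is False, keeps a single running residual stream and adds each sub-block's output back into it, pre-norm style. A toy sketch of that wiring with stand-in sub-layers; the class and layer names are illustrative, not surya's, and double_res_forward itself is not reproduced since the hunk above cuts it off.

import torch
import torch.nn as nn

class ToySingleFlowBlock(nn.Module):
    # Illustrative pre-norm block mirroring the residual pattern of the new forward:
    # norm -> sub-block -> add back into the same stream, for each sub-block in turn.
    def __init__(self, dim: int):
        super().__init__()
        self.attn_norm = nn.LayerNorm(dim)  # plays the role of temporal_pre_norm
        self.mlp_norm = nn.LayerNorm(dim)   # plays the role of channel_pre_norm
        self.attn = nn.Linear(dim, dim)     # stand-in for the temporal/cross-attention block
        self.mlp = nn.Linear(dim, dim)      # stand-in for SuryaADETRDecoderMlp

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.attn_norm(x))  # residual around the attention path
        x = x + self.mlp(self.mlp_norm(x))    # residual around the MLP path
        return x

block = ToySingleFlowBlock(16)
out = block(torch.randn(2, 4, 16))  # same shape in, same shape out

The double_res_forward branch keeps whatever wiring the existing models were trained with; the config changes below opt the layout and recognition models into it.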
5 changes: 3 additions & 2 deletions surya/model/layout/config.py
@@ -173,8 +173,8 @@ def __init__(
         aux_heads=0, # How many n-token-ahead heads to add
         causal=True,
         layer_norm_eps=1e-5,
-        pause_token_count=5,
-        max_pause_tokens=3,
+        pause_token_count=0,
+        max_pause_tokens=0,
         **kwargs,
     ):
         self.num_hidden_layers = num_hidden_layers
@@ -217,6 +217,7 @@ def __init__(
         self.layer_norm_eps = layer_norm_eps
         self.pause_token_count = pause_token_count
         self.max_pause_tokens = max_pause_tokens
+        self.double_residual_flow = True # Residual flow slightly different
 
         super().__init__(
             pad_token_id=pad_token_id,
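
Because decoder.py reads the flag with getattr(config, "double_residual_flow", False), configs that never define it keep the original single-flow forward; only configs that set it, like the layout config above and the recognition config below, route through double_res_forward. A minimal sketch of that dispatch, using hypothetical stand-in config classes:

class _Flagless:
    # Hypothetical config that predates the flag -- getattr falls back to False.
    pass

class _LayoutLike:
    # Hypothetical config mirroring the change above.
    double_residual_flow = True

def residual_flow(config) -> str:
    # Same pattern as decoder.py: a missing attribute means the single flow.
    return "double" if getattr(config, "double_residual_flow", False) else "single"

print(residual_flow(_LayoutLike()))  # double
print(residual_flow(_Flagless()))    # single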
1 change: 1 addition & 0 deletions surya/model/recognition/config.py
@@ -153,6 +153,7 @@ def __init__(
         self.aux_heads = aux_heads
         self.encoder_hidden_size = encoder_hidden_size
         self.causal = causal
+        self.double_residual_flow = True # Residual flow slightly different
 
         super().__init__(
             pad_token_id=pad_token_id,
