Skip to content

Commit

Permalink
add top_k to layout results
Browse files Browse the repository at this point in the history
  • Loading branch information
iammosespaulr committed Dec 13, 2024
1 parent a3fde2f commit cf7ee06
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 19 deletions.
36 changes: 20 additions & 16 deletions surya/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def find_pause_items(preds):
return pause_sequence


def batch_layout_detection(images: List, model, processor, batch_size=None) -> List[LayoutResult]:
def batch_layout_detection(images: List, model, processor, batch_size=None, top_k=5) -> List[LayoutResult]:
assert all([isinstance(image, Image.Image) for image in images])
if batch_size is None:
batch_size = get_batch_size()
Expand All @@ -80,7 +80,7 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L
if any([
sum(img_counts[start_idx:end_idx]) >= batch_size,
sum(img_counts[start_idx:end_idx + 1]) > batch_size,
]):
]):
batches.append((start_idx, end_idx))
start_idx = end_idx
end_idx += 1
Expand Down Expand Up @@ -136,7 +136,7 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L
box_logits = return_dict["bbox_logits"][:current_batch_size, -1, :].detach()
class_logits = return_dict["class_logits"][:current_batch_size, -1, :].detach()

probs = torch.nn.functional.softmax(class_logits, dim=-1).detach().cpu()
probs = torch.nn.functional.softmax(class_logits, dim=-1)
entropy = torch.special.entr(probs).sum(dim=-1)

class_preds = class_logits.argmax(-1)
Expand All @@ -161,11 +161,11 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L
"paused": False,
"pause_tokens": 0,
"polygon": prediction_to_polygon(
preds,
orig_sizes[j],
model.config.decoder.bbox_size,
model.config.decoder.skew_scaler
),
preds,
orig_sizes[j],
model.config.decoder.bbox_size,
model.config.decoder.skew_scaler
),
"label": preds[6].item() - model.decoder.config.special_token_count,
"class_logits": class_logits[j].detach().cpu(),
"orig_size": orig_sizes[j]
Expand All @@ -188,19 +188,20 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L
prediction["token"].fill_(model.decoder.config.pause_token_id)
batch_decoder_input[j, :] = model.decoder.config.pause_token_id
elif all([
prediction["text_label"] in ["PageHeader", "PageFooter"],
prediction["polygon"][0][1] < prediction["orig_size"][1] * .8,
prediction["polygon"][2][1] > prediction["orig_size"][1] * .2,
prediction["polygon"][0][0] < prediction["orig_size"][0] * .8,
prediction["polygon"][2][0] > prediction["orig_size"][0] * .2
]):
prediction["text_label"] in ["PageHeader", "PageFooter"],
prediction["polygon"][0][1] < prediction["orig_size"][1] * .8,
prediction["polygon"][2][1] > prediction["orig_size"][1] * .2,
prediction["polygon"][0][0] < prediction["orig_size"][0] * .8,
prediction["polygon"][2][0] > prediction["orig_size"][0] * .2
]):
# Ensure page footers only occur at the bottom of the page, headers only at top
prediction["class_logits"][int(preds[6].item())] = 0
new_prediction = prediction["class_logits"].argmax(-1).item()
prediction["label"] = new_prediction - model.decoder.config.special_token_count
prediction["token"][6] = new_prediction
batch_decoder_input[j, -1, 6] = new_prediction

prediction["class_logits"], prediction["class_indices"] = torch.topk(prediction["class_logits"], k=top_k, dim=-1)
batch_predictions[j].append(prediction)

token_count += inference_token_count
Expand All @@ -209,16 +210,19 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L

for j, (pred_dict, orig_size) in enumerate(zip(batch_predictions, orig_sizes)):
boxes = []
preds = [p for p in pred_dict if p["token"][6] > model.decoder.config.special_token_count] # Remove special tokens, like pause
preds = [p for p in pred_dict if p["token"][6] > model.decoder.config.special_token_count] # Remove special tokens, like pause
if len(preds) > 0:
polygons = [p["polygon"] for p in preds]
labels = [p["label"] for p in preds]
top_k_probs = [p["class_logits"] for p in preds]
top_k_indices = [p["class_indices"] - model.decoder.config.special_token_count for p in preds]

for z, (poly, label) in enumerate(zip(polygons, labels)):
lb = LayoutBox(
polygon=poly,
label=ID_TO_LABEL[int(label)],
position=z
position=z,
top_k={ID_TO_LABEL.get(max(int(l), 0)): prob.item() for (l, prob) in zip(top_k_indices[z], top_k_probs[z])}
)
boxes.append(lb)
boxes = clean_boxes(boxes)
Expand Down
7 changes: 4 additions & 3 deletions surya/schema.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
from typing import List, Tuple, Any, Optional
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, field_validator, computed_field
from pydantic import BaseModel, computed_field, field_validator

from surya.postprocessing.util import rescale_bbox

Expand Down Expand Up @@ -154,6 +154,7 @@ def intersection_pct(self, other):
class LayoutBox(PolygonBox):
label: str
position: int
top_k: Optional[Dict[str, float]] = None


class ColumnLine(Bbox):
Expand Down Expand Up @@ -183,7 +184,7 @@ class TextDetectionResult(BaseModel):
class LayoutResult(BaseModel):
bboxes: List[LayoutBox]
image_bbox: List[float]
sliced: bool = False # Whether the image was sliced and reconstructed
sliced: bool = False # Whether the image was sliced and reconstructed


class TableCell(Bbox):
Expand Down

0 comments on commit cf7ee06

Please sign in to comment.