Skip to content

Commit

Permalink
make inference mode configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
iammosespaulr committed Dec 20, 2024
1 parent 67de001 commit a3b80c1
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 12 deletions.
4 changes: 3 additions & 1 deletion surya/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def get_batch_size():
batch_size = 36
return batch_size


def pad_to_batch_size(tensor, batch_size):
current_batch_size = tensor.shape[0]
if current_batch_size >= batch_size:
Expand All @@ -37,6 +38,7 @@ def pad_to_batch_size(tensor, batch_size):

return F.pad(tensor, padding, mode='constant', value=0)


def batch_detection(
images: List,
model: EfficientViTForSemanticSegmentation,
Expand Down Expand Up @@ -86,7 +88,7 @@ def batch_detection(
if static_cache:
batch = pad_to_batch_size(batch, batch_size)

with torch.no_grad():
with settings.INFERENCE_MODE():
pred = model(pixel_values=batch)

logits = pred.logits
Expand Down
2 changes: 1 addition & 1 deletion surya/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def batch_layout_detection(images: List, model, processor, batch_size=None, top_

batch_predictions = [[] for _ in range(current_batch_size)]

with torch.no_grad():
with settings.INFERENCE_MODE():
encoder_hidden_states = model.encoder(pixel_values=batch_pixel_values)[0]

token_count = 0
Expand Down
18 changes: 10 additions & 8 deletions surya/ocr_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from surya.settings import settings
from surya.schema import OCRErrorDetectionResult


def get_batch_size():
batch_size = settings.OCR_ERROR_BATCH_SIZE
if batch_size is None:
Expand All @@ -20,26 +21,27 @@ def get_batch_size():
batch_size = 64
return batch_size


def batch_ocr_error_detection(
texts: List[str],
model: DistilBertForSequenceClassification,
tokenizer: DistilBertTokenizer,
batch_size: Optional[int] = None
texts: List[str],
model: DistilBertForSequenceClassification,
tokenizer: DistilBertTokenizer,
batch_size: Optional[int] = None
):
if batch_size is None:
batch_size = get_batch_size()

num_batches = ceil(len(texts)/batch_size)
num_batches = ceil(len(texts) / batch_size)
texts_processed = tokenizer(texts, padding='longest', truncation=True, return_tensors='pt')
predictions = []
for batch_idx in tqdm(range(num_batches)):
start_idx, end_idx = batch_idx*batch_size, (batch_idx+1)*batch_size
start_idx, end_idx = batch_idx * batch_size, (batch_idx + 1) * batch_size
batch_input_ids = texts_processed.input_ids[start_idx:end_idx].to(model.device)
batch_attention_mask = texts_processed.attention_mask[start_idx:end_idx].to(model.device)

with torch.inference_mode():
with settings.INFERENCE_MODE():
pred = model(batch_input_ids, attention_mask=batch_attention_mask)
logits = pred.logits.detach().cpu().numpy().astype(np.float32)
logits = pred.logits.to(torch.float32).cpu().detach().numpy()
predictions.extend(np.argmax(logits, axis=1).tolist())

return OCRErrorDetectionResult(
Expand Down
2 changes: 1 addition & 1 deletion surya/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def batch_recognition(images: List[Image.Image], languages: List[List[str] | Non
all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=model.device)
encoder_hidden_states = None

with torch.no_grad():
with settings.INFERENCE_MODE():
encoder_batch_size = batch_size // settings.RECOGNITION_ENCODER_BATCH_DIVISOR
for z in range(0, batch_pixel_values.shape[0], encoder_batch_size):
encoder_pixel_values = batch_pixel_values[z:min(z + encoder_batch_size, batch_pixel_values.shape[0])]
Expand Down
7 changes: 7 additions & 0 deletions surya/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,13 @@ def MODEL_DTYPE(self) -> torch.dtype:
return torch.bfloat16
return torch.float16

@computed_field
@property
def INFERENCE_MODE(self):
if self.TORCH_DEVICE_MODEL == "xla":
return torch.no_grad
return torch.inference_mode

class Config:
env_file = find_dotenv("local.env")
extra = "ignore"
Expand Down
2 changes: 1 addition & 1 deletion surya/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def batch_table_recognition(images: List, table_cells: List[List[Dict]], model:

batch_predictions = [[] for _ in range(current_batch_size)]

with torch.no_grad():
with settings.INFERENCE_MODE():
encoder_hidden_states = model.encoder(pixel_values=batch_pixel_values).last_hidden_state
text_encoder_hidden_states = model.text_encoder(
input_boxes=batch_bboxes,
Expand Down

0 comments on commit a3b80c1

Please sign in to comment.