Commit 0013f92
Merge pull request #212 from VikParuchuri/dev
Fix MPS issue with pytorch 2.5
VikParuchuri authored Oct 18, 2024
2 parents 865306f + 5a4efb1 commit 0013f92
Showing 7 changed files with 86 additions and 92 deletions.
5 changes: 1 addition & 4 deletions ocr_text.py
@@ -4,14 +4,11 @@
 import time
 from collections import defaultdict
 
-import torch
-
-from surya.input.langs import replace_lang_with_code, get_unique_langs
+from surya.input.langs import replace_lang_with_code
 from surya.input.load import load_from_folder, load_from_file, load_lang_file
 from surya.model.detection.model import load_model as load_detection_model, load_processor as load_detection_processor
 from surya.model.recognition.model import load_model as load_recognition_model
 from surya.model.recognition.processor import load_processor as load_recognition_processor
-from surya.model.recognition.tokenizer import _tokenize
 from surya.ocr import run_ocr
 from surya.postprocessing.text import draw_text_on_image
 from surya.settings import settings
132 changes: 66 additions & 66 deletions poetry.lock

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "surya-ocr"
-version = "0.6.4"
+version = "0.6.5"
 description = "OCR, layout, reading order, and table recognition in 90+ languages"
 authors = ["Vik Paruchuri <[email protected]>"]
 readme = "README.md"

@@ -23,7 +23,7 @@ include = [
 [tool.poetry.dependencies]
 python = "^3.10"
 transformers = "^4.41.0"
-torch = "^2.3.0"
+torch = "^2.4.1"
 pydantic = "^2.5.3"
 pydantic-settings = "^2.1.0"
 python-dotenv = "^1.0.0"

@@ -33,7 +33,7 @@ opencv-python = "^4.9.0.80"
 tabulate = "^0.9.0"
 filetype = "^1.2.0"
 ftfy = "^6.1.3"
-pdftext = "^0.3.16"
+pdftext = "^0.3.17"
 
 [tool.poetry.group.dev.dependencies]
 jupyter = "^1.0.0"
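The torch constraint moves from ^2.3.0 to ^2.4.1 alongside the MPS fix in the attention code below. A quick environment check, illustrative only and not part of the commit, for confirming what the new constraint installs before running Surya on Apple Silicon:

import torch

# Under the new "^2.4.1" constraint this should print 2.4.1 or later.
print(torch.__version__)

# Whether the MPS (Metal) backend is usable on this machine, and whether
# this torch build was compiled with MPS support at all.
print(torch.backends.mps.is_available())
print(torch.backends.mps.is_built())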
12 changes: 6 additions & 6 deletions surya/model/recognition/decoder.py
@@ -173,9 +173,9 @@ def forward(
     value_states = repeat_kv(value_states, self.num_key_value_groups)
 
     attn_output = torch.nn.functional.scaled_dot_product_attention(
-        query_states.contiguous(),
-        key_states.contiguous(),
-        value_states.contiguous(),
+        query_states,
+        key_states,
+        value_states,
         attn_mask=None,
         dropout_p=self.attention_dropout if self.training else 0.0,
         scale=self.head_dim**-0.5,

@@ -261,9 +261,9 @@ def forward(
     causal_mask = torch.where(position_mask, torch.finfo(causal_mask.dtype).min, causal_mask)
 
     attn_output = torch.nn.functional.scaled_dot_product_attention(
-        query_states.contiguous(),
-        key_states.contiguous(),
-        value_states.contiguous(),
+        query_states,
+        key_states,
+        value_states,
         attn_mask=causal_mask,
         dropout_p=self.attention_dropout if self.training else 0.0,
         scale=self.head_dim**-0.5,
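The same change repeats across the attention modules: the .contiguous() calls on the query/key/value tensors are dropped before scaled_dot_product_attention, which also avoids up to three tensor copies per attention call. A minimal sketch of the pattern, assuming (as the commit title suggests) the copies were a workaround for an MPS-specific issue that torch >= 2.4.1 resolves; the shapes, names, and script itself are illustrative and not from the repo:

import torch
import torch.nn.functional as F

device = "mps" if torch.backends.mps.is_available() else "cpu"

# Attention projections are typically reshaped then transposed,
# (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim),
# which yields non-contiguous views.
q = torch.randn(1, 16, 8, 64, device=device).transpose(1, 2)
k = torch.randn(1, 16, 8, 64, device=device).transpose(1, 2)
v = torch.randn(1, 16, 8, 64, device=device).transpose(1, 2)
assert not q.is_contiguous()

out = F.scaled_dot_product_attention(
    q, k, v,            # previously q.contiguous(), k.contiguous(), ...
    attn_mask=None,
    dropout_p=0.0,
    scale=64 ** -0.5,   # head_dim ** -0.5, matching the code above
)
print(out.shape)  # torch.Size([1, 8, 16, 64])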
6 changes: 3 additions & 3 deletions surya/model/recognition/encoder.py
@@ -363,9 +363,9 @@ def forward(
     attention_mask = attention_mask + relative_position_bias
 
     attn_output = torch.nn.functional.scaled_dot_product_attention(
-        query_layer.contiguous(),
-        key_layer.contiguous(),
-        value_layer.contiguous(),
+        query_layer,
+        key_layer,
+        value_layer,
         attn_mask=attention_mask,
         dropout_p=self.dropout_p if self.training else 0.0,
         scale=self.attention_head_size**-0.5,
5 changes: 1 addition & 4 deletions surya/model/recognition/tokenizer.py
@@ -1,7 +1,4 @@
-from itertools import chain
-import random
-from typing import List, Optional, Tuple, Union
-from tokenizers import AddedToken
+from typing import List, Union
 from transformers import ByT5Tokenizer
 import numpy as np
 import torch
12 changes: 6 additions & 6 deletions surya/model/table_rec/decoder.py
@@ -172,9 +172,9 @@ def forward(
     value_states = repeat_kv(value_states, self.num_key_value_groups)
 
     attn_output = torch.nn.functional.scaled_dot_product_attention(
-        query_states.contiguous(),
-        key_states.contiguous(),
-        value_states.contiguous(),
+        query_states,
+        key_states,
+        value_states,
         attn_mask=None,
         dropout_p=self.attention_dropout if self.training else 0.0,
         scale=self.head_dim**-0.5,

@@ -260,9 +260,9 @@ def forward(
     causal_mask = torch.where(position_mask, torch.finfo(causal_mask.dtype).min, causal_mask)
 
     attn_output = torch.nn.functional.scaled_dot_product_attention(
-        query_states.contiguous(),
-        key_states.contiguous(),
-        value_states.contiguous(),
+        query_states,
+        key_states,
+        value_states,
         attn_mask=causal_mask,
         dropout_p=self.attention_dropout if self.training else 0.0,
         scale=self.head_dim**-0.5,
