From cf7ee06ef6a0aa3d76d5db5bed9301e7e3e6c0eb Mon Sep 17 00:00:00 2001 From: Moses Paul R Date: Fri, 13 Dec 2024 15:48:13 +0000 Subject: [PATCH 01/12] add top_k to layout results --- surya/layout.py | 36 ++++++++++++++++++++---------------- surya/schema.py | 7 ++++--- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/surya/layout.py b/surya/layout.py index 429bc62e..ee67fdf1 100644 --- a/surya/layout.py +++ b/surya/layout.py @@ -64,7 +64,7 @@ def find_pause_items(preds): return pause_sequence -def batch_layout_detection(images: List, model, processor, batch_size=None) -> List[LayoutResult]: +def batch_layout_detection(images: List, model, processor, batch_size=None, top_k=5) -> List[LayoutResult]: assert all([isinstance(image, Image.Image) for image in images]) if batch_size is None: batch_size = get_batch_size() @@ -80,7 +80,7 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L if any([ sum(img_counts[start_idx:end_idx]) >= batch_size, sum(img_counts[start_idx:end_idx + 1]) > batch_size, - ]): + ]): batches.append((start_idx, end_idx)) start_idx = end_idx end_idx += 1 @@ -136,7 +136,7 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L box_logits = return_dict["bbox_logits"][:current_batch_size, -1, :].detach() class_logits = return_dict["class_logits"][:current_batch_size, -1, :].detach() - probs = torch.nn.functional.softmax(class_logits, dim=-1).detach().cpu() + probs = torch.nn.functional.softmax(class_logits, dim=-1) entropy = torch.special.entr(probs).sum(dim=-1) class_preds = class_logits.argmax(-1) @@ -161,11 +161,11 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L "paused": False, "pause_tokens": 0, "polygon": prediction_to_polygon( - preds, - orig_sizes[j], - model.config.decoder.bbox_size, - model.config.decoder.skew_scaler - ), + preds, + orig_sizes[j], + model.config.decoder.bbox_size, + model.config.decoder.skew_scaler + ), "label": preds[6].item() - model.decoder.config.special_token_count, "class_logits": class_logits[j].detach().cpu(), "orig_size": orig_sizes[j] @@ -188,12 +188,12 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L prediction["token"].fill_(model.decoder.config.pause_token_id) batch_decoder_input[j, :] = model.decoder.config.pause_token_id elif all([ - prediction["text_label"] in ["PageHeader", "PageFooter"], - prediction["polygon"][0][1] < prediction["orig_size"][1] * .8, - prediction["polygon"][2][1] > prediction["orig_size"][1] * .2, - prediction["polygon"][0][0] < prediction["orig_size"][0] * .8, - prediction["polygon"][2][0] > prediction["orig_size"][0] * .2 - ]): + prediction["text_label"] in ["PageHeader", "PageFooter"], + prediction["polygon"][0][1] < prediction["orig_size"][1] * .8, + prediction["polygon"][2][1] > prediction["orig_size"][1] * .2, + prediction["polygon"][0][0] < prediction["orig_size"][0] * .8, + prediction["polygon"][2][0] > prediction["orig_size"][0] * .2 + ]): # Ensure page footers only occur at the bottom of the page, headers only at top prediction["class_logits"][int(preds[6].item())] = 0 new_prediction = prediction["class_logits"].argmax(-1).item() @@ -201,6 +201,7 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L prediction["token"][6] = new_prediction batch_decoder_input[j, -1, 6] = new_prediction + prediction["class_logits"], prediction["class_indices"] = torch.topk(prediction["class_logits"], k=top_k, dim=-1) batch_predictions[j].append(prediction) token_count += inference_token_count @@ -209,16 +210,19 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L for j, (pred_dict, orig_size) in enumerate(zip(batch_predictions, orig_sizes)): boxes = [] - preds = [p for p in pred_dict if p["token"][6] > model.decoder.config.special_token_count] # Remove special tokens, like pause + preds = [p for p in pred_dict if p["token"][6] > model.decoder.config.special_token_count] # Remove special tokens, like pause if len(preds) > 0: polygons = [p["polygon"] for p in preds] labels = [p["label"] for p in preds] + top_k_probs = [p["class_logits"] for p in preds] + top_k_indices = [p["class_indices"] - model.decoder.config.special_token_count for p in preds] for z, (poly, label) in enumerate(zip(polygons, labels)): lb = LayoutBox( polygon=poly, label=ID_TO_LABEL[int(label)], - position=z + position=z, + top_k={ID_TO_LABEL.get(max(int(l), 0)): prob.item() for (l, prob) in zip(top_k_indices[z], top_k_probs[z])} ) boxes.append(lb) boxes = clean_boxes(boxes) diff --git a/surya/schema.py b/surya/schema.py index becfb870..96cbe535 100644 --- a/surya/schema.py +++ b/surya/schema.py @@ -1,7 +1,7 @@ import copy -from typing import List, Tuple, Any, Optional +from typing import Any, Dict, List, Optional -from pydantic import BaseModel, field_validator, computed_field +from pydantic import BaseModel, computed_field, field_validator from surya.postprocessing.util import rescale_bbox @@ -154,6 +154,7 @@ def intersection_pct(self, other): class LayoutBox(PolygonBox): label: str position: int + top_k: Optional[Dict[str, float]] = None class ColumnLine(Bbox): @@ -183,7 +184,7 @@ class TextDetectionResult(BaseModel): class LayoutResult(BaseModel): bboxes: List[LayoutBox] image_bbox: List[float] - sliced: bool = False # Whether the image was sliced and reconstructed + sliced: bool = False # Whether the image was sliced and reconstructed class TableCell(Bbox): From ed05c016f5e663fe8e9c00fc3209dc9ffccf7f51 Mon Sep 17 00:00:00 2001 From: Moses Paul R Date: Fri, 13 Dec 2024 16:02:07 +0000 Subject: [PATCH 02/12] proper probs --- surya/layout.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/surya/layout.py b/surya/layout.py index ee67fdf1..c12e6621 100644 --- a/surya/layout.py +++ b/surya/layout.py @@ -201,7 +201,8 @@ def batch_layout_detection(images: List, model, processor, batch_size=None, top_ prediction["token"][6] = new_prediction batch_decoder_input[j, -1, 6] = new_prediction - prediction["class_logits"], prediction["class_indices"] = torch.topk(prediction["class_logits"], k=top_k, dim=-1) + prediction["top_k_probs"], prediction["top_k_indices"] = torch.topk(torch.nn.functional.softmax(prediction["class_logits"], dim=-1), k=top_k, dim=-1) + del prediction["class_logits"] batch_predictions[j].append(prediction) token_count += inference_token_count @@ -214,8 +215,8 @@ def batch_layout_detection(images: List, model, processor, batch_size=None, top_ if len(preds) > 0: polygons = [p["polygon"] for p in preds] labels = [p["label"] for p in preds] - top_k_probs = [p["class_logits"] for p in preds] - top_k_indices = [p["class_indices"] - model.decoder.config.special_token_count for p in preds] + top_k_probs = [p["top_k_probs"] for p in preds] + top_k_indices = [p["top_k_indices"] - model.decoder.config.special_token_count for p in preds] for z, (poly, label) in enumerate(zip(polygons, labels)): lb = LayoutBox( From d3cf58ae98588b3e31974aa0a8acd4a1c0389d83 Mon Sep 17 00:00:00 2001 From: Moses Paul R Date: Fri, 13 Dec 2024 16:31:36 +0000 Subject: [PATCH 03/12] ditch blanks [skip ci] --- surya/layout.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/surya/layout.py b/surya/layout.py index c12e6621..c30d6f67 100644 --- a/surya/layout.py +++ b/surya/layout.py @@ -223,7 +223,7 @@ def batch_layout_detection(images: List, model, processor, batch_size=None, top_ polygon=poly, label=ID_TO_LABEL[int(label)], position=z, - top_k={ID_TO_LABEL.get(max(int(l), 0)): prob.item() for (l, prob) in zip(top_k_indices[z], top_k_probs[z])} + top_k={ID_TO_LABEL.get(int(l)): prob.item() for (l, prob) in zip(top_k_indices[z], top_k_probs[z]) if l > 0} ) boxes.append(lb) boxes = clean_boxes(boxes) From f000187f1d718fb792bb2ccb9e1cc20d2e394bc8 Mon Sep 17 00:00:00 2001 From: Moses Paul R Date: Fri, 13 Dec 2024 17:37:47 +0000 Subject: [PATCH 04/12] fix missing confidence in surya layout predictions --- surya/layout.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/surya/layout.py b/surya/layout.py index c30d6f67..476e54e4 100644 --- a/surya/layout.py +++ b/surya/layout.py @@ -219,11 +219,14 @@ def batch_layout_detection(images: List, model, processor, batch_size=None, top_ top_k_indices = [p["top_k_indices"] - model.decoder.config.special_token_count for p in preds] for z, (poly, label) in enumerate(zip(polygons, labels)): + l = ID_TO_LABEL[int(label)] + top_k_dict = {ID_TO_LABEL.get(int(l)): prob.item() for (l, prob) in zip(top_k_indices[z], top_k_probs[z]) if l > 0} lb = LayoutBox( polygon=poly, - label=ID_TO_LABEL[int(label)], + label=l, position=z, - top_k={ID_TO_LABEL.get(int(l)): prob.item() for (l, prob) in zip(top_k_indices[z], top_k_probs[z]) if l > 0} + top_k=top_k_dict, + confidence=top_k_dict[l] ) boxes.append(lb) boxes = clean_boxes(boxes) From 9f7ea2938aa4ee1bcea56f7e1f648807385bb2fe Mon Sep 17 00:00:00 2001 From: Tarun Ram Date: Sat, 14 Dec 2024 21:33:04 +0530 Subject: [PATCH 05/12] Add OCR Error Detection Model --- surya/model/ocr_error/config.py | 66 +++ surya/model/ocr_error/model.py | 820 +++++++++++++++++++++++++++++ surya/model/ocr_error/tokenizer.py | 642 ++++++++++++++++++++++ surya/ocr_error.py | 47 ++ surya/schema.py | 5 + surya/settings.py | 8 + 6 files changed, 1588 insertions(+) create mode 100644 surya/model/ocr_error/config.py create mode 100644 surya/model/ocr_error/model.py create mode 100644 surya/model/ocr_error/tokenizer.py create mode 100644 surya/ocr_error.py diff --git a/surya/model/ocr_error/config.py b/surya/model/ocr_error/config.py new file mode 100644 index 00000000..025f8d10 --- /dev/null +++ b/surya/model/ocr_error/config.py @@ -0,0 +1,66 @@ +from collections import OrderedDict +from typing import Mapping + +from transformers.configuration_utils import PretrainedConfig +from transformers.onnx import OnnxConfig + +ID2LABEL = { + 0: 'good', + 1: 'bad' +} + +class DistilBertConfig(PretrainedConfig): + model_type = "distilbert" + attribute_map = { + "hidden_size": "dim", + "num_attention_heads": "n_heads", + "num_hidden_layers": "n_layers", + } + + def __init__( + self, + vocab_size=30522, + max_position_embeddings=512, + sinusoidal_pos_embds=False, + n_layers=6, + n_heads=12, + dim=768, + hidden_dim=4 * 768, + dropout=0.1, + attention_dropout=0.1, + activation="gelu", + initializer_range=0.02, + qa_dropout=0.1, + seq_classif_dropout=0.2, + pad_token_id=0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.sinusoidal_pos_embds = sinusoidal_pos_embds + self.n_layers = n_layers + self.n_heads = n_heads + self.dim = dim + self.hidden_dim = hidden_dim + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation = activation + self.initializer_range = initializer_range + self.qa_dropout = qa_dropout + self.seq_classif_dropout = seq_classif_dropout + super().__init__(**kwargs, pad_token_id=pad_token_id) + + +class DistilBertOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + return OrderedDict( + [ + ("input_ids", dynamic_axis), + ("attention_mask", dynamic_axis), + ] + ) \ No newline at end of file diff --git a/surya/model/ocr_error/model.py b/surya/model/ocr_error/model.py new file mode 100644 index 00000000..9b6dfcbc --- /dev/null +++ b/surya/model/ocr_error/model.py @@ -0,0 +1,820 @@ +import math +from typing import Dict, List, Optional, Set, Tuple, Union +from __future__ import annotations + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import get_activation +from transformers.modeling_outputs import ( + BaseModelOutput, + SequenceClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from transformers.utils import ( + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, +) + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + +from surya.model.ocr_error.config import DistilBertConfig +from surya.model.ocr_error.tokenizer import DistilBertTokenizer +from surya.settings import settings + +def load_model(checkpoint=settings.OCR_ERROR_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE) -> DistilBertForSequenceClassification: + config = DistilBertConfig.from_pretrained(checkpoint) + model = DistilBertForSequenceClassification.from_pretrained(checkpoint, torch_dtype=dtype, config=config).to(device).eval() + + if settings.OCR_ERROR_STATIC_CACHE: + torch.set_float32_matmul_precision('high') + torch._dynamo.config.cache_size_limit = 1 + torch._dynamo.config.suppress_errors = False + + print(f"Compiling detection model {checkpoint} on device {device} with dtype {dtype}") + model = torch.compile(model) + + return model + +def load_tokenizer(checkpoint=settings.OCR_ERROR_MODEL_CHECKPOINT): + tokenizer = DistilBertTokenizer.from_pretrained(checkpoint) + return tokenizer + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor): + position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]) + out.requires_grad = False + out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) + out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) + out.detach_() + + +class Embeddings(nn.Module): + def __init__(self, config: DistilBertConfig): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim) + + self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12) + self.dropout = nn.Dropout(config.dropout) + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + def forward(self, input_ids: torch.Tensor, input_embeds: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Parameters: + input_ids (torch.Tensor): + torch.tensor(bs, max_seq_length) The token ids to embed. + input_embeds (*optional*, torch.Tensor): + The pre-computed word embeddings. Can only be passed if the input ids are `None`. + + + Returns: torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings, no token_type + embeddings) + """ + if input_ids is not None: + input_embeds = self.word_embeddings(input_ids) # (bs, max_seq_length, dim) + + seq_length = input_embeds.size(1) + + # Setting the position-ids to the registered buffer in constructor, it helps + # when tracing the model without passing position-ids, solves + # isues similar to issue #5664 + if hasattr(self, "position_ids"): + position_ids = self.position_ids[:, :seq_length] + else: + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) # (max_seq_length) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # (bs, max_seq_length) + + position_embeddings = self.position_embeddings(position_ids) # (bs, max_seq_length, dim) + + embeddings = input_embeds + position_embeddings # (bs, max_seq_length, dim) + embeddings = self.LayerNorm(embeddings) # (bs, max_seq_length, dim) + embeddings = self.dropout(embeddings) # (bs, max_seq_length, dim) + return embeddings + + +class MultiHeadSelfAttention(nn.Module): + def __init__(self, config: DistilBertConfig): + super().__init__() + self.config = config + + self.n_heads = config.n_heads + self.dim = config.dim + self.dropout = nn.Dropout(p=config.attention_dropout) + self.is_causal = False + + # Have an even number of multi heads that divide the dimensions + if self.dim % self.n_heads != 0: + # Raise value errors for even multi-head attention nodes + raise ValueError(f"self.n_heads: {self.n_heads} must divide self.dim: {self.dim} evenly") + + self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim) + self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim) + self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim) + self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim) + + self.pruned_heads: Set[int] = set() + self.attention_head_size = self.dim // self.n_heads + + def prune_heads(self, heads: List[int]): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.attention_head_size, self.pruned_heads + ) + # Prune linear layers + self.q_lin = prune_linear_layer(self.q_lin, index) + self.k_lin = prune_linear_layer(self.k_lin, index) + self.v_lin = prune_linear_layer(self.v_lin, index) + self.out_lin = prune_linear_layer(self.out_lin, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.dim = self.attention_head_size * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, ...]: + """ + Parameters: + query: torch.tensor(bs, seq_length, dim) + key: torch.tensor(bs, seq_length, dim) + value: torch.tensor(bs, seq_length, dim) + mask: torch.tensor(bs, seq_length) + + Returns: + weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs, + seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True` + """ + bs, q_length, dim = query.size() + k_length = key.size(1) + # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured' + # assert key.size() == value.size() + + dim_per_head = self.dim // self.n_heads + + mask_reshp = (bs, 1, 1, k_length) + + def shape(x: torch.Tensor) -> torch.Tensor: + """separate heads""" + return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2) + + def unshape(x: torch.Tensor) -> torch.Tensor: + """group heads""" + return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head) + + q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head) + k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head) + v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head) + + q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head) + scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length) + mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length) + scores = scores.masked_fill( + mask, torch.tensor(torch.finfo(scores.dtype).min) + ) # (bs, n_heads, q_length, k_length) + + weights = nn.functional.softmax(scores, dim=-1) # (bs, n_heads, q_length, k_length) + weights = self.dropout(weights) # (bs, n_heads, q_length, k_length) + + # Mask heads if we want to + if head_mask is not None: + weights = weights * head_mask + + context = torch.matmul(weights, v) # (bs, n_heads, q_length, dim_per_head) + context = unshape(context) # (bs, q_length, dim) + context = self.out_lin(context) # (bs, q_length, dim) + + if output_attentions: + return (context, weights) + else: + return (context,) + + +class DistilBertFlashAttention2(MultiHeadSelfAttention): + """ + DistilBert flash attention module. This module inherits from `MultiHeadSelfAttention` as the weights of the module + stays untouched. The only required change would be on the forward pass where it needs to correctly call the public + API of flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, ...]: + """ + Parameters: + query: torch.tensor(bs, seq_length, dim) + key: torch.tensor(bs, seq_length, dim) + value: torch.tensor(bs, seq_length, dim) + mask: torch.tensor(bs, seq_length) + + Returns: + weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs, + seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True` + """ + batch_size, q_length, dim = query.size() + + dim_per_head = self.dim // self.n_heads + + def reshape(x: torch.Tensor) -> torch.Tensor: + """separate heads""" + return x.view(batch_size, -1, self.n_heads, dim_per_head) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query_states = reshape(self.q_lin(query)) + key_states = reshape(self.k_lin(key)) + value_states = reshape(self.v_lin(value)) + + attn_dropout = self.config.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + if query_states.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_lin.weight.dtype + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_weights = self._flash_attention_forward( + query_states, key_states, value_states, mask, q_length, dropout=attn_dropout + ) + + attn_weights_reshaped = attn_weights.reshape(batch_size, q_length, self.n_heads * dim_per_head) + attn_output = self.out_lin(attn_weights_reshaped) + + if output_attentions: + return (attn_output, attn_weights) + else: + return (attn_output,) + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward with causal=True->causal=False + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input with num_heads->n_heads + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.n_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class FFN(nn.Module): + def __init__(self, config: DistilBertConfig): + super().__init__() + self.dropout = nn.Dropout(p=config.dropout) + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim) + self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim) + self.activation = get_activation(config.activation) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input) + + def ff_chunk(self, input: torch.Tensor) -> torch.Tensor: + x = self.lin1(input) + x = self.activation(x) + x = self.lin2(x) + x = self.dropout(x) + return x + + +DISTILBERT_ATTENTION_CLASSES = { + "eager": MultiHeadSelfAttention, + "flash_attention_2": DistilBertFlashAttention2, +} + + +class TransformerBlock(nn.Module): + def __init__(self, config: DistilBertConfig): + super().__init__() + + # Have an even number of Configure multi-heads + if config.dim % config.n_heads != 0: + raise ValueError(f"config.n_heads {config.n_heads} must divide config.dim {config.dim} evenly") + + self.attention = DISTILBERT_ATTENTION_CLASSES[config._attn_implementation](config) + self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12) + + self.ffn = FFN(config) + self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12) + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, ...]: + """ + Parameters: + x: torch.tensor(bs, seq_length, dim) + attn_mask: torch.tensor(bs, seq_length) + + Returns: + sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length) The attention weights ffn_output: + torch.tensor(bs, seq_length, dim) The output of the transformer block contextualization. + """ + # Self-Attention + sa_output = self.attention( + query=x, + key=x, + value=x, + mask=attn_mask, + head_mask=head_mask, + output_attentions=output_attentions, + ) + if output_attentions: + sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length) + else: # To handle these `output_attentions` or `output_hidden_states` cases returning tuples + if type(sa_output) != tuple: + raise TypeError(f"sa_output must be a tuple but it is {type(sa_output)} type") + + sa_output = sa_output[0] + sa_output = self.sa_layer_norm(sa_output + x) # (bs, seq_length, dim) + + # Feed Forward Network + ffn_output = self.ffn(sa_output) # (bs, seq_length, dim) + ffn_output: torch.Tensor = self.output_layer_norm(ffn_output + sa_output) # (bs, seq_length, dim) + + output = (ffn_output,) + if output_attentions: + output = (sa_weights,) + output + return output + + +class Transformer(nn.Module): + def __init__(self, config: DistilBertConfig): + super().__init__() + self.n_layers = config.n_layers + self.layer = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: # docstyle-ignore + """ + Parameters: + x: torch.tensor(bs, seq_length, dim) Input sequence embedded. + attn_mask: torch.tensor(bs, seq_length) Attention mask on the sequence. + + Returns: + hidden_state: torch.tensor(bs, seq_length, dim) Sequence of hidden states in the last (top) + layer all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)] + Tuple of length n_layers with the hidden states from each layer. + Optional: only if output_hidden_states=True + all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)] + Tuple of length n_layers with the attention weights from each layer + Optional: only if output_attentions=True + """ + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_state = x + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_state,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_state, + attn_mask, + head_mask[i], + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_state, + attn_mask, + head_mask[i], + output_attentions, + ) + + hidden_state = layer_outputs[-1] + + if output_attentions: + if len(layer_outputs) != 2: + raise ValueError(f"The length of the layer_outputs should be 2, but it is {len(layer_outputs)}") + + attentions = layer_outputs[0] + all_attentions = all_attentions + (attentions,) + else: + if len(layer_outputs) != 1: + raise ValueError(f"The length of the layer_outputs should be 1, but it is {len(layer_outputs)}") + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_state,) + + if not return_dict: + return tuple(v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL # +class DistilBertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = DistilBertConfig + load_tf_weights = None + base_model_prefix = "distilbert" + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, Embeddings) and self.config.sinusoidal_pos_embds: + create_sinusoidal_embeddings( + self.config.max_position_embeddings, self.config.dim, module.position_embeddings.weight + ) + + +class DistilBertModel(DistilBertPreTrainedModel): + def __init__(self, config: DistilBertConfig): + super().__init__(config) + + self.embeddings = Embeddings(config) # Embeddings + self.transformer = Transformer(config) # Encoder + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + # Initialize weights and apply final processing + self.post_init() + + def get_position_embeddings(self) -> nn.Embedding: + """ + Returns the position embeddings + """ + return self.embeddings.position_embeddings + + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`. + + Arguments: + new_num_position_embeddings (`int`): + The number of new position embedding matrix. If position embeddings are learned, increasing the size + will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the + end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the + size will add correct vectors at the end following the position encoding algorithm, whereas reducing + the size will remove vectors from the end. + """ + num_position_embeds_diff = new_num_position_embeddings - self.config.max_position_embeddings + + # no resizing needs to be done if the length stays the same + if num_position_embeds_diff == 0: + return + + self.config.max_position_embeddings = new_num_position_embeddings + + old_position_embeddings_weight = self.embeddings.position_embeddings.weight.clone() + + self.embeddings.position_embeddings = nn.Embedding(self.config.max_position_embeddings, self.config.dim) + + if self.config.sinusoidal_pos_embds: + create_sinusoidal_embeddings( + n_pos=self.config.max_position_embeddings, dim=self.config.dim, out=self.position_embeddings.weight + ) + else: + with torch.no_grad(): + if num_position_embeds_diff > 0: + self.embeddings.position_embeddings.weight[:-num_position_embeds_diff] = nn.Parameter( + old_position_embeddings_weight + ) + else: + self.embeddings.position_embeddings.weight = nn.Parameter( + old_position_embeddings_weight[:num_position_embeds_diff] + ) + # move position_embeddings to correct device + self.embeddings.position_embeddings.to(self.device) + + def get_input_embeddings(self) -> nn.Embedding: + return self.embeddings.word_embeddings + + def set_input_embeddings(self, new_embeddings: nn.Embedding): + self.embeddings.word_embeddings = new_embeddings + + def _prune_heads(self, heads_to_prune: Dict[int, List[List[int]]]): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.transformer.layer[layer].attention.prune_heads(heads) + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embeddings = self.embeddings(input_ids, inputs_embeds) # (bs, seq_length, dim) + + if self._use_flash_attention_2: + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) # (bs, seq_length) + + return self.transformer( + x=embeddings, + attn_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class DistilBertForSequenceClassification(DistilBertPreTrainedModel): + def __init__(self, config: DistilBertConfig): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.distilbert = DistilBertModel(config) + self.pre_classifier = nn.Linear(config.dim, config.dim) + self.classifier = nn.Linear(config.dim, config.num_labels) + self.dropout = nn.Dropout(config.seq_classif_dropout) + + # Initialize weights and apply final processing + self.post_init() + + def get_position_embeddings(self) -> nn.Embedding: + """ + Returns the position embeddings + """ + return self.distilbert.get_position_embeddings() + + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`. + + Arguments: + new_num_position_embeddings (`int`): + The number of new position embedding matrix. If position embeddings are learned, increasing the size + will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the + end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the + size will add correct vectors at the end following the position encoding algorithm, whereas reducing + the size will remove vectors from the end. + """ + self.distilbert.resize_position_embeddings(new_num_position_embeddings) + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[SequenceClassifierOutput, Tuple[torch.Tensor, ...]]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + distilbert_output = self.distilbert( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_state = distilbert_output[0] # (bs, seq_len, dim) + pooled_output = hidden_state[:, 0] # (bs, dim) + pooled_output = self.pre_classifier(pooled_output) # (bs, dim) + pooled_output = nn.ReLU()(pooled_output) # (bs, dim) + pooled_output = self.dropout(pooled_output) # (bs, dim) + logits = self.classifier(pooled_output) # (bs, num_labels) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + distilbert_output[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=distilbert_output.hidden_states, + attentions=distilbert_output.attentions, + ) \ No newline at end of file diff --git a/surya/model/ocr_error/tokenizer.py b/surya/model/ocr_error/tokenizer.py new file mode 100644 index 00000000..e6304c38 --- /dev/null +++ b/surya/model/ocr_error/tokenizer.py @@ -0,0 +1,642 @@ +import collections +import os +import json +import unicodedata +from typing import List, Optional, Tuple + +from tokenizers import normalizers + +from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace +from transformers.tokenization_utils_fast import PreTrainedTokenizerFast +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} + +# Copied from transformers.models.bert.tokenization_bert.load_vocab +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + with open(vocab_file, "r", encoding="utf-8") as reader: + tokens = reader.readlines() + for index, token in enumerate(tokens): + token = token.rstrip("\n") + vocab[token] = index + return vocab + + +# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class DistilBertTokenizer(PreTrainedTokenizer): + r""" + Construct a DistilBERT tokenizer. Based on WordPiece. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + File containing the vocabulary. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + do_basic_tokenize (`bool`, *optional*, defaults to `True`): + Whether or not to do basic tokenization before WordPiece. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + do_lower_case=True, + do_basic_tokenize=True, + never_split=None, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + tokenize_chinese_chars=True, + strip_accents=None, + **kwargs, + ): + if not os.path.isfile(vocab_file): + raise ValueError( + f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained" + " model use `tokenizer = DistilBertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer( + do_lower_case=do_lower_case, + never_split=never_split, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + ) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token)) + + super().__init__( + do_lower_case=do_lower_case, + do_basic_tokenize=do_basic_tokenize, + never_split=never_split, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + **kwargs, + ) + + @property + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.do_lower_case + def do_lower_case(self): + return self.basic_tokenizer.do_lower_case + + @property + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.vocab_size + def vocab_size(self): + return len(self.vocab) + + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_vocab + def get_vocab(self): + return dict(self.vocab, **self.added_tokens_encoder) + + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize + def _tokenize(self, text, split_special_tokens=False): + split_tokens = [] + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize( + text, never_split=self.all_special_tokens if not split_special_tokens else None + ): + # If the token is part of the never_split set + if token in self.basic_tokenizer.never_split: + split_tokens.append(token) + else: + split_tokens += self.wordpiece_tokenizer.tokenize(token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_token_to_id + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.vocab.get(token, self.vocab.get(self.unk_token)) + + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_id_to_token + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.ids_to_tokens.get(index, self.unk_token) + + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.convert_tokens_to_string + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + out_string = " ".join(tokens).replace(" ##", "").strip() + return out_string + + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.build_inputs_with_special_tokens + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_special_tokens_mask + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.create_token_type_ids_from_sequences + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence + pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + index = 0 + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + else: + vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + # logger.warning( + # f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive." + # " Please check that the vocabulary is not corrupted!" + # ) + index = token_index + writer.write(token + "\n") + index += 1 + return (vocab_file,) + + +# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer +class BasicTokenizer(object): + """ + Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). + + Args: + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + do_split_on_punc (`bool`, *optional*, defaults to `True`): + In some instances we want to skip the basic punctuation splitting so that later tokenization can capture + the full context of the words, such as contractions. + """ + + def __init__( + self, + do_lower_case=True, + never_split=None, + tokenize_chinese_chars=True, + strip_accents=None, + do_split_on_punc=True, + ): + if never_split is None: + never_split = [] + self.do_lower_case = do_lower_case + self.never_split = set(never_split) + self.tokenize_chinese_chars = tokenize_chinese_chars + self.strip_accents = strip_accents + self.do_split_on_punc = do_split_on_punc + + def tokenize(self, text, never_split=None): + """ + Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer. + + Args: + never_split (`List[str]`, *optional*) + Kept for backward compatibility purposes. Now implemented directly at the base class level (see + [`PreTrainedTokenizer.tokenize`]) List of token not to split. + """ + # union() returns a new set by concatenating the two sets. + never_split = self.never_split.union(set(never_split)) if never_split else self.never_split + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + if self.tokenize_chinese_chars: + text = self._tokenize_chinese_chars(text) + # prevents treating the same character with different unicode codepoints as different characters + unicode_normalized_text = unicodedata.normalize("NFC", text) + orig_tokens = whitespace_tokenize(unicode_normalized_text) + split_tokens = [] + for token in orig_tokens: + if token not in never_split: + if self.do_lower_case: + token = token.lower() + if self.strip_accents is not False: + token = self._run_strip_accents(token) + elif self.strip_accents: + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token, never_split)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text, never_split=None): + """Splits punctuation on a piece of text.""" + if not self.do_split_on_punc or (never_split is not None and text in never_split): + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) # + or (cp >= 0x20000 and cp <= 0x2A6DF) # + or (cp >= 0x2A700 and cp <= 0x2B73F) # + or (cp >= 0x2B740 and cp <= 0x2B81F) # + or (cp >= 0x2B820 and cp <= 0x2CEAF) # + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) # + ): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xFFFD or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +# Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer +class WordpieceTokenizer(object): + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token, max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """ + Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform + tokenization using the given vocabulary. + + For example, `input = "unaffable"` wil return as output `["un", "##aff", "##able"]`. + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through *BasicTokenizer*. + + Returns: + A list of wordpiece tokens. + """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + +class DistilBertTokenizerFast(PreTrainedTokenizerFast): + r""" + Construct a "fast" DistilBERT tokenizer (backed by HuggingFace's *tokenizers* library). Based on WordPiece. + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + File containing the vocabulary. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + clean_text (`bool`, *optional*, defaults to `True`): + Whether or not to clean the text before tokenization by removing any control characters and replacing all + whitespaces by the classic one. + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see [this + issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + wordpieces_prefix (`str`, *optional*, defaults to `"##"`): + The prefix for subwords. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + slow_tokenizer_class = DistilBertTokenizer + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + do_lower_case=True, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + tokenize_chinese_chars=True, + strip_accents=None, + **kwargs, + ): + super().__init__( + vocab_file, + tokenizer_file=tokenizer_file, + do_lower_case=do_lower_case, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + **kwargs, + ) + + normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__()) + if ( + normalizer_state.get("lowercase", do_lower_case) != do_lower_case + or normalizer_state.get("strip_accents", strip_accents) != strip_accents + or normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars + ): + normalizer_class = getattr(normalizers, normalizer_state.pop("type")) + normalizer_state["lowercase"] = do_lower_case + normalizer_state["strip_accents"] = strip_accents + normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars + self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state) + + self.do_lower_case = do_lower_case + + # Copied from transformers.models.bert.tokenization_bert_fast.BertTokenizerFast.build_inputs_with_special_tokens + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + + if token_ids_1 is not None: + output += token_ids_1 + [self.sep_token_id] + + return output + + # Copied from transformers.models.bert.tokenization_bert_fast.BertTokenizerFast.create_token_type_ids_from_sequences + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence + pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + # Copied from transformers.models.bert.tokenization_bert_fast.BertTokenizerFast.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) \ No newline at end of file diff --git a/surya/ocr_error.py b/surya/ocr_error.py new file mode 100644 index 00000000..4739ad6d --- /dev/null +++ b/surya/ocr_error.py @@ -0,0 +1,47 @@ +from typing import List, Optional +from math import ceil +from tqdm import tqdm +import torch +import numpy as np + +from surya.model.ocr_error.model import DistilBertForSequenceClassification, DistilBertTokenizer +from surya.model.ocr_error.config import ID2LABEL +from surya.settings import settings +from surya.schema import OCRErrorDetectionResult + +def get_batch_size(): + batch_size = settings.OCR_ERROR_BATCH_SIZE + if batch_size is None: + batch_size = 8 + if settings.TORCH_DEVICE_MODEL == "mps": + batch_size = 8 + if settings.TORCH_DEVICE_MODEL == "cuda": + batch_size = 64 + return batch_size + +def batch_ocr_error_detection( + texts: List[str], + model: DistilBertForSequenceClassification, + tokenizer: DistilBertTokenizer, + batch_size: Optional[int] = None +): + if batch_size is None: + batch_size = get_batch_size() + + num_batches = ceil(len(texts)/batch_size) + texts_processed = tokenizer(texts, padding='longest', truncation=True, return_tensors='pt') + predictions = [] + for batch_idx in tqdm(range(num_batches)): + start_idx, end_idx = batch_idx*batch_size, (batch_idx+1)*batch_size + batch_input_ids = texts_processed.input_ids[start_idx:end_idx].to(model.device) + batch_attention_mask = texts_processed.attention_mask[start_idx:end_idx].to(model.device) + + with torch.inference_mode(): + pred = model(batch_input_ids, attention_mask=batch_attention_mask) + logits = pred.logits.detach().cpu().numpy().astype(np.float32) + predictions.extend(np.argmax(logits, axis=1).tolist()) + + return OCRErrorDetectionResult( + texts=texts, + labels=[ID2LABEL[p] for p in predictions] + ) diff --git a/surya/schema.py b/surya/schema.py index becfb870..f52340e0 100644 --- a/surya/schema.py +++ b/surya/schema.py @@ -211,3 +211,8 @@ class TableResult(BaseModel): rows: List[TableRow] cols: List[TableCol] image_bbox: List[float] + + +class OCRErrorDetectionResult(BaseModel): + texts: List[str] + labels: List[str] \ No newline at end of file diff --git a/surya/settings.py b/surya/settings.py index 8d35f2c8..4340a737 100644 --- a/surya/settings.py +++ b/surya/settings.py @@ -84,6 +84,11 @@ def TORCH_DEVICE_MODEL(self) -> str: TABLE_REC_BENCH_DATASET_NAME: str = "vikp/fintabnet_bench" COMPILE_TABLE_REC: bool = False + # OCR Error Detection + OCR_ERROR_MODEL_CHECKPOINT: str = "tarun-menta/ocr_error_detection" + OCR_ERROR_BATCH_SIZE: Optional[int] = None + COMPILE_OCR_ERROR: bool = False + # Tesseract (for benchmarks only) TESSDATA_PREFIX: Optional[str] = None @@ -105,6 +110,9 @@ def LAYOUT_STATIC_CACHE(self) -> bool: def TABLE_REC_STATIC_CACHE(self) -> bool: return self.COMPILE_ALL or self.COMPILE_TABLE_REC + def OCR_ERROR_STATIC_CACHE(self) -> bool: + return self.COMPILE_ALL or self.COMPILE_OCR_ERROR + @computed_field @property def MODEL_DTYPE(self) -> torch.dtype: From 872c31faf3ef4365d0ea6c7a62cb849b841fb36d Mon Sep 17 00:00:00 2001 From: Tarun Ram Date: Sat, 14 Dec 2024 21:46:23 +0530 Subject: [PATCH 06/12] Minor bugfix --- surya/model/ocr_error/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/surya/model/ocr_error/model.py b/surya/model/ocr_error/model.py index 9b6dfcbc..542e8c30 100644 --- a/surya/model/ocr_error/model.py +++ b/surya/model/ocr_error/model.py @@ -1,6 +1,7 @@ +from __future__ import annotations + import math from typing import Dict, List, Optional, Set, Tuple, Union -from __future__ import annotations import numpy as np import torch From f63a4bf835e4167b4234475d1c6c7d288389e958 Mon Sep 17 00:00:00 2001 From: Tarun Menta <66506307+tarun-menta@users.noreply.github.com> Date: Sun, 15 Dec 2024 20:21:00 +0530 Subject: [PATCH 07/12] Update settings.py --- surya/settings.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/surya/settings.py b/surya/settings.py index 4340a737..59385076 100644 --- a/surya/settings.py +++ b/surya/settings.py @@ -110,6 +110,7 @@ def LAYOUT_STATIC_CACHE(self) -> bool: def TABLE_REC_STATIC_CACHE(self) -> bool: return self.COMPILE_ALL or self.COMPILE_TABLE_REC + @computed_field def OCR_ERROR_STATIC_CACHE(self) -> bool: return self.COMPILE_ALL or self.COMPILE_OCR_ERROR @@ -123,4 +124,4 @@ class Config: extra = "ignore" -settings = Settings() \ No newline at end of file +settings = Settings() From 20b7b62d5d9d5bd2e222a26196410e0c0898b97f Mon Sep 17 00:00:00 2001 From: Vik Paruchuri Date: Thu, 19 Dec 2024 10:48:54 -0500 Subject: [PATCH 08/12] Add bad OCR detection to app --- ocr_app.py | 59 +++++++++++++++++++++++++++++++++++++++++++---- surya/settings.py | 2 +- 2 files changed, 56 insertions(+), 5 deletions(-) diff --git a/ocr_app.py b/ocr_app.py index 064a33b9..04ccefc7 100644 --- a/ocr_app.py +++ b/ocr_app.py @@ -1,4 +1,5 @@ import io +import tempfile from typing import List import pypdfium2 @@ -15,6 +16,7 @@ from surya.model.recognition.processor import load_processor as load_rec_processor from surya.model.table_rec.model import load_model as load_table_model from surya.model.table_rec.processor import load_processor as load_table_processor +from surya.model.ocr_error.model import load_model as load_ocr_error_model, load_tokenizer as load_ocr_error_processor from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image from surya.ocr import run_ocr from surya.postprocessing.text import draw_text_on_image @@ -24,7 +26,9 @@ from surya.schema import OCRResult, TextDetectionResult, LayoutResult, TableResult from surya.settings import settings from surya.tables import batch_table_recognition -from surya.postprocessing.util import rescale_bboxes, rescale_bbox +from surya.postprocessing.util import rescale_bbox +from pdftext.extraction import plain_text_output +from surya.ocr_error import batch_ocr_error_detection @st.cache_resource() @@ -46,6 +50,39 @@ def load_layout_cached(): def load_table_cached(): return load_table_model(), load_table_processor() +@st.cache_resource() +def load_ocr_error_cached(): + return load_ocr_error_model(), load_ocr_error_processor() + + +def run_ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pages=15): + with tempfile.NamedTemporaryFile(suffix=".pdf") as f: + f.write(pdf_file.getvalue()) + f.seek(0) + + # Sample the text from the middle of the PDF + page_middle = page_count // 2 + page_range = range(max(page_middle - max_pages, 0), min(page_middle + max_pages, page_count)) + text = plain_text_output(f.name, page_range=page_range) + + sample_gap = len(text) // max_samples + if len(text) == 0 or sample_gap == 0: + return "This PDF has no text or very little text", ["no text"] + + if sample_gap < sample_len: + sample_gap = sample_len + + # Split the text into samples for the model + samples = [] + for i in range(0, len(text), sample_gap): + samples.append(text[i:i + sample_len]) + + results = batch_ocr_error_detection(samples, ocr_error_model, ocr_error_processor) + label = "This PDF has good text." + if results.labels.count("bad") / len(results.labels) > .2: + label = "This PDF may have garbled or bad OCR text." + return label, results.labels + def text_detection(img) -> (Image.Image, TextDetectionResult): pred = batch_text_detection([img], det_model, det_processor)[0] @@ -139,13 +176,16 @@ def get_page_image(pdf_file, page_num, dpi=settings.IMAGE_DPI): ) png = list(renderer)[0] png_image = png.convert("RGB") + doc.close() return png_image @st.cache_data() -def page_count(pdf_file): +def page_counter(pdf_file): doc = open_pdf(pdf_file) - return len(doc) + doc_len = len(doc) + doc.close() + return doc_len st.set_page_config(layout="wide") @@ -155,6 +195,7 @@ def page_count(pdf_file): rec_model, rec_processor = load_rec_cached() layout_model, layout_processor = load_layout_cached() table_model, table_processor = load_table_cached() +ocr_error_model, ocr_error_processor = load_ocr_error_cached() st.markdown(""" @@ -179,8 +220,9 @@ def page_count(pdf_file): filetype = in_file.type whole_image = False +page_count = None if "pdf" in filetype: - page_count = page_count(in_file) + page_count = page_counter(in_file) page_number = st.sidebar.number_input(f"Page number out of {page_count}:", min_value=1, value=1, max_value=page_count) pil_image = get_page_image(in_file, page_number, settings.IMAGE_DPI) @@ -194,6 +236,7 @@ def page_count(pdf_file): text_rec = st.sidebar.button("Run OCR") layout_det = st.sidebar.button("Run Layout Analysis") table_rec = st.sidebar.button("Run Table Rec") +ocr_errors = st.sidebar.button("Run bad PDF text detection") use_pdf_boxes = st.sidebar.checkbox("PDF table boxes", value=True, help="Table recognition only: Use the bounding boxes from the PDF file vs text detection model.") skip_table_detection = st.sidebar.checkbox("Skip table detection", value=False, help="Table recognition only: Skip table detection and treat the whole image/page as a table.") @@ -233,5 +276,13 @@ def page_count(pdf_file): st.image(table_img, caption="Table Recognition", use_container_width=True) st.json([p.model_dump() for p in pred], expanded=True) +if ocr_errors: + if "pdf" not in filetype: + st.error("This feature only works with PDFs.") + label, results = run_ocr_errors(in_file, page_count) + with col1: + st.write(label) + st.json(results) + with col2: st.image(pil_image, caption="Uploaded Image", use_container_width=True) \ No newline at end of file diff --git a/surya/settings.py b/surya/settings.py index 59385076..3692c7ad 100644 --- a/surya/settings.py +++ b/surya/settings.py @@ -85,7 +85,7 @@ def TORCH_DEVICE_MODEL(self) -> str: COMPILE_TABLE_REC: bool = False # OCR Error Detection - OCR_ERROR_MODEL_CHECKPOINT: str = "tarun-menta/ocr_error_detection" + OCR_ERROR_MODEL_CHECKPOINT: str = "datalab-to/ocr_error_detection" OCR_ERROR_BATCH_SIZE: Optional[int] = None COMPILE_OCR_ERROR: bool = False From 2281aec8b97373bd02ba57f5b4a39cde853ec513 Mon Sep 17 00:00:00 2001 From: Vik Paruchuri Date: Thu, 19 Dec 2024 11:01:58 -0500 Subject: [PATCH 09/12] Bump version and lockfile --- poetry.lock | 712 +++++++++++++++++++++++++------------------------ pyproject.toml | 2 +- 2 files changed, 358 insertions(+), 356 deletions(-) diff --git a/poetry.lock b/poetry.lock index f3369923..48dbc9fe 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -13,87 +13,87 @@ files = [ [[package]] name = "aiohttp" -version = "3.11.10" +version = "3.11.11" description = "Async http client/server framework (asyncio)" optional = false python-versions = ">=3.9" files = [ - {file = "aiohttp-3.11.10-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cbad88a61fa743c5d283ad501b01c153820734118b65aee2bd7dbb735475ce0d"}, - {file = "aiohttp-3.11.10-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:80886dac673ceaef499de2f393fc80bb4481a129e6cb29e624a12e3296cc088f"}, - {file = "aiohttp-3.11.10-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:61b9bae80ed1f338c42f57c16918853dc51775fb5cb61da70d590de14d8b5fb4"}, - {file = "aiohttp-3.11.10-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9e2e576caec5c6a6b93f41626c9c02fc87cd91538b81a3670b2e04452a63def6"}, - {file = "aiohttp-3.11.10-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:02c13415b5732fb6ee7ff64583a5e6ed1c57aa68f17d2bda79c04888dfdc2769"}, - {file = "aiohttp-3.11.10-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4cfce37f31f20800a6a6620ce2cdd6737b82e42e06e6e9bd1b36f546feb3c44f"}, - {file = "aiohttp-3.11.10-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3bbbfff4c679c64e6e23cb213f57cc2c9165c9a65d63717108a644eb5a7398df"}, - {file = "aiohttp-3.11.10-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:49c7dbbc1a559ae14fc48387a115b7d4bbc84b4a2c3b9299c31696953c2a5219"}, - {file = "aiohttp-3.11.10-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:68386d78743e6570f054fe7949d6cb37ef2b672b4d3405ce91fafa996f7d9b4d"}, - {file = "aiohttp-3.11.10-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:9ef405356ba989fb57f84cac66f7b0260772836191ccefbb987f414bcd2979d9"}, - {file = "aiohttp-3.11.10-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:5d6958671b296febe7f5f859bea581a21c1d05430d1bbdcf2b393599b1cdce77"}, - {file = "aiohttp-3.11.10-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:99b7920e7165be5a9e9a3a7f1b680f06f68ff0d0328ff4079e5163990d046767"}, - {file = "aiohttp-3.11.10-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:0dc49f42422163efb7e6f1df2636fe3db72713f6cd94688e339dbe33fe06d61d"}, - {file = "aiohttp-3.11.10-cp310-cp310-win32.whl", hash = "sha256:40d1c7a7f750b5648642586ba7206999650208dbe5afbcc5284bcec6579c9b91"}, - {file = "aiohttp-3.11.10-cp310-cp310-win_amd64.whl", hash = "sha256:68ff6f48b51bd78ea92b31079817aff539f6c8fc80b6b8d6ca347d7c02384e33"}, - {file = "aiohttp-3.11.10-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:77c4aa15a89847b9891abf97f3d4048f3c2d667e00f8a623c89ad2dccee6771b"}, - {file = "aiohttp-3.11.10-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:909af95a72cedbefe5596f0bdf3055740f96c1a4baa0dd11fd74ca4de0b4e3f1"}, - {file = "aiohttp-3.11.10-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:386fbe79863eb564e9f3615b959e28b222259da0c48fd1be5929ac838bc65683"}, - {file = "aiohttp-3.11.10-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3de34936eb1a647aa919655ff8d38b618e9f6b7f250cc19a57a4bf7fd2062b6d"}, - {file = "aiohttp-3.11.10-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0c9527819b29cd2b9f52033e7fb9ff08073df49b4799c89cb5754624ecd98299"}, - {file = "aiohttp-3.11.10-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:65a96e3e03300b41f261bbfd40dfdbf1c301e87eab7cd61c054b1f2e7c89b9e8"}, - {file = "aiohttp-3.11.10-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98f5635f7b74bcd4f6f72fcd85bea2154b323a9f05226a80bc7398d0c90763b0"}, - {file = "aiohttp-3.11.10-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:03b6002e20938fc6ee0918c81d9e776bebccc84690e2b03ed132331cca065ee5"}, - {file = "aiohttp-3.11.10-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6362cc6c23c08d18ddbf0e8c4d5159b5df74fea1a5278ff4f2c79aed3f4e9f46"}, - {file = "aiohttp-3.11.10-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:3691ed7726fef54e928fe26344d930c0c8575bc968c3e239c2e1a04bd8cf7838"}, - {file = "aiohttp-3.11.10-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:31d5093d3acd02b31c649d3a69bb072d539d4c7659b87caa4f6d2bcf57c2fa2b"}, - {file = "aiohttp-3.11.10-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:8b3cf2dc0f0690a33f2d2b2cb15db87a65f1c609f53c37e226f84edb08d10f52"}, - {file = "aiohttp-3.11.10-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:fbbaea811a2bba171197b08eea288b9402faa2bab2ba0858eecdd0a4105753a3"}, - {file = "aiohttp-3.11.10-cp311-cp311-win32.whl", hash = "sha256:4b2c7ac59c5698a7a8207ba72d9e9c15b0fc484a560be0788b31312c2c5504e4"}, - {file = "aiohttp-3.11.10-cp311-cp311-win_amd64.whl", hash = "sha256:974d3a2cce5fcfa32f06b13ccc8f20c6ad9c51802bb7f829eae8a1845c4019ec"}, - {file = "aiohttp-3.11.10-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:b78f053a7ecfc35f0451d961dacdc671f4bcbc2f58241a7c820e9d82559844cf"}, - {file = "aiohttp-3.11.10-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:ab7485222db0959a87fbe8125e233b5a6f01f4400785b36e8a7878170d8c3138"}, - {file = "aiohttp-3.11.10-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:cf14627232dfa8730453752e9cdc210966490992234d77ff90bc8dc0dce361d5"}, - {file = "aiohttp-3.11.10-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:076bc454a7e6fd646bc82ea7f98296be0b1219b5e3ef8a488afbdd8e81fbac50"}, - {file = "aiohttp-3.11.10-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:482cafb7dc886bebeb6c9ba7925e03591a62ab34298ee70d3dd47ba966370d2c"}, - {file = "aiohttp-3.11.10-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bf3d1a519a324af764a46da4115bdbd566b3c73fb793ffb97f9111dbc684fc4d"}, - {file = "aiohttp-3.11.10-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:24213ba85a419103e641e55c27dc7ff03536c4873470c2478cce3311ba1eee7b"}, - {file = "aiohttp-3.11.10-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b99acd4730ad1b196bfb03ee0803e4adac371ae8efa7e1cbc820200fc5ded109"}, - {file = "aiohttp-3.11.10-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:14cdb5a9570be5a04eec2ace174a48ae85833c2aadc86de68f55541f66ce42ab"}, - {file = "aiohttp-3.11.10-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:7e97d622cb083e86f18317282084bc9fbf261801b0192c34fe4b1febd9f7ae69"}, - {file = "aiohttp-3.11.10-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:012f176945af138abc10c4a48743327a92b4ca9adc7a0e078077cdb5dbab7be0"}, - {file = "aiohttp-3.11.10-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:44224d815853962f48fe124748227773acd9686eba6dc102578defd6fc99e8d9"}, - {file = "aiohttp-3.11.10-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c87bf31b7fdab94ae3adbe4a48e711bfc5f89d21cf4c197e75561def39e223bc"}, - {file = "aiohttp-3.11.10-cp312-cp312-win32.whl", hash = "sha256:06a8e2ee1cbac16fe61e51e0b0c269400e781b13bcfc33f5425912391a542985"}, - {file = "aiohttp-3.11.10-cp312-cp312-win_amd64.whl", hash = "sha256:be2b516f56ea883a3e14dda17059716593526e10fb6303189aaf5503937db408"}, - {file = "aiohttp-3.11.10-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8cc5203b817b748adccb07f36390feb730b1bc5f56683445bfe924fc270b8816"}, - {file = "aiohttp-3.11.10-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5ef359ebc6949e3a34c65ce20230fae70920714367c63afd80ea0c2702902ccf"}, - {file = "aiohttp-3.11.10-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:9bca390cb247dbfaec3c664326e034ef23882c3f3bfa5fbf0b56cad0320aaca5"}, - {file = "aiohttp-3.11.10-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:811f23b3351ca532af598405db1093f018edf81368e689d1b508c57dcc6b6a32"}, - {file = "aiohttp-3.11.10-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ddf5f7d877615f6a1e75971bfa5ac88609af3b74796ff3e06879e8422729fd01"}, - {file = "aiohttp-3.11.10-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6ab29b8a0beb6f8eaf1e5049252cfe74adbaafd39ba91e10f18caeb0e99ffb34"}, - {file = "aiohttp-3.11.10-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c49a76c1038c2dd116fa443eba26bbb8e6c37e924e2513574856de3b6516be99"}, - {file = "aiohttp-3.11.10-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7f3dc0e330575f5b134918976a645e79adf333c0a1439dcf6899a80776c9ab39"}, - {file = "aiohttp-3.11.10-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:efb15a17a12497685304b2d976cb4939e55137df7b09fa53f1b6a023f01fcb4e"}, - {file = "aiohttp-3.11.10-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:db1d0b28fcb7f1d35600150c3e4b490775251dea70f894bf15c678fdd84eda6a"}, - {file = "aiohttp-3.11.10-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:15fccaf62a4889527539ecb86834084ecf6e9ea70588efde86e8bc775e0e7542"}, - {file = "aiohttp-3.11.10-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:593c114a2221444f30749cc5e5f4012488f56bd14de2af44fe23e1e9894a9c60"}, - {file = "aiohttp-3.11.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:7852bbcb4d0d2f0c4d583f40c3bc750ee033265d80598d0f9cb6f372baa6b836"}, - {file = "aiohttp-3.11.10-cp313-cp313-win32.whl", hash = "sha256:65e55ca7debae8faaffee0ebb4b47a51b4075f01e9b641c31e554fd376595c6c"}, - {file = "aiohttp-3.11.10-cp313-cp313-win_amd64.whl", hash = "sha256:beb39a6d60a709ae3fb3516a1581777e7e8b76933bb88c8f4420d875bb0267c6"}, - {file = "aiohttp-3.11.10-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0580f2e12de2138f34debcd5d88894786453a76e98febaf3e8fe5db62d01c9bf"}, - {file = "aiohttp-3.11.10-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a55d2ad345684e7c3dd2c20d2f9572e9e1d5446d57200ff630e6ede7612e307f"}, - {file = "aiohttp-3.11.10-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:04814571cb72d65a6899db6099e377ed00710bf2e3eafd2985166f2918beaf59"}, - {file = "aiohttp-3.11.10-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e44a9a3c053b90c6f09b1bb4edd880959f5328cf63052503f892c41ea786d99f"}, - {file = "aiohttp-3.11.10-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:502a1464ccbc800b4b1995b302efaf426e8763fadf185e933c2931df7db9a199"}, - {file = "aiohttp-3.11.10-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:613e5169f8ae77b1933e42e418a95931fb4867b2991fc311430b15901ed67079"}, - {file = "aiohttp-3.11.10-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4cca22a61b7fe45da8fc73c3443150c3608750bbe27641fc7558ec5117b27fdf"}, - {file = "aiohttp-3.11.10-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:86a5dfcc39309470bd7b68c591d84056d195428d5d2e0b5ccadfbaf25b026ebc"}, - {file = "aiohttp-3.11.10-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:77ae58586930ee6b2b6f696c82cf8e78c8016ec4795c53e36718365f6959dc82"}, - {file = "aiohttp-3.11.10-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:78153314f26d5abef3239b4a9af20c229c6f3ecb97d4c1c01b22c4f87669820c"}, - {file = "aiohttp-3.11.10-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:98283b94cc0e11c73acaf1c9698dea80c830ca476492c0fe2622bd931f34b487"}, - {file = "aiohttp-3.11.10-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:53bf2097e05c2accc166c142a2090e4c6fd86581bde3fd9b2d3f9e93dda66ac1"}, - {file = "aiohttp-3.11.10-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:c5532f0441fc09c119e1dca18fbc0687e64fbeb45aa4d6a87211ceaee50a74c4"}, - {file = "aiohttp-3.11.10-cp39-cp39-win32.whl", hash = "sha256:47ad15a65fb41c570cd0ad9a9ff8012489e68176e7207ec7b82a0940dddfd8be"}, - {file = "aiohttp-3.11.10-cp39-cp39-win_amd64.whl", hash = "sha256:c6b9e6d7e41656d78e37ce754813fa44b455c3d0d0dced2a047def7dc5570b74"}, - {file = "aiohttp-3.11.10.tar.gz", hash = "sha256:b1fc6b45010a8d0ff9e88f9f2418c6fd408c99c211257334aff41597ebece42e"}, + {file = "aiohttp-3.11.11-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a60804bff28662cbcf340a4d61598891f12eea3a66af48ecfdc975ceec21e3c8"}, + {file = "aiohttp-3.11.11-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4b4fa1cb5f270fb3eab079536b764ad740bb749ce69a94d4ec30ceee1b5940d5"}, + {file = "aiohttp-3.11.11-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:731468f555656767cda219ab42e033355fe48c85fbe3ba83a349631541715ba2"}, + {file = "aiohttp-3.11.11-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb23d8bb86282b342481cad4370ea0853a39e4a32a0042bb52ca6bdde132df43"}, + {file = "aiohttp-3.11.11-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f047569d655f81cb70ea5be942ee5d4421b6219c3f05d131f64088c73bb0917f"}, + {file = "aiohttp-3.11.11-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dd7659baae9ccf94ae5fe8bfaa2c7bc2e94d24611528395ce88d009107e00c6d"}, + {file = "aiohttp-3.11.11-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af01e42ad87ae24932138f154105e88da13ce7d202a6de93fafdafb2883a00ef"}, + {file = "aiohttp-3.11.11-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5854be2f3e5a729800bac57a8d76af464e160f19676ab6aea74bde18ad19d438"}, + {file = "aiohttp-3.11.11-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6526e5fb4e14f4bbf30411216780c9967c20c5a55f2f51d3abd6de68320cc2f3"}, + {file = "aiohttp-3.11.11-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:85992ee30a31835fc482468637b3e5bd085fa8fe9392ba0bdcbdc1ef5e9e3c55"}, + {file = "aiohttp-3.11.11-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:88a12ad8ccf325a8a5ed80e6d7c3bdc247d66175afedbe104ee2aaca72960d8e"}, + {file = "aiohttp-3.11.11-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:0a6d3fbf2232e3a08c41eca81ae4f1dff3d8f1a30bae415ebe0af2d2458b8a33"}, + {file = "aiohttp-3.11.11-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:84a585799c58b795573c7fa9b84c455adf3e1d72f19a2bf498b54a95ae0d194c"}, + {file = "aiohttp-3.11.11-cp310-cp310-win32.whl", hash = "sha256:bfde76a8f430cf5c5584553adf9926534352251d379dcb266ad2b93c54a29745"}, + {file = "aiohttp-3.11.11-cp310-cp310-win_amd64.whl", hash = "sha256:0fd82b8e9c383af11d2b26f27a478640b6b83d669440c0a71481f7c865a51da9"}, + {file = "aiohttp-3.11.11-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ba74ec819177af1ef7f59063c6d35a214a8fde6f987f7661f4f0eecc468a8f76"}, + {file = "aiohttp-3.11.11-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4af57160800b7a815f3fe0eba9b46bf28aafc195555f1824555fa2cfab6c1538"}, + {file = "aiohttp-3.11.11-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ffa336210cf9cd8ed117011085817d00abe4c08f99968deef0013ea283547204"}, + {file = "aiohttp-3.11.11-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:81b8fe282183e4a3c7a1b72f5ade1094ed1c6345a8f153506d114af5bf8accd9"}, + {file = "aiohttp-3.11.11-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3af41686ccec6a0f2bdc66686dc0f403c41ac2089f80e2214a0f82d001052c03"}, + {file = "aiohttp-3.11.11-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:70d1f9dde0e5dd9e292a6d4d00058737052b01f3532f69c0c65818dac26dc287"}, + {file = "aiohttp-3.11.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:249cc6912405917344192b9f9ea5cd5b139d49e0d2f5c7f70bdfaf6b4dbf3a2e"}, + {file = "aiohttp-3.11.11-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0eb98d90b6690827dcc84c246811feeb4e1eea683c0eac6caed7549be9c84665"}, + {file = "aiohttp-3.11.11-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ec82bf1fda6cecce7f7b915f9196601a1bd1a3079796b76d16ae4cce6d0ef89b"}, + {file = "aiohttp-3.11.11-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:9fd46ce0845cfe28f108888b3ab17abff84ff695e01e73657eec3f96d72eef34"}, + {file = "aiohttp-3.11.11-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:bd176afcf8f5d2aed50c3647d4925d0db0579d96f75a31e77cbaf67d8a87742d"}, + {file = "aiohttp-3.11.11-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:ec2aa89305006fba9ffb98970db6c8221541be7bee4c1d027421d6f6df7d1ce2"}, + {file = "aiohttp-3.11.11-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:92cde43018a2e17d48bb09c79e4d4cb0e236de5063ce897a5e40ac7cb4878773"}, + {file = "aiohttp-3.11.11-cp311-cp311-win32.whl", hash = "sha256:aba807f9569455cba566882c8938f1a549f205ee43c27b126e5450dc9f83cc62"}, + {file = "aiohttp-3.11.11-cp311-cp311-win_amd64.whl", hash = "sha256:ae545f31489548c87b0cced5755cfe5a5308d00407000e72c4fa30b19c3220ac"}, + {file = "aiohttp-3.11.11-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:e595c591a48bbc295ebf47cb91aebf9bd32f3ff76749ecf282ea7f9f6bb73886"}, + {file = "aiohttp-3.11.11-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:3ea1b59dc06396b0b424740a10a0a63974c725b1c64736ff788a3689d36c02d2"}, + {file = "aiohttp-3.11.11-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8811f3f098a78ffa16e0ea36dffd577eb031aea797cbdba81be039a4169e242c"}, + {file = "aiohttp-3.11.11-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd7227b87a355ce1f4bf83bfae4399b1f5bb42e0259cb9405824bd03d2f4336a"}, + {file = "aiohttp-3.11.11-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d40f9da8cabbf295d3a9dae1295c69975b86d941bc20f0a087f0477fa0a66231"}, + {file = "aiohttp-3.11.11-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ffb3dc385f6bb1568aa974fe65da84723210e5d9707e360e9ecb51f59406cd2e"}, + {file = "aiohttp-3.11.11-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8f5f7515f3552d899c61202d99dcb17d6e3b0de777900405611cd747cecd1b8"}, + {file = "aiohttp-3.11.11-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3499c7ffbfd9c6a3d8d6a2b01c26639da7e43d47c7b4f788016226b1e711caa8"}, + {file = "aiohttp-3.11.11-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8e2bf8029dbf0810c7bfbc3e594b51c4cc9101fbffb583a3923aea184724203c"}, + {file = "aiohttp-3.11.11-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:b6212a60e5c482ef90f2d788835387070a88d52cf6241d3916733c9176d39eab"}, + {file = "aiohttp-3.11.11-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:d119fafe7b634dbfa25a8c597718e69a930e4847f0b88e172744be24515140da"}, + {file = "aiohttp-3.11.11-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:6fba278063559acc730abf49845d0e9a9e1ba74f85f0ee6efd5803f08b285853"}, + {file = "aiohttp-3.11.11-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:92fc484e34b733704ad77210c7957679c5c3877bd1e6b6d74b185e9320cc716e"}, + {file = "aiohttp-3.11.11-cp312-cp312-win32.whl", hash = "sha256:9f5b3c1ed63c8fa937a920b6c1bec78b74ee09593b3f5b979ab2ae5ef60d7600"}, + {file = "aiohttp-3.11.11-cp312-cp312-win_amd64.whl", hash = "sha256:1e69966ea6ef0c14ee53ef7a3d68b564cc408121ea56c0caa2dc918c1b2f553d"}, + {file = "aiohttp-3.11.11-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:541d823548ab69d13d23730a06f97460f4238ad2e5ed966aaf850d7c369782d9"}, + {file = "aiohttp-3.11.11-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:929f3ed33743a49ab127c58c3e0a827de0664bfcda566108989a14068f820194"}, + {file = "aiohttp-3.11.11-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0882c2820fd0132240edbb4a51eb8ceb6eef8181db9ad5291ab3332e0d71df5f"}, + {file = "aiohttp-3.11.11-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b63de12e44935d5aca7ed7ed98a255a11e5cb47f83a9fded7a5e41c40277d104"}, + {file = "aiohttp-3.11.11-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa54f8ef31d23c506910c21163f22b124facb573bff73930735cf9fe38bf7dff"}, + {file = "aiohttp-3.11.11-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a344d5dc18074e3872777b62f5f7d584ae4344cd6006c17ba12103759d407af3"}, + {file = "aiohttp-3.11.11-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b7fb429ab1aafa1f48578eb315ca45bd46e9c37de11fe45c7f5f4138091e2f1"}, + {file = "aiohttp-3.11.11-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c341c7d868750e31961d6d8e60ff040fb9d3d3a46d77fd85e1ab8e76c3e9a5c4"}, + {file = "aiohttp-3.11.11-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ed9ee95614a71e87f1a70bc81603f6c6760128b140bc4030abe6abaa988f1c3d"}, + {file = "aiohttp-3.11.11-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:de8d38f1c2810fa2a4f1d995a2e9c70bb8737b18da04ac2afbf3971f65781d87"}, + {file = "aiohttp-3.11.11-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:a9b7371665d4f00deb8f32208c7c5e652059b0fda41cf6dbcac6114a041f1cc2"}, + {file = "aiohttp-3.11.11-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:620598717fce1b3bd14dd09947ea53e1ad510317c85dda2c9c65b622edc96b12"}, + {file = "aiohttp-3.11.11-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:bf8d9bfee991d8acc72d060d53860f356e07a50f0e0d09a8dfedea1c554dd0d5"}, + {file = "aiohttp-3.11.11-cp313-cp313-win32.whl", hash = "sha256:9d73ee3725b7a737ad86c2eac5c57a4a97793d9f442599bea5ec67ac9f4bdc3d"}, + {file = "aiohttp-3.11.11-cp313-cp313-win_amd64.whl", hash = "sha256:c7a06301c2fb096bdb0bd25fe2011531c1453b9f2c163c8031600ec73af1cc99"}, + {file = "aiohttp-3.11.11-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:3e23419d832d969f659c208557de4a123e30a10d26e1e14b73431d3c13444c2e"}, + {file = "aiohttp-3.11.11-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:21fef42317cf02e05d3b09c028712e1d73a9606f02467fd803f7c1f39cc59add"}, + {file = "aiohttp-3.11.11-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1f21bb8d0235fc10c09ce1d11ffbd40fc50d3f08a89e4cf3a0c503dc2562247a"}, + {file = "aiohttp-3.11.11-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1642eceeaa5ab6c9b6dfeaaa626ae314d808188ab23ae196a34c9d97efb68350"}, + {file = "aiohttp-3.11.11-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2170816e34e10f2fd120f603e951630f8a112e1be3b60963a1f159f5699059a6"}, + {file = "aiohttp-3.11.11-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8be8508d110d93061197fd2d6a74f7401f73b6d12f8822bbcd6d74f2b55d71b1"}, + {file = "aiohttp-3.11.11-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4eed954b161e6b9b65f6be446ed448ed3921763cc432053ceb606f89d793927e"}, + {file = "aiohttp-3.11.11-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6c9af134da4bc9b3bd3e6a70072509f295d10ee60c697826225b60b9959acdd"}, + {file = "aiohttp-3.11.11-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:44167fc6a763d534a6908bdb2592269b4bf30a03239bcb1654781adf5e49caf1"}, + {file = "aiohttp-3.11.11-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:479b8c6ebd12aedfe64563b85920525d05d394b85f166b7873c8bde6da612f9c"}, + {file = "aiohttp-3.11.11-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:10b4ff0ad793d98605958089fabfa350e8e62bd5d40aa65cdc69d6785859f94e"}, + {file = "aiohttp-3.11.11-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:b540bd67cfb54e6f0865ceccd9979687210d7ed1a1cc8c01f8e67e2f1e883d28"}, + {file = "aiohttp-3.11.11-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1dac54e8ce2ed83b1f6b1a54005c87dfed139cf3f777fdc8afc76e7841101226"}, + {file = "aiohttp-3.11.11-cp39-cp39-win32.whl", hash = "sha256:568c1236b2fde93b7720f95a890741854c1200fba4a3471ff48b2934d2d93fd3"}, + {file = "aiohttp-3.11.11-cp39-cp39-win_amd64.whl", hash = "sha256:943a8b052e54dfd6439fd7989f67fc6a7f2138d0a2cf0a7de5f18aa4fe7eb3b1"}, + {file = "aiohttp-3.11.11.tar.gz", hash = "sha256:bb49c7f1e6ebf3821a42d81d494f538107610c3a705987f53068546b0e90303e"}, ] [package.dependencies] @@ -111,13 +111,13 @@ speedups = ["Brotli", "aiodns (>=3.2.0)", "brotlicffi"] [[package]] name = "aiosignal" -version = "1.3.1" +version = "1.3.2" description = "aiosignal: a list of registered asynchronous callbacks" optional = false -python-versions = ">=3.7" +python-versions = ">=3.9" files = [ - {file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"}, - {file = "aiosignal-1.3.1.tar.gz", hash = "sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc"}, + {file = "aiosignal-1.3.2-py2.py3-none-any.whl", hash = "sha256:45cde58e409a301715980c2b01d0c28bdde3770d8290b5eb2173759d9acb31a5"}, + {file = "aiosignal-1.3.2.tar.gz", hash = "sha256:a8c255c66fafb1e499c9351d0bf32ff2d8a0321595ebac3b93713656d2436f54"}, ] [package.dependencies] @@ -323,19 +323,19 @@ files = [ [[package]] name = "attrs" -version = "24.2.0" +version = "24.3.0" description = "Classes Without Boilerplate" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "attrs-24.2.0-py3-none-any.whl", hash = "sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2"}, - {file = "attrs-24.2.0.tar.gz", hash = "sha256:5cfb1b9148b5b086569baec03f20d7b6bf3bcacc9a42bebf87ffaaca362f6346"}, + {file = "attrs-24.3.0-py3-none-any.whl", hash = "sha256:ac96cd038792094f438ad1f6ff80837353805ac950cd2aa0e0625ef19850c308"}, + {file = "attrs-24.3.0.tar.gz", hash = "sha256:8f5c07333d543103541ba7be0e2ce16eeee8130cb0b3f9238ab904ce1e85baff"}, ] [package.extras] benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"] cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit-uv", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier (<24.7)"] tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] @@ -416,13 +416,13 @@ files = [ [[package]] name = "certifi" -version = "2024.8.30" +version = "2024.12.14" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" files = [ - {file = "certifi-2024.8.30-py3-none-any.whl", hash = "sha256:922820b53db7a7257ffbda3f597266d435245903d80737e34f8a45ff3e3230d8"}, - {file = "certifi-2024.8.30.tar.gz", hash = "sha256:bec941d2aa8195e248a60b31ff9f0558284cf01a52591ceda73ea9afffd69fd9"}, + {file = "certifi-2024.12.14-py3-none-any.whl", hash = "sha256:1275f7a45be9464efc1173084eaa30f866fe2e47d389406136d332ed4967ec56"}, + {file = "certifi-2024.12.14.tar.gz", hash = "sha256:b650d30f370c2b724812bee08008be0c4163b163ddaec3f2546c1caf65f191db"}, ] [[package]] @@ -706,37 +706,37 @@ vision = ["Pillow (>=9.4.0)"] [[package]] name = "debugpy" -version = "1.8.9" +version = "1.8.11" description = "An implementation of the Debug Adapter Protocol for Python" optional = false python-versions = ">=3.8" files = [ - {file = "debugpy-1.8.9-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:cfe1e6c6ad7178265f74981edf1154ffce97b69005212fbc90ca22ddfe3d017e"}, - {file = "debugpy-1.8.9-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ada7fb65102a4d2c9ab62e8908e9e9f12aed9d76ef44880367bc9308ebe49a0f"}, - {file = "debugpy-1.8.9-cp310-cp310-win32.whl", hash = "sha256:c36856343cbaa448171cba62a721531e10e7ffb0abff838004701454149bc037"}, - {file = "debugpy-1.8.9-cp310-cp310-win_amd64.whl", hash = "sha256:17c5e0297678442511cf00a745c9709e928ea4ca263d764e90d233208889a19e"}, - {file = "debugpy-1.8.9-cp311-cp311-macosx_14_0_universal2.whl", hash = "sha256:b74a49753e21e33e7cf030883a92fa607bddc4ede1aa4145172debc637780040"}, - {file = "debugpy-1.8.9-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:62d22dacdb0e296966d7d74a7141aaab4bec123fa43d1a35ddcb39bf9fd29d70"}, - {file = "debugpy-1.8.9-cp311-cp311-win32.whl", hash = "sha256:8138efff315cd09b8dcd14226a21afda4ca582284bf4215126d87342bba1cc66"}, - {file = "debugpy-1.8.9-cp311-cp311-win_amd64.whl", hash = "sha256:ff54ef77ad9f5c425398efb150239f6fe8e20c53ae2f68367eba7ece1e96226d"}, - {file = "debugpy-1.8.9-cp312-cp312-macosx_14_0_universal2.whl", hash = "sha256:957363d9a7a6612a37458d9a15e72d03a635047f946e5fceee74b50d52a9c8e2"}, - {file = "debugpy-1.8.9-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e565fc54b680292b418bb809f1386f17081d1346dca9a871bf69a8ac4071afe"}, - {file = "debugpy-1.8.9-cp312-cp312-win32.whl", hash = "sha256:3e59842d6c4569c65ceb3751075ff8d7e6a6ada209ceca6308c9bde932bcef11"}, - {file = "debugpy-1.8.9-cp312-cp312-win_amd64.whl", hash = "sha256:66eeae42f3137eb428ea3a86d4a55f28da9bd5a4a3d369ba95ecc3a92c1bba53"}, - {file = "debugpy-1.8.9-cp313-cp313-macosx_14_0_universal2.whl", hash = "sha256:957ecffff80d47cafa9b6545de9e016ae8c9547c98a538ee96ab5947115fb3dd"}, - {file = "debugpy-1.8.9-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1efbb3ff61487e2c16b3e033bc8595aea578222c08aaf3c4bf0f93fadbd662ee"}, - {file = "debugpy-1.8.9-cp313-cp313-win32.whl", hash = "sha256:7c4d65d03bee875bcb211c76c1d8f10f600c305dbd734beaed4077e902606fee"}, - {file = "debugpy-1.8.9-cp313-cp313-win_amd64.whl", hash = "sha256:e46b420dc1bea64e5bbedd678148be512442bc589b0111bd799367cde051e71a"}, - {file = "debugpy-1.8.9-cp38-cp38-macosx_14_0_x86_64.whl", hash = "sha256:472a3994999fe6c0756945ffa359e9e7e2d690fb55d251639d07208dbc37caea"}, - {file = "debugpy-1.8.9-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:365e556a4772d7d0d151d7eb0e77ec4db03bcd95f26b67b15742b88cacff88e9"}, - {file = "debugpy-1.8.9-cp38-cp38-win32.whl", hash = "sha256:54a7e6d3014c408eb37b0b06021366ee985f1539e12fe49ca2ee0d392d9ceca5"}, - {file = "debugpy-1.8.9-cp38-cp38-win_amd64.whl", hash = "sha256:8e99c0b1cc7bf86d83fb95d5ccdc4ad0586d4432d489d1f54e4055bcc795f693"}, - {file = "debugpy-1.8.9-cp39-cp39-macosx_14_0_x86_64.whl", hash = "sha256:7e8b079323a56f719977fde9d8115590cb5e7a1cba2fcee0986ef8817116e7c1"}, - {file = "debugpy-1.8.9-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6953b335b804a41f16a192fa2e7851bdcfd92173cbb2f9f777bb934f49baab65"}, - {file = "debugpy-1.8.9-cp39-cp39-win32.whl", hash = "sha256:7e646e62d4602bb8956db88b1e72fe63172148c1e25c041e03b103a25f36673c"}, - {file = "debugpy-1.8.9-cp39-cp39-win_amd64.whl", hash = "sha256:3d9755e77a2d680ce3d2c5394a444cf42be4a592caaf246dbfbdd100ffcf7ae5"}, - {file = "debugpy-1.8.9-py2.py3-none-any.whl", hash = "sha256:cc37a6c9987ad743d9c3a14fa1b1a14b7e4e6041f9dd0c8abf8895fe7a97b899"}, - {file = "debugpy-1.8.9.zip", hash = "sha256:1339e14c7d980407248f09824d1b25ff5c5616651689f1e0f0e51bdead3ea13e"}, + {file = "debugpy-1.8.11-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:2b26fefc4e31ff85593d68b9022e35e8925714a10ab4858fb1b577a8a48cb8cd"}, + {file = "debugpy-1.8.11-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61bc8b3b265e6949855300e84dc93d02d7a3a637f2aec6d382afd4ceb9120c9f"}, + {file = "debugpy-1.8.11-cp310-cp310-win32.whl", hash = "sha256:c928bbf47f65288574b78518449edaa46c82572d340e2750889bbf8cd92f3737"}, + {file = "debugpy-1.8.11-cp310-cp310-win_amd64.whl", hash = "sha256:8da1db4ca4f22583e834dcabdc7832e56fe16275253ee53ba66627b86e304da1"}, + {file = "debugpy-1.8.11-cp311-cp311-macosx_14_0_universal2.whl", hash = "sha256:85de8474ad53ad546ff1c7c7c89230db215b9b8a02754d41cb5a76f70d0be296"}, + {file = "debugpy-1.8.11-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ffc382e4afa4aee367bf413f55ed17bd91b191dcaf979890af239dda435f2a1"}, + {file = "debugpy-1.8.11-cp311-cp311-win32.whl", hash = "sha256:40499a9979c55f72f4eb2fc38695419546b62594f8af194b879d2a18439c97a9"}, + {file = "debugpy-1.8.11-cp311-cp311-win_amd64.whl", hash = "sha256:987bce16e86efa86f747d5151c54e91b3c1e36acc03ce1ddb50f9d09d16ded0e"}, + {file = "debugpy-1.8.11-cp312-cp312-macosx_14_0_universal2.whl", hash = "sha256:84e511a7545d11683d32cdb8f809ef63fc17ea2a00455cc62d0a4dbb4ed1c308"}, + {file = "debugpy-1.8.11-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce291a5aca4985d82875d6779f61375e959208cdf09fcec40001e65fb0a54768"}, + {file = "debugpy-1.8.11-cp312-cp312-win32.whl", hash = "sha256:28e45b3f827d3bf2592f3cf7ae63282e859f3259db44ed2b129093ca0ac7940b"}, + {file = "debugpy-1.8.11-cp312-cp312-win_amd64.whl", hash = "sha256:44b1b8e6253bceada11f714acf4309ffb98bfa9ac55e4fce14f9e5d4484287a1"}, + {file = "debugpy-1.8.11-cp313-cp313-macosx_14_0_universal2.whl", hash = "sha256:8988f7163e4381b0da7696f37eec7aca19deb02e500245df68a7159739bbd0d3"}, + {file = "debugpy-1.8.11-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c1f6a173d1140e557347419767d2b14ac1c9cd847e0b4c5444c7f3144697e4e"}, + {file = "debugpy-1.8.11-cp313-cp313-win32.whl", hash = "sha256:bb3b15e25891f38da3ca0740271e63ab9db61f41d4d8541745cfc1824252cb28"}, + {file = "debugpy-1.8.11-cp313-cp313-win_amd64.whl", hash = "sha256:d8768edcbeb34da9e11bcb8b5c2e0958d25218df7a6e56adf415ef262cd7b6d1"}, + {file = "debugpy-1.8.11-cp38-cp38-macosx_14_0_x86_64.whl", hash = "sha256:ad7efe588c8f5cf940f40c3de0cd683cc5b76819446abaa50dc0829a30c094db"}, + {file = "debugpy-1.8.11-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:189058d03a40103a57144752652b3ab08ff02b7595d0ce1f651b9acc3a3a35a0"}, + {file = "debugpy-1.8.11-cp38-cp38-win32.whl", hash = "sha256:32db46ba45849daed7ccf3f2e26f7a386867b077f39b2a974bb5c4c2c3b0a280"}, + {file = "debugpy-1.8.11-cp38-cp38-win_amd64.whl", hash = "sha256:116bf8342062246ca749013df4f6ea106f23bc159305843491f64672a55af2e5"}, + {file = "debugpy-1.8.11-cp39-cp39-macosx_14_0_x86_64.whl", hash = "sha256:654130ca6ad5de73d978057eaf9e582244ff72d4574b3e106fb8d3d2a0d32458"}, + {file = "debugpy-1.8.11-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23dc34c5e03b0212fa3c49a874df2b8b1b8fda95160bd79c01eb3ab51ea8d851"}, + {file = "debugpy-1.8.11-cp39-cp39-win32.whl", hash = "sha256:52d8a3166c9f2815bfae05f386114b0b2d274456980d41f320299a8d9a5615a7"}, + {file = "debugpy-1.8.11-cp39-cp39-win_amd64.whl", hash = "sha256:52c3cf9ecda273a19cc092961ee34eb9ba8687d67ba34cc7b79a521c1c64c4c0"}, + {file = "debugpy-1.8.11-py2.py3-none-any.whl", hash = "sha256:0e22f846f4211383e6a416d04b4c13ed174d24cc5d43f5fd52e7821d0ebc8920"}, + {file = "debugpy-1.8.11.tar.gz", hash = "sha256:6ad2688b69235c43b020e04fecccdf6a96c8943ca9c2fb340b8adc103c655e57"}, ] [[package]] @@ -1189,13 +1189,13 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "huggingface-hub" -version = "0.26.5" +version = "0.27.0" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" optional = false python-versions = ">=3.8.0" files = [ - {file = "huggingface_hub-0.26.5-py3-none-any.whl", hash = "sha256:fb7386090bbe892072e64b85f7c4479fd2d65eea5f2543327c970d5169e83924"}, - {file = "huggingface_hub-0.26.5.tar.gz", hash = "sha256:1008bd18f60bfb65e8dbc0a97249beeeaa8c99d3c2fa649354df9fa5a13ed83b"}, + {file = "huggingface_hub-0.27.0-py3-none-any.whl", hash = "sha256:8f2e834517f1f1ddf1ecc716f91b120d7333011b7485f665a9a412eacb1a2a81"}, + {file = "huggingface_hub-0.27.0.tar.gz", hash = "sha256:902cce1a1be5739f5589e560198a65a8edcfd3b830b1666f36e4b961f0454fac"}, ] [package.dependencies] @@ -1532,13 +1532,13 @@ test = ["ipykernel", "pre-commit", "pytest (<8)", "pytest-cov", "pytest-timeout" [[package]] name = "jupyter-events" -version = "0.10.0" +version = "0.11.0" description = "Jupyter Event System library" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "jupyter_events-0.10.0-py3-none-any.whl", hash = "sha256:4b72130875e59d57716d327ea70d3ebc3af1944d3717e5a498b8a06c6c159960"}, - {file = "jupyter_events-0.10.0.tar.gz", hash = "sha256:670b8229d3cc882ec782144ed22e0d29e1c2d639263f92ca8383e66682845e22"}, + {file = "jupyter_events-0.11.0-py3-none-any.whl", hash = "sha256:36399b41ce1ca45fe8b8271067d6a140ffa54cec4028e95491c93b78a855cacf"}, + {file = "jupyter_events-0.11.0.tar.gz", hash = "sha256:c0bc56a37aac29c1fbc3bcfbddb8c8c49533f9cf11f1c4e6adadba936574ab90"}, ] [package.dependencies] @@ -1552,7 +1552,7 @@ traitlets = ">=5.3" [package.extras] cli = ["click", "rich"] -docs = ["jupyterlite-sphinx", "myst-parser", "pydata-sphinx-theme", "sphinxcontrib-spelling"] +docs = ["jupyterlite-sphinx", "myst-parser", "pydata-sphinx-theme (>=0.16)", "sphinx (>=8)", "sphinxcontrib-spelling"] test = ["click", "pre-commit", "pytest (>=7.0)", "pytest-asyncio (>=0.19.0)", "pytest-console-scripts", "rich"] [[package]] @@ -1626,13 +1626,13 @@ test = ["jupyter-server (>=2.0.0)", "pytest (>=7.0)", "pytest-jupyter[server] (> [[package]] name = "jupyterlab" -version = "4.3.3" +version = "4.3.4" description = "JupyterLab computational environment" optional = false python-versions = ">=3.8" files = [ - {file = "jupyterlab-4.3.3-py3-none-any.whl", hash = "sha256:32a8fd30677e734ffcc3916a4758b9dab21b02015b668c60eb36f84357b7d4b1"}, - {file = "jupyterlab-4.3.3.tar.gz", hash = "sha256:76fa39e548fdac94dc1204af5956c556f54c785f70ee26aa47ea08eda4d5bbcd"}, + {file = "jupyterlab-4.3.4-py3-none-any.whl", hash = "sha256:b754c2601c5be6adf87cb5a1d8495d653ffb945f021939f77776acaa94dae952"}, + {file = "jupyterlab-4.3.4.tar.gz", hash = "sha256:f0bb9b09a04766e3423cccc2fc23169aa2ffedcdf8713e9e0fb33cac0b6859d0"}, ] [package.dependencies] @@ -1982,18 +1982,20 @@ dill = ">=0.3.8" [[package]] name = "narwhals" -version = "1.17.0" +version = "1.19.0" description = "Extremely lightweight compatibility layer between dataframe libraries" optional = false python-versions = ">=3.8" files = [ - {file = "narwhals-1.17.0-py3-none-any.whl", hash = "sha256:03772e08a6e70d6a5029b2299fa4d4acc51411cca5119afffd6bbc4c762038f4"}, - {file = "narwhals-1.17.0.tar.gz", hash = "sha256:a2607fa66f0b6ccf6f54df57ee6bc4967b68095d2968c828894a5f77ed9d3c50"}, + {file = "narwhals-1.19.0-py3-none-any.whl", hash = "sha256:517eca140103dbf61e4513fe462885a06bc21b565521a5ac0b79a7e31f152efe"}, + {file = "narwhals-1.19.0.tar.gz", hash = "sha256:a1a03bf922548ed1da5426acc954327a94c02c3a08558b65a0937dbd3b77fb48"}, ] [package.extras] cudf = ["cudf (>=23.08.00)"] dask = ["dask[dataframe] (>=2024.7)"] +dev = ["covdefaults", "dask[dataframe]", "duckdb", "hypothesis", "hypothesis[numpy]", "pandas", "polars", "pyarrow", "pyarrow-stubs", "pyspark", "pytest", "pytest-cov", "pytest-env", "pytest-randomly", "scikit-learn", "tqdm", "typing-extensions"] +docs = ["black", "duckdb", "jinja2", "markdown-exec[ansi]", "mkdocs", "mkdocs-autorefs", "mkdocs-material", "mkdocstrings[python]", "pandas", "polars (>=1.0.0)", "pyarrow"] modin = ["modin"] pandas = ["pandas (>=0.25.3)"] polars = ["polars (>=0.20.3)"] @@ -2002,13 +2004,13 @@ pyspark = ["pyspark (>=3.3.0)"] [[package]] name = "nbclient" -version = "0.10.1" +version = "0.10.2" description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor." optional = false -python-versions = ">=3.8.0" +python-versions = ">=3.9.0" files = [ - {file = "nbclient-0.10.1-py3-none-any.whl", hash = "sha256:949019b9240d66897e442888cfb618f69ef23dc71c01cb5fced8499c2cfc084d"}, - {file = "nbclient-0.10.1.tar.gz", hash = "sha256:3e93e348ab27e712acd46fccd809139e356eb9a31aab641d1a7991a6eb4e6f68"}, + {file = "nbclient-0.10.2-py3-none-any.whl", hash = "sha256:4ffee11e788b4a27fabeb7955547e4318a5298f34342a4bfd01f2e1faaeadc3d"}, + {file = "nbclient-0.10.2.tar.gz", hash = "sha256:90b7fc6b810630db87a6d0c2250b1f0ab4cf4d3c27a299b0cde78a4ed3fd9193"}, ] [package.dependencies] @@ -2019,8 +2021,8 @@ traitlets = ">=5.4" [package.extras] dev = ["pre-commit"] -docs = ["autodoc-traits", "flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "mock", "moto", "myst-parser", "nbconvert (>=7.0.0)", "pytest (>=7.0,<8)", "pytest-asyncio", "pytest-cov (>=4.0)", "sphinx (>=1.7)", "sphinx-book-theme", "sphinxcontrib-spelling", "testpath", "xmltodict"] -test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>=7.0.0)", "pytest (>=7.0,<8)", "pytest-asyncio", "pytest-cov (>=4.0)", "testpath", "xmltodict"] +docs = ["autodoc-traits", "flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "mock", "moto", "myst-parser", "nbconvert (>=7.1.0)", "pytest (>=7.0,<8)", "pytest-asyncio", "pytest-cov (>=4.0)", "sphinx (>=1.7)", "sphinx-book-theme", "sphinxcontrib-spelling", "testpath", "xmltodict"] +test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>=7.1.0)", "pytest (>=7.0,<8)", "pytest-asyncio", "pytest-cov (>=4.0)", "testpath", "xmltodict"] [[package]] name = "nbconvert" @@ -2811,22 +2813,22 @@ files = [ [[package]] name = "protobuf" -version = "5.29.1" +version = "5.29.2" description = "" optional = false python-versions = ">=3.8" files = [ - {file = "protobuf-5.29.1-cp310-abi3-win32.whl", hash = "sha256:22c1f539024241ee545cbcb00ee160ad1877975690b16656ff87dde107b5f110"}, - {file = "protobuf-5.29.1-cp310-abi3-win_amd64.whl", hash = "sha256:1fc55267f086dd4050d18ef839d7bd69300d0d08c2a53ca7df3920cc271a3c34"}, - {file = "protobuf-5.29.1-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:d473655e29c0c4bbf8b69e9a8fb54645bc289dead6d753b952e7aa660254ae18"}, - {file = "protobuf-5.29.1-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:b5ba1d0e4c8a40ae0496d0e2ecfdbb82e1776928a205106d14ad6985a09ec155"}, - {file = "protobuf-5.29.1-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:8ee1461b3af56145aca2800e6a3e2f928108c749ba8feccc6f5dd0062c410c0d"}, - {file = "protobuf-5.29.1-cp38-cp38-win32.whl", hash = "sha256:50879eb0eb1246e3a5eabbbe566b44b10348939b7cc1b267567e8c3d07213853"}, - {file = "protobuf-5.29.1-cp38-cp38-win_amd64.whl", hash = "sha256:027fbcc48cea65a6b17028510fdd054147057fa78f4772eb547b9274e5219331"}, - {file = "protobuf-5.29.1-cp39-cp39-win32.whl", hash = "sha256:5a41deccfa5e745cef5c65a560c76ec0ed8e70908a67cc8f4da5fce588b50d57"}, - {file = "protobuf-5.29.1-cp39-cp39-win_amd64.whl", hash = "sha256:012ce28d862ff417fd629285aca5d9772807f15ceb1a0dbd15b88f58c776c98c"}, - {file = "protobuf-5.29.1-py3-none-any.whl", hash = "sha256:32600ddb9c2a53dedc25b8581ea0f1fd8ea04956373c0c07577ce58d312522e0"}, - {file = "protobuf-5.29.1.tar.gz", hash = "sha256:683be02ca21a6ffe80db6dd02c0b5b2892322c59ca57fd6c872d652cb80549cb"}, + {file = "protobuf-5.29.2-cp310-abi3-win32.whl", hash = "sha256:c12ba8249f5624300cf51c3d0bfe5be71a60c63e4dcf51ffe9a68771d958c851"}, + {file = "protobuf-5.29.2-cp310-abi3-win_amd64.whl", hash = "sha256:842de6d9241134a973aab719ab42b008a18a90f9f07f06ba480df268f86432f9"}, + {file = "protobuf-5.29.2-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:a0c53d78383c851bfa97eb42e3703aefdc96d2036a41482ffd55dc5f529466eb"}, + {file = "protobuf-5.29.2-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:494229ecd8c9009dd71eda5fd57528395d1eacdf307dbece6c12ad0dd09e912e"}, + {file = "protobuf-5.29.2-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:b6b0d416bbbb9d4fbf9d0561dbfc4e324fd522f61f7af0fe0f282ab67b22477e"}, + {file = "protobuf-5.29.2-cp38-cp38-win32.whl", hash = "sha256:e621a98c0201a7c8afe89d9646859859be97cb22b8bf1d8eacfd90d5bda2eb19"}, + {file = "protobuf-5.29.2-cp38-cp38-win_amd64.whl", hash = "sha256:13d6d617a2a9e0e82a88113d7191a1baa1e42c2cc6f5f1398d3b054c8e7e714a"}, + {file = "protobuf-5.29.2-cp39-cp39-win32.whl", hash = "sha256:36000f97ea1e76e8398a3f02936aac2a5d2b111aae9920ec1b769fc4a222c4d9"}, + {file = "protobuf-5.29.2-cp39-cp39-win_amd64.whl", hash = "sha256:2d2e674c58a06311c8e99e74be43e7f3a8d1e2b2fdf845eaa347fbd866f23355"}, + {file = "protobuf-5.29.2-py3-none-any.whl", hash = "sha256:fde4554c0e578a5a0bcc9a276339594848d1e89f9ea47b4427c80e5d72f90181"}, + {file = "protobuf-5.29.2.tar.gz", hash = "sha256:b2cc8e8bb7c9326996f0e160137b0861f1a82162502658df2951209d0cb0309e"}, ] [[package]] @@ -2951,18 +2953,18 @@ files = [ [[package]] name = "pydantic" -version = "2.10.3" +version = "2.10.4" description = "Data validation using Python type hints" optional = false python-versions = ">=3.8" files = [ - {file = "pydantic-2.10.3-py3-none-any.whl", hash = "sha256:be04d85bbc7b65651c5f8e6b9976ed9c6f41782a55524cef079a34a0bb82144d"}, - {file = "pydantic-2.10.3.tar.gz", hash = "sha256:cb5ac360ce894ceacd69c403187900a02c4b20b693a9dd1d643e1effab9eadf9"}, + {file = "pydantic-2.10.4-py3-none-any.whl", hash = "sha256:597e135ea68be3a37552fb524bc7d0d66dcf93d395acd93a00682f1efcb8ee3d"}, + {file = "pydantic-2.10.4.tar.gz", hash = "sha256:82f12e9723da6de4fe2ba888b5971157b3be7ad914267dea8f05f82b28254f06"}, ] [package.dependencies] annotated-types = ">=0.6.0" -pydantic-core = "2.27.1" +pydantic-core = "2.27.2" typing-extensions = ">=4.12.2" [package.extras] @@ -2971,111 +2973,111 @@ timezone = ["tzdata"] [[package]] name = "pydantic-core" -version = "2.27.1" +version = "2.27.2" description = "Core functionality for Pydantic validation and serialization" optional = false python-versions = ">=3.8" files = [ - {file = "pydantic_core-2.27.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:71a5e35c75c021aaf400ac048dacc855f000bdfed91614b4a726f7432f1f3d6a"}, - {file = "pydantic_core-2.27.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f82d068a2d6ecfc6e054726080af69a6764a10015467d7d7b9f66d6ed5afa23b"}, - {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:121ceb0e822f79163dd4699e4c54f5ad38b157084d97b34de8b232bcaad70278"}, - {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4603137322c18eaf2e06a4495f426aa8d8388940f3c457e7548145011bb68e05"}, - {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a33cd6ad9017bbeaa9ed78a2e0752c5e250eafb9534f308e7a5f7849b0b1bfb4"}, - {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:15cc53a3179ba0fcefe1e3ae50beb2784dede4003ad2dfd24f81bba4b23a454f"}, - {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:45d9c5eb9273aa50999ad6adc6be5e0ecea7e09dbd0d31bd0c65a55a2592ca08"}, - {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8bf7b66ce12a2ac52d16f776b31d16d91033150266eb796967a7e4621707e4f6"}, - {file = "pydantic_core-2.27.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:655d7dd86f26cb15ce8a431036f66ce0318648f8853d709b4167786ec2fa4807"}, - {file = "pydantic_core-2.27.1-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:5556470f1a2157031e676f776c2bc20acd34c1990ca5f7e56f1ebf938b9ab57c"}, - {file = "pydantic_core-2.27.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f69ed81ab24d5a3bd93861c8c4436f54afdf8e8cc421562b0c7504cf3be58206"}, - {file = "pydantic_core-2.27.1-cp310-none-win32.whl", hash = "sha256:f5a823165e6d04ccea61a9f0576f345f8ce40ed533013580e087bd4d7442b52c"}, - {file = "pydantic_core-2.27.1-cp310-none-win_amd64.whl", hash = "sha256:57866a76e0b3823e0b56692d1a0bf722bffb324839bb5b7226a7dbd6c9a40b17"}, - {file = "pydantic_core-2.27.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:ac3b20653bdbe160febbea8aa6c079d3df19310d50ac314911ed8cc4eb7f8cb8"}, - {file = "pydantic_core-2.27.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a5a8e19d7c707c4cadb8c18f5f60c843052ae83c20fa7d44f41594c644a1d330"}, - {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7f7059ca8d64fea7f238994c97d91f75965216bcbe5f695bb44f354893f11d52"}, - {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bed0f8a0eeea9fb72937ba118f9db0cb7e90773462af7962d382445f3005e5a4"}, - {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a3cb37038123447cf0f3ea4c74751f6a9d7afef0eb71aa07bf5f652b5e6a132c"}, - {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:84286494f6c5d05243456e04223d5a9417d7f443c3b76065e75001beb26f88de"}, - {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:acc07b2cfc5b835444b44a9956846b578d27beeacd4b52e45489e93276241025"}, - {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4fefee876e07a6e9aad7a8c8c9f85b0cdbe7df52b8a9552307b09050f7512c7e"}, - {file = "pydantic_core-2.27.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:258c57abf1188926c774a4c94dd29237e77eda19462e5bb901d88adcab6af919"}, - {file = "pydantic_core-2.27.1-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:35c14ac45fcfdf7167ca76cc80b2001205a8d5d16d80524e13508371fb8cdd9c"}, - {file = "pydantic_core-2.27.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d1b26e1dff225c31897696cab7d4f0a315d4c0d9e8666dbffdb28216f3b17fdc"}, - {file = "pydantic_core-2.27.1-cp311-none-win32.whl", hash = "sha256:2cdf7d86886bc6982354862204ae3b2f7f96f21a3eb0ba5ca0ac42c7b38598b9"}, - {file = "pydantic_core-2.27.1-cp311-none-win_amd64.whl", hash = "sha256:3af385b0cee8df3746c3f406f38bcbfdc9041b5c2d5ce3e5fc6637256e60bbc5"}, - {file = "pydantic_core-2.27.1-cp311-none-win_arm64.whl", hash = "sha256:81f2ec23ddc1b476ff96563f2e8d723830b06dceae348ce02914a37cb4e74b89"}, - {file = "pydantic_core-2.27.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:9cbd94fc661d2bab2bc702cddd2d3370bbdcc4cd0f8f57488a81bcce90c7a54f"}, - {file = "pydantic_core-2.27.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5f8c4718cd44ec1580e180cb739713ecda2bdee1341084c1467802a417fe0f02"}, - {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:15aae984e46de8d376df515f00450d1522077254ef6b7ce189b38ecee7c9677c"}, - {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1ba5e3963344ff25fc8c40da90f44b0afca8cfd89d12964feb79ac1411a260ac"}, - {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:992cea5f4f3b29d6b4f7f1726ed8ee46c8331c6b4eed6db5b40134c6fe1768bb"}, - {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0325336f348dbee6550d129b1627cb8f5351a9dc91aad141ffb96d4937bd9529"}, - {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7597c07fbd11515f654d6ece3d0e4e5093edc30a436c63142d9a4b8e22f19c35"}, - {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3bbd5d8cc692616d5ef6fbbbd50dbec142c7e6ad9beb66b78a96e9c16729b089"}, - {file = "pydantic_core-2.27.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:dc61505e73298a84a2f317255fcc72b710b72980f3a1f670447a21efc88f8381"}, - {file = "pydantic_core-2.27.1-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:e1f735dc43da318cad19b4173dd1ffce1d84aafd6c9b782b3abc04a0d5a6f5bb"}, - {file = "pydantic_core-2.27.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:f4e5658dbffe8843a0f12366a4c2d1c316dbe09bb4dfbdc9d2d9cd6031de8aae"}, - {file = "pydantic_core-2.27.1-cp312-none-win32.whl", hash = "sha256:672ebbe820bb37988c4d136eca2652ee114992d5d41c7e4858cdd90ea94ffe5c"}, - {file = "pydantic_core-2.27.1-cp312-none-win_amd64.whl", hash = "sha256:66ff044fd0bb1768688aecbe28b6190f6e799349221fb0de0e6f4048eca14c16"}, - {file = "pydantic_core-2.27.1-cp312-none-win_arm64.whl", hash = "sha256:9a3b0793b1bbfd4146304e23d90045f2a9b5fd5823aa682665fbdaf2a6c28f3e"}, - {file = "pydantic_core-2.27.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:f216dbce0e60e4d03e0c4353c7023b202d95cbaeff12e5fd2e82ea0a66905073"}, - {file = "pydantic_core-2.27.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a2e02889071850bbfd36b56fd6bc98945e23670773bc7a76657e90e6b6603c08"}, - {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42b0e23f119b2b456d07ca91b307ae167cc3f6c846a7b169fca5326e32fdc6cf"}, - {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:764be71193f87d460a03f1f7385a82e226639732214b402f9aa61f0d025f0737"}, - {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1c00666a3bd2f84920a4e94434f5974d7bbc57e461318d6bb34ce9cdbbc1f6b2"}, - {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3ccaa88b24eebc0f849ce0a4d09e8a408ec5a94afff395eb69baf868f5183107"}, - {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c65af9088ac534313e1963443d0ec360bb2b9cba6c2909478d22c2e363d98a51"}, - {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:206b5cf6f0c513baffaeae7bd817717140770c74528f3e4c3e1cec7871ddd61a"}, - {file = "pydantic_core-2.27.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:062f60e512fc7fff8b8a9d680ff0ddaaef0193dba9fa83e679c0c5f5fbd018bc"}, - {file = "pydantic_core-2.27.1-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:a0697803ed7d4af5e4c1adf1670af078f8fcab7a86350e969f454daf598c4960"}, - {file = "pydantic_core-2.27.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:58ca98a950171f3151c603aeea9303ef6c235f692fe555e883591103da709b23"}, - {file = "pydantic_core-2.27.1-cp313-none-win32.whl", hash = "sha256:8065914ff79f7eab1599bd80406681f0ad08f8e47c880f17b416c9f8f7a26d05"}, - {file = "pydantic_core-2.27.1-cp313-none-win_amd64.whl", hash = "sha256:ba630d5e3db74c79300d9a5bdaaf6200172b107f263c98a0539eeecb857b2337"}, - {file = "pydantic_core-2.27.1-cp313-none-win_arm64.whl", hash = "sha256:45cf8588c066860b623cd11c4ba687f8d7175d5f7ef65f7129df8a394c502de5"}, - {file = "pydantic_core-2.27.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:5897bec80a09b4084aee23f9b73a9477a46c3304ad1d2d07acca19723fb1de62"}, - {file = "pydantic_core-2.27.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d0165ab2914379bd56908c02294ed8405c252250668ebcb438a55494c69f44ab"}, - {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b9af86e1d8e4cfc82c2022bfaa6f459381a50b94a29e95dcdda8442d6d83864"}, - {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5f6c8a66741c5f5447e047ab0ba7a1c61d1e95580d64bce852e3df1f895c4067"}, - {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9a42d6a8156ff78981f8aa56eb6394114e0dedb217cf8b729f438f643608cbcd"}, - {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64c65f40b4cd8b0e049a8edde07e38b476da7e3aaebe63287c899d2cff253fa5"}, - {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fdcf339322a3fae5cbd504edcefddd5a50d9ee00d968696846f089b4432cf78"}, - {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:bf99c8404f008750c846cb4ac4667b798a9f7de673ff719d705d9b2d6de49c5f"}, - {file = "pydantic_core-2.27.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8f1edcea27918d748c7e5e4d917297b2a0ab80cad10f86631e488b7cddf76a36"}, - {file = "pydantic_core-2.27.1-cp38-cp38-musllinux_1_1_armv7l.whl", hash = "sha256:159cac0a3d096f79ab6a44d77a961917219707e2a130739c64d4dd46281f5c2a"}, - {file = "pydantic_core-2.27.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:029d9757eb621cc6e1848fa0b0310310de7301057f623985698ed7ebb014391b"}, - {file = "pydantic_core-2.27.1-cp38-none-win32.whl", hash = "sha256:a28af0695a45f7060e6f9b7092558a928a28553366519f64083c63a44f70e618"}, - {file = "pydantic_core-2.27.1-cp38-none-win_amd64.whl", hash = "sha256:2d4567c850905d5eaaed2f7a404e61012a51caf288292e016360aa2b96ff38d4"}, - {file = "pydantic_core-2.27.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:e9386266798d64eeb19dd3677051f5705bf873e98e15897ddb7d76f477131967"}, - {file = "pydantic_core-2.27.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4228b5b646caa73f119b1ae756216b59cc6e2267201c27d3912b592c5e323b60"}, - {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b3dfe500de26c52abe0477dde16192ac39c98f05bf2d80e76102d394bd13854"}, - {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:aee66be87825cdf72ac64cb03ad4c15ffef4143dbf5c113f64a5ff4f81477bf9"}, - {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3b748c44bb9f53031c8cbc99a8a061bc181c1000c60a30f55393b6e9c45cc5bd"}, - {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ca038c7f6a0afd0b2448941b6ef9d5e1949e999f9e5517692eb6da58e9d44be"}, - {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e0bd57539da59a3e4671b90a502da9a28c72322a4f17866ba3ac63a82c4498e"}, - {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ac6c2c45c847bbf8f91930d88716a0fb924b51e0c6dad329b793d670ec5db792"}, - {file = "pydantic_core-2.27.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b94d4ba43739bbe8b0ce4262bcc3b7b9f31459ad120fb595627eaeb7f9b9ca01"}, - {file = "pydantic_core-2.27.1-cp39-cp39-musllinux_1_1_armv7l.whl", hash = "sha256:00e6424f4b26fe82d44577b4c842d7df97c20be6439e8e685d0d715feceb9fb9"}, - {file = "pydantic_core-2.27.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:38de0a70160dd97540335b7ad3a74571b24f1dc3ed33f815f0880682e6880131"}, - {file = "pydantic_core-2.27.1-cp39-none-win32.whl", hash = "sha256:7ccebf51efc61634f6c2344da73e366c75e735960b5654b63d7e6f69a5885fa3"}, - {file = "pydantic_core-2.27.1-cp39-none-win_amd64.whl", hash = "sha256:a57847b090d7892f123726202b7daa20df6694cbd583b67a592e856bff603d6c"}, - {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3fa80ac2bd5856580e242dbc202db873c60a01b20309c8319b5c5986fbe53ce6"}, - {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d950caa237bb1954f1b8c9227b5065ba6875ac9771bb8ec790d956a699b78676"}, - {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0e4216e64d203e39c62df627aa882f02a2438d18a5f21d7f721621f7a5d3611d"}, - {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02a3d637bd387c41d46b002f0e49c52642281edacd2740e5a42f7017feea3f2c"}, - {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:161c27ccce13b6b0c8689418da3885d3220ed2eae2ea5e9b2f7f3d48f1d52c27"}, - {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:19910754e4cc9c63bc1c7f6d73aa1cfee82f42007e407c0f413695c2f7ed777f"}, - {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:e173486019cc283dc9778315fa29a363579372fe67045e971e89b6365cc035ed"}, - {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:af52d26579b308921b73b956153066481f064875140ccd1dfd4e77db89dbb12f"}, - {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:981fb88516bd1ae8b0cbbd2034678a39dedc98752f264ac9bc5839d3923fa04c"}, - {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5fde892e6c697ce3e30c61b239330fc5d569a71fefd4eb6512fc6caec9dd9e2f"}, - {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:816f5aa087094099fff7edabb5e01cc370eb21aa1a1d44fe2d2aefdfb5599b31"}, - {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c10c309e18e443ddb108f0ef64e8729363adbfd92d6d57beec680f6261556f3"}, - {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98476c98b02c8e9b2eec76ac4156fd006628b1b2d0ef27e548ffa978393fd154"}, - {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c3027001c28434e7ca5a6e1e527487051136aa81803ac812be51802150d880dd"}, - {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:7699b1df36a48169cdebda7ab5a2bac265204003f153b4bd17276153d997670a"}, - {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:1c39b07d90be6b48968ddc8c19e7585052088fd7ec8d568bb31ff64c70ae3c97"}, - {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:46ccfe3032b3915586e469d4972973f893c0a2bb65669194a5bdea9bacc088c2"}, - {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:62ba45e21cf6571d7f716d903b5b7b6d2617e2d5d67c0923dc47b9d41369f840"}, - {file = "pydantic_core-2.27.1.tar.gz", hash = "sha256:62a763352879b84aa31058fc931884055fd75089cccbd9d58bb6afd01141b235"}, + {file = "pydantic_core-2.27.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2d367ca20b2f14095a8f4fa1210f5a7b78b8a20009ecced6b12818f455b1e9fa"}, + {file = "pydantic_core-2.27.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:491a2b73db93fab69731eaee494f320faa4e093dbed776be1a829c2eb222c34c"}, + {file = "pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7969e133a6f183be60e9f6f56bfae753585680f3b7307a8e555a948d443cc05a"}, + {file = "pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3de9961f2a346257caf0aa508a4da705467f53778e9ef6fe744c038119737ef5"}, + {file = "pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e2bb4d3e5873c37bb3dd58714d4cd0b0e6238cebc4177ac8fe878f8b3aa8e74c"}, + {file = "pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:280d219beebb0752699480fe8f1dc61ab6615c2046d76b7ab7ee38858de0a4e7"}, + {file = "pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47956ae78b6422cbd46f772f1746799cbb862de838fd8d1fbd34a82e05b0983a"}, + {file = "pydantic_core-2.27.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:14d4a5c49d2f009d62a2a7140d3064f686d17a5d1a268bc641954ba181880236"}, + {file = "pydantic_core-2.27.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:337b443af21d488716f8d0b6164de833e788aa6bd7e3a39c005febc1284f4962"}, + {file = "pydantic_core-2.27.2-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:03d0f86ea3184a12f41a2d23f7ccb79cdb5a18e06993f8a45baa8dfec746f0e9"}, + {file = "pydantic_core-2.27.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:7041c36f5680c6e0f08d922aed302e98b3745d97fe1589db0a3eebf6624523af"}, + {file = "pydantic_core-2.27.2-cp310-cp310-win32.whl", hash = "sha256:50a68f3e3819077be2c98110c1f9dcb3817e93f267ba80a2c05bb4f8799e2ff4"}, + {file = "pydantic_core-2.27.2-cp310-cp310-win_amd64.whl", hash = "sha256:e0fd26b16394ead34a424eecf8a31a1f5137094cabe84a1bcb10fa6ba39d3d31"}, + {file = "pydantic_core-2.27.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:8e10c99ef58cfdf2a66fc15d66b16c4a04f62bca39db589ae8cba08bc55331bc"}, + {file = "pydantic_core-2.27.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:26f32e0adf166a84d0cb63be85c562ca8a6fa8de28e5f0d92250c6b7e9e2aff7"}, + {file = "pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c19d1ea0673cd13cc2f872f6c9ab42acc4e4f492a7ca9d3795ce2b112dd7e15"}, + {file = "pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5e68c4446fe0810e959cdff46ab0a41ce2f2c86d227d96dc3847af0ba7def306"}, + {file = "pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d9640b0059ff4f14d1f37321b94061c6db164fbe49b334b31643e0528d100d99"}, + {file = "pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:40d02e7d45c9f8af700f3452f329ead92da4c5f4317ca9b896de7ce7199ea459"}, + {file = "pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c1fd185014191700554795c99b347d64f2bb637966c4cfc16998a0ca700d048"}, + {file = "pydantic_core-2.27.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d81d2068e1c1228a565af076598f9e7451712700b673de8f502f0334f281387d"}, + {file = "pydantic_core-2.27.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1a4207639fb02ec2dbb76227d7c751a20b1a6b4bc52850568e52260cae64ca3b"}, + {file = "pydantic_core-2.27.2-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:3de3ce3c9ddc8bbd88f6e0e304dea0e66d843ec9de1b0042b0911c1663ffd474"}, + {file = "pydantic_core-2.27.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:30c5f68ded0c36466acede341551106821043e9afaad516adfb6e8fa80a4e6a6"}, + {file = "pydantic_core-2.27.2-cp311-cp311-win32.whl", hash = "sha256:c70c26d2c99f78b125a3459f8afe1aed4d9687c24fd677c6a4436bc042e50d6c"}, + {file = "pydantic_core-2.27.2-cp311-cp311-win_amd64.whl", hash = "sha256:08e125dbdc505fa69ca7d9c499639ab6407cfa909214d500897d02afb816e7cc"}, + {file = "pydantic_core-2.27.2-cp311-cp311-win_arm64.whl", hash = "sha256:26f0d68d4b235a2bae0c3fc585c585b4ecc51382db0e3ba402a22cbc440915e4"}, + {file = "pydantic_core-2.27.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:9e0c8cfefa0ef83b4da9588448b6d8d2a2bf1a53c3f1ae5fca39eb3061e2f0b0"}, + {file = "pydantic_core-2.27.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:83097677b8e3bd7eaa6775720ec8e0405f1575015a463285a92bfdfe254529ef"}, + {file = "pydantic_core-2.27.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:172fce187655fece0c90d90a678424b013f8fbb0ca8b036ac266749c09438cb7"}, + {file = "pydantic_core-2.27.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:519f29f5213271eeeeb3093f662ba2fd512b91c5f188f3bb7b27bc5973816934"}, + {file = "pydantic_core-2.27.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:05e3a55d124407fffba0dd6b0c0cd056d10e983ceb4e5dbd10dda135c31071d6"}, + {file = "pydantic_core-2.27.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9c3ed807c7b91de05e63930188f19e921d1fe90de6b4f5cd43ee7fcc3525cb8c"}, + {file = "pydantic_core-2.27.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6fb4aadc0b9a0c063206846d603b92030eb6f03069151a625667f982887153e2"}, + {file = "pydantic_core-2.27.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:28ccb213807e037460326424ceb8b5245acb88f32f3d2777427476e1b32c48c4"}, + {file = "pydantic_core-2.27.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:de3cd1899e2c279b140adde9357c4495ed9d47131b4a4eaff9052f23398076b3"}, + {file = "pydantic_core-2.27.2-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:220f892729375e2d736b97d0e51466252ad84c51857d4d15f5e9692f9ef12be4"}, + {file = "pydantic_core-2.27.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a0fcd29cd6b4e74fe8ddd2c90330fd8edf2e30cb52acda47f06dd615ae72da57"}, + {file = "pydantic_core-2.27.2-cp312-cp312-win32.whl", hash = "sha256:1e2cb691ed9834cd6a8be61228471d0a503731abfb42f82458ff27be7b2186fc"}, + {file = "pydantic_core-2.27.2-cp312-cp312-win_amd64.whl", hash = "sha256:cc3f1a99a4f4f9dd1de4fe0312c114e740b5ddead65bb4102884b384c15d8bc9"}, + {file = "pydantic_core-2.27.2-cp312-cp312-win_arm64.whl", hash = "sha256:3911ac9284cd8a1792d3cb26a2da18f3ca26c6908cc434a18f730dc0db7bfa3b"}, + {file = "pydantic_core-2.27.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:7d14bd329640e63852364c306f4d23eb744e0f8193148d4044dd3dacdaacbd8b"}, + {file = "pydantic_core-2.27.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:82f91663004eb8ed30ff478d77c4d1179b3563df6cdb15c0817cd1cdaf34d154"}, + {file = "pydantic_core-2.27.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:71b24c7d61131bb83df10cc7e687433609963a944ccf45190cfc21e0887b08c9"}, + {file = "pydantic_core-2.27.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fa8e459d4954f608fa26116118bb67f56b93b209c39b008277ace29937453dc9"}, + {file = "pydantic_core-2.27.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ce8918cbebc8da707ba805b7fd0b382816858728ae7fe19a942080c24e5b7cd1"}, + {file = "pydantic_core-2.27.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eda3f5c2a021bbc5d976107bb302e0131351c2ba54343f8a496dc8783d3d3a6a"}, + {file = "pydantic_core-2.27.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bd8086fa684c4775c27f03f062cbb9eaa6e17f064307e86b21b9e0abc9c0f02e"}, + {file = "pydantic_core-2.27.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8d9b3388db186ba0c099a6d20f0604a44eabdeef1777ddd94786cdae158729e4"}, + {file = "pydantic_core-2.27.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:7a66efda2387de898c8f38c0cf7f14fca0b51a8ef0b24bfea5849f1b3c95af27"}, + {file = "pydantic_core-2.27.2-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:18a101c168e4e092ab40dbc2503bdc0f62010e95d292b27827871dc85450d7ee"}, + {file = "pydantic_core-2.27.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ba5dd002f88b78a4215ed2f8ddbdf85e8513382820ba15ad5ad8955ce0ca19a1"}, + {file = "pydantic_core-2.27.2-cp313-cp313-win32.whl", hash = "sha256:1ebaf1d0481914d004a573394f4be3a7616334be70261007e47c2a6fe7e50130"}, + {file = "pydantic_core-2.27.2-cp313-cp313-win_amd64.whl", hash = "sha256:953101387ecf2f5652883208769a79e48db18c6df442568a0b5ccd8c2723abee"}, + {file = "pydantic_core-2.27.2-cp313-cp313-win_arm64.whl", hash = "sha256:ac4dbfd1691affb8f48c2c13241a2e3b60ff23247cbcf981759c768b6633cf8b"}, + {file = "pydantic_core-2.27.2-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:d3e8d504bdd3f10835468f29008d72fc8359d95c9c415ce6e767203db6127506"}, + {file = "pydantic_core-2.27.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:521eb9b7f036c9b6187f0b47318ab0d7ca14bd87f776240b90b21c1f4f149320"}, + {file = "pydantic_core-2.27.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85210c4d99a0114f5a9481b44560d7d1e35e32cc5634c656bc48e590b669b145"}, + {file = "pydantic_core-2.27.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d716e2e30c6f140d7560ef1538953a5cd1a87264c737643d481f2779fc247fe1"}, + {file = "pydantic_core-2.27.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f66d89ba397d92f840f8654756196d93804278457b5fbede59598a1f9f90b228"}, + {file = "pydantic_core-2.27.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:669e193c1c576a58f132e3158f9dfa9662969edb1a250c54d8fa52590045f046"}, + {file = "pydantic_core-2.27.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fdbe7629b996647b99c01b37f11170a57ae675375b14b8c13b8518b8320ced5"}, + {file = "pydantic_core-2.27.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d262606bf386a5ba0b0af3b97f37c83d7011439e3dc1a9298f21efb292e42f1a"}, + {file = "pydantic_core-2.27.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:cabb9bcb7e0d97f74df8646f34fc76fbf793b7f6dc2438517d7a9e50eee4f14d"}, + {file = "pydantic_core-2.27.2-cp38-cp38-musllinux_1_1_armv7l.whl", hash = "sha256:d2d63f1215638d28221f664596b1ccb3944f6e25dd18cd3b86b0a4c408d5ebb9"}, + {file = "pydantic_core-2.27.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bca101c00bff0adb45a833f8451b9105d9df18accb8743b08107d7ada14bd7da"}, + {file = "pydantic_core-2.27.2-cp38-cp38-win32.whl", hash = "sha256:f6f8e111843bbb0dee4cb6594cdc73e79b3329b526037ec242a3e49012495b3b"}, + {file = "pydantic_core-2.27.2-cp38-cp38-win_amd64.whl", hash = "sha256:fd1aea04935a508f62e0d0ef1f5ae968774a32afc306fb8545e06f5ff5cdf3ad"}, + {file = "pydantic_core-2.27.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:c10eb4f1659290b523af58fa7cffb452a61ad6ae5613404519aee4bfbf1df993"}, + {file = "pydantic_core-2.27.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ef592d4bad47296fb11f96cd7dc898b92e795032b4894dfb4076cfccd43a9308"}, + {file = "pydantic_core-2.27.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c61709a844acc6bf0b7dce7daae75195a10aac96a596ea1b776996414791ede4"}, + {file = "pydantic_core-2.27.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:42c5f762659e47fdb7b16956c71598292f60a03aa92f8b6351504359dbdba6cf"}, + {file = "pydantic_core-2.27.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4c9775e339e42e79ec99c441d9730fccf07414af63eac2f0e48e08fd38a64d76"}, + {file = "pydantic_core-2.27.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:57762139821c31847cfb2df63c12f725788bd9f04bc2fb392790959b8f70f118"}, + {file = "pydantic_core-2.27.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d1e85068e818c73e048fe28cfc769040bb1f475524f4745a5dc621f75ac7630"}, + {file = "pydantic_core-2.27.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:097830ed52fd9e427942ff3b9bc17fab52913b2f50f2880dc4a5611446606a54"}, + {file = "pydantic_core-2.27.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:044a50963a614ecfae59bb1eaf7ea7efc4bc62f49ed594e18fa1e5d953c40e9f"}, + {file = "pydantic_core-2.27.2-cp39-cp39-musllinux_1_1_armv7l.whl", hash = "sha256:4e0b4220ba5b40d727c7f879eac379b822eee5d8fff418e9d3381ee45b3b0362"}, + {file = "pydantic_core-2.27.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5e4f4bb20d75e9325cc9696c6802657b58bc1dbbe3022f32cc2b2b632c3fbb96"}, + {file = "pydantic_core-2.27.2-cp39-cp39-win32.whl", hash = "sha256:cca63613e90d001b9f2f9a9ceb276c308bfa2a43fafb75c8031c4f66039e8c6e"}, + {file = "pydantic_core-2.27.2-cp39-cp39-win_amd64.whl", hash = "sha256:77d1bca19b0f7021b3a982e6f903dcd5b2b06076def36a652e3907f596e29f67"}, + {file = "pydantic_core-2.27.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:2bf14caea37e91198329b828eae1618c068dfb8ef17bb33287a7ad4b61ac314e"}, + {file = "pydantic_core-2.27.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:b0cb791f5b45307caae8810c2023a184c74605ec3bcbb67d13846c28ff731ff8"}, + {file = "pydantic_core-2.27.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:688d3fd9fcb71f41c4c015c023d12a79d1c4c0732ec9eb35d96e3388a120dcf3"}, + {file = "pydantic_core-2.27.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d591580c34f4d731592f0e9fe40f9cc1b430d297eecc70b962e93c5c668f15f"}, + {file = "pydantic_core-2.27.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:82f986faf4e644ffc189a7f1aafc86e46ef70372bb153e7001e8afccc6e54133"}, + {file = "pydantic_core-2.27.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:bec317a27290e2537f922639cafd54990551725fc844249e64c523301d0822fc"}, + {file = "pydantic_core-2.27.2-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:0296abcb83a797db256b773f45773da397da75a08f5fcaef41f2044adec05f50"}, + {file = "pydantic_core-2.27.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:0d75070718e369e452075a6017fbf187f788e17ed67a3abd47fa934d001863d9"}, + {file = "pydantic_core-2.27.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:7e17b560be3c98a8e3aa66ce828bdebb9e9ac6ad5466fba92eb74c4c95cb1151"}, + {file = "pydantic_core-2.27.2-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:c33939a82924da9ed65dab5a65d427205a73181d8098e79b6b426bdf8ad4e656"}, + {file = "pydantic_core-2.27.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:00bad2484fa6bda1e216e7345a798bd37c68fb2d97558edd584942aa41b7d278"}, + {file = "pydantic_core-2.27.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c817e2b40aba42bac6f457498dacabc568c3b7a986fc9ba7c8d9d260b71485fb"}, + {file = "pydantic_core-2.27.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:251136cdad0cb722e93732cb45ca5299fb56e1344a833640bf93b2803f8d1bfd"}, + {file = "pydantic_core-2.27.2-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d2088237af596f0a524d3afc39ab3b036e8adb054ee57cbb1dcf8e09da5b29cc"}, + {file = "pydantic_core-2.27.2-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:d4041c0b966a84b4ae7a09832eb691a35aec90910cd2dbe7a208de59be77965b"}, + {file = "pydantic_core-2.27.2-pp39-pypy39_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:8083d4e875ebe0b864ffef72a4304827015cff328a1be6e22cc850753bfb122b"}, + {file = "pydantic_core-2.27.2-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f141ee28a0ad2123b6611b6ceff018039df17f32ada8b534e6aa039545a3efb2"}, + {file = "pydantic_core-2.27.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7d0c8399fcc1848491f00e0314bd59fb34a9c008761bcb422a057670c3f65e35"}, + {file = "pydantic_core-2.27.2.tar.gz", hash = "sha256:eb026e5a4c1fee05726072337ff51d1efb6f59090b7da90d30ea58625b1ffb39"}, ] [package.dependencies] @@ -3083,13 +3085,13 @@ typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" [[package]] name = "pydantic-settings" -version = "2.6.1" +version = "2.7.0" description = "Settings management using Pydantic" optional = false python-versions = ">=3.8" files = [ - {file = "pydantic_settings-2.6.1-py3-none-any.whl", hash = "sha256:7fb0637c786a558d3103436278a7c4f1cfd29ba8973238a50c5bb9a55387da87"}, - {file = "pydantic_settings-2.6.1.tar.gz", hash = "sha256:e0f92546d8a9923cb8941689abf85d6601a8c19a23e97a34b2964a2e3f813ca0"}, + {file = "pydantic_settings-2.7.0-py3-none-any.whl", hash = "sha256:e00c05d5fa6cbbb227c84bd7487c5c1065084119b750df7c8c1a554aed236eb5"}, + {file = "pydantic_settings-2.7.0.tar.gz", hash = "sha256:ac4bfd4a36831a48dbf8b2d9325425b549a0a6f18cea118436d728eb4f1c4d66"}, ] [package.dependencies] @@ -3235,13 +3237,13 @@ cli = ["click (>=5.0)"] [[package]] name = "python-json-logger" -version = "3.2.0" +version = "3.2.1" description = "JSON Log Formatter for the Python Logging Package" optional = false python-versions = ">=3.8" files = [ - {file = "python_json_logger-3.2.0-py3-none-any.whl", hash = "sha256:d73522ddcfc6d0461394120feaddea9025dc64bf804d96357dd42fa878cc5fe8"}, - {file = "python_json_logger-3.2.0.tar.gz", hash = "sha256:2c11056458d3f56614480b24e9cb28f7aba69cbfbebddbb77c92f0ec0d4947ab"}, + {file = "python_json_logger-3.2.1-py3-none-any.whl", hash = "sha256:cdc17047eb5374bd311e748b42f99d71223f3b0e186f4206cc5d52aefe85b090"}, + {file = "python_json_logger-3.2.1.tar.gz", hash = "sha256:8eb0554ea17cb75b05d2848bc14fb02fbdbd9d6972120781b974380bfa162008"}, ] [package.extras] @@ -3485,99 +3487,99 @@ cffi = {version = "*", markers = "implementation_name == \"pypy\""} [[package]] name = "rapidfuzz" -version = "3.10.1" +version = "3.11.0" description = "rapid fuzzy string matching" optional = false python-versions = ">=3.9" files = [ - {file = "rapidfuzz-3.10.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f17d9f21bf2f2f785d74f7b0d407805468b4c173fa3e52c86ec94436b338e74a"}, - {file = "rapidfuzz-3.10.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b31f358a70efc143909fb3d75ac6cd3c139cd41339aa8f2a3a0ead8315731f2b"}, - {file = "rapidfuzz-3.10.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7f4f43f2204b56a61448ec2dd061e26fd344c404da99fb19f3458200c5874ba2"}, - {file = "rapidfuzz-3.10.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9d81bf186a453a2757472133b24915768abc7c3964194406ed93e170e16c21cb"}, - {file = "rapidfuzz-3.10.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3611c8f45379a12063d70075c75134f2a8bd2e4e9b8a7995112ddae95ca1c982"}, - {file = "rapidfuzz-3.10.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3c3b537b97ac30da4b73930fa8a4fe2f79c6d1c10ad535c5c09726612cd6bed9"}, - {file = "rapidfuzz-3.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:231ef1ec9cf7b59809ce3301006500b9d564ddb324635f4ea8f16b3e2a1780da"}, - {file = "rapidfuzz-3.10.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ed4f3adc1294834955b7e74edd3c6bd1aad5831c007f2d91ea839e76461a5879"}, - {file = "rapidfuzz-3.10.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:7b6015da2e707bf632a71772a2dbf0703cff6525732c005ad24987fe86e8ec32"}, - {file = "rapidfuzz-3.10.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:1b35a118d61d6f008e8e3fb3a77674d10806a8972c7b8be433d6598df4d60b01"}, - {file = "rapidfuzz-3.10.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:bc308d79a7e877226f36bdf4e149e3ed398d8277c140be5c1fd892ec41739e6d"}, - {file = "rapidfuzz-3.10.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f017dbfecc172e2d0c37cf9e3d519179d71a7f16094b57430dffc496a098aa17"}, - {file = "rapidfuzz-3.10.1-cp310-cp310-win32.whl", hash = "sha256:36c0e1483e21f918d0f2f26799fe5ac91c7b0c34220b73007301c4f831a9c4c7"}, - {file = "rapidfuzz-3.10.1-cp310-cp310-win_amd64.whl", hash = "sha256:10746c1d4c8cd8881c28a87fd7ba0c9c102346dfe7ff1b0d021cdf093e9adbff"}, - {file = "rapidfuzz-3.10.1-cp310-cp310-win_arm64.whl", hash = "sha256:dfa64b89dcb906835e275187569e51aa9d546a444489e97aaf2cc84011565fbe"}, - {file = "rapidfuzz-3.10.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:92958ae075c87fef393f835ed02d4fe8d5ee2059a0934c6c447ea3417dfbf0e8"}, - {file = "rapidfuzz-3.10.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ba7521e072c53e33c384e78615d0718e645cab3c366ecd3cc8cb732befd94967"}, - {file = "rapidfuzz-3.10.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00d02cbd75d283c287471b5b3738b3e05c9096150f93f2d2dfa10b3d700f2db9"}, - {file = "rapidfuzz-3.10.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:efa1582a397da038e2f2576c9cd49b842f56fde37d84a6b0200ffebc08d82350"}, - {file = "rapidfuzz-3.10.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f12912acee1f506f974f58de9fdc2e62eea5667377a7e9156de53241c05fdba8"}, - {file = "rapidfuzz-3.10.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:666d5d8b17becc3f53447bcb2b6b33ce6c2df78792495d1fa82b2924cd48701a"}, - {file = "rapidfuzz-3.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26f71582c0d62445067ee338ddad99b655a8f4e4ed517a90dcbfbb7d19310474"}, - {file = "rapidfuzz-3.10.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8a2ef08b27167bcff230ffbfeedd4c4fa6353563d6aaa015d725dd3632fc3de7"}, - {file = "rapidfuzz-3.10.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:365e4fc1a2b95082c890f5e98489b894e6bf8c338c6ac89bb6523c2ca6e9f086"}, - {file = "rapidfuzz-3.10.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:1996feb7a61609fa842e6b5e0c549983222ffdedaf29644cc67e479902846dfe"}, - {file = "rapidfuzz-3.10.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:cf654702f144beaa093103841a2ea6910d617d0bb3fccb1d1fd63c54dde2cd49"}, - {file = "rapidfuzz-3.10.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ec108bf25de674781d0a9a935030ba090c78d49def3d60f8724f3fc1e8e75024"}, - {file = "rapidfuzz-3.10.1-cp311-cp311-win32.whl", hash = "sha256:031f8b367e5d92f7a1e27f7322012f3c321c3110137b43cc3bf678505583ef48"}, - {file = "rapidfuzz-3.10.1-cp311-cp311-win_amd64.whl", hash = "sha256:f98f36c6a1bb9a6c8bbec99ad87c8c0e364f34761739b5ea9adf7b48129ae8cf"}, - {file = "rapidfuzz-3.10.1-cp311-cp311-win_arm64.whl", hash = "sha256:f1da2028cb4e41be55ee797a82d6c1cf589442504244249dfeb32efc608edee7"}, - {file = "rapidfuzz-3.10.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:1340b56340896bede246f612b6ecf685f661a56aabef3d2512481bfe23ac5835"}, - {file = "rapidfuzz-3.10.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2316515169b7b5a453f0ce3adbc46c42aa332cae9f2edb668e24d1fc92b2f2bb"}, - {file = "rapidfuzz-3.10.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e06fe6a12241ec1b72c0566c6b28cda714d61965d86569595ad24793d1ab259"}, - {file = "rapidfuzz-3.10.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d99c1cd9443b19164ec185a7d752f4b4db19c066c136f028991a480720472e23"}, - {file = "rapidfuzz-3.10.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a1d9aa156ed52d3446388ba4c2f335e312191d1ca9d1f5762ee983cf23e4ecf6"}, - {file = "rapidfuzz-3.10.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:54bcf4efaaee8e015822be0c2c28214815f4f6b4f70d8362cfecbd58a71188ac"}, - {file = "rapidfuzz-3.10.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0c955e32afdbfdf6e9ee663d24afb25210152d98c26d22d399712d29a9b976b"}, - {file = "rapidfuzz-3.10.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:191633722203f5b7717efcb73a14f76f3b124877d0608c070b827c5226d0b972"}, - {file = "rapidfuzz-3.10.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:195baad28057ec9609e40385991004e470af9ef87401e24ebe72c064431524ab"}, - {file = "rapidfuzz-3.10.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:0fff4a6b87c07366662b62ae994ffbeadc472e72f725923f94b72a3db49f4671"}, - {file = "rapidfuzz-3.10.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:4ffed25f9fdc0b287f30a98467493d1e1ce5b583f6317f70ec0263b3c97dbba6"}, - {file = "rapidfuzz-3.10.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:d02cf8e5af89a9ac8f53c438ddff6d773f62c25c6619b29db96f4aae248177c0"}, - {file = "rapidfuzz-3.10.1-cp312-cp312-win32.whl", hash = "sha256:f3bb81d4fe6a5d20650f8c0afcc8f6e1941f6fecdb434f11b874c42467baded0"}, - {file = "rapidfuzz-3.10.1-cp312-cp312-win_amd64.whl", hash = "sha256:aaf83e9170cb1338922ae42d320699dccbbdca8ffed07faeb0b9257822c26e24"}, - {file = "rapidfuzz-3.10.1-cp312-cp312-win_arm64.whl", hash = "sha256:c5da802a0d085ad81b0f62828fb55557996c497b2d0b551bbdfeafd6d447892f"}, - {file = "rapidfuzz-3.10.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:fc22d69a1c9cccd560a5c434c0371b2df0f47c309c635a01a913e03bbf183710"}, - {file = "rapidfuzz-3.10.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:38b0dac2c8e057562b8f0d8ae5b663d2d6a28c5ab624de5b73cef9abb6129a24"}, - {file = "rapidfuzz-3.10.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6fde3bbb14e92ce8fcb5c2edfff72e474d0080cadda1c97785bf4822f037a309"}, - {file = "rapidfuzz-3.10.1-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9141fb0592e55f98fe9ac0f3ce883199b9c13e262e0bf40c5b18cdf926109d16"}, - {file = "rapidfuzz-3.10.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:237bec5dd1bfc9b40bbd786cd27949ef0c0eb5fab5eb491904c6b5df59d39d3c"}, - {file = "rapidfuzz-3.10.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:18123168cba156ab5794ea6de66db50f21bb3c66ae748d03316e71b27d907b95"}, - {file = "rapidfuzz-3.10.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b75fe506c8e02769cc47f5ab21ce3e09b6211d3edaa8f8f27331cb6988779be"}, - {file = "rapidfuzz-3.10.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9da82aa4b46973aaf9e03bb4c3d6977004648c8638febfc0f9d237e865761270"}, - {file = "rapidfuzz-3.10.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:c34c022d5ad564f1a5a57a4a89793bd70d7bad428150fb8ff2760b223407cdcf"}, - {file = "rapidfuzz-3.10.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:1e96c84d6c2a0ca94e15acb5399118fff669f4306beb98a6d8ec6f5dccab4412"}, - {file = "rapidfuzz-3.10.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:e8e154b84a311263e1aca86818c962e1fa9eefdd643d1d5d197fcd2738f88cb9"}, - {file = "rapidfuzz-3.10.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:335fee93188f8cd585552bb8057228ce0111bd227fa81bfd40b7df6b75def8ab"}, - {file = "rapidfuzz-3.10.1-cp313-cp313-win32.whl", hash = "sha256:6729b856166a9e95c278410f73683957ea6100c8a9d0a8dbe434c49663689255"}, - {file = "rapidfuzz-3.10.1-cp313-cp313-win_amd64.whl", hash = "sha256:0e06d99ad1ad97cb2ef7f51ec6b1fedd74a3a700e4949353871cf331d07b382a"}, - {file = "rapidfuzz-3.10.1-cp313-cp313-win_arm64.whl", hash = "sha256:8d1b7082104d596a3eb012e0549b2634ed15015b569f48879701e9d8db959dbb"}, - {file = "rapidfuzz-3.10.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:779027d3307e1a2b1dc0c03c34df87a470a368a1a0840a9d2908baf2d4067956"}, - {file = "rapidfuzz-3.10.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:440b5608ab12650d0390128d6858bc839ae77ffe5edf0b33a1551f2fa9860651"}, - {file = "rapidfuzz-3.10.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82cac41a411e07a6f3dc80dfbd33f6be70ea0abd72e99c59310819d09f07d945"}, - {file = "rapidfuzz-3.10.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:958473c9f0bca250590200fd520b75be0dbdbc4a7327dc87a55b6d7dc8d68552"}, - {file = "rapidfuzz-3.10.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ef60dfa73749ef91cb6073be1a3e135f4846ec809cc115f3cbfc6fe283a5584"}, - {file = "rapidfuzz-3.10.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7fbac18f2c19fc983838a60611e67e3262e36859994c26f2ee85bb268de2355"}, - {file = "rapidfuzz-3.10.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a0d519ff39db887cd73f4e297922786d548f5c05d6b51f4e6754f452a7f4296"}, - {file = "rapidfuzz-3.10.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:bebb7bc6aeb91cc57e4881b222484c26759ca865794187217c9dcea6c33adae6"}, - {file = "rapidfuzz-3.10.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:fe07f8b9c3bb5c5ad1d2c66884253e03800f4189a60eb6acd6119ebaf3eb9894"}, - {file = "rapidfuzz-3.10.1-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:bfa48a4a2d45a41457f0840c48e579db157a927f4e97acf6e20df8fc521c79de"}, - {file = "rapidfuzz-3.10.1-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:2cf44d01bfe8ee605b7eaeecbc2b9ca64fc55765f17b304b40ed8995f69d7716"}, - {file = "rapidfuzz-3.10.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1e6bbca9246d9eedaa1c84e04a7f555493ba324d52ae4d9f3d9ddd1b740dcd87"}, - {file = "rapidfuzz-3.10.1-cp39-cp39-win32.whl", hash = "sha256:567f88180f2c1423b4fe3f3ad6e6310fc97b85bdba574801548597287fc07028"}, - {file = "rapidfuzz-3.10.1-cp39-cp39-win_amd64.whl", hash = "sha256:6b2cd7c29d6ecdf0b780deb587198f13213ac01c430ada6913452fd0c40190fc"}, - {file = "rapidfuzz-3.10.1-cp39-cp39-win_arm64.whl", hash = "sha256:9f912d459e46607ce276128f52bea21ebc3e9a5ccf4cccfef30dd5bddcf47be8"}, - {file = "rapidfuzz-3.10.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:ac4452f182243cfab30ba4668ef2de101effaedc30f9faabb06a095a8c90fd16"}, - {file = "rapidfuzz-3.10.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:565c2bd4f7d23c32834652b27b51dd711814ab614b4e12add8476be4e20d1cf5"}, - {file = "rapidfuzz-3.10.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:187d9747149321607be4ccd6f9f366730078bed806178ec3eeb31d05545e9e8f"}, - {file = "rapidfuzz-3.10.1-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:616290fb9a8fa87e48cb0326d26f98d4e29f17c3b762c2d586f2b35c1fd2034b"}, - {file = "rapidfuzz-3.10.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:073a5b107e17ebd264198b78614c0206fa438cce749692af5bc5f8f484883f50"}, - {file = "rapidfuzz-3.10.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:39c4983e2e2ccb9732f3ac7d81617088822f4a12291d416b09b8a1eadebb3e29"}, - {file = "rapidfuzz-3.10.1-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:ac7adee6bcf0c6fee495d877edad1540a7e0f5fc208da03ccb64734b43522d7a"}, - {file = "rapidfuzz-3.10.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:425f4ac80b22153d391ee3f94bc854668a0c6c129f05cf2eaf5ee74474ddb69e"}, - {file = "rapidfuzz-3.10.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:65a2fa13e8a219f9b5dcb9e74abe3ced5838a7327e629f426d333dfc8c5a6e66"}, - {file = "rapidfuzz-3.10.1-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:75561f3df9a906aaa23787e9992b228b1ab69007932dc42070f747103e177ba8"}, - {file = "rapidfuzz-3.10.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:edd062490537e97ca125bc6c7f2b7331c2b73d21dc304615afe61ad1691e15d5"}, - {file = "rapidfuzz-3.10.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:cfcc8feccf63245a22dfdd16e222f1a39771a44b870beb748117a0e09cbb4a62"}, - {file = "rapidfuzz-3.10.1.tar.gz", hash = "sha256:5a15546d847a915b3f42dc79ef9b0c78b998b4e2c53b252e7166284066585979"}, + {file = "rapidfuzz-3.11.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:eb8a54543d16ab1b69e2c5ed96cabbff16db044a50eddfc028000138ca9ddf33"}, + {file = "rapidfuzz-3.11.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:231c8b2efbd7f8d2ecd1ae900363ba168b8870644bb8f2b5aa96e4a7573bde19"}, + {file = "rapidfuzz-3.11.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:54e7f442fb9cca81e9df32333fb075ef729052bcabe05b0afc0441f462299114"}, + {file = "rapidfuzz-3.11.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:906f1f2a1b91c06599b3dd1be207449c5d4fc7bd1e1fa2f6aef161ea6223f165"}, + {file = "rapidfuzz-3.11.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8ed59044aea9eb6c663112170f2399b040d5d7b162828b141f2673e822093fa8"}, + {file = "rapidfuzz-3.11.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1cb1965a28b0fa64abdee130c788a0bc0bb3cf9ef7e3a70bf055c086c14a3d7e"}, + {file = "rapidfuzz-3.11.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b488b244931d0291412917e6e46ee9f6a14376625e150056fe7c4426ef28225"}, + {file = "rapidfuzz-3.11.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f0ba13557fec9d5ffc0a22826754a7457cc77f1b25145be10b7bb1d143ce84c6"}, + {file = "rapidfuzz-3.11.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:3871fa7dfcef00bad3c7e8ae8d8fd58089bad6fb21f608d2bf42832267ca9663"}, + {file = "rapidfuzz-3.11.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:b2669eafee38c5884a6e7cc9769d25c19428549dcdf57de8541cf9e82822e7db"}, + {file = "rapidfuzz-3.11.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:ffa1bb0e26297b0f22881b219ffc82a33a3c84ce6174a9d69406239b14575bd5"}, + {file = "rapidfuzz-3.11.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:45b15b8a118856ac9caac6877f70f38b8a0d310475d50bc814698659eabc1cdb"}, + {file = "rapidfuzz-3.11.0-cp310-cp310-win32.whl", hash = "sha256:22033677982b9c4c49676f215b794b0404073f8974f98739cb7234e4a9ade9ad"}, + {file = "rapidfuzz-3.11.0-cp310-cp310-win_amd64.whl", hash = "sha256:be15496e7244361ff0efcd86e52559bacda9cd975eccf19426a0025f9547c792"}, + {file = "rapidfuzz-3.11.0-cp310-cp310-win_arm64.whl", hash = "sha256:714a7ba31ba46b64d30fccfe95f8013ea41a2e6237ba11a805a27cdd3bce2573"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8724a978f8af7059c5323d523870bf272a097478e1471295511cf58b2642ff83"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8b63cb1f2eb371ef20fb155e95efd96e060147bdd4ab9fc400c97325dfee9fe1"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82497f244aac10b20710448645f347d862364cc4f7d8b9ba14bd66b5ce4dec18"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:339607394941801e6e3f6c1ecd413a36e18454e7136ed1161388de674f47f9d9"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84819390a36d6166cec706b9d8f0941f115f700b7faecab5a7e22fc367408bc3"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eea8d9e20632d68f653455265b18c35f90965e26f30d4d92f831899d6682149b"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5b659e1e2ea2784a9a397075a7fc395bfa4fe66424042161c4bcaf6e4f637b38"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1315cd2a351144572e31fe3df68340d4b83ddec0af8b2e207cd32930c6acd037"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:a7743cca45b4684c54407e8638f6d07b910d8d811347b9d42ff21262c7c23245"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:5bb636b0150daa6d3331b738f7c0f8b25eadc47f04a40e5c23c4bfb4c4e20ae3"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:42f4dd264ada7a9aa0805ea0da776dc063533917773cf2df5217f14eb4429eae"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:51f24cb39e64256221e6952f22545b8ce21cacd59c0d3e367225da8fc4b868d8"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-win32.whl", hash = "sha256:aaf391fb6715866bc14681c76dc0308f46877f7c06f61d62cc993b79fc3c4a2a"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-win_amd64.whl", hash = "sha256:ebadd5b8624d8ad503e505a99b8eb26fe3ea9f8e9c2234e805a27b269e585842"}, + {file = "rapidfuzz-3.11.0-cp311-cp311-win_arm64.whl", hash = "sha256:d895998fec712544c13cfe833890e0226585cf0391dd3948412441d5d68a2b8c"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f382fec4a7891d66fb7163c90754454030bb9200a13f82ee7860b6359f3f2fa8"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:dfaefe08af2a928e72344c800dcbaf6508e86a4ed481e28355e8d4b6a6a5230e"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:92ebb7c12f682b5906ed98429f48a3dd80dd0f9721de30c97a01473d1a346576"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9a1b3ebc62d4bcdfdeba110944a25ab40916d5383c5e57e7c4a8dc0b6c17211a"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9c6d7fea39cb33e71de86397d38bf7ff1a6273e40367f31d05761662ffda49e4"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:99aebef8268f2bc0b445b5640fd3312e080bd17efd3fbae4486b20ac00466308"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4469307f464ae3089acf3210b8fc279110d26d10f79e576f385a98f4429f7d97"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:eb97c53112b593f89a90b4f6218635a9d1eea1d7f9521a3b7d24864228bbc0aa"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:ef8937dae823b889c0273dfa0f0f6c46a3658ac0d851349c464d1b00e7ff4252"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:d95f9e9f3777b96241d8a00d6377cc9c716981d828b5091082d0fe3a2924b43e"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:b1d67d67f89e4e013a5295e7523bc34a7a96f2dba5dd812c7c8cb65d113cbf28"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:d994cf27e2f874069884d9bddf0864f9b90ad201fcc9cb2f5b82bacc17c8d5f2"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-win32.whl", hash = "sha256:ba26d87fe7fcb56c4a53b549a9e0e9143f6b0df56d35fe6ad800c902447acd5b"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-win_amd64.whl", hash = "sha256:b1f7efdd7b7adb32102c2fa481ad6f11923e2deb191f651274be559d56fc913b"}, + {file = "rapidfuzz-3.11.0-cp312-cp312-win_arm64.whl", hash = "sha256:ed78c8e94f57b44292c1a0350f580e18d3a3c5c0800e253f1583580c1b417ad2"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e60814edd0c9b511b5f377d48b9782b88cfe8be07a98f99973669299c8bb318a"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3f28952da055dbfe75828891cd3c9abf0984edc8640573c18b48c14c68ca5e06"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e8f93bc736020351a6f8e71666e1f486bb8bd5ce8112c443a30c77bfde0eb68"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:76a4a11ba8f678c9e5876a7d465ab86def047a4fcc043617578368755d63a1bc"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dc0e0d41ad8a056a9886bac91ff9d9978e54a244deb61c2972cc76b66752de9c"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5e8ea35f2419c7d56b3e75fbde2698766daedb374f20eea28ac9b1f668ef4f74"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd340bbd025302276b5aa221dccfe43040c7babfc32f107c36ad783f2ffd8775"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:494eef2c68305ab75139034ea25328a04a548d297712d9cf887bf27c158c388b"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:5a167344c1d6db06915fb0225592afdc24d8bafaaf02de07d4788ddd37f4bc2f"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:8c7af25bda96ac799378ac8aba54a8ece732835c7b74cfc201b688a87ed11152"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:d2a0f7e17f33e7890257367a1662b05fecaf56625f7dbb6446227aaa2b86448b"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4d0d26c7172bdb64f86ee0765c5b26ea1dc45c52389175888ec073b9b28f4305"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-win32.whl", hash = "sha256:6ad02bab756751c90fa27f3069d7b12146613061341459abf55f8190d899649f"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-win_amd64.whl", hash = "sha256:b1472986fd9c5d318399a01a0881f4a0bf4950264131bb8e2deba9df6d8c362b"}, + {file = "rapidfuzz-3.11.0-cp313-cp313-win_arm64.whl", hash = "sha256:c408f09649cbff8da76f8d3ad878b64ba7f7abdad1471efb293d2c075e80c822"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1bac4873f6186f5233b0084b266bfb459e997f4c21fc9f029918f44a9eccd304"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4f9f12c2d0aa52b86206d2059916153876a9b1cf9dfb3cf2f344913167f1c3d4"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8dd501de6f7a8f83557d20613b58734d1cb5f0be78d794cde64fe43cfc63f5f2"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4416ca69af933d4a8ad30910149d3db6d084781d5c5fdedb713205389f535385"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f0821b9bdf18c5b7d51722b906b233a39b17f602501a966cfbd9b285f8ab83cd"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d0edecc3f90c2653298d380f6ea73b536944b767520c2179ec5d40b9145e47aa"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4513dd01cee11e354c31b75f652d4d466c9440b6859f84e600bdebfccb17735a"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:d9727b85511b912571a76ce53c7640ba2c44c364e71cef6d7359b5412739c570"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:ab9eab33ee3213f7751dc07a1a61b8d9a3d748ca4458fffddd9defa6f0493c16"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:6b01c1ddbb054283797967ddc5433d5c108d680e8fa2684cf368be05407b07e4"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:3857e335f97058c4b46fa39ca831290b70de554a5c5af0323d2f163b19c5f2a6"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:d98a46cf07c0c875d27e8a7ed50f304d83063e49b9ab63f21c19c154b4c0d08d"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-win32.whl", hash = "sha256:c36539ed2c0173b053dafb221458812e178cfa3224ade0960599bec194637048"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-win_amd64.whl", hash = "sha256:ec8d7d8567e14af34a7911c98f5ac74a3d4a743cd848643341fc92b12b3784ff"}, + {file = "rapidfuzz-3.11.0-cp39-cp39-win_arm64.whl", hash = "sha256:62171b270ecc4071be1c1f99960317db261d4c8c83c169e7f8ad119211fe7397"}, + {file = "rapidfuzz-3.11.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:f06e3c4c0a8badfc4910b9fd15beb1ad8f3b8fafa8ea82c023e5e607b66a78e4"}, + {file = "rapidfuzz-3.11.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:fe7aaf5a54821d340d21412f7f6e6272a9b17a0cbafc1d68f77f2fc11009dcd5"}, + {file = "rapidfuzz-3.11.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25398d9ac7294e99876a3027ffc52c6bebeb2d702b1895af6ae9c541ee676702"}, + {file = "rapidfuzz-3.11.0-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9a52eea839e4bdc72c5e60a444d26004da00bb5bc6301e99b3dde18212e41465"}, + {file = "rapidfuzz-3.11.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c87319b0ab9d269ab84f6453601fd49b35d9e4a601bbaef43743f26fabf496c"}, + {file = "rapidfuzz-3.11.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:3048c6ed29d693fba7d2a7caf165f5e0bb2b9743a0989012a98a47b975355cca"}, + {file = "rapidfuzz-3.11.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:b04f29735bad9f06bb731c214f27253bd8bedb248ef9b8a1b4c5bde65b838454"}, + {file = "rapidfuzz-3.11.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:7864e80a0d4e23eb6194254a81ee1216abdc53f9dc85b7f4d56668eced022eb8"}, + {file = "rapidfuzz-3.11.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3794df87313dfb56fafd679b962e0613c88a293fd9bd5dd5c2793d66bf06a101"}, + {file = "rapidfuzz-3.11.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d71da0012face6f45432a11bc59af19e62fac5a41f8ce489e80c0add8153c3d1"}, + {file = "rapidfuzz-3.11.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff38378346b7018f42cbc1f6d1d3778e36e16d8595f79a312b31e7c25c50bd08"}, + {file = "rapidfuzz-3.11.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:6668321f90aa02a5a789d4e16058f2e4f2692c5230252425c3532a8a62bc3424"}, + {file = "rapidfuzz-3.11.0.tar.gz", hash = "sha256:a53ca4d3f52f00b393fab9b5913c5bafb9afc27d030c8a1db1283da6917a860f"}, ] [package.extras] @@ -4125,13 +4127,13 @@ tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] [[package]] name = "streamlit" -version = "1.41.0" +version = "1.41.1" description = "A faster way to build and share data apps" optional = false python-versions = "!=3.9.7,>=3.9" files = [ - {file = "streamlit-1.41.0-py2.py3-none-any.whl", hash = "sha256:66a0e7ccb6c685b52782966f4a2fe63f757627eeaad0ca3f5b335b5bf191e7bd"}, - {file = "streamlit-1.41.0.tar.gz", hash = "sha256:87c20027fcbc8690ad0955b7e89aea2276d14fca2c8c4b54a251bf2b66b40201"}, + {file = "streamlit-1.41.1-py2.py3-none-any.whl", hash = "sha256:0def00822480071d642e6df36cd63c089f991da3a69fd9eb4ab8f65ce27de4e0"}, + {file = "streamlit-1.41.1.tar.gz", hash = "sha256:6626d32b098ba1458b71eebdd634c62af2dd876380e59c4b6a1e828a39d62d69"}, ] [package.dependencies] @@ -4437,13 +4439,13 @@ test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0, [[package]] name = "transformers" -version = "4.47.0" +version = "4.47.1" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" optional = false python-versions = ">=3.9.0" files = [ - {file = "transformers-4.47.0-py3-none-any.whl", hash = "sha256:a8e1bafdaae69abdda3cad638fe392e37c86d2ce0ecfcae11d60abb8f949ff4d"}, - {file = "transformers-4.47.0.tar.gz", hash = "sha256:f8ead7a5a4f6937bb507e66508e5e002dc5930f7b6122a9259c37b099d0f3b19"}, + {file = "transformers-4.47.1-py3-none-any.whl", hash = "sha256:d2f5d19bb6283cd66c893ec7e6d931d6370bbf1cc93633326ff1f41a40046c9c"}, + {file = "transformers-4.47.1.tar.gz", hash = "sha256:6c29c05a5f595e278481166539202bf8641281536df1c42357ee58a45d0a564a"}, ] [package.dependencies] diff --git a/pyproject.toml b/pyproject.toml index 63ee32e3..317235d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "surya-ocr" -version = "0.8.0" +version = "0.8.1" description = "OCR, layout, reading order, and table recognition in 90+ languages" authors = ["Vik Paruchuri "] readme = "README.md" From cd795a71c0cf9bf118633d347e99bc4de2cf9618 Mon Sep 17 00:00:00 2001 From: Vik Paruchuri Date: Thu, 19 Dec 2024 11:15:03 -0500 Subject: [PATCH 10/12] Add in tests --- .../workflows/{tests.yml => benchmarks.yml} | 6 --- .github/workflows/ci.yml | 26 ++++++++++ poetry.lock | 50 ++++++++++++++++++- pyproject.toml | 1 + pytest.ini | 7 +++ surya/benchmark/tesseract.py | 5 +- tests/conftest.py | 10 ++++ tests/test_ocr_errors.py | 18 +++++++ 8 files changed, 114 insertions(+), 9 deletions(-) rename .github/workflows/{tests.yml => benchmarks.yml} (85%) create mode 100644 .github/workflows/ci.yml create mode 100644 pytest.ini create mode 100644 tests/conftest.py create mode 100644 tests/test_ocr_errors.py diff --git a/.github/workflows/tests.yml b/.github/workflows/benchmarks.yml similarity index 85% rename from .github/workflows/tests.yml rename to .github/workflows/benchmarks.yml index 9edbe6eb..8955d732 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/benchmarks.yml @@ -14,16 +14,10 @@ jobs: uses: actions/setup-python@v4 with: python-version: 3.11 - - name: Install apt dependencies - run: | - sudo apt-get update - sudo apt-get install -y tesseract-ocr tesseract-ocr-eng - name: Install python dependencies run: | pip install poetry poetry install - poetry remove torch - poetry run pip install torch --index-url https://download.pytorch.org/whl/cpu - name: Run detection benchmark test run: | poetry run python benchmark/detection.py --max 2 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..a6057e9d --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,26 @@ +name: Integration test + +on: [push] + +env: + TORCH_DEVICE: "cpu" + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Python 3.11 + uses: actions/setup-python@v4 + with: + python-version: 3.11 + - name: Install apt dependencies + run: | + sudo apt-get update + sudo apt-get install -y tesseract-ocr tesseract-ocr-eng + - name: Install python dependencies + run: | + pip install poetry + poetry install + - name: Run tests + run: poetry run pytest \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 48dbc9fe..76cb7d78 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1235,6 +1235,17 @@ files = [ [package.extras] all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"] +[[package]] +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.7" +files = [ + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, +] + [[package]] name = "ipykernel" version = "6.29.5" @@ -2692,6 +2703,21 @@ files = [ greenlet = "3.1.1" pyee = "12.0.0" +[[package]] +name = "pluggy" +version = "1.5.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, + {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + [[package]] name = "prometheus-client" version = "0.21.1" @@ -3207,6 +3233,28 @@ files = [ packaging = ">=21.3" Pillow = ">=8.0.0" +[[package]] +name = "pytest" +version = "8.3.4" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6"}, + {file = "pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=1.5,<2" +tomli = {version = ">=1", markers = "python_version < \"3.11\""} + +[package.extras] +dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -4925,4 +4973,4 @@ propcache = ">=0.2.0" [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "e5dfcdc29e7912fe8cedfb2af75c0edfe5a911de4600890f75922f501b440046" +content-hash = "dd035c4c1f7634ad4fc809b9a11bad6d9c936eb2ce5c992830f68835f69fea12" diff --git a/pyproject.toml b/pyproject.toml index 317235d7..444eea9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ rapidfuzz = "^3.6.1" arabic-reshaper = "^3.0.0" streamlit = "^1.31.0" playwright = "^1.41.2" +pytest = "^8.3.4" [tool.poetry.scripts] surya_detect = "detect_text:main" diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..05ef2959 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,7 @@ +[pytest] +testpaths=tests +pythonpath=. +filterwarnings = + ignore::UserWarning + ignore::PendingDeprecationWarning + ignore::DeprecationWarning \ No newline at end of file diff --git a/surya/benchmark/tesseract.py b/surya/benchmark/tesseract.py index a2d025e0..140d46c9 100644 --- a/surya/benchmark/tesseract.py +++ b/surya/benchmark/tesseract.py @@ -1,8 +1,6 @@ from typing import List, Optional import numpy as np -import pytesseract -from pytesseract import Output from tqdm import tqdm from surya.input.processing import slice_bboxes_from_image @@ -24,6 +22,7 @@ def surya_lang_to_tesseract(code: str) -> Optional[str]: def tesseract_ocr(img, bboxes, lang: str): + import pytesseract line_imgs = slice_bboxes_from_image(img, bboxes) config = f'--tessdata-dir "{settings.TESSDATA_PREFIX}"' lines = [] @@ -50,6 +49,8 @@ def tesseract_ocr_parallel(imgs, bboxes, langs: List[str], cpus=None): def tesseract_bboxes(img): + import pytesseract + from pytesseract import Output arr_img = np.asarray(img, dtype=np.uint8) ocr = pytesseract.image_to_data(arr_img, output_type=Output.DICT) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..8f37c5bc --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,10 @@ +import pytest +from surya.model.ocr_error.model import load_model as load_ocr_error_model, load_tokenizer as load_ocr_error_processor + +@pytest.fixture(scope="session") +def ocr_error_model(): + ocr_error_m = load_ocr_error_model() + ocr_error_p = load_ocr_error_processor() + ocr_error_m.processor = ocr_error_p + yield ocr_error_m + del ocr_error_m \ No newline at end of file diff --git a/tests/test_ocr_errors.py b/tests/test_ocr_errors.py new file mode 100644 index 00000000..6471dca7 --- /dev/null +++ b/tests/test_ocr_errors.py @@ -0,0 +1,18 @@ +from surya.ocr_error import batch_ocr_error_detection + + +def test_garbled_text(ocr_error_model): + text = """" + ; dh vksj ls mifLFkr vf/koDrk % Jh vfuy dqekj + 2. vfHk;qDr dh vksj ls mifLFkr vf/koDrk % Jh iznhi d + """.strip() + results = batch_ocr_error_detection([text], ocr_error_model, ocr_error_model.processor) + assert results.labels[0] == "bad" + + +def test_good_text(ocr_error_model): + text = """" + There are professions more harmful than industrial design, but only a very few of them. + """.strip() + results = batch_ocr_error_detection([text], ocr_error_model, ocr_error_model.processor) + assert results.labels[0] == "good" \ No newline at end of file From 4e60cc5e87eac2b0397f62db002d047b959683b2 Mon Sep 17 00:00:00 2001 From: Vik Paruchuri Date: Thu, 19 Dec 2024 11:43:06 -0500 Subject: [PATCH 11/12] Add test for topk --- README.md | 1 + detect_layout.py | 2 -- surya/detection.py | 2 -- surya/layout.py | 10 ++++++---- tests/conftest.py | 14 +++++++++++++- tests/test_layout.py | 26 ++++++++++++++++++++++++++ 6 files changed, 46 insertions(+), 9 deletions(-) create mode 100644 tests/test_layout.py diff --git a/README.md b/README.md index c5493e1c..4fb5cef6 100644 --- a/README.md +++ b/README.md @@ -217,6 +217,7 @@ The `results.json` file will contain a json dictionary where the keys are the in - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left. - `position` - the reading order of the box. - `label` - the label for the bbox. One of `Caption`, `Footnote`, `Formula`, `List-item`, `Page-footer`, `Page-header`, `Picture`, `Figure`, `Section-header`, `Table`, `Form`, `Table-of-contents`, `Handwriting`, `Text`, `Text-inline-math`. + - `top_k` - the top-k other potential labels for the box. A dictionary with labels as keys and confidences as values. - `page` - the page number in the file - `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox. diff --git a/detect_layout.py b/detect_layout.py index 6524af28..1adb5a90 100644 --- a/detect_layout.py +++ b/detect_layout.py @@ -1,6 +1,4 @@ import time - -import pypdfium2 # Causes a warning if not the top import import argparse import copy import json diff --git a/surya/detection.py b/surya/detection.py index 4d0ac9dc..7ae21679 100644 --- a/surya/detection.py +++ b/surya/detection.py @@ -1,5 +1,3 @@ -import contextlib - import torch from typing import List, Tuple, Generator diff --git a/surya/layout.py b/surya/layout.py index 476e54e4..2378d828 100644 --- a/surya/layout.py +++ b/surya/layout.py @@ -1,5 +1,4 @@ from typing import List - import numpy as np import torch from PIL import Image @@ -136,7 +135,7 @@ def batch_layout_detection(images: List, model, processor, batch_size=None, top_ box_logits = return_dict["bbox_logits"][:current_batch_size, -1, :].detach() class_logits = return_dict["class_logits"][:current_batch_size, -1, :].detach() - probs = torch.nn.functional.softmax(class_logits, dim=-1) + probs = torch.nn.functional.softmax(class_logits, dim=-1).cpu() entropy = torch.special.entr(probs).sum(dim=-1) class_preds = class_logits.argmax(-1) @@ -218,9 +217,12 @@ def batch_layout_detection(images: List, model, processor, batch_size=None, top_ top_k_probs = [p["top_k_probs"] for p in preds] top_k_indices = [p["top_k_indices"] - model.decoder.config.special_token_count for p in preds] - for z, (poly, label) in enumerate(zip(polygons, labels)): + for z, (poly, label, top_k_prob, top_k_index) in enumerate(zip(polygons, labels, top_k_probs, top_k_indices)): + top_k_dict = { + ID_TO_LABEL.get(int(l)): prob.item() + for (l, prob) in zip(top_k_index, top_k_prob) if l > 0 + } l = ID_TO_LABEL[int(label)] - top_k_dict = {ID_TO_LABEL.get(int(l)): prob.item() for (l, prob) in zip(top_k_indices[z], top_k_probs[z]) if l > 0} lb = LayoutBox( polygon=poly, label=l, diff --git a/tests/conftest.py b/tests/conftest.py index 8f37c5bc..fb9a8647 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,7 @@ import pytest from surya.model.ocr_error.model import load_model as load_ocr_error_model, load_tokenizer as load_ocr_error_processor +from surya.model.layout.model import load_model as load_layout_model +from surya.model.layout.processor import load_processor as load_layout_processor @pytest.fixture(scope="session") def ocr_error_model(): @@ -7,4 +9,14 @@ def ocr_error_model(): ocr_error_p = load_ocr_error_processor() ocr_error_m.processor = ocr_error_p yield ocr_error_m - del ocr_error_m \ No newline at end of file + del ocr_error_m + + +@pytest.fixture(scope="session") +def layout_model(): + layout_m = load_layout_model() + layout_p = load_layout_processor() + layout_m.processor = layout_p + yield layout_m + del layout_m + diff --git a/tests/test_layout.py b/tests/test_layout.py new file mode 100644 index 00000000..bc140e6f --- /dev/null +++ b/tests/test_layout.py @@ -0,0 +1,26 @@ +import os +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + +from surya.layout import batch_layout_detection +from PIL import Image, ImageDraw + +def test_layout_topk(layout_model): + image = Image.new("RGB", (1024, 1024), "white") + draw = ImageDraw.Draw(image) + draw.text((10, 10), "Hello World", fill="black", font_size=72) + draw.text((10, 200), "This is a sentence of text.\nNow it is a paragraph.\nA three-line one.", fill="black", + font_size=24) + + layout_results = batch_layout_detection([image], layout_model, layout_model.processor) + + assert len(layout_results) == 1 + assert layout_results[0].image_bbox == [0, 0, 1024, 1024] + + bboxes = layout_results[0].bboxes + assert len(bboxes) == 2 + + assert bboxes[0].label == "SectionHeader" + assert len(bboxes[0].top_k) == 5 + + assert bboxes[1].label == "Text" + assert len(bboxes[1].top_k) == 5 From 6bbdc57b0d55bf57e7619c7a5b2b32667fb5cd12 Mon Sep 17 00:00:00 2001 From: Vik Paruchuri Date: Thu, 19 Dec 2024 11:56:40 -0500 Subject: [PATCH 12/12] Refactor encoder --- surya/model/ocr_error/encoder.py | 797 +++++++++++++++++++++++++++++++ surya/model/ocr_error/model.py | 796 +----------------------------- surya/ocr_error.py | 3 +- 3 files changed, 800 insertions(+), 796 deletions(-) create mode 100644 surya/model/ocr_error/encoder.py diff --git a/surya/model/ocr_error/encoder.py b/surya/model/ocr_error/encoder.py new file mode 100644 index 00000000..4e27700d --- /dev/null +++ b/surya/model/ocr_error/encoder.py @@ -0,0 +1,797 @@ +from __future__ import annotations + +import math +from typing import Optional, Set, List, Tuple, Union, Dict + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F, MSELoss, CrossEntropyLoss, BCEWithLogitsLoss +from transformers import apply_chunking_to_forward, PreTrainedModel +from transformers.activations import get_activation +from transformers.modeling_outputs import BaseModelOutput, SequenceClassifierOutput +from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer + +from transformers.utils import ( + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, +) + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + +from surya.model.ocr_error.config import DistilBertConfig + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor): + position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]) + out.requires_grad = False + out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) + out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) + out.detach_() + + +class Embeddings(nn.Module): + def __init__(self, config: DistilBertConfig): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim) + + self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12) + self.dropout = nn.Dropout(config.dropout) + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + def forward(self, input_ids: torch.Tensor, input_embeds: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Parameters: + input_ids (torch.Tensor): + torch.tensor(bs, max_seq_length) The token ids to embed. + input_embeds (*optional*, torch.Tensor): + The pre-computed word embeddings. Can only be passed if the input ids are `None`. + + + Returns: torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings, no token_type + embeddings) + """ + if input_ids is not None: + input_embeds = self.word_embeddings(input_ids) # (bs, max_seq_length, dim) + + seq_length = input_embeds.size(1) + + # Setting the position-ids to the registered buffer in constructor, it helps + # when tracing the model without passing position-ids, solves + # isues similar to issue #5664 + if hasattr(self, "position_ids"): + position_ids = self.position_ids[:, :seq_length] + else: + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) # (max_seq_length) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # (bs, max_seq_length) + + position_embeddings = self.position_embeddings(position_ids) # (bs, max_seq_length, dim) + + embeddings = input_embeds + position_embeddings # (bs, max_seq_length, dim) + embeddings = self.LayerNorm(embeddings) # (bs, max_seq_length, dim) + embeddings = self.dropout(embeddings) # (bs, max_seq_length, dim) + return embeddings + + +class MultiHeadSelfAttention(nn.Module): + def __init__(self, config: DistilBertConfig): + super().__init__() + self.config = config + + self.n_heads = config.n_heads + self.dim = config.dim + self.dropout = nn.Dropout(p=config.attention_dropout) + self.is_causal = False + + # Have an even number of multi heads that divide the dimensions + if self.dim % self.n_heads != 0: + # Raise value errors for even multi-head attention nodes + raise ValueError(f"self.n_heads: {self.n_heads} must divide self.dim: {self.dim} evenly") + + self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim) + self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim) + self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim) + self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim) + + self.pruned_heads: Set[int] = set() + self.attention_head_size = self.dim // self.n_heads + + def prune_heads(self, heads: List[int]): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.attention_head_size, self.pruned_heads + ) + # Prune linear layers + self.q_lin = prune_linear_layer(self.q_lin, index) + self.k_lin = prune_linear_layer(self.k_lin, index) + self.v_lin = prune_linear_layer(self.v_lin, index) + self.out_lin = prune_linear_layer(self.out_lin, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.dim = self.attention_head_size * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, ...]: + """ + Parameters: + query: torch.tensor(bs, seq_length, dim) + key: torch.tensor(bs, seq_length, dim) + value: torch.tensor(bs, seq_length, dim) + mask: torch.tensor(bs, seq_length) + + Returns: + weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs, + seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True` + """ + bs, q_length, dim = query.size() + k_length = key.size(1) + # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured' + # assert key.size() == value.size() + + dim_per_head = self.dim // self.n_heads + + mask_reshp = (bs, 1, 1, k_length) + + def shape(x: torch.Tensor) -> torch.Tensor: + """separate heads""" + return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2) + + def unshape(x: torch.Tensor) -> torch.Tensor: + """group heads""" + return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head) + + q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head) + k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head) + v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head) + + q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head) + scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length) + mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length) + scores = scores.masked_fill( + mask, torch.tensor(torch.finfo(scores.dtype).min) + ) # (bs, n_heads, q_length, k_length) + + weights = nn.functional.softmax(scores, dim=-1) # (bs, n_heads, q_length, k_length) + weights = self.dropout(weights) # (bs, n_heads, q_length, k_length) + + # Mask heads if we want to + if head_mask is not None: + weights = weights * head_mask + + context = torch.matmul(weights, v) # (bs, n_heads, q_length, dim_per_head) + context = unshape(context) # (bs, q_length, dim) + context = self.out_lin(context) # (bs, q_length, dim) + + if output_attentions: + return (context, weights) + else: + return (context,) + + +class DistilBertFlashAttention2(MultiHeadSelfAttention): + """ + DistilBert flash attention module. This module inherits from `MultiHeadSelfAttention` as the weights of the module + stays untouched. The only required change would be on the forward pass where it needs to correctly call the public + API of flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, ...]: + """ + Parameters: + query: torch.tensor(bs, seq_length, dim) + key: torch.tensor(bs, seq_length, dim) + value: torch.tensor(bs, seq_length, dim) + mask: torch.tensor(bs, seq_length) + + Returns: + weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs, + seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True` + """ + batch_size, q_length, dim = query.size() + + dim_per_head = self.dim // self.n_heads + + def reshape(x: torch.Tensor) -> torch.Tensor: + """separate heads""" + return x.view(batch_size, -1, self.n_heads, dim_per_head) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query_states = reshape(self.q_lin(query)) + key_states = reshape(self.k_lin(key)) + value_states = reshape(self.v_lin(value)) + + attn_dropout = self.config.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + if query_states.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_lin.weight.dtype + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_weights = self._flash_attention_forward( + query_states, key_states, value_states, mask, q_length, dropout=attn_dropout + ) + + attn_weights_reshaped = attn_weights.reshape(batch_size, q_length, self.n_heads * dim_per_head) + attn_output = self.out_lin(attn_weights_reshaped) + + if output_attentions: + return (attn_output, attn_weights) + else: + return (attn_output,) + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward with causal=True->causal=False + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input with num_heads->n_heads + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.n_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class FFN(nn.Module): + def __init__(self, config: DistilBertConfig): + super().__init__() + self.dropout = nn.Dropout(p=config.dropout) + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim) + self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim) + self.activation = get_activation(config.activation) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input) + + def ff_chunk(self, input: torch.Tensor) -> torch.Tensor: + x = self.lin1(input) + x = self.activation(x) + x = self.lin2(x) + x = self.dropout(x) + return x + + +DISTILBERT_ATTENTION_CLASSES = { + "eager": MultiHeadSelfAttention, + "flash_attention_2": DistilBertFlashAttention2, +} + + +class TransformerBlock(nn.Module): + def __init__(self, config: DistilBertConfig): + super().__init__() + + # Have an even number of Configure multi-heads + if config.dim % config.n_heads != 0: + raise ValueError(f"config.n_heads {config.n_heads} must divide config.dim {config.dim} evenly") + + self.attention = DISTILBERT_ATTENTION_CLASSES[config._attn_implementation](config) + self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12) + + self.ffn = FFN(config) + self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12) + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, ...]: + """ + Parameters: + x: torch.tensor(bs, seq_length, dim) + attn_mask: torch.tensor(bs, seq_length) + + Returns: + sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length) The attention weights ffn_output: + torch.tensor(bs, seq_length, dim) The output of the transformer block contextualization. + """ + # Self-Attention + sa_output = self.attention( + query=x, + key=x, + value=x, + mask=attn_mask, + head_mask=head_mask, + output_attentions=output_attentions, + ) + if output_attentions: + sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length) + else: # To handle these `output_attentions` or `output_hidden_states` cases returning tuples + if type(sa_output) != tuple: + raise TypeError(f"sa_output must be a tuple but it is {type(sa_output)} type") + + sa_output = sa_output[0] + sa_output = self.sa_layer_norm(sa_output + x) # (bs, seq_length, dim) + + # Feed Forward Network + ffn_output = self.ffn(sa_output) # (bs, seq_length, dim) + ffn_output: torch.Tensor = self.output_layer_norm(ffn_output + sa_output) # (bs, seq_length, dim) + + output = (ffn_output,) + if output_attentions: + output = (sa_weights,) + output + return output + + +class Transformer(nn.Module): + def __init__(self, config: DistilBertConfig): + super().__init__() + self.n_layers = config.n_layers + self.layer = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: # docstyle-ignore + """ + Parameters: + x: torch.tensor(bs, seq_length, dim) Input sequence embedded. + attn_mask: torch.tensor(bs, seq_length) Attention mask on the sequence. + + Returns: + hidden_state: torch.tensor(bs, seq_length, dim) Sequence of hidden states in the last (top) + layer all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)] + Tuple of length n_layers with the hidden states from each layer. + Optional: only if output_hidden_states=True + all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)] + Tuple of length n_layers with the attention weights from each layer + Optional: only if output_attentions=True + """ + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_state = x + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_state,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_state, + attn_mask, + head_mask[i], + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_state, + attn_mask, + head_mask[i], + output_attentions, + ) + + hidden_state = layer_outputs[-1] + + if output_attentions: + if len(layer_outputs) != 2: + raise ValueError(f"The length of the layer_outputs should be 2, but it is {len(layer_outputs)}") + + attentions = layer_outputs[0] + all_attentions = all_attentions + (attentions,) + else: + if len(layer_outputs) != 1: + raise ValueError(f"The length of the layer_outputs should be 1, but it is {len(layer_outputs)}") + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_state,) + + if not return_dict: + return tuple(v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +class DistilBertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = DistilBertConfig + load_tf_weights = None + base_model_prefix = "distilbert" + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, Embeddings) and self.config.sinusoidal_pos_embds: + create_sinusoidal_embeddings( + self.config.max_position_embeddings, self.config.dim, module.position_embeddings.weight + ) + + +class DistilBertModel(DistilBertPreTrainedModel): + def __init__(self, config: DistilBertConfig): + super().__init__(config) + + self.embeddings = Embeddings(config) # Embeddings + self.transformer = Transformer(config) # Encoder + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + # Initialize weights and apply final processing + self.post_init() + + def get_position_embeddings(self) -> nn.Embedding: + """ + Returns the position embeddings + """ + return self.embeddings.position_embeddings + + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`. + + Arguments: + new_num_position_embeddings (`int`): + The number of new position embedding matrix. If position embeddings are learned, increasing the size + will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the + end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the + size will add correct vectors at the end following the position encoding algorithm, whereas reducing + the size will remove vectors from the end. + """ + num_position_embeds_diff = new_num_position_embeddings - self.config.max_position_embeddings + + # no resizing needs to be done if the length stays the same + if num_position_embeds_diff == 0: + return + + self.config.max_position_embeddings = new_num_position_embeddings + + old_position_embeddings_weight = self.embeddings.position_embeddings.weight.clone() + + self.embeddings.position_embeddings = nn.Embedding(self.config.max_position_embeddings, self.config.dim) + + if self.config.sinusoidal_pos_embds: + create_sinusoidal_embeddings( + n_pos=self.config.max_position_embeddings, dim=self.config.dim, out=self.position_embeddings.weight + ) + else: + with torch.no_grad(): + if num_position_embeds_diff > 0: + self.embeddings.position_embeddings.weight[:-num_position_embeds_diff] = nn.Parameter( + old_position_embeddings_weight + ) + else: + self.embeddings.position_embeddings.weight = nn.Parameter( + old_position_embeddings_weight[:num_position_embeds_diff] + ) + # move position_embeddings to correct device + self.embeddings.position_embeddings.to(self.device) + + def get_input_embeddings(self) -> nn.Embedding: + return self.embeddings.word_embeddings + + def set_input_embeddings(self, new_embeddings: nn.Embedding): + self.embeddings.word_embeddings = new_embeddings + + def _prune_heads(self, heads_to_prune: Dict[int, List[List[int]]]): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.transformer.layer[layer].attention.prune_heads(heads) + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embeddings = self.embeddings(input_ids, inputs_embeds) # (bs, seq_length, dim) + + if self._use_flash_attention_2: + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) # (bs, seq_length) + + return self.transformer( + x=embeddings, + attn_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class DistilBertForSequenceClassification(DistilBertPreTrainedModel): + def __init__(self, config: DistilBertConfig): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.distilbert = DistilBertModel(config) + self.pre_classifier = nn.Linear(config.dim, config.dim) + self.classifier = nn.Linear(config.dim, config.num_labels) + self.dropout = nn.Dropout(config.seq_classif_dropout) + + # Initialize weights and apply final processing + self.post_init() + + def get_position_embeddings(self) -> nn.Embedding: + """ + Returns the position embeddings + """ + return self.distilbert.get_position_embeddings() + + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`. + + Arguments: + new_num_position_embeddings (`int`): + The number of new position embedding matrix. If position embeddings are learned, increasing the size + will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the + end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the + size will add correct vectors at the end following the position encoding algorithm, whereas reducing + the size will remove vectors from the end. + """ + self.distilbert.resize_position_embeddings(new_num_position_embeddings) + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[SequenceClassifierOutput, Tuple[torch.Tensor, ...]]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + distilbert_output = self.distilbert( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_state = distilbert_output[0] # (bs, seq_len, dim) + pooled_output = hidden_state[:, 0] # (bs, dim) + pooled_output = self.pre_classifier(pooled_output) # (bs, dim) + pooled_output = nn.ReLU()(pooled_output) # (bs, dim) + pooled_output = self.dropout(pooled_output) # (bs, dim) + logits = self.classifier(pooled_output) # (bs, num_labels) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + distilbert_output[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=distilbert_output.hidden_states, + attentions=distilbert_output.attentions, + ) diff --git a/surya/model/ocr_error/model.py b/surya/model/ocr_error/model.py index 542e8c30..72ed5816 100644 --- a/surya/model/ocr_error/model.py +++ b/surya/model/ocr_error/model.py @@ -1,30 +1,7 @@ from __future__ import annotations - -import math -from typing import Dict, List, Optional, Set, Tuple, Union - -import numpy as np import torch -import torch.nn.functional as F -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -from transformers.activations import get_activation -from transformers.modeling_outputs import ( - BaseModelOutput, - SequenceClassifierOutput, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer -from transformers.utils import ( - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, -) - -if is_flash_attn_2_available(): - from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa +from surya.model.ocr_error.encoder import DistilBertForSequenceClassification from surya.model.ocr_error.config import DistilBertConfig from surya.model.ocr_error.tokenizer import DistilBertTokenizer from surya.settings import settings @@ -47,775 +24,4 @@ def load_tokenizer(checkpoint=settings.OCR_ERROR_MODEL_CHECKPOINT): tokenizer = DistilBertTokenizer.from_pretrained(checkpoint) return tokenizer -def _get_unpad_data(attention_mask): - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -def create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor): - position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]) - out.requires_grad = False - out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) - out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) - out.detach_() - - -class Embeddings(nn.Module): - def __init__(self, config: DistilBertConfig): - super().__init__() - self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id) - self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim) - - self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12) - self.dropout = nn.Dropout(config.dropout) - self.register_buffer( - "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False - ) - - def forward(self, input_ids: torch.Tensor, input_embeds: Optional[torch.Tensor] = None) -> torch.Tensor: - """ - Parameters: - input_ids (torch.Tensor): - torch.tensor(bs, max_seq_length) The token ids to embed. - input_embeds (*optional*, torch.Tensor): - The pre-computed word embeddings. Can only be passed if the input ids are `None`. - - - Returns: torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings, no token_type - embeddings) - """ - if input_ids is not None: - input_embeds = self.word_embeddings(input_ids) # (bs, max_seq_length, dim) - - seq_length = input_embeds.size(1) - - # Setting the position-ids to the registered buffer in constructor, it helps - # when tracing the model without passing position-ids, solves - # isues similar to issue #5664 - if hasattr(self, "position_ids"): - position_ids = self.position_ids[:, :seq_length] - else: - position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) # (max_seq_length) - position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # (bs, max_seq_length) - - position_embeddings = self.position_embeddings(position_ids) # (bs, max_seq_length, dim) - - embeddings = input_embeds + position_embeddings # (bs, max_seq_length, dim) - embeddings = self.LayerNorm(embeddings) # (bs, max_seq_length, dim) - embeddings = self.dropout(embeddings) # (bs, max_seq_length, dim) - return embeddings - - -class MultiHeadSelfAttention(nn.Module): - def __init__(self, config: DistilBertConfig): - super().__init__() - self.config = config - - self.n_heads = config.n_heads - self.dim = config.dim - self.dropout = nn.Dropout(p=config.attention_dropout) - self.is_causal = False - - # Have an even number of multi heads that divide the dimensions - if self.dim % self.n_heads != 0: - # Raise value errors for even multi-head attention nodes - raise ValueError(f"self.n_heads: {self.n_heads} must divide self.dim: {self.dim} evenly") - - self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim) - self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim) - self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim) - self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim) - - self.pruned_heads: Set[int] = set() - self.attention_head_size = self.dim // self.n_heads - - def prune_heads(self, heads: List[int]): - if len(heads) == 0: - return - heads, index = find_pruneable_heads_and_indices( - heads, self.n_heads, self.attention_head_size, self.pruned_heads - ) - # Prune linear layers - self.q_lin = prune_linear_layer(self.q_lin, index) - self.k_lin = prune_linear_layer(self.k_lin, index) - self.v_lin = prune_linear_layer(self.v_lin, index) - self.out_lin = prune_linear_layer(self.out_lin, index, dim=1) - # Update hyper params - self.n_heads = self.n_heads - len(heads) - self.dim = self.attention_head_size * self.n_heads - self.pruned_heads = self.pruned_heads.union(heads) - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - mask: torch.Tensor, - head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, ...]: - """ - Parameters: - query: torch.tensor(bs, seq_length, dim) - key: torch.tensor(bs, seq_length, dim) - value: torch.tensor(bs, seq_length, dim) - mask: torch.tensor(bs, seq_length) - - Returns: - weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs, - seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True` - """ - bs, q_length, dim = query.size() - k_length = key.size(1) - # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured' - # assert key.size() == value.size() - - dim_per_head = self.dim // self.n_heads - - mask_reshp = (bs, 1, 1, k_length) - - def shape(x: torch.Tensor) -> torch.Tensor: - """separate heads""" - return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2) - - def unshape(x: torch.Tensor) -> torch.Tensor: - """group heads""" - return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head) - - q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head) - k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head) - v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head) - - q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head) - scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length) - mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length) - scores = scores.masked_fill( - mask, torch.tensor(torch.finfo(scores.dtype).min) - ) # (bs, n_heads, q_length, k_length) - - weights = nn.functional.softmax(scores, dim=-1) # (bs, n_heads, q_length, k_length) - weights = self.dropout(weights) # (bs, n_heads, q_length, k_length) - - # Mask heads if we want to - if head_mask is not None: - weights = weights * head_mask - - context = torch.matmul(weights, v) # (bs, n_heads, q_length, dim_per_head) - context = unshape(context) # (bs, q_length, dim) - context = self.out_lin(context) # (bs, q_length, dim) - - if output_attentions: - return (context, weights) - else: - return (context,) - - -class DistilBertFlashAttention2(MultiHeadSelfAttention): - """ - DistilBert flash attention module. This module inherits from `MultiHeadSelfAttention` as the weights of the module - stays untouched. The only required change would be on the forward pass where it needs to correctly call the public - API of flash attention and deal with padding tokens in case the input contains any of them. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - mask: torch.Tensor, - head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, ...]: - """ - Parameters: - query: torch.tensor(bs, seq_length, dim) - key: torch.tensor(bs, seq_length, dim) - value: torch.tensor(bs, seq_length, dim) - mask: torch.tensor(bs, seq_length) - - Returns: - weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs, - seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True` - """ - batch_size, q_length, dim = query.size() - - dim_per_head = self.dim // self.n_heads - - def reshape(x: torch.Tensor) -> torch.Tensor: - """separate heads""" - return x.view(batch_size, -1, self.n_heads, dim_per_head) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - query_states = reshape(self.q_lin(query)) - key_states = reshape(self.k_lin(key)) - value_states = reshape(self.v_lin(value)) - - attn_dropout = self.config.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - if query_states.dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_lin.weight.dtype - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_weights = self._flash_attention_forward( - query_states, key_states, value_states, mask, q_length, dropout=attn_dropout - ) - - attn_weights_reshaped = attn_weights.reshape(batch_size, q_length, self.n_heads * dim_per_head) - attn_output = self.out_lin(attn_weights_reshaped) - - if output_attentions: - return (attn_output, attn_weights) - else: - return (attn_output,) - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward with causal=True->causal=False - def _flash_attention_forward( - self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None - ): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. - - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`float`): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) - """ - if not self._flash_attn_uses_top_left_mask: - causal = self.is_causal - else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. - causal = self.is_causal and query_length != 1 - - # Contains at least one padding token in the sequence - if attention_mask is not None: - batch_size = query_states.shape[0] - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length - ) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) - else: - attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal - ) - - return attn_output - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input with num_heads->n_heads - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - - key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) - value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.n_heads, head_dim), indices_k - ) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. - attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -class FFN(nn.Module): - def __init__(self, config: DistilBertConfig): - super().__init__() - self.dropout = nn.Dropout(p=config.dropout) - self.chunk_size_feed_forward = config.chunk_size_feed_forward - self.seq_len_dim = 1 - self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim) - self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim) - self.activation = get_activation(config.activation) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input) - - def ff_chunk(self, input: torch.Tensor) -> torch.Tensor: - x = self.lin1(input) - x = self.activation(x) - x = self.lin2(x) - x = self.dropout(x) - return x - - -DISTILBERT_ATTENTION_CLASSES = { - "eager": MultiHeadSelfAttention, - "flash_attention_2": DistilBertFlashAttention2, -} - - -class TransformerBlock(nn.Module): - def __init__(self, config: DistilBertConfig): - super().__init__() - - # Have an even number of Configure multi-heads - if config.dim % config.n_heads != 0: - raise ValueError(f"config.n_heads {config.n_heads} must divide config.dim {config.dim} evenly") - - self.attention = DISTILBERT_ATTENTION_CLASSES[config._attn_implementation](config) - self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12) - - self.ffn = FFN(config) - self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12) - - def forward( - self, - x: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, ...]: - """ - Parameters: - x: torch.tensor(bs, seq_length, dim) - attn_mask: torch.tensor(bs, seq_length) - - Returns: - sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length) The attention weights ffn_output: - torch.tensor(bs, seq_length, dim) The output of the transformer block contextualization. - """ - # Self-Attention - sa_output = self.attention( - query=x, - key=x, - value=x, - mask=attn_mask, - head_mask=head_mask, - output_attentions=output_attentions, - ) - if output_attentions: - sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length) - else: # To handle these `output_attentions` or `output_hidden_states` cases returning tuples - if type(sa_output) != tuple: - raise TypeError(f"sa_output must be a tuple but it is {type(sa_output)} type") - - sa_output = sa_output[0] - sa_output = self.sa_layer_norm(sa_output + x) # (bs, seq_length, dim) - - # Feed Forward Network - ffn_output = self.ffn(sa_output) # (bs, seq_length, dim) - ffn_output: torch.Tensor = self.output_layer_norm(ffn_output + sa_output) # (bs, seq_length, dim) - - output = (ffn_output,) - if output_attentions: - output = (sa_weights,) + output - return output - - -class Transformer(nn.Module): - def __init__(self, config: DistilBertConfig): - super().__init__() - self.n_layers = config.n_layers - self.layer = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)]) - self.gradient_checkpointing = False - - def forward( - self, - x: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: Optional[bool] = None, - ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: # docstyle-ignore - """ - Parameters: - x: torch.tensor(bs, seq_length, dim) Input sequence embedded. - attn_mask: torch.tensor(bs, seq_length) Attention mask on the sequence. - - Returns: - hidden_state: torch.tensor(bs, seq_length, dim) Sequence of hidden states in the last (top) - layer all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)] - Tuple of length n_layers with the hidden states from each layer. - Optional: only if output_hidden_states=True - all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)] - Tuple of length n_layers with the attention weights from each layer - Optional: only if output_attentions=True - """ - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - hidden_state = x - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_state,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_state, - attn_mask, - head_mask[i], - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_state, - attn_mask, - head_mask[i], - output_attentions, - ) - - hidden_state = layer_outputs[-1] - - if output_attentions: - if len(layer_outputs) != 2: - raise ValueError(f"The length of the layer_outputs should be 2, but it is {len(layer_outputs)}") - - attentions = layer_outputs[0] - all_attentions = all_attentions + (attentions,) - else: - if len(layer_outputs) != 1: - raise ValueError(f"The length of the layer_outputs should be 1, but it is {len(layer_outputs)}") - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_state,) - - if not return_dict: - return tuple(v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None) - return BaseModelOutput( - last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions - ) - - -# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL # -class DistilBertPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = DistilBertConfig - load_tf_weights = None - base_model_prefix = "distilbert" - supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - - def _init_weights(self, module: nn.Module): - """Initialize the weights.""" - if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, Embeddings) and self.config.sinusoidal_pos_embds: - create_sinusoidal_embeddings( - self.config.max_position_embeddings, self.config.dim, module.position_embeddings.weight - ) - - -class DistilBertModel(DistilBertPreTrainedModel): - def __init__(self, config: DistilBertConfig): - super().__init__(config) - - self.embeddings = Embeddings(config) # Embeddings - self.transformer = Transformer(config) # Encoder - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - - # Initialize weights and apply final processing - self.post_init() - - def get_position_embeddings(self) -> nn.Embedding: - """ - Returns the position embeddings - """ - return self.embeddings.position_embeddings - - def resize_position_embeddings(self, new_num_position_embeddings: int): - """ - Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`. - - Arguments: - new_num_position_embeddings (`int`): - The number of new position embedding matrix. If position embeddings are learned, increasing the size - will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the - end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the - size will add correct vectors at the end following the position encoding algorithm, whereas reducing - the size will remove vectors from the end. - """ - num_position_embeds_diff = new_num_position_embeddings - self.config.max_position_embeddings - - # no resizing needs to be done if the length stays the same - if num_position_embeds_diff == 0: - return - - self.config.max_position_embeddings = new_num_position_embeddings - - old_position_embeddings_weight = self.embeddings.position_embeddings.weight.clone() - - self.embeddings.position_embeddings = nn.Embedding(self.config.max_position_embeddings, self.config.dim) - - if self.config.sinusoidal_pos_embds: - create_sinusoidal_embeddings( - n_pos=self.config.max_position_embeddings, dim=self.config.dim, out=self.position_embeddings.weight - ) - else: - with torch.no_grad(): - if num_position_embeds_diff > 0: - self.embeddings.position_embeddings.weight[:-num_position_embeds_diff] = nn.Parameter( - old_position_embeddings_weight - ) - else: - self.embeddings.position_embeddings.weight = nn.Parameter( - old_position_embeddings_weight[:num_position_embeds_diff] - ) - # move position_embeddings to correct device - self.embeddings.position_embeddings.to(self.device) - - def get_input_embeddings(self) -> nn.Embedding: - return self.embeddings.word_embeddings - - def set_input_embeddings(self, new_embeddings: nn.Embedding): - self.embeddings.word_embeddings = new_embeddings - - def _prune_heads(self, heads_to_prune: Dict[int, List[List[int]]]): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - for layer, heads in heads_to_prune.items(): - self.transformer.layer[layer].attention.prune_heads(heads) - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) - input_shape = input_ids.size() - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - - # Prepare head mask if needed - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - - embeddings = self.embeddings(input_ids, inputs_embeds) # (bs, seq_length, dim) - - if self._use_flash_attention_2: - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - if attention_mask is None: - attention_mask = torch.ones(input_shape, device=device) # (bs, seq_length) - - return self.transformer( - x=embeddings, - attn_mask=attention_mask, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - -class DistilBertForSequenceClassification(DistilBertPreTrainedModel): - def __init__(self, config: DistilBertConfig): - super().__init__(config) - self.num_labels = config.num_labels - self.config = config - - self.distilbert = DistilBertModel(config) - self.pre_classifier = nn.Linear(config.dim, config.dim) - self.classifier = nn.Linear(config.dim, config.num_labels) - self.dropout = nn.Dropout(config.seq_classif_dropout) - - # Initialize weights and apply final processing - self.post_init() - - def get_position_embeddings(self) -> nn.Embedding: - """ - Returns the position embeddings - """ - return self.distilbert.get_position_embeddings() - - def resize_position_embeddings(self, new_num_position_embeddings: int): - """ - Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`. - - Arguments: - new_num_position_embeddings (`int`): - The number of new position embedding matrix. If position embeddings are learned, increasing the size - will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the - end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the - size will add correct vectors at the end following the position encoding algorithm, whereas reducing - the size will remove vectors from the end. - """ - self.distilbert.resize_position_embeddings(new_num_position_embeddings) - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[SequenceClassifierOutput, Tuple[torch.Tensor, ...]]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - distilbert_output = self.distilbert( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_state = distilbert_output[0] # (bs, seq_len, dim) - pooled_output = hidden_state[:, 0] # (bs, dim) - pooled_output = self.pre_classifier(pooled_output) # (bs, dim) - pooled_output = nn.ReLU()(pooled_output) # (bs, dim) - pooled_output = self.dropout(pooled_output) # (bs, dim) - logits = self.classifier(pooled_output) # (bs, num_labels) - - loss = None - if labels is not None: - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(logits, labels) - - if not return_dict: - output = (logits,) + distilbert_output[1:] - return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=distilbert_output.hidden_states, - attentions=distilbert_output.attentions, - ) \ No newline at end of file diff --git a/surya/ocr_error.py b/surya/ocr_error.py index 4739ad6d..0cf714f4 100644 --- a/surya/ocr_error.py +++ b/surya/ocr_error.py @@ -4,7 +4,8 @@ import torch import numpy as np -from surya.model.ocr_error.model import DistilBertForSequenceClassification, DistilBertTokenizer +from surya.model.ocr_error.model import DistilBertTokenizer +from surya.model.ocr_error.encoder import DistilBertForSequenceClassification from surya.model.ocr_error.config import ID2LABEL from surya.settings import settings from surya.schema import OCRErrorDetectionResult