Skip to content

Commit

Permalink
refactor grobid processors (#29)
Browse files Browse the repository at this point in the history
refactor grobid processors, deprecate legacy methods
  • Loading branch information
lfoppiano authored Mar 4, 2024
1 parent c08e73a commit 104b3a9
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 38 deletions.
2 changes: 1 addition & 1 deletion document_qa/document_qa_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1,
print("File", pdf_file_path)
filename = Path(pdf_file_path).stem
coordinates = True # if chunk_size == -1 else False
- structure = self.grobid_processor.process_structure(pdf_file_path, coordinates=coordinates)
+ structure = self.grobid_processor.process(pdf_file_path, coordinates=coordinates)

biblio = structure['biblio']
biblio['filename'] = filename.replace(" ", "_")
Expand Down
68 changes: 31 additions & 37 deletions document_qa/grobid_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import OrderedDict
from html import escape
from pathlib import Path
+ from typing_extensions import deprecated

import dateparser
import grobid_tei_xml
Expand Down Expand Up @@ -54,6 +55,7 @@ def decorate_text_with_annotations(text, spans, tag="span"):
return annotated_text


+ @deprecated("Use GrobidQuantitiesProcessor.process() instead")
def extract_quantities(client, x_all, column_text_index):
# relevant_items = ['magnetic field strength', 'magnetic induction', 'maximum energy product',
# "magnetic flux density", "magnetic flux"]
Expand All @@ -63,7 +65,7 @@ def extract_quantities(client, x_all, column_text_index):

for idx, example in tqdm(enumerate(x_all), desc="extract quantities"):
text = example[column_text_index]
- spans = GrobidQuantitiesProcessor(client).extract_quantities(text)
+ spans = GrobidQuantitiesProcessor(client).process(text)

data_record = {
"id": example[0],
Expand All @@ -78,12 +80,13 @@ def extract_quantities(client, x_all, column_text_index):
return output_data


+ @deprecated("Use GrobidMaterialsProcessor.process() instead")
def extract_materials(client, x_all, column_text_index):
output_data = []

for idx, example in tqdm(enumerate(x_all), desc="extract materials"):
text = example[column_text_index]
- spans = GrobidMaterialsProcessor(client).extract_materials(text)
+ spans = GrobidMaterialsProcessor(client).process(text)
data_record = {
"id": example[0],
"filename": example[1],
Expand Down Expand Up @@ -131,7 +134,7 @@ def __init__(self, grobid_client):
# super().__init__()
self.grobid_client = grobid_client

- def process_structure(self, input_path, coordinates=False):
+ def process(self, input_path, coordinates=False):
pdf_file, status, text = self.grobid_client.process_pdf("processFulltextDocument",
input_path,
consolidate_header=True,
Expand All @@ -145,19 +148,10 @@ def process_structure(self, input_path, coordinates=False):
if status != 200:
return

- output_data = self.parse_grobid_xml(text, coordinates=coordinates)
- output_data['filename'] = Path(pdf_file).stem.replace(".tei", "")
+ document_object = self.parse_grobid_xml(text, coordinates=coordinates)
+ document_object['filename'] = Path(pdf_file).stem.replace(".tei", "")

- return output_data
-
- def process_single(self, input_file):
- doc = self.process_structure(input_file)
-
- for paragraph in doc['passages']:
- entities = self.process_single_text(paragraph['text'])
- paragraph['spans'] = entities
-
- return doc
+ return document_object

def parse_grobid_xml(self, text, coordinates=False):
output_data = OrderedDict()
Expand Down Expand Up @@ -187,10 +181,10 @@ def parse_grobid_xml(self, text, coordinates=False):
"text": f"authors: {biblio['authors']}",
"type": passage_type,
"section": "<header>",
- "subSection": "<title>",
- "passage_id": "htitle",
+ "subSection": "<authors>",
+ "passage_id": "hauthors",
"coordinates": ";".join([node['coords'] if coordinates and node.has_attr('coords') else "" for node in
- blocks_header['authors']])
+ blocks_header['authors']])
})

passages.append({
Expand Down Expand Up @@ -293,7 +287,7 @@ class GrobidQuantitiesProcessor(BaseProcessor):
def __init__(self, grobid_quantities_client):
self.grobid_quantities_client = grobid_quantities_client

- def extract_quantities(self, text):
+ def process(self, text):
status, result = self.grobid_quantities_client.process_text(text.strip())

if status != 200:
Expand Down Expand Up @@ -465,7 +459,7 @@ class GrobidMaterialsProcessor(BaseProcessor):
def __init__(self, grobid_superconductors_client):
self.grobid_superconductors_client = grobid_superconductors_client

- def extract_materials(self, text):
+ def process(self, text):
preprocessed_text = text.strip()
status, result = self.grobid_superconductors_client.process_text(preprocessed_text,
"processText_disable_linking")
Expand Down Expand Up @@ -568,17 +562,17 @@ def __init__(self, grobid_client, grobid_quantities_client=None, grobid_supercon
self.gmp = GrobidMaterialsProcessor(grobid_superconductors_client)

def process_single_text(self, text):
- extracted_quantities_spans = self.gqp.extract_quantities(text)
- extracted_materials_spans = self.gmp.extract_materials(text)
+ extracted_quantities_spans = self.process_properties(text)
+ extracted_materials_spans = self.process_materials(text)
all_entities = extracted_quantities_spans + extracted_materials_spans
entities = self.prune_overlapping_annotations(all_entities)
return entities

- def extract_quantities(self, text):
- return self.gqp.extract_quantities(text)
+ def process_properties(self, text):
+ return self.gqp.process(text)

- def extract_materials(self, text):
- return self.gmp.extract_materials(text)
+ def process_materials(self, text):
+ return self.gmp.process(text)

@staticmethod
def box_to_dict(box, color=None, type=None):
Expand Down Expand Up @@ -715,8 +709,8 @@ def prune_overlapping_annotations(entities: list) -> list:


class XmlProcessor(BaseProcessor):
- def __init__(self, grobid_superconductors_client, grobid_quantities_client):
- super().__init__(grobid_superconductors_client, grobid_quantities_client)
+ def __init__(self):
+ super().__init__()

def process_structure(self, input_file):
text = ""
Expand All @@ -728,16 +722,16 @@ def process_structure(self, input_file):

return output_data

- def process_single(self, input_file):
- doc = self.process_structure(input_file)
-
- for paragraph in doc['passages']:
- entities = self.process_single_text(paragraph['text'])
- paragraph['spans'] = entities
-
- return doc
+ # def process_single(self, input_file):
+ # doc = self.process_structure(input_file)
+ #
+ # for paragraph in doc['passages']:
+ # entities = self.process_single_text(paragraph['text'])
+ # paragraph['spans'] = entities
+ #
+ # return doc

- def parse_xml(self, text):
+ def process(self, text):
output_data = OrderedDict()
soup = BeautifulSoup(text, 'xml')
text_blocks_children = get_children_list_supermat(soup, verbose=False)
Expand Down

0 comments on commit 104b3a9

Please sign in to comment.